In [38]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from modules import torch_classes
from torchvision import transforms, utils
from modules.grad_cam import *

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [3]:
#Train set
labels_path = '../DATA/train_labels.pkl'
data_path = 'train.dat'
data_dims = (8269,10404)
genes_path = 'train.csv'

#Data and labels
train = torch_classes.TumorDataset(labels_path,data_path,data_dims,genes_path,transform = transforms.Compose([torch_classes.ToImage(),torch_classes.ToTensor()]))

#Test set
labels_path = '../DATA/test_labels.pkl'
data_path = 'test.dat'
data_dims = (2085,10404)
genes_path = 'test.csv'

test = torch_classes.TumorDataset(labels_path,data_path,data_dims,genes_path,transform = transforms.Compose([torch_classes.ToImage(),torch_classes.ToTensor()]))

In [28]:
num_epochs = 30
batch_size = 75
learning_rate = 0.0001

### Using sklearn stratified kfold splitter, iterate over the folds to run tests

In [29]:
skf = StratifiedKFold(n_splits=3, random_state=None, shuffle=True)
skf.split(np.zeros(len(train.int_labels)), train.int_labels)

<generator object _BaseKFold.split at 0x7f1622e5eb88>

### CV to find optimal number of epochs

In [30]:
training_losses = []
validation_losses = []
fold = 0

for train_index, val_index in skf.split(np.zeros(len(train.int_labels)), train.int_labels):
    train_subset = torch.utils.data.Subset(train, train_index)
    val_subset = torch.utils.data.Subset(train, val_index)
    
    #initialize data loaders using subsets
    dataloader_train = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=4)
    dataloader_val = DataLoader(val_subset, batch_size=batch_size, shuffle=True, num_workers=4)
    
    #Reinitialize network for every instance
    net = torch_classes.Net(num_of_classes=33)
    net.to(device)
    net.double()
    
    #Training loss function
    criterion = nn.CrossEntropyLoss()

    #Training optimizer
    optimizer = torch.optim.Adam(params=net.parameters(), lr=learning_rate)
    
    fold_trn_loss = []
    fold_val_loss = []
    
    for epoch in tqdm(range(num_epochs)):
        running_loss = 0.0
        running_val_loss = 0.0
        
        for i, samples in enumerate(dataloader_train):
            #Data
            images = samples['data'].to(device).requires_grad_()
            images = torch.unsqueeze(images,1)
            labels = samples['label'].to(device)
            _, labels = torch.max(labels,1) #converts from one hot to integer
            
            #Forward Pass
            outputs = net(images)

            #Backward Pass
            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.data
            print("Epoch {} Batch {} loss is: {}".format(epoch, i, loss.data))
            
        epoch_training_loss = running_loss.item()/(i+1) #i + 1 since index starts from 0
        
        for j, val_samples in enumerate(dataloader_val):
            #validation loss
            images = val_samples['data'].to(device)
            images = torch.unsqueeze(images,1)
            labels = val_samples['label'].to(device)
            _, labels = torch.max(labels,1)

            #Forward Pass
            outputs = net(images)
            loss = criterion(outputs, labels)
            running_val_loss += loss.data
            
        epoch_validation_loss = running_val_loss.item()/(j+1)
        
        fold_trn_loss.append(epoch_training_loss)
        fold_val_loss.append(epoch_validation_loss)
        print('fold', fold, 'epoch', epoch, 'trn loss is:', epoch_training_loss, 'val loss is:', epoch_validation_loss)
    
    plt.plot(fold_trn_loss, label="Train")
    plt.plot(fold_val_loss, label="Validation")
    plt.title("Training vs Validation Loss")
    plt.xlabel('Epochs')
    plt.ylabel('Loss (Cross Entropy)')
    plt.legend()
    plt.show()
    
    training_losses.append(fold_trn_loss)
    validation_losses.append(fold_val_loss)
    fold = fold + 1



  0%|          | 0/30 [00:00<?, ?it/s][A[A

Epoch 0 Batch 0 loss is: 3.526394989167268
Epoch 0 Batch 1 loss is: 3.734350669301763
Epoch 0 Batch 2 loss is: 3.058990553101805
Epoch 0 Batch 3 loss is: 2.8045508597411852
Epoch 0 Batch 4 loss is: 2.6530753433712113
Epoch 0 Batch 5 loss is: 2.591426679875444
Epoch 0 Batch 6 loss is: 2.9702986891943866
Epoch 0 Batch 7 loss is: 2.0337783129479874
Epoch 0 Batch 8 loss is: 2.2438971160183843
Epoch 0 Batch 9 loss is: 2.282969901815195
Epoch 0 Batch 10 loss is: 2.1202171558270915
Epoch 0 Batch 11 loss is: 2.2085308789001052
Epoch 0 Batch 12 loss is: 2.037771703134282
Epoch 0 Batch 13 loss is: 1.9373709842477929
Epoch 0 Batch 14 loss is: 1.7592555547695437
Epoch 0 Batch 15 loss is: 1.5762485790674152
Epoch 0 Batch 16 loss is: 1.5301700636922906
Epoch 0 Batch 17 loss is: 1.6216274740300602
Epoch 0 Batch 18 loss is: 1.5379127275408697
Epoch 0 Batch 19 loss is: 1.4458990303518962
Epoch 0 Batch 20 loss is: 1.4225203783424403
Epoch 0 Batch 21 loss is: 1.4084797070876958
Epoch 0 Batch 22 loss is: 



  3%|▎         | 1/30 [04:39<2:14:53, 279.07s/it][A[A

fold 0 epoch 0 trn loss is: 1.1094013271466967 val loss is: 0.3838830445016632
Epoch 1 Batch 0 loss is: 0.2340885143330947
Epoch 1 Batch 1 loss is: 0.2772651324737858
Epoch 1 Batch 2 loss is: 0.34755187386783004
Epoch 1 Batch 3 loss is: 0.20636125881134404
Epoch 1 Batch 4 loss is: 0.31229896083467046
Epoch 1 Batch 5 loss is: 0.21393649523737898
Epoch 1 Batch 6 loss is: 0.22157254930137532
Epoch 1 Batch 7 loss is: 0.3229338597527511
Epoch 1 Batch 8 loss is: 0.17659077036290202
Epoch 1 Batch 9 loss is: 0.1902776763903071
Epoch 1 Batch 10 loss is: 0.14315736164162374
Epoch 1 Batch 11 loss is: 0.3185886003954748
Epoch 1 Batch 12 loss is: 0.20752832166916194
Epoch 1 Batch 13 loss is: 0.16161590679751886
Epoch 1 Batch 14 loss is: 0.22207666287020839
Epoch 1 Batch 15 loss is: 0.30388714737887457
Epoch 1 Batch 16 loss is: 0.21956523114472565
Epoch 1 Batch 17 loss is: 0.2825451141188534
Epoch 1 Batch 18 loss is: 0.12416731852666267
Epoch 1 Batch 19 loss is: 0.19800135202662666
Epoch 1 Batch 20 



  7%|▋         | 2/30 [09:16<2:09:58, 278.52s/it][A[A

fold 0 epoch 1 trn loss is: 0.20019811557855294 val loss is: 0.27030171395854297
Epoch 2 Batch 0 loss is: 0.110746220090146
Epoch 2 Batch 1 loss is: 0.08003607990583192
Epoch 2 Batch 2 loss is: 0.06735338395587191
Epoch 2 Batch 3 loss is: 0.06783151995877601
Epoch 2 Batch 4 loss is: 0.08524570066722308
Epoch 2 Batch 5 loss is: 0.07025220408574086
Epoch 2 Batch 6 loss is: 0.14948928489323787
Epoch 2 Batch 7 loss is: 0.09156333040719847
Epoch 2 Batch 8 loss is: 0.057033819463421766
Epoch 2 Batch 9 loss is: 0.09450885290295466
Epoch 2 Batch 10 loss is: 0.08604039753154348
Epoch 2 Batch 11 loss is: 0.0659681597964413
Epoch 2 Batch 12 loss is: 0.05539656152792535
Epoch 2 Batch 13 loss is: 0.06000104432629214
Epoch 2 Batch 14 loss is: 0.09998444106680783
Epoch 2 Batch 15 loss is: 0.06098862606239597
Epoch 2 Batch 16 loss is: 0.0826831390214546
Epoch 2 Batch 17 loss is: 0.0500740885036428
Epoch 2 Batch 18 loss is: 0.04258995072321738
Epoch 2 Batch 19 loss is: 0.04051752872347284
Epoch 2 Batch



 10%|█         | 3/30 [13:52<2:05:05, 277.96s/it][A[A

fold 0 epoch 2 trn loss is: 0.06162011537163516 val loss is: 0.22012910351070514
Epoch 3 Batch 0 loss is: 0.032005120226269286
Epoch 3 Batch 1 loss is: 0.030444749836491147
Epoch 3 Batch 2 loss is: 0.028469887976702496
Epoch 3 Batch 3 loss is: 0.03611809889731373
Epoch 3 Batch 4 loss is: 0.03293422494593008
Epoch 3 Batch 5 loss is: 0.025212381076764375
Epoch 3 Batch 6 loss is: 0.0360155128351954
Epoch 3 Batch 7 loss is: 0.02113452084556947
Epoch 3 Batch 8 loss is: 0.0332236319341042
Epoch 3 Batch 9 loss is: 0.028608796841826004
Epoch 3 Batch 10 loss is: 0.03580494513520689
Epoch 3 Batch 11 loss is: 0.023940384467573933
Epoch 3 Batch 12 loss is: 0.019260425178120228
Epoch 3 Batch 13 loss is: 0.035133009835435845
Epoch 3 Batch 14 loss is: 0.027501660890966104
Epoch 3 Batch 15 loss is: 0.01671551323840164
Epoch 3 Batch 16 loss is: 0.029221111547352324
Epoch 3 Batch 17 loss is: 0.02002017693789463
Epoch 3 Batch 18 loss is: 0.016920592017687176
Epoch 3 Batch 19 loss is: 0.023202232247072418



 13%|█▎        | 4/30 [18:29<2:00:15, 277.51s/it][A[A

fold 0 epoch 3 trn loss is: 0.02069520498785716 val loss is: 0.2028622677513221
Epoch 4 Batch 0 loss is: 0.013289743667583948
Epoch 4 Batch 1 loss is: 0.015536424252992908
Epoch 4 Batch 2 loss is: 0.019318240799764463
Epoch 4 Batch 3 loss is: 0.008197325348680819
Epoch 4 Batch 4 loss is: 0.01750242997240805
Epoch 4 Batch 5 loss is: 0.016325094231767815
Epoch 4 Batch 6 loss is: 0.010941045957473333
Epoch 4 Batch 7 loss is: 0.0112498148544224
Epoch 4 Batch 8 loss is: 0.007087786337849049
Epoch 4 Batch 9 loss is: 0.012791673318595839
Epoch 4 Batch 10 loss is: 0.0052627216865980845
Epoch 4 Batch 11 loss is: 0.009775189525819043
Epoch 4 Batch 12 loss is: 0.010037754749576718
Epoch 4 Batch 13 loss is: 0.006541558147421789
Epoch 4 Batch 14 loss is: 0.010168191094528692
Epoch 4 Batch 15 loss is: 0.01050611408114431
Epoch 4 Batch 16 loss is: 0.006320644658530969
Epoch 4 Batch 17 loss is: 0.007166548606502517
Epoch 4 Batch 18 loss is: 0.00916196639712659
Epoch 4 Batch 19 loss is: 0.0101573408664



 17%|█▋        | 5/30 [23:07<1:55:39, 277.57s/it][A[A

fold 0 epoch 4 trn loss is: 0.00921655093176823 val loss is: 0.19745542689776543
Epoch 5 Batch 0 loss is: 0.008924677928841464
Epoch 5 Batch 1 loss is: 0.005279333881958869
Epoch 5 Batch 2 loss is: 0.006184432363501446
Epoch 5 Batch 3 loss is: 0.005139512895896559
Epoch 5 Batch 4 loss is: 0.0076896708984055566
Epoch 5 Batch 5 loss is: 0.005509764778694108
Epoch 5 Batch 6 loss is: 0.007706219763222097
Epoch 5 Batch 7 loss is: 0.0047785013438931285
Epoch 5 Batch 8 loss is: 0.004181602391851958
Epoch 5 Batch 9 loss is: 0.007955858205115901
Epoch 5 Batch 10 loss is: 0.004547620313301858
Epoch 5 Batch 11 loss is: 0.006682651547153533
Epoch 5 Batch 12 loss is: 0.002640567958967163
Epoch 5 Batch 13 loss is: 0.002532350706514753
Epoch 5 Batch 14 loss is: 0.0056922264528238035
Epoch 5 Batch 15 loss is: 0.006536873924101556
Epoch 5 Batch 16 loss is: 0.0066917561486583566
Epoch 5 Batch 17 loss is: 0.009236069626007091
Epoch 5 Batch 18 loss is: 0.005326944591085491
Epoch 5 Batch 19 loss is: 0.0050



 20%|██        | 6/30 [27:45<1:51:05, 277.74s/it][A[A

fold 0 epoch 5 trn loss is: 0.005337307343272916 val loss is: 0.21046600726875564
Epoch 6 Batch 0 loss is: 0.0033113728234739643
Epoch 6 Batch 1 loss is: 0.0064197320462226655
Epoch 6 Batch 2 loss is: 0.005488292742012793
Epoch 6 Batch 3 loss is: 0.0030768441161118667
Epoch 6 Batch 4 loss is: 0.00451893263368575
Epoch 6 Batch 5 loss is: 0.004120963738749642
Epoch 6 Batch 6 loss is: 0.003501558699072594
Epoch 6 Batch 7 loss is: 0.002483376358638812
Epoch 6 Batch 8 loss is: 0.0023040547038710338
Epoch 6 Batch 9 loss is: 0.004018492078641991
Epoch 6 Batch 10 loss is: 0.006017590625546324
Epoch 6 Batch 11 loss is: 0.003457276068015934
Epoch 6 Batch 12 loss is: 0.004029580264920073
Epoch 6 Batch 13 loss is: 0.0069543170558406565
Epoch 6 Batch 14 loss is: 0.0038895353592610367
Epoch 6 Batch 15 loss is: 0.00346337719000528
Epoch 6 Batch 16 loss is: 0.005641966962670632
Epoch 6 Batch 17 loss is: 0.0033238279179612106
Epoch 6 Batch 18 loss is: 0.0036358019735268717
Epoch 6 Batch 19 loss is: 0.0



 23%|██▎       | 7/30 [32:19<1:46:03, 276.68s/it][A[A

fold 0 epoch 6 trn loss is: 0.003437454204452766 val loss is: 0.20497724388560673
Epoch 7 Batch 0 loss is: 0.004101312465327851
Epoch 7 Batch 1 loss is: 0.002083532807143958
Epoch 7 Batch 2 loss is: 0.002497942762134002
Epoch 7 Batch 3 loss is: 0.0028983197030788923
Epoch 7 Batch 4 loss is: 0.0032545194589927817
Epoch 7 Batch 5 loss is: 0.0029340055554646226
Epoch 7 Batch 6 loss is: 0.0031379968418912794
Epoch 7 Batch 7 loss is: 0.0028349508045461863
Epoch 7 Batch 8 loss is: 0.002748673760238501
Epoch 7 Batch 9 loss is: 0.0017400216291223577
Epoch 7 Batch 10 loss is: 0.002204371476475515
Epoch 7 Batch 11 loss is: 0.004716393531226414
Epoch 7 Batch 12 loss is: 0.003149824084575575
Epoch 7 Batch 13 loss is: 0.001951721540182542
Epoch 7 Batch 14 loss is: 0.0021281269218057257
Epoch 7 Batch 15 loss is: 0.0022094789171805
Epoch 7 Batch 16 loss is: 0.001873957457381484
Epoch 7 Batch 17 loss is: 0.0024132751317446593
Epoch 7 Batch 18 loss is: 0.004327378481269856
Epoch 7 Batch 19 loss is: 0.0



 27%|██▋       | 8/30 [36:54<1:41:14, 276.11s/it][A[A

fold 0 epoch 7 trn loss is: 0.0025364895657705665 val loss is: 0.20622525542908193
Epoch 8 Batch 0 loss is: 0.0031434965004646406
Epoch 8 Batch 1 loss is: 0.0023727940734443345
Epoch 8 Batch 2 loss is: 0.0022794093842719576
Epoch 8 Batch 3 loss is: 0.002535660737593825
Epoch 8 Batch 4 loss is: 0.002396574123759597
Epoch 8 Batch 5 loss is: 0.0026939720765753114
Epoch 8 Batch 6 loss is: 0.003742510964371384
Epoch 8 Batch 7 loss is: 0.0017789294608166984
Epoch 8 Batch 8 loss is: 0.002212332639295932
Epoch 8 Batch 9 loss is: 0.0031743205466538645
Epoch 8 Batch 10 loss is: 0.0016994850930055823
Epoch 8 Batch 11 loss is: 0.0010843667876313152
Epoch 8 Batch 12 loss is: 0.002195749976625147
Epoch 8 Batch 13 loss is: 0.0013924780892084362
Epoch 8 Batch 14 loss is: 0.0018347815465699568
Epoch 8 Batch 15 loss is: 0.002095909664469057
Epoch 8 Batch 16 loss is: 0.0015153322230972558
Epoch 8 Batch 17 loss is: 0.0021125354895381323
Epoch 8 Batch 18 loss is: 0.0033233754525434853
Epoch 8 Batch 19 loss



 30%|███       | 9/30 [41:30<1:36:38, 276.14s/it][A[A

fold 0 epoch 8 trn loss is: 0.0019449368427977077 val loss is: 0.2123480150169797
Epoch 9 Batch 0 loss is: 0.001986331841901586
Epoch 9 Batch 1 loss is: 0.0023501830623935877
Epoch 9 Batch 2 loss is: 0.0011441118381290967
Epoch 9 Batch 3 loss is: 0.0019005513291196982
Epoch 9 Batch 4 loss is: 0.0018422606306368057
Epoch 9 Batch 5 loss is: 0.0011521928174090827
Epoch 9 Batch 6 loss is: 0.003171727422657552
Epoch 9 Batch 7 loss is: 0.0018843737556821813
Epoch 9 Batch 8 loss is: 0.0013607915254985887
Epoch 9 Batch 9 loss is: 0.0018699196242565345
Epoch 9 Batch 10 loss is: 0.0015207967031192501
Epoch 9 Batch 11 loss is: 0.0010716513636722642
Epoch 9 Batch 12 loss is: 0.001966155666805823
Epoch 9 Batch 13 loss is: 0.00152709582926074
Epoch 9 Batch 14 loss is: 0.0020330973282950047
Epoch 9 Batch 15 loss is: 0.0014904641993162926
Epoch 9 Batch 16 loss is: 0.0016885151024982766
Epoch 9 Batch 17 loss is: 0.0015727090305887497
Epoch 9 Batch 18 loss is: 0.001676501108569989
Epoch 9 Batch 19 loss 



 33%|███▎      | 10/30 [46:04<1:31:47, 275.36s/it][A[A

fold 0 epoch 9 trn loss is: 0.0015491716547504717 val loss is: 0.21318182467005364
Epoch 10 Batch 0 loss is: 0.0014120930086278113
Epoch 10 Batch 1 loss is: 0.0008785628241533535
Epoch 10 Batch 2 loss is: 0.0012272200514882078
Epoch 10 Batch 3 loss is: 0.0018977565882937596
Epoch 10 Batch 4 loss is: 0.0012400651739195468
Epoch 10 Batch 5 loss is: 0.0017705731868221145
Epoch 10 Batch 6 loss is: 0.0012690060671549474
Epoch 10 Batch 7 loss is: 0.0008668498660822622
Epoch 10 Batch 8 loss is: 0.0008914964763643951
Epoch 10 Batch 9 loss is: 0.0014010753146805636
Epoch 10 Batch 10 loss is: 0.0014048225156313284
Epoch 10 Batch 11 loss is: 0.000717700283430555
Epoch 10 Batch 12 loss is: 0.0013951875664357516
Epoch 10 Batch 13 loss is: 0.0017780766079000425
Epoch 10 Batch 14 loss is: 0.0014738156166020624
Epoch 10 Batch 15 loss is: 0.0010321173591292402
Epoch 10 Batch 16 loss is: 0.002049751250878889
Epoch 10 Batch 17 loss is: 0.002008101868946284
Epoch 10 Batch 18 loss is: 0.0015016730771040443



 37%|███▋      | 11/30 [50:39<1:27:12, 275.38s/it][A[A

fold 0 epoch 10 trn loss is: 0.001303215594162098 val loss is: 0.21152194913135589


OSError: [Errno 12] Cannot allocate memory

In [31]:
#This is list of lists of lists
losses = [training_losses,validation_losses]
with open('CV_file/losses/losses.pkl', 'wb') as f:
    pickle.dump(losses, f)

In [None]:
fig = plt.figure()
for i in range(len(training_losses)):
    plt.subplot(len(training_losses)+1,1,i+1)
    plt.plot(training_losses[i], label="Train")
    plt.plot(validation_losses[i], label="Validation")
    plt.title("Training vs Validation Loss")
    plt.xlabel('Epochs')
    plt.ylabel('Loss (Cross Entropy)')
    plt.legend()
plt.show()
fig.set_size_inches(8, 20)    
fig.savefig("CV_file/Loss.png")