In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from modules import torch_classes
from torchvision import transforms, utils
from modules.grad_cam import *

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [5]:
#Train set
labels_path = '../DATA/train_labels.pkl'
data_path = 'train.dat'
data_dims = (8269,10404)
genes_path = 'train.csv'

#Data and labels
train = torch_classes.TumorDataset(labels_path,data_path,data_dims,genes_path,transform = transforms.Compose([torch_classes.ToImage(),torch_classes.ToTensor()]))

#Test set
labels_path = '../DATA/test_labels.pkl'
data_path = 'test.dat'
data_dims = (2085,10404)
genes_path = 'test.csv'

test = torch_classes.TumorDataset(labels_path,data_path,data_dims,genes_path,transform = transforms.Compose([torch_classes.ToImage(),torch_classes.ToTensor()]))

In [8]:
num_epochs = 30
batch_size = 75
learning_rate = 0.0001

net = torch_classes.Net(num_of_classes=33)
net.to(device)
net.double()

#Training loss function
criterion = nn.CrossEntropyLoss()

#Training optimizer
optimizer = torch.optim.Adam(params=net.parameters(), lr=learning_rate)

### Train using early stopping

In [5]:
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.15)

#Single loop to get indicies
for train_index, val_index in sss.split(np.zeros(len(train.int_labels)), train.int_labels):
    train_subset = torch.utils.data.Subset(train, train_index)
    val_subset = torch.utils.data.Subset(train, val_index)
    
#initialize data loaders using subsets
dataloader_train = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=4)
dataloader_val = DataLoader(val_subset, batch_size=batch_size, shuffle=True, num_workers=4)

In [6]:
training_losses = []
validation_losses = []
weights = []

for epoch in tqdm(range(num_epochs)):
        running_loss = 0.0
        running_val_loss = 0.0
        for i, samples in enumerate(dataloader_train):
            #Data
            images = samples['data'].to(device).requires_grad_()
            images = torch.unsqueeze(images,1)
            labels = samples['label'].to(device)
            _, labels = torch.max(labels,1) #converts from one hot to integer
            
            #Forward Pass
            outputs = net(images)

            #Backward Pass
            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.data
            print("Epoch {} Batch {} loss is: {}".format(epoch, i, loss.data))
            
        epoch_training_loss = running_loss.item()/(i+1) #i + 1 since index starts from 0
        
        for j, val_samples in enumerate(dataloader_val):
            #validation loss
            images = val_samples['data'].to(device)
            images = torch.unsqueeze(images,1)
            labels = val_samples['label'].to(device)
            _, labels = torch.max(labels,1)

            #Forward Pass
            outputs = net(images)
            loss = criterion(outputs, labels)
            running_val_loss += loss.data
            
        epoch_validation_loss = running_val_loss.item()/(j+1) #i + 1 since index starts from 0
        
        training_losses.append(epoch_training_loss)
        validation_losses.append(epoch_validation_loss)
        print('epoch', epoch, 'trn loss is:', epoch_training_loss, 'val loss is:', epoch_validation_loss)
        
        weights.append(net.state_dict())
        if epoch >= 3:
            #keep only the 3 most current network weights
            del weights[0]
            
        
        #Criteria for early stopping - if validation loss goes up after 3 iterations or stops changing over 3 epochs
        if (epoch >= 3) and (abs(validation_losses[epoch] - validation_losses[epoch-1]) <= 0.0001) and (abs(validation_losses[epoch - 1] - validation_losses[epoch - 2]) <= 0.0001):
            print("Early Stopping - losses stopped changing")
            for i in range(len(weights)):
                path = "CV_file/es_weights/weights"+str(i)+'.pt'
                torch.save(weights[i], path)
            break   
       
        elif (epoch>=3) and (validation_losses[epoch]-validation_losses[epoch-1]>0) and (validation_losses[epoch-1]-validation_losses[epoch-2]>0):
            print("Early Stopping - losses going up")
            for i in range(len(weights)):
                path = "CV_file/es_weights/weights"+str(i)+'.pt'
                torch.save(weights[i], path)
            break
            
            
        

  0%|          | 0/30 [00:00<?, ?it/s]

Epoch 0 Batch 0 loss is: 3.518521642254989
Epoch 0 Batch 1 loss is: 3.8712316085365654
Epoch 0 Batch 2 loss is: 3.758162479418121
Epoch 0 Batch 3 loss is: 3.010863333427325
Epoch 0 Batch 4 loss is: 2.8834931909197303
Epoch 0 Batch 5 loss is: 2.7282659703670475
Epoch 0 Batch 6 loss is: 2.5472590976950875
Epoch 0 Batch 7 loss is: 2.421582608740413
Epoch 0 Batch 8 loss is: 2.340670277796199
Epoch 0 Batch 9 loss is: 1.9347495236049523
Epoch 0 Batch 10 loss is: 1.9101258586908905
Epoch 0 Batch 11 loss is: 1.989105378318807
Epoch 0 Batch 12 loss is: 1.9776118401771139
Epoch 0 Batch 13 loss is: 1.777035872639607
Epoch 0 Batch 14 loss is: 1.6195798734507103
Epoch 0 Batch 15 loss is: 1.5927843424120072
Epoch 0 Batch 16 loss is: 1.5080177131561887
Epoch 0 Batch 17 loss is: 1.6767222727746214
Epoch 0 Batch 18 loss is: 1.5160949423773311
Epoch 0 Batch 19 loss is: 1.5239111896529707
Epoch 0 Batch 20 loss is: 1.5049549563826186
Epoch 0 Batch 21 loss is: 1.1292033984866316
Epoch 0 Batch 22 loss is: 1

  3%|▎         | 1/30 [05:25<2:37:24, 325.68s/it]

epoch 0 trn loss is: 0.9858828791265282 val loss is: 0.324025127160218
Epoch 1 Batch 0 loss is: 0.3168642864457478
Epoch 1 Batch 1 loss is: 0.15662032296944073
Epoch 1 Batch 2 loss is: 0.17473186596717677
Epoch 1 Batch 3 loss is: 0.18570017552929396
Epoch 1 Batch 4 loss is: 0.25425772489217413
Epoch 1 Batch 5 loss is: 0.1728998121589691
Epoch 1 Batch 6 loss is: 0.16132577270652376
Epoch 1 Batch 7 loss is: 0.23149082238690985
Epoch 1 Batch 8 loss is: 0.1973013082365515
Epoch 1 Batch 9 loss is: 0.2071432819473308
Epoch 1 Batch 10 loss is: 0.1827564199520861
Epoch 1 Batch 11 loss is: 0.15438469742783578
Epoch 1 Batch 12 loss is: 0.16595231833719337
Epoch 1 Batch 13 loss is: 0.2757961729871755
Epoch 1 Batch 14 loss is: 0.18486318400048707
Epoch 1 Batch 15 loss is: 0.16223833195986612
Epoch 1 Batch 16 loss is: 0.23300447334119043
Epoch 1 Batch 17 loss is: 0.11112117974885233
Epoch 1 Batch 18 loss is: 0.16215076065295192
Epoch 1 Batch 19 loss is: 0.2143294810227289
Epoch 1 Batch 20 loss is: 

  7%|▋         | 2/30 [10:47<2:31:25, 324.47s/it]

epoch 1 trn loss is: 0.16158415550112853 val loss is: 0.23817242570842953
Epoch 2 Batch 0 loss is: 0.0506625638822142
Epoch 2 Batch 1 loss is: 0.04729554993876312
Epoch 2 Batch 2 loss is: 0.07088883265576168
Epoch 2 Batch 3 loss is: 0.048792079525662814
Epoch 2 Batch 4 loss is: 0.07305927993488105
Epoch 2 Batch 5 loss is: 0.03343409158850915
Epoch 2 Batch 6 loss is: 0.061317470342308104
Epoch 2 Batch 7 loss is: 0.049329274595795995
Epoch 2 Batch 8 loss is: 0.07559569332557486
Epoch 2 Batch 9 loss is: 0.03330100525382346
Epoch 2 Batch 10 loss is: 0.08857022865556168
Epoch 2 Batch 11 loss is: 0.05417088136668831
Epoch 2 Batch 12 loss is: 0.08043171256482279
Epoch 2 Batch 13 loss is: 0.039293178339807194
Epoch 2 Batch 14 loss is: 0.045295200532208015
Epoch 2 Batch 15 loss is: 0.05929021817565861
Epoch 2 Batch 16 loss is: 0.06986502359677241
Epoch 2 Batch 17 loss is: 0.10945911839879763
Epoch 2 Batch 18 loss is: 0.05977618208380431
Epoch 2 Batch 19 loss is: 0.07749219448001626
Epoch 2 Batc

 10%|█         | 3/30 [16:08<2:25:30, 323.34s/it]

epoch 2 trn loss is: 0.051366608737230826 val loss is: 0.1892537194475813
Epoch 3 Batch 0 loss is: 0.013629628188310896
Epoch 3 Batch 1 loss is: 0.010914982311761818
Epoch 3 Batch 2 loss is: 0.01097779108637346
Epoch 3 Batch 3 loss is: 0.035916224192536664
Epoch 3 Batch 4 loss is: 0.009554402631258832
Epoch 3 Batch 5 loss is: 0.06234674656347818
Epoch 3 Batch 6 loss is: 0.03711202746383142
Epoch 3 Batch 7 loss is: 0.02314872509862949
Epoch 3 Batch 8 loss is: 0.012632863775477444
Epoch 3 Batch 9 loss is: 0.016246075048026422
Epoch 3 Batch 10 loss is: 0.03167317015383263
Epoch 3 Batch 11 loss is: 0.031158778990014884
Epoch 3 Batch 12 loss is: 0.020456973466155427
Epoch 3 Batch 13 loss is: 0.01625877559764452
Epoch 3 Batch 14 loss is: 0.0189253825972639
Epoch 3 Batch 15 loss is: 0.014378094066169212
Epoch 3 Batch 16 loss is: 0.008739748792454047
Epoch 3 Batch 17 loss is: 0.017721036603493504
Epoch 3 Batch 18 loss is: 0.02649847988238405
Epoch 3 Batch 19 loss is: 0.009309155523682193
Epoch

 13%|█▎        | 4/30 [21:28<2:19:41, 322.37s/it]

epoch 3 trn loss is: 0.016923289961583265 val loss is: 0.17467267249095905
Epoch 4 Batch 0 loss is: 0.007270413293278999
Epoch 4 Batch 1 loss is: 0.003898265546206545
Epoch 4 Batch 2 loss is: 0.00520021434161567
Epoch 4 Batch 3 loss is: 0.006655417440726137
Epoch 4 Batch 4 loss is: 0.007380588106580532
Epoch 4 Batch 5 loss is: 0.0075192358855298855
Epoch 4 Batch 6 loss is: 0.012403985741994425
Epoch 4 Batch 7 loss is: 0.011379830857219773
Epoch 4 Batch 8 loss is: 0.005870022484572814
Epoch 4 Batch 9 loss is: 0.006683259099626279
Epoch 4 Batch 10 loss is: 0.009347826402701228
Epoch 4 Batch 11 loss is: 0.011093795824841936
Epoch 4 Batch 12 loss is: 0.013088189186694283
Epoch 4 Batch 13 loss is: 0.008413297762447888
Epoch 4 Batch 14 loss is: 0.004207447944164263
Epoch 4 Batch 15 loss is: 0.008346996428660868
Epoch 4 Batch 16 loss is: 0.006965366402902309
Epoch 4 Batch 17 loss is: 0.0029496351516103184
Epoch 4 Batch 18 loss is: 0.0036045908414290297
Epoch 4 Batch 19 loss is: 0.007982388224

 17%|█▋        | 5/30 [26:52<2:14:36, 323.06s/it]

epoch 4 trn loss is: 0.006331152021418543 val loss is: 0.1899711493104799
Epoch 5 Batch 0 loss is: 0.00635381414230937
Epoch 5 Batch 1 loss is: 0.002834773270445273
Epoch 5 Batch 2 loss is: 0.005239060737044182
Epoch 5 Batch 3 loss is: 0.003242006125278086
Epoch 5 Batch 4 loss is: 0.005119372727966815
Epoch 5 Batch 5 loss is: 0.004447210587629442
Epoch 5 Batch 6 loss is: 0.0039348512104122576
Epoch 5 Batch 7 loss is: 0.0024460410332303203
Epoch 5 Batch 8 loss is: 0.004562013520453772
Epoch 5 Batch 9 loss is: 0.004062582332145747
Epoch 5 Batch 10 loss is: 0.004034481326151536
Epoch 5 Batch 11 loss is: 0.005001651826457885
Epoch 5 Batch 12 loss is: 0.004070209614732315
Epoch 5 Batch 13 loss is: 0.005593114318720858
Epoch 5 Batch 14 loss is: 0.005243853153786743
Epoch 5 Batch 15 loss is: 0.004160651079607156
Epoch 5 Batch 16 loss is: 0.0032511826021499484
Epoch 5 Batch 17 loss is: 0.0033378072259752636
Epoch 5 Batch 18 loss is: 0.003239877860187856
Epoch 5 Batch 19 loss is: 0.003275566288

 20%|██        | 6/30 [32:19<2:09:39, 324.14s/it]

epoch 5 trn loss is: 0.003700566821940186 val loss is: 0.18695301702435443
Epoch 6 Batch 0 loss is: 0.002378847697910871
Epoch 6 Batch 1 loss is: 0.0016280277403948418
Epoch 6 Batch 2 loss is: 0.0012248494810085475
Epoch 6 Batch 3 loss is: 0.0025620186900371066
Epoch 6 Batch 4 loss is: 0.0028393442876458153
Epoch 6 Batch 5 loss is: 0.0010323838557442617
Epoch 6 Batch 6 loss is: 0.003096869538903183
Epoch 6 Batch 7 loss is: 0.0023976359398107405
Epoch 6 Batch 8 loss is: 0.0011345827989219734
Epoch 6 Batch 9 loss is: 0.001962797044596248
Epoch 6 Batch 10 loss is: 0.0027706179672060097
Epoch 6 Batch 11 loss is: 0.002286142821119353
Epoch 6 Batch 12 loss is: 0.0017167877877762312
Epoch 6 Batch 13 loss is: 0.003046461144250685
Epoch 6 Batch 14 loss is: 0.002957660515465387
Epoch 6 Batch 15 loss is: 0.0027636322748466097
Epoch 6 Batch 16 loss is: 0.0024795895527892496
Epoch 6 Batch 17 loss is: 0.00246559974017084
Epoch 6 Batch 18 loss is: 0.0034293549484012923
Epoch 6 Batch 19 loss is: 0.002

 23%|██▎       | 7/30 [37:39<2:03:49, 323.01s/it]

epoch 6 trn loss is: 0.002505615651311519 val loss is: 0.17494964313679628
Epoch 7 Batch 0 loss is: 0.0016710678263526428
Epoch 7 Batch 1 loss is: 0.0016931128034728478
Epoch 7 Batch 2 loss is: 0.001880093763787277
Epoch 7 Batch 3 loss is: 0.0013338757203382556
Epoch 7 Batch 4 loss is: 0.001788059738387607
Epoch 7 Batch 5 loss is: 0.0018304121531627971
Epoch 7 Batch 6 loss is: 0.0022303210504240074
Epoch 7 Batch 7 loss is: 0.0017149370816303626
Epoch 7 Batch 8 loss is: 0.0015026987081997116
Epoch 7 Batch 9 loss is: 0.002262583247986309
Epoch 7 Batch 10 loss is: 0.002272649436851519
Epoch 7 Batch 11 loss is: 0.0018087179108458192
Epoch 7 Batch 12 loss is: 0.00229387898486929
Epoch 7 Batch 13 loss is: 0.001192747613273113
Epoch 7 Batch 14 loss is: 0.0014062016522378448
Epoch 7 Batch 15 loss is: 0.0009763709348558554
Epoch 7 Batch 16 loss is: 0.0021795494651941044
Epoch 7 Batch 17 loss is: 0.0019270039762862012
Epoch 7 Batch 18 loss is: 0.0022152327139485094
Epoch 7 Batch 19 loss is: 0.00

 27%|██▋       | 8/30 [43:05<1:58:42, 323.74s/it]

epoch 7 trn loss is: 0.0018283885469876097 val loss is: 0.19116536693314778
Epoch 8 Batch 0 loss is: 0.0024529295023719417
Epoch 8 Batch 1 loss is: 0.000998507855704034
Epoch 8 Batch 2 loss is: 0.001994420653015965
Epoch 8 Batch 3 loss is: 0.0013572493091011494
Epoch 8 Batch 4 loss is: 0.0009859615288616889
Epoch 8 Batch 5 loss is: 0.0017456708916885806
Epoch 8 Batch 6 loss is: 0.0014104079463032803
Epoch 8 Batch 7 loss is: 0.0019910868094938296
Epoch 8 Batch 8 loss is: 0.0011474811880551764
Epoch 8 Batch 9 loss is: 0.0016603986408588157
Epoch 8 Batch 10 loss is: 0.001383618895809633
Epoch 8 Batch 11 loss is: 0.001157685591228533
Epoch 8 Batch 12 loss is: 0.0017452896580850564
Epoch 8 Batch 13 loss is: 0.0006777849095801732
Epoch 8 Batch 14 loss is: 0.0012077451474695996
Epoch 8 Batch 15 loss is: 0.0029785159771874704
Epoch 8 Batch 16 loss is: 0.0015395220936051145
Epoch 8 Batch 17 loss is: 0.0008971635313396576
Epoch 8 Batch 18 loss is: 0.0016173423604866363
Epoch 8 Batch 19 loss is: 

In [9]:
#These results are using the weights that it exited on (validation loss went up twice)
dataloader_test = DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=4) #0.9510791366906475 -> using two steps before 0.951558752997602
dataloader_fulltrain = DataLoader(train, batch_size=batch_size, shuffle=False, num_workers=4)#0.9922602491232313

In [10]:
test_net = torch_classes.Net(num_of_classes=33)
test_net.load_state_dict(torch.load('CV_file/es_weights/weights0.pt'))
test_net.to(device)
test_net.double()
test_net.eval()

Net(
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drop2D): Dropout2d(p=0.25)
  (vp): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=36864, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=33, bias=True)
)

In [11]:
true_labels = np.array([])
labels_predicted = np.array([])
for ii, test_sample in enumerate(dataloader_test):
    #Data
    images = test_sample['data'].to(device)
    images = torch.unsqueeze(images,1)
    labels = test_sample['label'].to(device)
    _, labels = torch.max(labels,1)
    
    #Forward pass
    outputs = test_net(images)
    
    #Label classified
    _, predicted = torch.max(outputs.data, 1)
    
    true_labels = np.append(true_labels, labels.cpu().numpy().astype('int8'))
    labels_predicted = np.append(labels_predicted, predicted.cpu().numpy().astype('int8'))

In [12]:
accuracy = np.sum(true_labels==labels_predicted)
print(accuracy/len(true_labels))

0.951558752997602
