In [1]:
# personal imports
from dataloader import DataLoader
import utils
from utils import calculate_auc, auc
from callbacks import *

# python stuffs
import os
import numpy as np
import torchvision.models as models
from torchvision import transforms as trn
import skimage.io
import skimage
import torch.utils.model_zoo as model_zoo
import torch

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision

from transfer_classifier import *


Using TensorFlow backend.


In [3]:
USE_GPU = True
dtype = torch.float32

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
    print("gpu available!")
else:
    device = torch.device('cpu')
    print("gpu NOT available!")

gpu available!


## Data loading
Numpy tensors exist in /scratch/users/gmachi/codex/data/train

In [2]:
ppb = 1 # "patches-per-batch"; batch size to see all 25 slices in a patch

train_loader = DataLoader(utils.train_dir, batch_size=ppb, transfer=True)
val_loader = DataLoader(utils.val_dir, batch_size=ppb, transfer=True)
test_loader = DataLoader(utils.test_dir, batch_size=ppb, transfer=True)

print("begin sanity checks for shapes...\n")
for f, d, l in train_loader: # filename, batched data, label
    print("train <filenames, data batch, labels>:\n", len(f), d.shape, l.shape)
    break
    
for f, d, l in val_loader:
    print("val <filenames, data batch, labels>:\n", len(f), d.shape, l.shape)
    break
    
for f, d, l in test_loader:
    print("test <filenames, data batch, labels>:\n", len(f), d.shape, l.shape)
    break


begin sanity checks for shapes...

train <filenames, data batch, labels>:
 25 (25, 3, 96, 96) (25,)
val <filenames, data batch, labels>:
 25 (25, 3, 96, 96) (25,)
test <filenames, data batch, labels>:
 25 (25, 3, 96, 96) (25,)


In [7]:
# Get image summary stats

from utils import labels_dict

def count_files(dir):
    return len([1 for x in list(os.scandir(dir)) if x.is_file()])

def unique_files(dir):
    return set([x.split("_")[0].split("reg")[1] for x in os.listdir(dir)])

def set_splits(dir):
    all_files = [x.split("_")[0].split("reg")[1] for x in os.listdir(dir)]
    labels = [labels_dict[u][1] for u in all_files]
    pos = np.sum(labels)
    neg = len(labels) - pos
    return pos, neg
    

print("After augmentation/up-sampling, we have...\n------------------------------------------")
print("train set size:", count_files(utils.train_dir))
print("val set size:", count_files(utils.val_dir))
print("test set size:", count_files(utils.test_dir))

print("\nSee composition of patients in sets...\n--------------------------------------")
print("train set unique files:", unique_files(utils.train_dir))
print("val set unique files:", unique_files(utils.val_dir))
print("test set unique files:", unique_files(utils.test_dir))

print("\n(+/-) splits in sets...\n-----------------------")
print("train set split:", set_splits(utils.train_dir))
print("val set split:", set_splits(utils.val_dir))
print("test set split:", set_splits(utils.test_dir))


After augmentation/up-sampling, we have...
------------------------------------------
train set size: 13069
val set size: 2625
test set size: 2646

See composition of patients in sets...
--------------------------------------
train set unique files: {'034', '024', '027', '012', '008', '020', '007', '014', '015', '004'}
val set unique files: {'011', '016', '030', '023'}
test set unique files: {'005', '017', '006', '019'}

(+/-) splits in sets...
-----------------------
train set split: (6897, 6172)
val set split: (1750, 875)
test set split: (1764, 882)


## Model definition - VGG19

In [11]:
model = torchvision.models.vgg19(pretrained=True).features
print(model)
for param in model.parameters():
    param.requires_grad = True

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace)
  (18): MaxPool2d(kernel_size=2, stride=2, padding=0, 

## Transfer learning: apply pre-trained weights to 96x96x3 patch-slices 
This process gives us slice-level tensors for aggregation for the whole patch (96x96x75)

In [6]:
def flatten(x):
    N = x.shape[0] # read in N, C, H, W
    return x.view(N, -1)

class Flatten(nn.Module):
    def forward(self, x):
        return flatten(x)
    
class PrintLayer(nn.Module):
    def __init__(self):
        super(PrintLayer, self).__init__()
    def forward(self, x):
        print(x.shape)
        return x

In [12]:
model_t = nn.Sequential(
    Flatten(),
    
    nn.Linear(in_features= 4608, out_features=4096, bias=True),
    nn.ReLU(inplace = True),
    nn.Dropout(p=0.5),
    
    nn.Linear(in_features=4096, out_features=4096, bias=True),
    nn.ReLU(inplace = True),
    nn.Dropout(p=0.5),
    
    nn.Linear(in_features=4096, out_features=2, bias=True),
)

In [8]:
def validation(model_t, val_loader):
    it = 0
    pooled_val_data = []
    pooled_val_labels =[]
    val_losses = []
    for fdl_val in val_loader:
        
        (f_val, d_val, l_val) = fdl_val

        d_val = torch.from_numpy(d_val)
        l_val = torch.from_numpy(l_val)

        d_val = d_val.to(device=device, dtype=dtype)  # move to device, e.g. GPU
        l_val = l_val.to(device=device, dtype=torch.long)

        d_slice_v = model(d_val)
        d_pooled_v = pool_batch(d_slice_v, mode='mean')
        l_pooled_v = pool_labels(l_val)
        pooled_val_data.append(d_pooled_v)
        pooled_val_labels.append(l_pooled_v)
        
        if it != 0 and it % 4:
            d_pooled_v = torch.cat(pooled_val_data, dim=0)
            l_pooled_v = torch.cat(pooled_val_labels, dim=0)
            pooled_val_data = []
            pooled_val_labels =[]
            
            val_scores = model_t(d_pooled_v)
            val_loss = F.cross_entropy(val_scores, l_pooled_v)
            val_losses.append(val_loss.item())
            
        it += 0
    return np.mean(val_losses)

In [13]:
def train_transfer(model, model_t, vgg_optimizer, t_optimizer):
    ppb = 5
    print_every = 10
    epochs = 10

    train_loader = DataLoader(utils.train_dir, batch_size=ppb, transfer=True)

    i = 0 # batch number
    train_losses, val_losses = [], []
    cur_val = 999
    consec_increases = 0
    model = model.to(device=device)
    model_t = model_t.to(device=device)
    
    model.eval()
#     model.train()
    model_t.train()

    for e in range(1, epochs + 1):
        pooled_train_data = []
        pooled_train_labels =[]
        ct = 0
        for fdl_train in train_loader:

            # train
            (f_train, d_train, l_train) = fdl_train
            
            d_train = torch.from_numpy(d_train)
            l_train = torch.from_numpy(l_train)
            
            d_train = d_train.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            l_train = l_train.to(device=device, dtype=torch.long)
            
            d_slice_t = model(d_train) # (5*25, 3, 3, 512)
            d_pooled_t = pool_batch(d_slice_t, batch_size=ppb, mode='max') # (5, 3, 3, 512)
            l_pooled_t = pool_labels(l_train)
            pooled_train_data.append(d_pooled_t)
            pooled_train_labels.append(l_pooled_t)
            
            if i != 0 and i % 2 == 0:
                
                d_pooled_t = torch.cat(pooled_train_data, dim=0)
                l_pooled_t = torch.cat(pooled_train_labels, dim=0)

                pooled_train_data = []
                pooled_train_labels =[]
            
                # get train metrics
                train_scores = model_t(d_pooled_t)
                train_loss = F.cross_entropy(train_scores, l_pooled_t)

                vgg_optimizer.zero_grad()
                t_optimizer.zero_grad()
                train_loss.backward()
                vgg_optimizer.step()
                t_optimizer.step()
                
                print(d_train.grad)

                train_losses.append(train_loss)

                if ct % print_every == 0:
                    print("iter:", ct+print_every, "train:", train_loss.item())


                # update cur_val
                ct += 1

            i += 1

        torch.save(model, utils.model_dir + "transfer_epoch%s_maxpool_vgg.pt" % e)
        torch.save(model_t, utils.model_dir + "transfer_epoch%s_maxpool_clf.pt" % e)
#         with torch.no_grad():
#             model_t.eval()
#             print('validating...')
#             val_loader = DataLoader(utils.val_dir, batch_size=ppb, transfer=True)
#             val_loss = validation(model_t, val_loader)
#             val_losses.append(val_loss)
        
        print('\nepoch:', e, 'train:', train_loss.item()) #, "val:", val_loss)
        
        
    return train_losses, val_losses


In [14]:
learning_rate = 1e-5
#optimizer = optim.SGD(model.parameters(),
 #                     lr=learning_rate,
  #                    momentum=0.9,
   #                   nesterov=True)
    
vgg_optimizer = optim.Adam(model.parameters(),
                      lr=learning_rate)
t_optimizer = optim.Adam(model_t.parameters(),
                      lr=learning_rate)

loss_history, val_history = train_transfer(model, model_t, vgg_optimizer, t_optimizer)
torch.save(model, utils.model_dir + "transfer_full_maxpool_vgg.pt" % e)
torch.save(model_t, utils.model_dir + "transfer_full_maxpool_clf.pt")

None
iter: 10 train: 0.6485349535942078


KeyboardInterrupt: 

In [None]:

#     # cute printout for sanity (~27,000 train)
#     if (i > 0) and ((i+1) % print_every == 0):
#         print("%i patches complete" % ((i+1)*ppb))


# train_names = ['train_loss', 'train_acc', "train_auc"]
# val_names = ['val_loss', 'val_acc', "val_auc"]