# UNet model

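This notebook trains a UNet to segment cystic regions in lung CT slices (`ILDDataset`) and reports the validation Dice coefficient after each epoch.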

In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import skorch
import torchvision.datasets as dset
import torchvision.models as models
import collections
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torchvision.transforms as transforms
from tqdm import tqdm


from optparse import OptionParser
from unet import UNet
from utils import *
from myloss import dice_coeff


import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from UNet_Loader import ILDDataset

In [26]:
# Index-based train/validation split of lung_dataset_train:
# indices [0, NUM_TRAIN) are used for training, [NUM_TRAIN, NUM_TOTAL) for validation.
NUM_TOTAL = 885
NUM_TRAIN = 700
BATCH_SIZE = 4

# NOTE(review): machine-specific absolute path — point DATA_ROOT at your local copy of the data.
DATA_ROOT = '/Users/magdy/Desktop/Stanford Spring/BMI260/Project/Data'

lung_dataset_train = ILDDataset(cystic_path=DATA_ROOT + '/Cystic_masks_new/Train',
                                root_dir=DATA_ROOT + '/Cystic Dataset/Train',
                                mask=True, HU=True, resize=256)

lung_dataset_test = ILDDataset(cystic_path=DATA_ROOT + '/Cystic_masks_new/Test',
                               root_dir=DATA_ROOT + '/Cystic Dataset/Test',
                               mask=True, HU=True, resize=256)

# Fail fast if the hard-coded split no longer fits the data on disk, instead of
# erroring inside the DataLoader partway through the first epoch.
assert 0 < NUM_TRAIN < NUM_TOTAL <= len(lung_dataset_train), (
    'split NUM_TRAIN={} / NUM_TOTAL={} does not fit a training set of {} samples'
    .format(NUM_TRAIN, NUM_TOTAL, len(lung_dataset_train)))

loader_train = DataLoader(lung_dataset_train, batch_size=BATCH_SIZE,
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

loader_val = DataLoader(lung_dataset_train, batch_size=BATCH_SIZE,
                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, NUM_TOTAL)))

loader_test = DataLoader(lung_dataset_test, batch_size=BATCH_SIZE)

In [27]:
USE_GPU = True

# Floating-point type every input/target tensor is cast to before reaching the model.
dtype = torch.float32

# Use CUDA only when it was requested AND a GPU is actually present; fall back to the CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

# Constant to control how frequently we print train loss
print_every = 100

print('using device:', device)

using device: cpu


# Check accuracy (Dice coefficient)

In [28]:
def eval_net(net, loader, device, gpu=False):
    if loader.dataset.train:
        print('Checking accuracy on training set')
    else:
        print('Checking accuracy on test set')   

    tot = 0
    cntr = 0
    with torch.no_grad():
        for X, y in loader:
            X = X.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.float32)

            if gpu:
                X = Variable(X, requires_grad=True).cuda()
                y = Variable(y, requires_grad=True).cuda()
            else:
                X = Variable(X, requires_grad=True)
                y = Variable(y, requires_grad=True)

            X.unsqueeze_(1)
            y.unsqueeze_(1)
            y_pred = net(X)

            y_pred = (F.sigmoid(y_pred) > 0.6).float()
            dice = dice_coeff(y_pred, y.float()).data[0]
            tot += dice
            cntr += 1
            if 1:
                X = X.data.squeeze(0).cpu().numpy()
                X = np.transpose(X, axes=[1, 2, 0])
                y = y.data.squeeze(0).cpu().numpy()
                y_pred = y_pred.data.squeeze(0).squeeze(0).cpu().numpy()
                print(y_pred.shape)

                fig = plt.figure()
                ax1 = fig.add_subplot(1, 4, 1)
                ax1.imshow(X)
                ax2 = fig.add_subplot(1, 4, 2)
                ax2.imshow(y)
                ax3 = fig.add_subplot(1, 4, 3)
                ax3.imshow((y_pred > 0.5))

                Q = dense_crf(((X * 255).round()).astype(np.uint8), y_pred)
                ax4 = fig.add_subplot(1, 4, 4)
                print(Q)
                ax4.imshow(Q > 0.5)
                plt.show()
    return tot / cntr

# Train

In [32]:
#Train function

def train_net(net, epochs=5, batch_size=5, lr=0.1, val_percent=0.05, cp=True, gpu=False):
    optimizer = optim.SGD(net.parameters(),
                          lr=lr, momentum=0.9, weight_decay=0.0005)
    criterion = nn.BCELoss()
    dir_checkpoint = 'checkpoints/'
    
    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))

        epoch_loss = 0
            
        for t, (X, y) in enumerate(loader_train):
            X = X.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=dtype)
            X.unsqueeze_(1)
            y.unsqueeze_(1)

            if gpu:
                X = Variable(X).cuda()
                y = Variable(y).cuda()
            else:
                X = Variable(X)
                y = Variable(y)
            
            
            y_pred = net(X)
            print(y_pred.shape)
            
            fig = plt.figure()
            ax1 = fig.add_subplot(1, 2, 1)
            ax1.imshow(X.data.numpy()[0,0,:])
            ax2 = fig.add_subplot(1, 2, 2)
            ax2.imshow(y_pred.data.numpy()[0,0,:])
            plt.show()
            
            probs = F.sigmoid(y_pred)
            probs_flat = probs.view(-1)

            y_flat = y.view(-1)

            loss = criterion(probs_flat, y_flat.float())
            epoch_loss += loss.data[0]

            if(t%10 == 0):
                print('{0:.4f} --- loss: {1:.6f}'.format(t,
                                                     loss.data[0]))
            optimizer.zero_grad()

            loss.backward()

            optimizer.step()
        print('Epoch finished ! Loss: {}'.format(epoch_loss / t))
        
        if 1:
            val_dice = eval_net(net, loader_val, device, gpu)
            print('Validation Dice Coeff: {}'.format(val_dice))

        if cp:
            torch.save(net.state_dict(),
                       dir_checkpoint + 'CP{}.pth'.format(epoch + 1))

            print('Checkpoint {} saved !'.format(epoch + 1))


In [33]:
# NOTE(review): UNet(1, 1) presumably means 1 input channel and 1 output channel —
# confirm against unet.UNet.
net = UNet(1, 1)

try:
    train_net(net)
except KeyboardInterrupt:
    # Interrupting the kernel stops training but keeps the weights learned so far.
    torch.save(net.state_dict(), 'INTERRUPTED.pth')
    print('Saved interrupt')

Starting epoch 1/5.
torch.Size([4, 1, 256, 256])
0.0000 --- loss: 0.688255
torch.Size([4, 1, 256, 256])
torch.Size([4, 1, 256, 256])
torch.Size([4, 1, 256, 256])
Saved interrupt
