# UNet model

This notebook builds the data loaders for the cystic lung dataset and trains the UNet on them.

In [1]:
import collections
import os
from optparse import OptionParser

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import skorch
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import sampler

from eval import eval_net
from unet import UNet
from utils import *
from UNet_Loader import ILDDataset

In [2]:
# Index split applied to the Train dataset by the samplers below:
# [0, NUM_TRAIN) feeds training, [NUM_TRAIN, NUM_TOTAL) feeds validation.
NUM_TOTAL = 885
NUM_TRAIN = 700

BATCH_SIZE = 4    # mini-batch size shared by all three loaders
IMAGE_SIZE = 256  # forwarded to ILDDataset as `resize`

# Root folder of the project data. The default is the original author's
# location; set the ILD_DATA_DIR environment variable to run elsewhere.
DATA_DIR = os.environ.get(
    'ILD_DATA_DIR',
    '/Users/magdy/Desktop/Stanford Spring/BMI260/Project/Data')


def _load_split(split):
    """Build the ILDDataset for one split folder ('Train' or 'Test') under DATA_DIR."""
    return ILDDataset(
        cystic_path=os.path.join(DATA_DIR, 'cystic_dataset_masks', split),
        root_dir=os.path.join(DATA_DIR, 'Cystic Dataset', split),
        mask=True, HU=True, resize=IMAGE_SIZE)


lung_dataset_train = _load_split('Train')
lung_dataset_test = _load_split('Test')

# Training and validation both draw from the Train dataset, over disjoint
# index ranges, each shuffled by SubsetRandomSampler.
loader_train = DataLoader(lung_dataset_train, batch_size=BATCH_SIZE,
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

loader_val = DataLoader(lung_dataset_train, batch_size=BATCH_SIZE,
                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, NUM_TOTAL)))

loader_test = DataLoader(lung_dataset_test, batch_size=BATCH_SIZE)

In [3]:
# Set to False to stay on the CPU even when CUDA is present.
USE_GPU = True

# Floating-point type that input tensors are cast to.
dtype = torch.float32

# CUDA is chosen only when it is both requested and actually available.
device = torch.device('cuda' if (USE_GPU and torch.cuda.is_available()) else 'cpu')

# Controls how frequently the training loss is printed.
print_every = 100

print('using device:', device)

using device: cpu


# Check Accuracy

In [4]:
def check_accuracy(loader, model, train=False):
    """Measure how often the model predicts each batch's majority label.

    Every batch is passed to ``model`` as a single ``(1, 1, N, H, W)`` tensor.
    The arg-max over dimension 1 of the output is compared with the most
    frequent value found in that batch's labels.

    Relies on the module-level ``device`` and ``dtype`` for tensor placement.

    Args:
        loader: DataLoader yielding ``(x, y)`` pairs; ``x`` must have exactly
            three dimensions ``(N, H, W)``. ``loader.dataset.train`` is read
            to choose the banner that is printed.
        model: callable module; it is switched to evaluation mode.
        train: when true and ``loader.dataset.train`` is true, the banner says
            "training set" instead of "validation set".

    Returns:
        float: fraction of correct predictions, or 0.0 when the loader yields
        no batches.
    """
    if not loader.dataset.train:
        print('Checking accuracy on test set')
    elif train:
        print('Checking accuracy on training set')
    else:
        print('Checking accuracy on validation set')

    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            N, H, W = x.size()

            # Majority label of the batch. collections.Counter cannot be used
            # here: tensors hash by identity, so equal labels are never merged
            # and the "most common" entry is simply the first one.
            target = y.reshape(-1).mode()[0].reshape(1)

            scores = model(x.view(1, 1, N, H, W))
            _, preds = scores.max(1)
            num_correct += (preds == target).sum().item()
            num_samples += preds.size(0)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    acc = float(num_correct) / num_samples if num_samples else 0.0
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))

    return acc

# Train

In [5]:
#Train function

def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05, cp=True, gpu=False,
              dir_checkpoint=None, max_batches=7):
    """Train ``net`` with SGD and binary cross-entropy on ``loader_train``.

    Each epoch first evaluates on ``loader_val`` through ``eval_net``, then
    runs up to ``max_batches`` training steps and optionally saves a
    checkpoint. Relies on the module-level ``loader_train``, ``loader_val``,
    ``device``, ``dtype`` and ``eval_net``.

    Args:
        net: network to optimise; its output is passed through a sigmoid.
        epochs: number of epochs to run.
        batch_size: unused; the batch size is fixed by ``loader_train``.
            Kept for interface compatibility.
        lr: SGD learning rate.
        val_percent: unused; kept for interface compatibility.
        cp: save ``net.state_dict()`` after every epoch when true.
        gpu: forwarded to ``eval_net``; when true, batches are also moved
            with ``.cuda()``.
        dir_checkpoint: string prefix for checkpoint files (normally a folder
            ending in '/'). ``None`` uses a module-level ``dir_checkpoint`` if
            one is defined, otherwise ``'checkpoints/'``.
        max_batches: maximum training batches per epoch. The default of 7
            keeps the cap that used to be hard-coded as ``if t > 5: break``;
            pass ``None`` to iterate over the whole loader.
    """
    optimizer = optim.SGD(net.parameters(),
                          lr=lr, momentum=0.9, weight_decay=0.0005)
    criterion = nn.BCELoss()

    if dir_checkpoint is None:
        # No checkpoint location is defined in the visible notebook; honour a
        # module-level value if some import provides one.
        dir_checkpoint = globals().get('dir_checkpoint', 'checkpoints/')
    if cp:
        # torch.save does not create missing folders.
        checkpoint_folder = os.path.dirname(dir_checkpoint + 'CP')
        if checkpoint_folder:
            os.makedirs(checkpoint_folder, exist_ok=True)

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))

        val_dice = eval_net(net, loader_val, device, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

        # NOTE(review): eval_net presumably leaves the network in eval mode;
        # switch back so training-only layer behaviour is active. Confirm.
        net.train()

        epoch_loss = 0.0
        num_batches = 0

        for t, (X, y) in enumerate(loader_train):
            # Move to the target device and add a dimension at index 1.
            X = X.to(device=device, dtype=dtype).unsqueeze(1)
            y = y.to(device=device, dtype=dtype).unsqueeze(1)

            if gpu:
                X = X.cuda()
                y = y.cuda()

            probs = torch.sigmoid(net(X))
            loss = criterion(probs.view(-1), y.view(-1).float())

            # .item() replaces loss.data[0], which fails on 0-dim tensors.
            batch_loss = loss.item()
            epoch_loss += batch_loss
            num_batches += 1

            if t % 10 == 0:
                print('{0:.4f} --- loss: {1:.6f}'.format(t, batch_loss))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Checked after the step so no extra batch is pulled from the loader.
            if max_batches is not None and t + 1 >= max_batches:
                break

        # Mean loss over the batches actually processed (guarded for an empty loader).
        print('Epoch finished ! Loss: {}'.format(epoch_loss / max(num_batches, 1)))

        if cp:
            torch.save(net.state_dict(),
                       dir_checkpoint + 'CP{}.pth'.format(epoch + 1))

            print('Checkpoint {} saved !'.format(epoch + 1))


In [None]:
# Hyper-parameters are passed directly to train_net (for example
# train_net(net, epochs=10, lr=0.01)) rather than parsed from a command line.
# To resume from saved weights, call net.load_state_dict(torch.load(path))
# before training.

net = UNet(1, 1)

# train_net moves every batch to `device`, so the weights must live there too;
# otherwise the forward pass fails whenever CUDA is selected.
net.to(device)

try:
    train_net(net)
except KeyboardInterrupt:
    # Keep the partially trained weights when the cell is interrupted. The
    # kernel is left running so the notebook state stays available.
    torch.save(net.state_dict(), 'INTERRUPTED.pth')
    print('Saved interrupt')

Starting epoch 1/5.
Checking accuracy on test set
