In [30]:
import torch
from torch import autograd
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn as nn
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import time
import pickle
import copy
import torch.optim as optim

In [23]:
# Input image geometry used when building the network below:
# colour channels first, then the (square) spatial size in pixels.
in_channels = 3
in_height = in_width = 120

In [27]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier.

    The feature extractor (``self.net``) is five convolutions interleaved with
    ReLU, batch-norm and max-pooling; the head (``self.classifier``) is three
    fully connected layers with dropout, producing ``num_classes`` raw logits.

    Args:
        num_classes: number of output logits.
        input_shape: optional ``(channels, height, width)`` of the input
            images. When omitted, the notebook-level ``in_channels``,
            ``in_height`` and ``in_width`` are used.
    """

    def __init__(self, num_classes = 5, input_shape = None):
        super(AlexNet, self).__init__()
        if input_shape is None:
            # Fall back to the notebook-level configuration constants.
            input_shape = (in_channels, in_height, in_width)
        channels, height, width = input_shape

        self.net = nn.Sequential(
            nn.Conv2d(channels, 48, kernel_size = 11, stride = 4),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(48, eps = 0.001),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 128, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(128, eps=0.001),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # padding=1 keeps the spatial size through the 3x3 convolutions.
            # Without it a 120x120 input shrinks 6 -> 4 -> 2 -> 0 and the
            # forward pass cannot run.
            nn.Conv2d(128, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(128, eps=0.001),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        # Measure the flattened feature size instead of guessing it with
        # height/32 * width/32, which does not match the layers above.
        flat_features = self._flat_feature_count(channels, height, width)

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(flat_features, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num_classes)
        )

    def _flat_feature_count(self, channels, height, width):
        """Return the number of features ``self.net`` emits per sample."""
        was_training = self.net.training
        # Probe in eval mode so batch-norm running statistics stay untouched.
        self.net.eval()
        with torch.no_grad():
            probe = torch.zeros(1, channels, height, width)
            count = self.net(probe).view(1, -1).size(1)
        self.net.train(was_training)
        return count

    def forward(self, x):
        """Map a batch of images ``(N, C, H, W)`` to logits ``(N, num_classes)``."""
        x = self.net(x)
        # Flatten per sample; keeping the batch dimension explicit means a size
        # mismatch raises in the classifier instead of silently re-batching.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [14]:
def train(model, criterion, optimizer, lr_scheduler, dsets_loader, dset_sizes, lr = 0.001, num_epochs = 20):
    """Train ``model`` and keep the snapshot with the best validation accuracy.

    Each epoch runs a 'train' phase followed by a 'val' phase. Per-phase
    statistics are collected and pickled to ``history.pickle`` in the current
    working directory when training finishes.

    Args:
        model: the network to optimise (updated in place).
        criterion: loss callable ``criterion(outputs, labels)``; assumed to
            return the mean loss over the batch (the default for
            ``nn.CrossEntropyLoss``).
        optimizer: optimizer bound to ``model``'s parameters.
        lr_scheduler: callable ``(optimizer, epoch, init_lr=...)`` returning
            the optimizer to use for that epoch.
        dsets_loader: mapping with 'train' and 'val' iterables yielding
            ``(inputs, labels)`` batches.
        dset_sizes: mapping with the number of samples for 'train' and 'val'.
        lr: initial learning rate handed to ``lr_scheduler``.
        num_epochs: number of epochs to run.

    Returns:
        ``(best_model, best_acc)``: a deep copy of the model at its best
        validation accuracy and that accuracy as a float. If validation
        accuracy never rises above 0, a copy of the final model is returned.
    """
    since = time.time()
    best_acc = 0.0
    best_model = None
    history = {phase_name: [] for phase_name in ['train', 'val']}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-'*10)
        for mode in ['train', 'val']:
            start = time.time()
            is_train = mode == 'train'

            if is_train:
                optimizer = lr_scheduler(optimizer, epoch, init_lr = lr)
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dsets_loader[mode]:
                if is_train:
                    optimizer.zero_grad()

                # Only build the autograd graph when we are going to use it.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                _, preds = torch.max(outputs.detach(), 1)

                # The criterion yields a per-batch mean, so weight it by the
                # batch size before dividing by the dataset size below.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += int(torch.sum(preds == labels).item())

            epoch_loss = running_loss / float(dset_sizes[mode])
            epoch_acc = running_corrects / float(dset_sizes[mode])
            epoch_time = time.time() - start

            history_dict =  {
                'mode':mode,
                'epoch':epoch,
                'epoch_loss': epoch_loss,
                'epoch_accuracy': epoch_acc,
                # Key spelling kept as-is so existing history files stay readable.
                'learing_rate':optimizer.state_dict()['param_groups'][0]['lr']
            }

            history[mode].append(history_dict)

            print('{} Loss: {:.4f} Acc:{:.4f} Time: {:.0f}m {:.0f}s'.format(mode, epoch_loss, epoch_acc, epoch_time//60, epoch_time%60))

            if mode == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model = copy.deepcopy(model)
        print()

    if best_model is None:
        # No validation pass ever beat 0.0 accuracy; still return a usable model.
        best_model = copy.deepcopy(model)

    time_elapsed = time.time() - since

    print('Training completion in {:.0f}m {:.0f}s'.format(time_elapsed//60, time_elapsed%60))
    print('Best validation accuracy is {:.4f}'.format(best_acc))

    print('save mode')
    # Persist the per-epoch history so a run can be inspected afterwards.
    with open('history.pickle','wb') as f:
        pickle.dump(history, f)

    return best_model, best_acc


In [15]:
#add scheduler to better control learning rate 
def exp_lr_scheduler(optimizer, epoch, init_lr = 0.001, lr_decay_epoch = 5):
    """Apply a step-wise exponential learning-rate decay to ``optimizer``.

    The learning rate is ``init_lr * 0.9 ** (epoch // lr_decay_epoch)``, i.e.
    it is multiplied by 0.9 once every ``lr_decay_epoch`` epochs.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts); every
            group's ``'lr'`` entry is overwritten.
        epoch: zero-based index of the current epoch.
        init_lr: learning rate used before any decay.
        lr_decay_epoch: number of epochs between two decay steps.

    Returns:
        The same ``optimizer`` object, updated in place.
    """
    lr = init_lr * (0.9 ** (epoch // lr_decay_epoch))

    # Announce the rate only on epochs where a new decay step starts.
    if epoch % lr_decay_epoch == 0:
        print('Learning rate is changed to {}'.format(lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer

In [17]:
#import mri_loader

In [19]:
# Locations of the training / validation data; fill these in before loading.
train_file = ''
val_file = ''
# Load data (disabled until the mri_loader import is enabled and the paths are set).
# The call passes the variables defined above; the keyword names are kept from
# the earlier draft -- TODO confirm them against mri_loader.load_data.
#dset_loaders, dset_sizes, dset_classes = mri_loader.load_data(train_path=train_file, val_path=val_file)

In [28]:
# Build the classifier with its default head of five output classes.
net = AlexNet(num_classes=5)

In [32]:
# Training setup: the scheduler function is passed around by reference, the
# loss is cross-entropy over the logits, and Adam adds L2 weight decay.
lr_scheduler = exp_lr_scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), weight_decay=5e-4)

In [34]:
# Initial learning rate handed to train() and, through it, to the scheduler.
lr = 0.001
# Training is disabled until dset_loaders / dset_sizes have been loaded.
# train() returns (best_model, best_acc); bind the result to those names.
#best_model, best_acc = train(net, criterion, optimizer, lr_scheduler, dset_loaders, dset_sizes, lr, 40)

In [None]:
# Persist only the learned parameters (the state dict) of the best model.
# The validation accuracy is embedded in the file name.
# Requires best_model and best_acc from a completed train() run.
print('Saving best model')
filename = 'trained_model_{:.2f}.pt'.format(best_acc)
torch.save(best_model.state_dict(), filename)