In [None]:
import torch
import torch.nn as nn
import torchvision

%matplotlib nbagg
import numpy as np
import matplotlib.pyplot as plt

import time
import os
# Tell the Intel OpenMP runtime to tolerate more than one copy of itself being
# loaded into the process (avoids the "OMP: Error #15" abort).
# The variable name must be spelled exactly, otherwise the setting is ignored.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from dataset import *
from transforms import *
from criteria import *
from torch.utils.data import DataLoader, random_split

# torch.manual_seed(42)

## Dataset

In [None]:
def _parse_label(y):
    """Evaluate the raw label value `y` as a Python expression and wrap it in a NumPy array.

    NOTE(review): eval() executes arbitrary Python. `y` presumably comes from
    the label text files passed to SimpleDataset below; if those only contain
    plain literals, ast.literal_eval would be a safer choice -- confirm before
    changing.
    """
    return np.array(eval(y))


# Input pipeline (project transforms; presumably pad/crop to 256x256, convert
# an HWC image to a CHW tensor, then move it to the GPU -- confirm in transforms.py).
x_steps = [
    PadOrCenterCrop((256, 256)),
    ToTensor(make_CHW=True, input_format="HWC"),
    ToCuda(),
]
x_transform = torchvision.transforms.Compose(x_steps)

# Target pipeline: parse the label, convert it to a torch.long tensor, then ToCuda.
y_steps = [
    _parse_label,
    ToTensor(make_CHW=False, out_type=torch.long),
    ToCuda(),
]
y_transform = torchvision.transforms.Compose(y_steps)

data_dir = "../dr_experiments/data/"

# Keyword arguments shared by the training and validation datasets.
shared_ds_kwargs = dict(
    x_transform=x_transform,
    y_transform=y_transform,
    x_path_prefix="../dr_experiments",
)
ds_train = SimpleDataset(
    data_dir + "x_train.txt", data_dir + "t_train.txt", **shared_ds_kwargs
)
ds_val = SimpleDataset(
    data_dir + "x_val.txt", data_dir + "t_val.txt", **shared_ds_kwargs
)

# Both loaders shuffle and load in the main process (num_workers=0).
dl_train = DataLoader(ds_train, batch_size=12, shuffle=True, num_workers=0)
dl_val = DataLoader(ds_val, batch_size=18, shuffle=True, num_workers=0)

## Model

In [None]:
# Start from a pretrained ResNet-152.
backbone = torchvision.models.resnet152(pretrained=True, progress=True)

# Replace the fc layer: keep every child module except the last one and
# attach a fresh 5-way linear head in its place.
feature_layers = list(backbone.children())[:-1]
model = nn.Sequential(
    *feature_layers,
    nn.Flatten(),        # or shapes won't work out for the Linear layer
    nn.Linear(2048, 5),
)

# Move the network to the GPU when one is present.
if torch.cuda.is_available():
    model = model.cuda()

In [None]:
# Total number of scalar parameters in the model (displayed as the cell output).
sum(param.numel() for param in model.parameters())

In [None]:
# Set `load_model` to True to initialise the network from a saved state dict.
load_model = False
model_file = "models_gn_8_1-9/model_e200.pkl"

if load_model:
    # Without CUDA, remap the stored tensors onto the CPU; with CUDA, keep
    # torch.load's default placement (map_location=None).
    remap_to = None if torch.cuda.is_available() else torch.device('cpu')
    model.load_state_dict(torch.load(model_file, map_location=remap_to))

## Training

In [None]:
import torch.optim as optim

# Cross-entropy loss, moved to the GPU when one is available.
criterion = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    criterion = criterion.cuda()

# Adam over every parameter of the model, learning rate 1e-3.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Live loss plot: two initially empty curves that the training loop extends
# with the mean train / validation loss after every epoch.
fig, ax = plt.subplots(1, 1)
ax.set_ylim(0, 1)  # fixed y-range; the training loop only rescales the x-axis
# Label the axes so the figure (also saved as a PNG later) can stand alone.
ax.set_xlabel("Epoch")
ax.set_ylabel("Mean loss")
hl_train, = ax.plot([], [], label="Train")
hl_val, = ax.plot([], [], label="Val")
fig.canvas.draw()

In [None]:
import pickle
from datetime import datetime, timedelta

start_time = datetime.now()

# --- Checkpoint configuration ---
save_model = True
save_period = 10  # write a checkpoint every `save_period` epochs
save_root = "../dr_experiments/exp1/"
model_file_template = save_root + "model_e{}.pkl"

# --- Loss logging configuration ---
save_loss = True
loss_file_path = save_root + "loss.pkl"
loss_png_path = save_root + "loss.png"
train_loss = []  # one list of per-batch training losses per epoch
val_loss = []    # one list of per-batch validation losses per epoch

# Create the output directory once, up front, so that the loss artefacts can
# be written even when checkpointing is disabled.
if save_model or save_loss:
    os.makedirs(save_root, exist_ok=True)

epochs = 200
n_train_batches = len(dl_train)
n_val_batches = len(dl_val)

for epoch in range(epochs):  # loop over the dataset multiple times

    # ---- training pass ----
    model.train()
    train_loss_epoch = []
    for i, (inputs, labels) in enumerate(dl_train):
        optimizer.zero_grad()
        # Call the module (not .forward) so hooks run. No .squeeze(): the
        # model ends in Flatten + Linear, so the output is already
        # (batch, classes); squeezing would drop the batch dimension whenever
        # the final batch holds a single sample and break the loss.
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss_epoch.append(loss.item())

        # print statistics: estimate the remaining time from the fraction of
        # all training batches processed so far (always > 0 here).
        progress = (i + 1 + epoch * n_train_batches) / (epochs * n_train_batches)
        time_elapsed = datetime.now() - start_time
        time_to_completion = time_elapsed / progress - time_elapsed
        print("Epoch: {}, Train, Batch {}/{}, ETA: ".format(epoch+1, i+1, n_train_batches) +
              str(time_to_completion), end='\r')

    # ---- validation pass (no gradients needed) ----
    model.eval()
    val_loss_epoch = []
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dl_val):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss_epoch.append(loss.item())

            # Trailing spaces overwrite the longer training status line.
            print("Epoch: {}, Val, Batch {}/{}".format(epoch+1, i+1, n_val_batches)+' '*40, end='\r')

    train_loss.append(train_loss_epoch)
    val_loss.append(val_loss_epoch)

    # update loss graph: append this epoch's mean loss to each curve
    for curve, epoch_losses in ((hl_train, train_loss_epoch), (hl_val, val_loss_epoch)):
        curve.set_xdata(np.append(curve.get_xdata(), epoch + 1))
        curve.set_ydata(np.append(curve.get_ydata(), np.mean(epoch_losses)))
    ax.legend(['Train', 'Val'])
    ax.relim()
    ax.autoscale(axis='x')
    fig.canvas.draw()

    if save_model and (epoch + 1) % save_period == 0:
        # Unwrap nn.DataParallel so the saved state dict is that of the inner module.
        net = model.module if isinstance(model, nn.DataParallel) else model
        torch.save(net.state_dict(), model_file_template.format(epoch + 1))

if save_loss:
    # Context manager guarantees the file is flushed and closed.
    with open(loss_file_path, 'wb') as loss_file:
        pickle.dump((train_loss, val_loss), loss_file)
    # Save the loss figure explicitly rather than whichever figure is current.
    fig.savefig(loss_png_path)

print('\nFinished Training')

In [None]:
# if save_loss:
#     pickle.dump((train_loss, val_loss), open(loss_file_path, 'wb'))
#     plt.savefig(loss_png_path)

# Tests