### Setup and load torch

In [None]:
%load_ext autoreload
%autoreload 2

from IPython.core.display import display, HTML
import sys,cv2,timeit,gc
sys.path.append('../')
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

from math import isinf

from Utils.load_batch import *
from Utils.utils import *
from Utils.center_images import *
from Utils.f1_score import *
from ipywidgets import interact
%matplotlib inline
display(HTML("<style>.container { width:100% !important; }</style>"))

from jupyterthemes import jtplot
jtplot.style(theme='grade3',context='paper', fscale=2.5, spines=True, gridlines='-',ticks=True, grid=True, figsize=(6, 4.5))
plotcolor = (0, 0.6, 1.0)

import torch
import torchvision
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn

import deepdish as dd

data_folder = 'D:/data/HPA/all/'
model_folder = 'D:/data/HPA/models/'
LOSS = 'BCE'
USE_SMALL_IMAGES = False
USE_ALL_CHANNELS = True
np.random.seed(100)

print("Using GPU:",torch.cuda.is_available())
print("Using device ",torch.cuda.get_device_name(0))
print("Done.")

### Load training data

In [None]:
%%time
filename = "_augmented"
if USE_SMALL_IMAGES:
    filename = filename + "_small.h5"
else:
    filename = filename + ".h5"
if USE_ALL_CHANNELS:
    filename = "all_channel" + filename
else:
    filename = "poi" + filename

d = dd.io.load(data_folder+filename)
    
X = d['filenames'] # filenames

y = d['labels']

idx = np.arange(y.shape[0])
np.random.shuffle(idx)
X = X[idx]
y = y[idx]

print(X.shape)
print(y.shape)
print("Done")

### Initialize (and load) model

In [None]:
LOAD_OLD_MODEL = False

# Set up a ResNet-34 initialised with pretrained weights.
net = models.resnet34(pretrained=True)

# Keep the pretrained first-layer kernels; conv1 is replaced below.
old_weights = net.conv1.weight

# Rebuild conv1 for the configured number of input channels.
in_channels = 4 if USE_ALL_CHANNELS else 1
net.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
# Seed every input channel with the pretrained kernels of input channel 0.
# This must assign the `.weight` Parameter: assigning any other attribute name
# (e.g. `.weights`) leaves conv1 randomly initialised, and it is frozen below.
net.conv1.weight = nn.Parameter(old_weights[:, :1].detach().repeat(1, in_channels, 1, 1))

# Replace the classification head with 28 outputs.
# NOTE(review): 51200 presumably matches the flattened pooled features for the
# full-size images with the torchvision version in use -- confirm.
net.fc = nn.Linear(512 if USE_SMALL_IMAGES else 51200, 28)

# Freeze the first 7 children (conv1, bn1, relu, maxpool, layer1-3);
# layer4, avgpool and fc stay trainable. This also freezes the new conv1.
for child in list(net.children())[:7]:
    for param in child.parameters():
        param.requires_grad = False

if LOAD_OLD_MODEL:
    # Same naming scheme as the save cell.
    modelname = "resnet34"
    if USE_ALL_CHANNELS:
        modelname = modelname + "_all"
    if USE_SMALL_IMAGES:
        modelname = modelname + "_small"
    net.load_state_dict(torch.load(model_folder+modelname+".model"))

net = net.cuda()

print("Done.")

### Initialize loss and optimizer

In [None]:
from Utils.f1_loss import *

if LOSS == 'F1':
    # Custom F1 loss; the training loop passes it sigmoid probabilities.
    criterion = F1_Loss().cuda()
else:
    # BCE on logits with per-class positive weights N / (#positives of the class),
    # to counter the skewed label distribution. A class with no positives would
    # get N / 0 = inf, so it gets weight N instead.
    n_samples = y.shape[0]
    positives = np.sum(y, axis=0)
    weights = np.where(positives > 0, n_samples / np.maximum(positives, 1), n_samples)
    print("Weights = ",weights) # we weight classes given their skewed distribution
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(weights,dtype=torch.float).cuda()).cuda()

# Only parameters left trainable by the model cell are optimised.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),lr=4e-4)
# Halve the LR when the monitored loss stops improving for 100 scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.5, patience = 100, min_lr = 1e-7, verbose=True)

print("Done.")

### Run training

In [None]:
gc.collect()

epochs = 5
batch_size = 100
time_per_epoch = 0
net.train()

iterations_per_epoch = np.ceil(X.shape[0] / batch_size).astype(int)
runtime = 0

for epoch in range(epochs):
    print_horizontal_divider()
    print("Starting Epoch ", epoch)
    print_horizontal_divider()

    running_loss,running_f1,average_targets_predicted = 0,0,0

    # reshuffle so every epoch sees different batches
    idx = np.arange(y.shape[0])
    np.random.shuffle(idx)
    X = X[idx]
    y = y[idx]

    for i in range(iterations_per_epoch):
        start = timeit.default_timer() #measure time

        # Slicing past the end truncates, so the last batch may be smaller than batch_size.
        start_idx = i * batch_size
        filenames = X[start_idx:start_idx+batch_size]
        real_batch_size = len(filenames)
        X_batch = load_batch(filenames,USE_ALL_CHANNELS).cuda()
        y_batch = torch.tensor(y[start_idx:start_idx+batch_size].astype(np.float32),dtype=torch.float).cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(X_batch)

        if LOSS == 'F1':
            # the F1 loss works on probabilities; BCEWithLogitsLoss applies the sigmoid itself
            outputs = torch.sigmoid(outputs)

        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # compute F1 scores on probabilities thresholded at 0.5
        # (in BCE mode `outputs` are raw logits, so apply the sigmoid first)
        probs = outputs.detach() if LOSS == 'F1' else torch.sigmoid(outputs.detach())
        label = y_batch.cpu().numpy().astype(bool)
        predicted = probs.cpu().numpy() > 0.5
        average_targets_predicted = np.sum(predicted) / real_batch_size
        average_targets_present = np.sum(label) / real_batch_size
        running_f1 += f1_score(label,predicted)

        #measure runtime
        stop = timeit.default_timer()
        time_per_epoch = 0.5 * time_per_epoch + 0.5 * (stop-start) * iterations_per_epoch
        runtime += (stop-start)

        # update LR if stagnating: monitor the mean loss of this epoch so far.
        # The cumulative sum only grows, which would trigger a reduction every `patience` steps.
        scheduler.step(running_loss / (i+1))

        #print performance metrics
        if i % 5 == 0:
            print('[epoch = (%d/%d), iteration = (%3d/%d), time = %3ds, est. time per epoch = %5ds] \t loss = %.5f ## F1 = %.5f pred/img = %.3f / %.3f'\
                  %(epoch + 1, epochs,i + 1, iterations_per_epoch, runtime, time_per_epoch, running_loss / (i+1), running_f1 / (i+1), average_targets_predicted, average_targets_present))

print("Overall done.")

### Load validation data

In [None]:
%%time
filename = "_6_"
if USE_SMALL_IMAGES:
    filename = filename + "small.h5"
else:
    filename = filename + ".h5"
if USE_ALL_CHANNELS:
    filename = "all_channel" + filename
else:
    filename = "poi" + filename

d = dd.io.load(data_folder+filename)
    
Xval = center_images(d['X'].astype(np.float32) / 255.0) # torch likes float images
yval = d['labels']

idx = np.arange(yval.shape[0])

np.random.shuffle(idx)
Xval = Xval[idx]
yval = yval[idx]
print("Done")

### Run validation

In [None]:
# Run validation (forward passes only)
gc.collect()
net.eval()

batch_size = 67
iterations_per_epoch = np.ceil(Xval.shape[0] / batch_size).astype(int)
time_per_epoch,runtime,running_loss,running_f1 = 0,0,0,0

# No backward pass here, so don't build the autograd graph (saves GPU memory).
with torch.no_grad():
    for i in range(iterations_per_epoch):
        start = timeit.default_timer() #measure time

        start_idx = i * batch_size
        # Xval is channels-last (hence the transpose); batches must live on the
        # same device as the network and the criterion's pos_weight.
        X_batch = torch.tensor(Xval[start_idx:start_idx+batch_size].transpose(0,3,1,2)).cuda()
        y_batch = torch.tensor(yval[start_idx:start_idx+batch_size].astype(np.float32),dtype=torch.float).cuda()

        # forward pass
        outputs = net(X_batch)
        act = torch.sigmoid(outputs)

        # Same convention as training: the F1 loss takes probabilities,
        # BCEWithLogitsLoss takes logits.
        loss = criterion(act if LOSS == 'F1' else outputs, y_batch)

        running_loss += loss.item()

        #compute F1 scores on probabilities thresholded at 0.5
        label = y_batch.cpu().numpy().astype(bool)
        predicted = act.cpu().numpy() > 0.5
        print("Targets in batch = ",np.sum(label),"Predicted targets = ",np.sum(predicted))
        running_f1 += f1_score(label,predicted)

        #measure runtime
        stop = timeit.default_timer()
        time_per_epoch = 0.5 * time_per_epoch + 0.5 * (stop-start) * iterations_per_epoch
        runtime += (stop-start)
        # criterion already averages over the batch, so average over batches (as in training)
        print('[iteration = (%3d/%d), time = %3ds, est. time per epoch = %5ds] \t loss = %.5f ## F1 = %.5f'\
              %(i + 1, iterations_per_epoch, runtime, time_per_epoch, running_loss / (i+1), running_f1 / (i+1)))

### Save the trained model

In [None]:
# Save model under the same name the model-initialization cell loads from
modelname = "resnet34"
if USE_ALL_CHANNELS:
    modelname = modelname + "_all"
if USE_SMALL_IMAGES:
    modelname = modelname + "_small"

# Only the state_dict is saved; load it back into a network built the same way.
torch.save(net.state_dict(), model_folder+modelname+".model")
print("Done.")

### Shut down the system (optional; run only after training and saving)

In [None]:
import subprocess

# Safety switch: without it, "Run All" powers off the machine. Set to True only
# for unattended runs that should shut the PC down after training and saving.
SHUTDOWN_WHEN_DONE = False

if SHUTDOWN_WHEN_DONE:
    # Windows `shutdown` syntax (the data paths in this notebook are Windows paths too).
    cmdCommand = "shutdown -s"
    # run() waits for the command to return and raises if it fails,
    # instead of leaving an unread stdout pipe behind.
    subprocess.run(cmdCommand.split(), check=True)