In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

import tools._torch_tools as tt
import models.loss as lo
from models.extrapolation import Model

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device is", device)

# Fix RNG seeds so weight initialisation (and anything in the training tools
# that draws from NumPy's global RNG) repeats across "Restart & Run All".
# Some CUDA kernels are nondeterministic, so GPU runs may still differ slightly.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices
np.random.seed(SEED)

# Directory passed to tt.loadData for every data split below.
data_folder = "dataset_extrapolation/"

In [None]:
# Training split; typeF presumably selects the array dtype (confirm in tools/_torch_tools.loadData).
X_train, y_train = tt.loadData(data_folder, typeF='float16', train=True)

In [None]:
# Validation split, used by training.fit(..., val=True) below.
X_val, y_val = tt.loadData(data_folder, typeF='float16', val=True)

In [None]:
# Held-out test split, only used by the SSIM evaluation at the end.
X_test, y_test = tt.loadData(data_folder, typeF='float16', test=True)

In [None]:
class Training(tt.Training):
    """Trainer that supervises only the first channel of every target batch.

    Construction is inherited unchanged from `tt.Training`; only batch
    fetching is customised.
    """

    def getBatch(self, offset, batch_size, val=False):
        """Return `(inputs, target)` from `tt.Training`, with target cut to channel 0.

        NOTE(review): presumably the model is trained to predict only the first
        of the target channels/frames (evaluation later asks `predict` for 3) —
        confirm against models/extrapolation.py.
        """
        inputs, target = super().getBatch(offset, batch_size, val=val)
        # Slicing with `:1` rather than indexing with `0` keeps the channel
        # axis (size 1) instead of dropping it.
        return inputs, target[:, :1, :, :]

### Training with the L1 loss

In [None]:
# Fresh model trained with PyTorch's L1Loss (mean absolute error by default).
model = Model()
training = Training(
    model, device,
    X_train, y_train,
    X_val, y_val,
    loss_function=nn.L1Loss(),
)

In [None]:
# Interrupting the kernel ends training early but lets this cell finish
# normally, so the save/plot cells that follow can still use `training`.
try:
    # NOTE(review): positional args presumably (batch size, epoch count) — confirm in tools/_torch_tools.py
    training.fit(32, 40, val=True)
except KeyboardInterrupt:
    print("\n\nFinished training.")

In [None]:
# Checkpoint of the L1 phase; the perceptual-loss phase reloads "l1.pth" as its starting point.
training.save("l1.pth")

In [None]:
# Plot the L1-phase history to history.png and export the same data as CSV.
tt.plotHistory(training.history, size=(5, 4), save="history.png")
tt.toCSV("history.csv", training.history)

### Fine-tuning with the perceptual loss (starting from the L1 checkpoint)

In [None]:
# Warm-start a new model from the checkpoint written at the end of the L1 phase.
model = Model()
model.load_state_dict(torch.load("l1.pth", map_location=device))
# NOTE(review): the meaning of CombinedLoss's positional arguments (-18, 1e-5)
# lives in models/loss.py — consider naming them here.
loss = lo.CombinedLoss(-18, 1e-5)
training = Training(
    model, device,
    X_train, y_train,
    X_val, y_val,
    loss_function=loss,
)

In [None]:
# Same early-stop behaviour as the L1 phase: an interrupt ends training but
# not the cell, so the final weights can still be saved.
try:
    # NOTE(review): positional args presumably (batch size, epoch count) — confirm in tools/_torch_tools.py
    training.fit(32, 35, val=True)
except KeyboardInterrupt:
    print("\n\nFinished training.")

In [None]:
# Final weights after perceptual-loss fine-tuning; reloaded for the SSIM evaluation.
training.save("final.pth")

In [None]:
# Plot the fine-tuning history to history_2.png and export the same data as CSV.
tt.plotHistory(training.history, save="history_2.png", size=(5, 4))
tt.toCSV("history_2.csv", training.history)

### SSIM evaluation on the test set

In [None]:
def toTensor(x):
    """Convert an array slice to a float32 tensor shaped (1, C, 96, 96) on `device`.

    Every leading dimension of `x` is folded into the channel axis by `view`,
    so `x` must contain a multiple of 96 * 96 elements. Relies on the
    notebook-global `device`.
    """
    as_float32 = torch.tensor(x, dtype=torch.float)
    batched = as_float32.view(1, -1, 96, 96)
    return batched.to(device)

In [None]:
# Metric plus a fresh model restored from the final checkpoint for evaluation.
ssim = lo.SSIMLoss()
model = Model()
model = model.to(device)
model.load_state_dict(torch.load("final.pth", map_location=device))
# Inference mode (affects layers such as dropout/batch-norm, if Model has any);
# left as the last expression so the notebook still displays the model.
model.eval()

In [None]:
# Mean value of lo.SSIMLoss between each of the n_outputs predicted outputs and
# the matching target channel, averaged over the whole test set.
# NOTE(review): whether SSIMLoss returns SSIM or 1 - SSIM is defined in
# models/loss.py — confirm before reading the numbers as "higher is better".
n_outputs = 3
n_test = len(X_test)
if n_test == 0:
    raise ValueError("X_test is empty; cannot compute a mean SSIM")

ssim_sums = [0.0] * n_outputs
# Inference only. Without no_grad every prediction may build an autograd
# graph, and summing the resulting tensors would keep those graphs alive
# for the whole loop.
with torch.no_grad():
    for i in range(n_test):
        # NOTE(review): assumes predict's second argument is the number of outputs to produce.
        out = model.predict(toTensor(X_test[i:i+1]), n_outputs)
        for j in range(n_outputs):
            # float() stores a plain number (assumes ssim(...) yields a single value).
            ssim_sums[j] += float(ssim(out[:, j:j+1], toTensor(y_test[i:i+1, j:j+1])))

        if i % 8 == 0:
            # Running mean for the first output, redrawn on one line via '\r'.
            print('\r' + str(i).zfill(5) + " " + str(ssim_sums[0] / (i + 1)), end='', flush=True)

print("")
# Final per-output means (plain floats), one per predicted output.
res = [total / n_test for total in ssim_sums]
for mean_value in res:
    print(mean_value)
    print("================")