In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

import tools._torch_tools as tt
import models.loss as lo
from models.interpolation import Model

# Pin the global RNGs so that anything drawing on torch's or NumPy's global
# generator (e.g. weight initialisation) starts from the same state on every
# Restart & Run All. GPU kernels may still be non-deterministic; this only
# fixes the seeds.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices
np.random.seed(SEED)

# Use the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device is", device)

# Folder handed to tt.loadData for the train / validation / test splits.
data_folder = "dataset_interpolation/"

In [None]:
# Training split. typeF / channels_last are forwarded unchanged to tt.loadData
# (presumably half-precision, channels-first arrays — confirm in tools._torch_tools).
X_train, y_train = tt.loadData(
    data_folder,
    train=True,
    typeF="float16",
    channels_last=False,
)

In [None]:
# Validation split, requested with the same typeF / channels_last settings as
# the training split.
X_val, y_val = tt.loadData(
    data_folder,
    val=True,
    typeF="float16",
    channels_last=False,
)

In [None]:
# Test split, requested with the same typeF / channels_last settings as the
# other splits; it is only used by the final SSIM evaluation.
X_test, y_test = tt.loadData(
    data_folder,
    test=True,
    typeF="float16",
    channels_last=False,
)

### Training with the L1 loss

In [None]:
# Stage 1: a freshly constructed network trained against plain L1 loss,
# with the validation split passed alongside the training split.
model = Model()
training = tt.Training(
    model,
    device,
    X_train,
    y_train,
    X_val,
    y_val,
    loss_function=nn.L1Loss(),
)

In [None]:
try:
    # NOTE(review): the positional arguments are presumably batch size (32)
    # and number of epochs (50) — confirm against tt.Training.fit.
    training.fit(32, 50, val=True)
except KeyboardInterrupt:
    # Interrupting the kernel stops training early; the exception is caught
    # so the `training` object stays usable by the cells that follow.
    print("\n\nFinished training.")

In [None]:
# Persist the result of the L1 stage. The perceptual-loss stage reloads this
# file with torch.load and feeds it to load_state_dict, so the file is
# presumably a model state dict — confirm in tt.Training.save.
training.save("l1.pth")

In [None]:
# Export the L1-stage training history twice: as a figure and as a CSV file.
tt.plotHistory(
    training.history,
    size=(5, 4),
    save="history.png",
)
tt.toCSV("history.csv", training.history)

### Training with the perceptual loss

In [None]:
# Stage 2: rebuild the network, restore the weights written to l1.pth by the
# L1 stage, and continue training with the combined loss.
model = Model()
model.load_state_dict(
    torch.load("l1.pth", map_location=device)
)

# NOTE(review): the meaning of the two arguments (-18 and 1e-5) is defined by
# lo.CombinedLoss and is not visible here — document them once confirmed.
loss = lo.CombinedLoss(-18, 1e-5)
training = tt.Training(
    model,
    device,
    X_train,
    y_train,
    X_val,
    y_val,
    loss_function=loss,
)

In [None]:
try:
    # NOTE(review): the positional arguments are presumably batch size (32)
    # and number of epochs (35) — confirm against tt.Training.fit.
    training.fit(32, 35, val=True)
except KeyboardInterrupt:
    # Interrupting the kernel stops training early; the exception is caught
    # so the `training` object stays usable by the cells that follow.
    print("\n\nFinished training.")

In [None]:
# Persist the result of the perceptual-loss stage; this is the file that the
# SSIM evaluation on the test set loads.
training.save("final.pth")

In [None]:
# Export the second-stage training history under separate file names so the
# L1-stage outputs are not overwritten.
tt.plotHistory(
    training.history,
    save="history_2.png",
    size=(5, 4),
)
tt.toCSV("history_2.csv", training.history)

### SSIM index on test dataset

In [None]:
def validateSSIM(in_weights, X_set, y_set):
    """Run a validation pass with ``lo.SSIMLoss`` for a set of saved weights.

    A fresh ``Model`` is built, the state dict stored at ``in_weights`` is
    loaded into it (mapped onto the notebook-level ``device``) and the model
    is put into eval mode. The same data is passed in both the training and
    the validation slots of ``tt.Training``; only ``validate`` is called.

    Args:
        in_weights: Location of the saved state dict, as accepted by
            ``torch.load``.
        X_set: Inputs to evaluate on.
        y_set: Targets matching ``X_set``.

    Returns:
        Whatever ``tt.Training.validate(1)`` returns — presumably the
        SSIM-based score computed with batch size 1; confirm in
        ``tools._torch_tools``.
    """
    net = Model()
    saved_state = torch.load(in_weights, map_location=device)
    net.load_state_dict(saved_state)
    net.eval()

    evaluator = tt.Training(
        net,
        device,
        X_set,
        y_set,
        X_set,
        y_set,
        loss_function=lo.SSIMLoss(),
    )
    return evaluator.validate(1)

In [None]:
# Evaluate the final (perceptual-loss stage) weights on the held-out test split.
# NOTE(review): the result is only stored here, not displayed by this cell.
final_test_SSIM = validateSSIM("final.pth",X_test, y_test)