## IMPORTS

In [None]:
import os
import pathlib
import time
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

import deterministic
import dataset as dst

# Global matplotlib defaults for every figure in this notebook: large square
# canvases and a big font.
plt.rcParams.update({
    'figure.figsize': (12, 12),
    'font.size': 20,
})

# Report whether PyTorch can see a CUDA device.
cuda_ok = torch.cuda.is_available()
print("Using CUDA?", cuda_ok)

# IPython magics: load the autoreload extension and, in mode 2, reload modules
# before executing each cell, so edits to the project modules imported above
# (`deterministic`, `dataset`) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

## SETTINGS, PRELIMINARIES

In [None]:
# Paths are resolved relative to the notebook's working directory: the project
# root is taken to be its parent directory.
path_to_project = pathlib.Path(os.getcwd()).parent

data_path = os.path.join(path_to_project, 'data')
# One result directory per run, named by the start time (minute resolution, so
# two runs started within the same minute share a directory).
result_path = os.path.join(path_to_project, 'results', 'deterministic', datetime.now().strftime('%Y_%m_%d_%H_%M'))
figure_path = os.path.join(result_path, 'figures')
model_path = os.path.join(result_path, 'models')

if os.path.isdir(result_path):
    print("The result directory already exists. Its contents may be overwritten!")

# makedirs creates any missing parents (results/, results/deterministic/ on a
# fresh checkout, where plain os.mkdir would raise FileNotFoundError).
# exist_ok=True also fills in a missing subdirectory when result_path is being
# reused.
for directory in (figure_path, model_path):
    os.makedirs(directory, exist_ok=True)

print("Results are being saved to", result_path)

## HYPERPARAMETERS

In [None]:
# Run hyperparameters; the whole dict is also written to params.txt in the
# result directory so each run records its configuration.
PARAMS = dict(
    PATCH_SIZE=64,            # passed to AgbDataset(patch_size=...)
    BATCH_SIZE=128,           # batch size for training_ds.get_batch
    LEARNING_RATE=5e-5,       # RMSprop learning rate
    ITS=2000 + 1,             # +1 so the last iteration index is 2000
    LOSS_FUNC=nn.MSELoss(),   # loss handed to net.training_iteration
)

params_file = os.path.join(result_path, 'params.txt')
with open(params_file, "w") as log:
    log.write(f"{PARAMS}\n")

## DATASETS

In [None]:
# One AgbDataset per data split, each read from data/<split> and configured
# with the same patch_size. The generator is consumed in order: training,
# validation, testing.
training_ds, validation_ds, testing_ds = (
    dst.AgbDataset(os.path.join(data_path, split), patch_size=PARAMS['PATCH_SIZE'])
    for split in ('training', 'validation', 'testing')
)

# Render test sample `idx` with AgbDataset.show and save the current figure.
idx = 0
testing_ds.show(idx)
plt.savefig(os.path.join(figure_path, "obs_bm.png"))

## INITIALIZE NETWORK AND OPTIMIZER

In [None]:
# Network from the project-local `deterministic` module.
net = deterministic.Deterministic()

# RMSprop over every parameter of `net`, learning rate from PARAMS.
opt = optim.RMSprop(params=net.parameters(), lr=PARAMS['LEARNING_RATE'])

## TRAINING

In [None]:
# Train for PARAMS['ITS'] iterations. Every VALIDATION_INTERVAL iterations
# (starting at iteration 0), evaluate on the validation split, append the RMSE to
# rmse_log, and checkpoint the weights whenever the RMSE improves. Afterwards,
# restore the best checkpoint from *this* run.
VALIDATION_INTERVAL = 100  # the loss-curve cell scales its x-axis by this same 100

start = time.time()

min_rmse = np.inf
best_it = None  # iteration of the checkpoint saved by this run; None until one is saved
rmse_log = []
checkpoint_file = os.path.join(model_path, 'net.pt')

for it in tqdm(range(PARAMS['ITS'])):

    # Training: one call to training_iteration (given the optimizer) on a batch
    # of BATCH_SIZE drawn from the training split.
    l, x = training_ds.get_batch(PARAMS['BATCH_SIZE'])
    loss = net.training_iteration(l, x, loss_func=PARAMS['LOSS_FUNC'], optimizer=opt)

    # Validation
    if it % VALIDATION_INTERVAL == 0:
        rmse = deterministic.evaluate_net_on_ds(net, validation_ds)
        rmse_log.append(rmse)

        # A NaN RMSE never compares smaller, so it can never become the best.
        if rmse < min_rmse:
            torch.save(net.state_dict(), checkpoint_file)
            min_rmse = rmse
            best_it = it

# Stop the clock before reloading so the time covers training and validation only.
stop = time.time()

# Without this guard, a run that never saved a checkpoint (e.g. NaN RMSE or
# ITS == 0) would crash on a missing file or, worse, silently load a stale net.pt
# left by an earlier run that used the same result directory.
if best_it is None:
    raise RuntimeError(
        "No checkpoint was saved during training (validation RMSE never improved "
        f"on {min_rmse}); refusing to load {checkpoint_file}."
    )
net.load_state_dict(torch.load(checkpoint_file))
print(f"Restored checkpoint from iteration {best_it} (validation RMSE {min_rmse}).")

## VALIDATION RMSE CURVE

In [None]:
# Plot validation RMSE against training iteration. rmse_log holds one entry per
# validation pass, and the training cell validates every 100 iterations starting
# at iteration 0. The x-values are therefore iteration indices, not epochs: each
# iteration processes a single batch.
plt.figure()
plt.plot(np.arange(len(rmse_log)) * 100, np.array(rmse_log), 'g-')
plt.xlabel('Iterations')
plt.ylabel('RMSE', color='g')

plt.savefig(os.path.join(figure_path, "losscurve.png"))

## EVALUATE NETWORK AND SHOW RESULT

In [None]:
# Score the (reloaded) network on the validation split, then the test split.
rmse_val, rmse_tst = (
    deterministic.evaluate_net_on_ds(net, ds)
    for ds in (validation_ds, testing_ds)
)

# Apply the network to full test sample `idx`, passing the unnormalized target.
# fig=True presumably makes net.apply draw a figure, which is then saved.
# NOTE(review): if Deterministic is an nn.Module, this `apply` shadows
# nn.Module.apply(fn) — confirm that is intended.
l, x = testing_ds.get_full(idx)
x = dst.unnormalize_x(x)
net.apply(l, x=x, fig=True)
plt.savefig(os.path.join(figure_path, "est.png"))

## LOG RESULTS

In [None]:
# Write wall-clock training time and the final validation/test RMSE to res.txt.
summary_lines = (
    f"Training took {stop - start} seconds.",
    f"Final Validation - RMSE: {rmse_val}",
    f"Final Test - RMSE: {rmse_tst}",
)
with open(os.path.join(result_path, 'res.txt'), "w") as log:
    print(*summary_lines, sep="\n", file=log)