## IMPORTS

In [None]:
import os
import pathlib
import random
import time
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

import mhnn
import dataset as dst

plt.rcParams.update({'figure.figsize': (12, 12)})
plt.rcParams.update({'font.size': 20})
print("Using CUDA?", torch.cuda.is_available())

%load_ext autoreload
%autoreload 2

## SETTINGS, PRELIMINARIES

In [None]:
# The notebook is expected to live one level below the project root.
path_to_project = pathlib.Path(os.getcwd()).parent

data_path = os.path.join(path_to_project, 'data')
# One result directory per run, stamped to the minute.
result_path = os.path.join(path_to_project, 'results', 'mhnn', datetime.now().strftime('%Y_%m_%d_%H_%M'))
figure_path = os.path.join(result_path, 'figures')
model_path = os.path.join(result_path, 'models')

# Two runs started within the same minute share a directory: warn, don't fail.
if os.path.exists(result_path):
    print("The result directory already exists. Its contents may be overwritten!")

# makedirs creates missing parents (e.g. results/mhnn on a fresh checkout) and,
# unlike a chain of mkdir calls inside one try block, still creates the
# sub-directories when result_path itself already exists.
for directory in (figure_path, model_path):
    os.makedirs(directory, exist_ok=True)

print("Results are being saved to", result_path)

## HYPERPARAMETERS

In [None]:
PARAMS = {
    'PATCH_SIZE': 64,
    'BATCH_SIZE': 128,
    'LEARNING_RATE': 5e-5,
    # The "+ 1" makes the last loop index (5000 / 1000) a multiple of the
    # validation interval (it % 100 == 0), so the final weights of each
    # phase are validated too.
    'ITS_WU': 5000 + 1,
    'ENSEMBLE_SIZE': 10,
    'ITS_PER_CYCLE': 1000 + 1,
    # Seed for the Python, NumPy and PyTorch RNGs so runs are repeatable.
    'SEED': 0,
}

# Seed every RNG before the datasets and the network are created, so weight
# initialisation is repeatable. Presumably the dataset's batch sampling draws from
# one of these RNGs too -- confirm in dataset.py. Bit-exact GPU runs would
# additionally need deterministic cuDNN kernels.
random.seed(PARAMS['SEED'])
np.random.seed(PARAMS['SEED'])
torch.manual_seed(PARAMS['SEED'])

# Record the exact configuration next to the results.
with open(os.path.join(result_path, 'params.txt'), "w") as log:
    print(PARAMS, file=log)

## DATASETS

In [None]:
def _load_agb_split(split_name):
    """Build the AgbDataset stored under ``data_path/<split_name>`` with the configured patch size."""
    return dst.AgbDataset(os.path.join(data_path, split_name), patch_size=PARAMS['PATCH_SIZE'])


training_ds = _load_agb_split('training')
validation_ds = _load_agb_split('validation')
testing_ds = _load_agb_split('testing')

# Index of the test sample used for every qualitative figure in this notebook.
idx = 0

testing_ds.show(idx)
plt.savefig(os.path.join(figure_path, "obs_bm.png"))

## INITIALIZE NETWORK AND OPTIMIZER

In [None]:
# Model plus its RMSprop optimiser, using the learning rate from PARAMS.
# NOTE(review): net is never moved to a CUDA device in this notebook;
# presumably MHNN handles device placement internally -- confirm.
net = mhnn.MHNN()
opt = optim.RMSprop(params=net.parameters(), lr=PARAMS['LEARNING_RATE'])

## WARMUP

In [None]:
start_warmup = time.time()

print("--- WARMUP ---")

min_rmse = np.inf
min_uce = np.inf
# Loop index of the best (lowest validation UCE) checkpoint. It stays None if no
# validation ever compared below the initial inf, e.g. because UCE was NaN.
best_it = None

# Validation metrics, one entry per 100 iterations; the ensemble cell keeps appending.
rmse_log = []
uce_log = []

warmup_ckpt = os.path.join(model_path, 'warmup_net.pt')

for it in tqdm(range(PARAMS['ITS_WU'])):

    l, x = training_ds.get_batch(PARAMS['BATCH_SIZE'])

    # Training
    loss = net.training_iteration(l, x, optimizer=opt)

    # Validation
    if it % 100 == 0:
        rmse, uce = mhnn.evaluate_net_on_ds(net, validation_ds)
        rmse_log.append(rmse)
        uce_log.append(uce)

        # Checkpoint selection is driven by calibration (UCE), not RMSE.
        if uce < min_uce:
            torch.save(net.state_dict(), warmup_ckpt)
            min_rmse, min_uce = rmse, uce
            best_it = it

# Without this check, torch.load below would either crash on a missing file or
# silently reload a stale warmup_net.pt (from a reused result directory or an
# earlier execution of this cell).
if best_it is None:
    raise RuntimeError(
        "Warmup never produced a validation UCE below inf (is it NaN?); no checkpoint was saved."
    )

print(f"Best warmup checkpoint at iteration {best_it}: RMSE {min_rmse}, UCE {min_uce}")

# Continue from the best warmup weights rather than the final ones.
net.load_state_dict(torch.load(warmup_ckpt))

stop_warmup = time.time()

## EVALUATE AND SHOW RESULTS OF THE WARMUP MODEL

In [None]:
# Score the reloaded best warmup weights on both held-out splits.
rmse_wu_val, uce_wu_val = mhnn.evaluate_net_on_ds(net, validation_ds)
rmse_wu_tst, uce_wu_tst = mhnn.evaluate_net_on_ds(net, testing_ds)

# Qualitative check on the reference test sample; x is unnormalized before it
# is handed to net.apply.
l, x_normalized = testing_ds.get_full(idx)
x = dst.unnormalize_x(x_normalized)

net.apply(l, fig=True, x=x)
plt.savefig(os.path.join(figure_path, "est_wu.png"))

## INITIALIZE SCHEDULER

In [None]:
# Cosine learning-rate schedule that restarts every ITS_PER_CYCLE scheduler steps.
# This lines up with one ensemble member per cycle, assuming training_iteration
# steps the scheduler once per call -- TODO confirm in mhnn.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=PARAMS['ITS_PER_CYCLE'])

## ENSEMBLES

In [None]:
start_ens = time.time()

print("--- ENSEMBLES ---")

# NOTE(review): presumably gives net.apply a fresh current figure -- confirm it is needed.
plt.figure()
for i in range(PARAMS['ENSEMBLE_SIZE']):

    print(f"Ensemble {i}")

    # Each member selects its own best checkpoint within its cycle.
    min_rmse = np.inf
    min_uce = np.inf
    best_it = None

    # Training continues from the current weights; the net is not reloaded between cycles.
    for it in tqdm(range(PARAMS['ITS_PER_CYCLE'])):

        # Training
        l, x = training_ds.get_batch(PARAMS['BATCH_SIZE'])
        loss = net.training_iteration(l, x, optimizer=opt, scheduler=scheduler)

        # Validation
        if it % 100 == 0:
            rmse, uce = mhnn.evaluate_net_on_ds(net, validation_ds)
            rmse_log.append(rmse)
            uce_log.append(uce)

            if uce < min_uce:
                torch.save(net.state_dict(), os.path.join(model_path, f"ens_{i}_net.pt"))
                min_rmse, min_uce = rmse, uce
                best_it = it

    # Without this check a member would be silently missing from the ensemble
    # (or replaced by a stale file left in a reused result directory).
    if best_it is None:
        raise RuntimeError(
            f"Ensemble member {i} never produced a validation UCE below inf (is it NaN?); no checkpoint was saved."
        )

    # This figure uses the weights at the END of the cycle, which need not be
    # the best-UCE checkpoint saved above.
    l, x = testing_ds.get_full(idx)
    x = dst.unnormalize_x(x)
    net.apply(l, fig=True, x=x)
    plt.savefig(os.path.join(figure_path, f"est_{i}.png"))

stop_ens = time.time()

## VALIDATION CURVES (RMSE AND UCE)

In [None]:
# Validation RMSE (green, left axis) and UCE (blue, right axis) over training.
# Validation ran right after the training step of loop indices 0, 100, 200, ...
# in every phase, and each phase restarts its loop index. Reconstruct the global
# training step of each log entry from PARAMS instead of assuming uniform spacing.
val_every = 100  # must match the `it % 100` validation interval of the training cells
wu_steps = np.arange(0, PARAMS['ITS_WU'], val_every) + 1
cycle_steps = np.arange(0, PARAMS['ITS_PER_CYCLE'], val_every) + 1
ens_steps = [
    PARAMS['ITS_WU'] + i * PARAMS['ITS_PER_CYCLE'] + cycle_steps
    for i in range(PARAMS['ENSEMBLE_SIZE'])
]
# Truncate to support an interrupted ensemble phase.
log_steps = np.concatenate([wu_steps, *ens_steps])[:len(rmse_log)]
if len(log_steps) != len(rmse_log):
    # More entries than one warmup + ensemble run produces (e.g. a training cell
    # was re-run): fall back to evenly spaced points.
    log_steps = (np.arange(len(rmse_log)) + 1) * val_every

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax1.plot(log_steps, np.array(rmse_log), 'g-')
ax2.plot(log_steps, np.array(uce_log), 'b-')
ax1.set_xlabel('Training Iteration')
ax1.set_ylabel('RMSE', color='g')
ax2.set_ylabel('UCE', color='b')
ax1.set_title('Validation metrics during training')

fig.savefig(os.path.join(figure_path, "losscurve.png"))

## EVALUATE ENSEMBLE

In [None]:
# Evaluate the saved checkpoints in model_path as an ensemble on both held-out
# splits; only the test-split call requests a figure (saved as calib_curve.png).
# NOTE(review): warmup_net.pt lives in the same model_path as the ens_*_net.pt
# files -- confirm evaluate_ensemble_on_ds does not load it as an extra member.
ensemble_kwargs = dict(module=mhnn.MHNN)
ensemble_rmses_val, ensemble_uces_val, rmses_val, uces_val = mhnn.evaluate_ensemble_on_ds(
    model_path, validation_ds, fig=False, **ensemble_kwargs
)
ensemble_rmses_tst, ensemble_uces_tst, rmses_tst, uces_tst = mhnn.evaluate_ensemble_on_ds(
    model_path, testing_ds, fig=True, **ensemble_kwargs
)
plt.savefig(os.path.join(figure_path, "calib_curve.png"))

## LOG RESULTS

In [None]:
# Plain-text summary of training time and metrics, written next to the other results.
training_seconds = (stop_warmup-start_warmup) + (stop_ens-start_ens)
summary_lines = [
    f"Training took {training_seconds} seconds.",
    f'Validation after warmup - RMSE: {rmse_wu_val}, UCE: {uce_wu_val}',
    f'Testing after warmup - RMSE: {rmse_wu_tst}, UCE: {uce_wu_tst}',
    f"Final Ensemble Validation - Ensemble RMSE: {np.mean(ensemble_rmses_val)}, Ensemble UCE: {np.mean(ensemble_uces_val)}, Individual RMSE: {np.mean(rmses_val)}, Individual UCE: {np.mean(uces_val)}",
    f"Final Ensemble Test - Ensemble RMSE: {np.mean(ensemble_rmses_tst)}, Ensemble UCE: {np.mean(ensemble_uces_tst)}, Individual RMSE: {np.mean(rmses_tst)}, Individual UCE: {np.mean(uces_tst)}",
]
with open(os.path.join(result_path, 'res.txt'), "w") as res_file:
    for line in summary_lines:
        print(line, file=res_file)