## IMPORTS

In [None]:
import os
import pathlib
import time
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

import cwgan
import dataset as dst

# Global matplotlib defaults for every figure produced by this notebook.
plt.rcParams.update({
    'figure.figsize': (12, 12),
    'font.size': 20,
})

# Report whether a CUDA device is visible to PyTorch.
print("Using CUDA?", torch.cuda.is_available())

%load_ext autoreload
%autoreload 2

## SETTINGS, PRELIMINARIES

In [None]:
# The project root is taken to be the parent of the current working directory.
path_to_project = pathlib.Path(os.getcwd()).parent

data_path = os.path.join(path_to_project, 'data')
# Results go into a directory named after the current time (minute resolution).
result_path = os.path.join(path_to_project, 'results', 'cgan', datetime.now().strftime('%Y_%m_%d_%H_%M'))
figure_path = os.path.join(result_path, 'figures')
model_path = os.path.join(result_path, 'models')

# Two runs started within the same minute share a result directory; warn about it.
if os.path.isdir(result_path):
    print("The result directory already exists. Its contents may be overwritten!")

# makedirs creates missing parents (e.g. 'results/cgan' on a fresh checkout) and,
# with exist_ok=True, still creates the 'figures' and 'models' sub-directories
# when the result directory itself is already present.
for _directory in (result_path, figure_path, model_path):
    os.makedirs(_directory, exist_ok=True)

print("Results are being saved to", result_path)

## HYPERPARAMETERS

In [None]:
# All tunable settings of the run, grouped by the stage that reads them.
PARAMS = {
    # Data
    'PATCH_SIZE': 64,
    'BATCH_SIZE': 128,

    # Deterministic pre-training of the generator
    'LEARNING_RATE_PT': 5e-5,
    'LOSS_FUNC_PT': nn.MSELoss(),
    'ITS_PT': 2000 + 1,

    # Adversarial training (warmup followed by ensemble cycles)
    'LEARNING_RATE': 1e-5,
    'CRITIC_ITERATIONS': 5,
    'WEIGHT_CLIP': 0.01,
    'ITS_WU': 8000 + 1,
    'ENSEMBLE_SIZE': 10,
    'ITS_PER_CYCLE': 1000 + 1,
    'FINE_TUNING': False,
}

# Keep a plain-text copy of the settings next to the results of this run.
with open(os.path.join(result_path, 'params.txt'), "w") as params_file:
    print(PARAMS, file=params_file)

## DATASETS

In [None]:
# One dataset per split; each is read from the sub-directory of the same name
# under data_path and uses the configured patch size.
training_ds, validation_ds, testing_ds = (
    dst.AgbDataset(os.path.join(data_path, split), patch_size=PARAMS['PATCH_SIZE'])
    for split in ('training', 'validation', 'testing')
)

# Index of the test item used for every sample figure in this notebook.
idx = 0

testing_ds.show(idx)
plt.savefig(os.path.join(figure_path, "obs_bm.png"))

## INITIALIZE GENERATOR

In [None]:
# Generator network; it is pre-trained first and then trained adversarially.
G = cwgan.Generator()

# Optimizer used only during the deterministic pre-training stage.
opt_PT = optim.RMSprop(params=G.parameters(), lr=PARAMS['LEARNING_RATE_PT'])

## DETERMINISTIC PRE-TRAINING

In [None]:
start_PT = time.time()

print("--- PRE-TRAINING ---")

# Lowest validation RMSE seen so far; decides when a checkpoint is written.
min_rmse = np.inf

# Validation metrics, one entry per validation step. The adversarial stages
# further down keep appending to these same lists.
rmse_log = []
uce_log = []

for it in tqdm(range(PARAMS['ITS_PT'])):

    # Training
    l, x = training_ds.get_batch(PARAMS['BATCH_SIZE'])
    loss = G.pretraining_iteration(l, x, loss_func=PARAMS['LOSS_FUNC_PT'], optimizer=opt_PT)

    # Validation (every 100 iterations, including iteration 0)
    if it % 100 == 0:
        rmse_pt, uce_pt = cwgan.evaluate_net_on_ds(G, validation_ds)
        rmse_log.append(rmse_pt)
        uce_log.append(uce_pt)

        if rmse_pt < min_rmse:
            torch.save(G.state_dict(), os.path.join(model_path, 'pretrain_G.pt'))
            # Record the new best score. Without this update the comparison
            # above is always true and the checkpoint ends up holding the most
            # recent model instead of the best one.
            min_rmse = rmse_pt

# Continue from the best pre-trained weights rather than the final ones.
G.load_state_dict(torch.load(os.path.join(model_path, 'pretrain_G.pt')))

stop_PT = time.time()

## EVALUATE AND SHOW SAMPLE RESULT FROM PRE-TRAINED MODEL

In [None]:
# Score the pre-trained generator on the validation and the test split.
(rmse_pt_val, uce_pt_val) = cwgan.evaluate_net_on_ds(G, validation_ds)
(rmse_pt_tst, uce_pt_tst) = cwgan.evaluate_net_on_ds(G, testing_ds)

# Fetch the reference test item and undo the normalisation of x before plotting.
l, x = testing_ds.get_full(idx)
x = dst.unnormalize_x(x)

# Draw four generator samples for the item and store the figure.
G.apply(l, x=x, n_samples=4, fig=True)
plt.savefig(os.path.join(figure_path, "sample_ests_pt.png"))

## INITIALIZE CWGAN DISCRIMINATOR AND OPTIMIZERS

In [None]:
# Discriminator network for the adversarial stages.
D = cwgan.Discriminator()

# Generator and discriminator each get their own RMSprop optimizer with the
# same adversarial learning rate.
opt_G, opt_D = (
    optim.RMSprop(net.parameters(), lr=PARAMS['LEARNING_RATE'])
    for net in (G, D)
)

## CWGAN WARMUP

In [None]:
start_warmup = time.time()

print("--- WARMUP ---")

min_rmse = np.inf
min_uce = np.inf

# Checkpoint file name -> object whose state_dict is stored in that file.
warmup_ckpts = (
    ('warmup_G.pt', G),
    ('warmup_opt_G.pt', opt_G),
    ('warmup_D.pt', D),
    ('warmup_opt_D.pt', opt_D),
)

for it in tqdm(range(PARAMS['ITS_WU'])):

    # Several discriminator updates, each on a fresh training batch ...
    for i in range(PARAMS['CRITIC_ITERATIONS']):
        l, x = training_ds.get_batch(PARAMS['BATCH_SIZE'])
        loss_D = D.training_iteration(l, x, G, optimizer=opt_D, weight_clip=PARAMS['WEIGHT_CLIP'])

    # ... followed by a single generator update (only l is passed on).
    l, x = training_ds.get_batch(PARAMS['BATCH_SIZE'])
    loss_G = G.training_iteration(l, D, optimizer=opt_G)

    # Validation happens only on every 100th iteration.
    if it % 100 != 0:
        continue

    rmse, uce = cwgan.evaluate_net_on_ds(G, validation_ds)
    rmse_log.append(rmse)
    uce_log.append(uce)

    # Checkpoint both networks and both optimizers whenever the validation UCE improves.
    if uce < min_uce:
        for ckpt_name, ckpt_obj in warmup_ckpts:
            torch.save(ckpt_obj.state_dict(), os.path.join(model_path, ckpt_name))
        min_rmse, min_uce = rmse, uce

# Restore the best warmup state before moving on.
for ckpt_name, ckpt_obj in warmup_ckpts:
    ckpt_obj.load_state_dict(torch.load(os.path.join(model_path, ckpt_name)))

stop_warmup = time.time()

## EVALUATE AND SHOW SAMPLE RESULT FROM WARMUP MODEL

In [None]:
# Score the restored warmup generator on the validation and the test split.
(rmse_wu_val, uce_wu_val) = cwgan.evaluate_net_on_ds(G, validation_ds)
(rmse_wu_tst, uce_wu_tst) = cwgan.evaluate_net_on_ds(G, testing_ds)

# Same reference test item as before, with x un-normalised for plotting.
l, x = testing_ds.get_full(idx)
x = dst.unnormalize_x(x)

# Draw four generator samples for the item and store the figure.
G.apply(l, x=x, n_samples=4, fig=True)
plt.savefig(os.path.join(figure_path, "sample_ests_wu.png"))

## INITIALIZE SCHEDULERS

In [None]:
# Cosine annealing with warm restarts. The generator's restart period is one
# ensemble cycle; the discriminator's period is CRITIC_ITERATIONS times longer,
# matching the number of discriminator updates per generator update.
scheduler_G = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    opt_G,
    T_0=PARAMS['ITS_PER_CYCLE'],
)
scheduler_D = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    opt_D,
    T_0=PARAMS['ITS_PER_CYCLE'] * PARAMS['CRITIC_ITERATIONS'],
)

## CWGAN ENSEMBLING

In [None]:
start_ens = time.time()

print("--- ENSEMBLES ---")

# Each pass of the outer loop trains for one cycle and stores one ensemble member.
for i in range(PARAMS['ENSEMBLE_SIZE']):

    print(f"Ensemble {i}")

    # Best validation scores are tracked per cycle.
    min_rmse = np.inf
    min_uce = np.inf

    for it in tqdm(range(PARAMS['ITS_PER_CYCLE'])):

        # Discriminator updates, each on a fresh training batch.
        for _ in range(PARAMS['CRITIC_ITERATIONS']):
            l, x = training_ds.get_batch(PARAMS['BATCH_SIZE'])
            loss_D = D.training_iteration(l, x, G, optimizer=opt_D, scheduler=scheduler_D, weight_clip=PARAMS['WEIGHT_CLIP'])

        # Generator update: with FINE_TUNING enabled the batch is drawn from the
        # test split, otherwise from the training split. Only l is used.
        if PARAMS['FINE_TUNING']:
            l, _ = testing_ds.get_batch(PARAMS['BATCH_SIZE'])
        else:
            l, _ = training_ds.get_batch(PARAMS['BATCH_SIZE'])
        loss_G = G.training_iteration(l, D, optimizer=opt_G, scheduler=scheduler_G)

        # Validation happens only on every 100th iteration.
        if it % 100 != 0:
            continue

        rmse, uce = cwgan.evaluate_net_on_ds(G, validation_ds)
        rmse_log.append(rmse)
        uce_log.append(uce)

        # Keep the generator weights with the lowest validation UCE of this cycle.
        if uce < min_uce:
            torch.save(G.state_dict(), os.path.join(model_path, f"ens_{i}_G.pt"))
            min_rmse, min_uce = rmse, uce

    # Sample figure for the reference test item, drawn with the generator as it
    # stands at the end of the cycle.
    l, x = testing_ds.get_full(idx)
    x = dst.unnormalize_x(x)
    G.apply(l, x=x, n_samples=4, fig=True)
    plt.savefig(os.path.join(figure_path, f"sample_ests_{i}.png"))

stop_ens = time.time()

## LOSS CURVE

In [None]:
# RMSE and UCE share the x-axis but get separate y-axes (left and right).
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()

# The logs hold one entry per validation step; entry k is drawn at x = (k + 1) * 100.
ax1.plot(np.arange(1, len(rmse_log) + 1) * 100, np.asarray(rmse_log), 'g-')
ax2.plot(np.arange(1, len(uce_log) + 1) * 100, np.asarray(uce_log), 'b-')

ax1.set_xlabel('Training Iteration')
ax1.set_ylabel('RMSE', color='g')
ax2.set_ylabel('UCE', color='b')

plt.savefig(os.path.join(figure_path, "losscurve.png"))

## EVALUATE ENSEMBLE

In [None]:
# Evaluate the generators stored under model_path on the validation split (no figure).
(
    ensemble_rmses_val,
    ensemble_uces_val,
    rmses_val,
    uces_val,
) = cwgan.evaluate_ensemble_on_ds(model_path, validation_ds, module=cwgan.Generator, fig=False)

# Same evaluation on the test split, this time with a figure, which is then saved.
(
    ensemble_rmses_tst,
    ensemble_uces_tst,
    rmses_tst,
    uces_tst,
) = cwgan.evaluate_ensemble_on_ds(model_path, testing_ds, module=cwgan.Generator, fig=True)
plt.savefig(os.path.join(figure_path, "calib_curve.png"))

## LOG RESULTS

In [None]:
# Write the timings and scores of the run to res.txt in the result directory.
with open(os.path.join(result_path, 'res.txt'), "w") as res_file:
    # Pre-training stage
    print("Pre-training took", stop_PT-start_PT, "seconds.", file=res_file)
    print(f"Validation after pre-training - RMSE: {rmse_pt_val}, UCE: {uce_pt_val}", file=res_file)
    print(f"Test after pre-training - RMSE: {rmse_pt_tst}, UCE: {uce_pt_tst}", file=res_file)

    # Adversarial stage; the reported time is warmup plus all ensemble cycles.
    print("Adversarial training took", (stop_warmup-start_warmup) + (stop_ens-start_ens), "seconds.", file=res_file)
    print(f'Validation after warmup - RMSE: {rmse_wu_val}, UCE: {uce_wu_val}', file=res_file)
    print(f'Testing after warmup - RMSE: {rmse_wu_tst}, UCE: {uce_wu_tst}', file=res_file)

    # Final ensemble; every reported number is the mean of the returned values.
    print(
        f"Final Ensemble Validation - Ensemble RMSE: {np.mean(ensemble_rmses_val)}, "
        f"Ensemble UCE: {np.mean(ensemble_uces_val)}, "
        f"Individual RMSE: {np.mean(rmses_val)}, "
        f"Individual UCE: {np.mean(uces_val)}",
        file=res_file,
    )
    print(
        f"Final Ensemble Test - Ensemble RMSE: {np.mean(ensemble_rmses_tst)}, "
        f"Ensemble UCE: {np.mean(ensemble_uces_tst)}, "
        f"Individual RMSE: {np.mean(rmses_tst)}, "
        f"Individual UCE: {np.mean(uces_tst)}",
        file=res_file,
    )