# UNet with Idealized Grass Dataset
### [Link to MLFlow Dashboard](http://localhost:8443/mlflow)

## Imports

In [None]:
from __future__ import unicode_literals, print_function, division
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils import data
import time
from unet.unets import U_net
from unet.utils_unet import train_epoch, eval_epoch, test_epoch
from unet.dataset import IdealizedGrasslands
import warnings
import mlflow
import imageio
from pathlib import Path
import os
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
# Silence every Python warning for the rest of the session.
warnings.filterwarnings(action="ignore")

## Training

In [None]:
# Names and output locations shared by the training, testing and plotting cells.
model_name = 'U_Net'
model_filename = f"{model_name}_model.pth"

# Everything is written below a single results folder:
#   results/graphs      -> static plots
#   results/pics        -> one picture per time step
#   results/pics/movie  -> animation assembled from those pictures
results_dir = Path('results')
graphs_directory = results_dir.joinpath('graphs')
pics_temp_directory = results_dir.joinpath('pics')
movie_directory = pics_temp_directory.joinpath('movie')

# Creating the movie folder also creates its parent pictures folder.
for _output_directory in (graphs_directory, movie_directory):
    _output_directory.mkdir(parents=True, exist_ok=True)

In [None]:
# Create the experiment (if it does not exist yet) and start a new MLflow run.
mlflow.end_run()  # close any run still active from an earlier execution

# The experiment name contains the username for uniqueness. JUPYTERHUB_USER is
# not guaranteed to be set (e.g. outside JupyterHub), so fall back to USER and
# finally to a fixed label instead of failing on `None.split(...)`.
user_name = os.getenv('JUPYTERHUB_USER') or os.getenv('USER') or 'local'
suffix = '_' + user_name.split("@")[0]  # keep only the part before any "@"
experiment_name = model_name + suffix
try:
    mlflow.create_experiment(experiment_name)
except mlflow.exceptions.MlflowException:
    # Best effort: MLflow rejects the call when the experiment already exists.
    # set_experiment below selects the experiment either way.
    pass
mlflow.set_experiment(experiment_name)
mlflow.start_run()


# set parameters
# Best validation loss seen so far. Starting at +inf guarantees that the first
# epoch writes a checkpoint, so the artifact upload after the loop (and the
# testing cell that loads the file) never hits a missing model file.
min_mse = float("inf")
output_length = 100
input_length = 7
learning_rate = 0.001
dropout_rate = 0
kernel_size = 3
batch_size = 1
max_epochs = 1

# split data
train_indices = list(range(0, 3))
valid_indices = list(range(3, 5))

# configure model, datasets, loss, optimizer and learning-rate schedule
model = U_net(input_channels=input_length, output_channels=1, kernel_size=kernel_size,
              dropout_rate=dropout_rate).to(device)
train_set = IdealizedGrasslands(train_indices, input_length, 15, output_length, "train", file='uniform-pgml-success_list_simulation_runs.csv')
valid_set = IdealizedGrasslands(valid_indices, input_length, 15, output_length, "test", file='uniform-pgml-success_list_simulation_runs.csv')
train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1)
valid_loader = data.DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=1)
loss_fun = torch.nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), learning_rate, betas=(0.9, 0.999), weight_decay=4e-4)
# multiplies the learning rate by 0.9 on every scheduler.step()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

# log parameters into MLFlow (objects are logged via their string form)
mlflow.log_param("learning_rate", learning_rate)
mlflow.log_param("batch_size", batch_size)
mlflow.log_param("loss_function", str(loss_fun))
mlflow.log_param("max_epochs", max_epochs)
mlflow.log_param("optimizer", str(optimizer))
mlflow.log_param("scheduler", str(scheduler))

# start training
train_mse = []
valid_mse = []
test_mse = []
for i in range(max_epochs):
    mlflow.log_metric("Current Training Epoch", i + 1)
    print(f'Epoch {i} started')
    start = time.time()
    torch.cuda.empty_cache()

    # one pass over the training data
    model.train()
    # linearly decays from 1 to 0 over the epochs (0.03 per epoch)
    teacher_force_ratio = np.maximum(0, 1 - i * 0.03)
    train_loss = train_epoch(train_loader, model, optimizer, loss_fun, teacher_force_ratio)
    train_mse.append(train_loss)

    # Decay the learning rate only after this epoch's optimizer updates
    # (train_epoch receives the optimizer, so presumably it performs them).
    # Stepping the scheduler first would apply the decay before any training.
    scheduler.step()

    # one pass over the validation data
    model.eval()
    mse, preds, trues = eval_epoch(valid_loader, model, loss_fun)
    valid_mse.append(mse)

    # send training metrics to MLFlow
    mlflow.log_metric("Epoch Loss", train_loss, step=i)
    mlflow.log_metric("Epoch Validation", mse, step=i)

    # checkpoint whenever the validation loss improves
    if valid_mse[-1] < min_mse:
        min_mse = valid_mse[-1]
        # NOTE: this is a reference to the live model, not a snapshot; the
        # snapshot is the file written on the next line.
        best_model = model
        torch.save(best_model, results_dir / model_filename)
    end = time.time()

    # early stopping: after more than 50 epochs, stop once the mean of the
    # last 5 validation losses is no better than the mean of the 5 before
    if len(train_mse) > 50 and np.mean(valid_mse[-5:]) >= np.mean(valid_mse[-10:-5]):
        break
    print(train_mse[-1], valid_mse[-1], round((end - start) / 60, 5))
    print(f'Epoch {i} ended')

# send model file to mlflow (as a string path, like the log_artifacts calls)
mlflow.log_artifact(str(results_dir / model_filename))

## Testing

In [None]:
loss_fun = torch.nn.L1Loss()

# Load the checkpoint written by the training cell. map_location places the
# model on the device selected at the top of the notebook, so a checkpoint
# saved from a GPU session can still be loaded when only a CPU is available.
# NOTE(review): the checkpoint is a whole pickled model; newer PyTorch releases
# may require weights_only=False here - confirm against the installed version.
best_model = torch.load(results_dir / model_filename, map_location=device)

# evaluate on the held-out simulation runs
test_indices = list(range(5, 7))
test_set = IdealizedGrasslands(test_indices, input_length, 15, output_length, 'test', file='uniform-pgml-success_list_simulation_runs.csv')
test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=8)
loss_curve, preds, trues = test_epoch(test_loader, best_model, loss_fun)

# send testing metrics to MLFlow, one "Test MAE" value per entry of the curve
for i, mae_item in enumerate(loss_curve):
    mlflow.log_metric("Test MAE", mae_item, step=i)

# calculate mean MAE for all frames
mean_mae = np.mean(loss_curve)
mlflow.log_metric("Test Avg MAE", mean_mae)

# save testing results (only the first 10 entries of preds/trues are kept)
torch.save({"preds": preds[:10],
            "trues": trues[:10],
            "loss_curve": loss_curve},
           results_dir / f"{model_name}_results.pt")

## Produce Visualizations

In [None]:
pt_file = torch.load(results_dir / f"{model_name}_results.pt")

# loss curve stored under 'loss_curve' in the results file
y = pt_file['loss_curve']

# create x ticks for plot (one string label per entry of the curve)
x = [str(step) for step in range(len(y))]

# mae over timesteps graph
mae_fig, ax = plt.subplots()
ax.plot(x, y)
ax.set_title(f'{model_name.upper()} Model MAE over the Time Steps')
ax.set_xlabel('Time Steps')
ax.set_ylabel('MAE')
# rotate the tick labels without replacing them
ax.tick_params(axis='x', labelrotation=90)
mae_fig.savefig(graphs_directory / f'{model_name}_mae.png')

# release the MAE figure
plt.close(mae_fig)

# one figure with two panels, reused for every time step picture
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# different models produce different shapes, so trying to support everything:
# 5-dimensional results need one extra index after the time step
trues_stored = pt_file['trues']
preds_stored = pt_file['preds']
trues_are_5d = len(trues_stored.shape) == 5
preds_are_5d = len(preds_stored.shape) == 5

# generate png for each timestamp containing true and pred
filenames = []
for i, step_label in enumerate(x):
    true = trues_stored[0][i][0] if trues_are_5d else trues_stored[0][i]
    pred = preds_stored[0][i][0] if preds_are_5d else preds_stored[0][i]
    for panel, frame, panel_title in zip(axs, (true, pred), ('Ground Truth', 'Prediction')):
        # drop the previous step's image; otherwise images pile up on the
        # axes and every savefig redraws all of them
        panel.clear()
        panel.imshow(frame, cmap='rainbow', origin="lower")
        panel.set_title(panel_title)
        panel.set_xlabel('x_coord')
        panel.set_ylabel('y_coord')
    fig.suptitle(f'Time Step {step_label}')
    # apply the layout before saving so it is part of the written picture
    fig.tight_layout()
    frame_path = pics_temp_directory / f'{i}.png'
    fig.savefig(frame_path)
    filenames.append(frame_path)

# make animation from the saved pictures, in time step order
images = [imageio.imread(frame_path) for frame_path in filenames]
animation_file = f'movie_{model_name}.gif'
imageio.mimsave(movie_directory / animation_file, images)

# release the frame figure
plt.close(fig)

# send to MLFlow
mlflow.log_artifacts(str(movie_directory))
mlflow.log_artifacts(str(graphs_directory))
mlflow.end_run()