In [None]:
import datetime
import time

import matplotlib.pyplot as plt
import torch
import yaml
from hy2dl.modelzoo import get_model
from hy2dl.utils.config import Config
from hy2dl.utils.utils import set_random_seed
from utilities.data import DataHandler
from utilities.postprocessing import Postprocessor
from utilities.training import Trainer

### Part 1. Setup

In [None]:
# Load configuration
# Read the YAML inside a context manager so the file handle is closed as soon
# as parsing finishes, instead of being left open for the kernel's lifetime.
with open("files/camels_gb.yml") as config_file:
    config_raw = yaml.safe_load(config_file)

# Modify configuration
# These overrides are applied on the plain dict, before it is wrapped below.
config_raw["experiment_name"] = "CAMELS-GB_LSTM-064_NSE"
config_raw["path_save_folder"] = "results/run_LSTM"
config_raw["model"] = "CudaLSTM"

# Convert into 'Config' object
# `config` only ever refers to the Config instance; the raw dict keeps its own name.
config = Config(config_raw)
config.init_experiment()
config.dump()

### Part 2. Load data

In [None]:
# Get data
# Create the handler from the experiment configuration and load the data.
handler_data = DataHandler(config)
handler_data.load_data()

basin_ids = handler_data.get_basin_ids()

# Request one loader per period (training first, then validation) and keep
# them keyed by period name.
dataloaders = {
    period: handler_data.get_loader(period)
    for period in ("training", "validation")
}

# Individual handles to the same loader objects.
loader_training = dataloaders["training"]
loader_validation = dataloaders["validation"]

### Part 3. Model

In [None]:
# Initialize model
# Apply the seed settings from the configuration before the model is built.
set_random_seed(cfg=config)

# Build the model selected by the configuration, then move it to the configured device.
model = get_model(config)
model = model.to(config.device)

In [None]:
# Load model
# Restore the weights stored in `model/model_epoch_09.pt` inside the
# experiment's save folder.
# `map_location` remaps the stored tensors onto the configured device, so a
# checkpoint that was saved on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(
    torch.load(
        config.path_save_folder / "model/model_epoch_09.pt",
        map_location=config.device,
    )
)

### Part 4. Training

In [None]:
# Start trainer
handler_training = Trainer(config, dataloaders, model)

# Get list of learning rates
# `config.learning_rate` is indexed by epoch number here: for each epoch
# (numbered from 1) take the value stored under the largest key that does not
# exceed that epoch.
num_epochs = config.epochs

lrs = []
for epoch_number in range(1, num_epochs + 1):
    latest_key = max(key for key in config.learning_rate if key <= epoch_number)
    lrs.append(config.learning_rate[latest_key])

In [None]:
# Start training and report
# The log is laid out as a table: the first header row groups the columns by
# phase, the second names each column. Every epoch then adds one row.
config.logger.info("Starting training")
config.logger.info(f"{'':^5} | {'':^8} | {'Training':^30} | {'Validation':^30} |")
config.logger.info(f"{'Epoch':^5} | {'LR':^8} | {'Loss':^8} | {'NSE':^8} | {'Time':^8} | {'Loss':^8} | {'NSE':^8} | {'Time':^8} |")

time_start = time.time()
for epoch in range(num_epochs):
    # The 1-based epoch number is used for the learning-rate update, the
    # validation schedule and the log row; `lrs` is indexed 0-based.
    epoch_number = epoch + 1

    # Set learning rate
    handler_training.optimizer.update_optimizer_lr(epoch=epoch_number)

    # Train
    loss_train, nse_train, time_train = handler_training.run_epoch("training")
    row = f"{epoch_number:^5} | {lrs[epoch]:^8.1e} | {loss_train:^8.4f} | {nse_train:^8.4f} | {time_train:^8} | "

    # Validate only every `config.validate_every` epochs; otherwise leave the
    # validation columns blank so the table stays aligned.
    if epoch_number % config.validate_every == 0:
        loss_val, nse_val, time_val = handler_training.run_epoch("validation")
        row += f"{loss_val:^8.4f} | {nse_val:^8.4f} | {time_val:^8} |"
    else:
        row += f"{'':^8} | {'':^8} | {'':^8} |"

    config.logger.info(row)

# Finish training
# Report the elapsed wall-clock time as H:MM:SS (whole seconds).
time_training = str(datetime.timedelta(seconds=int(time.time() - time_start)))
config.logger.info("Run completed successfully")
config.logger.info(f"Total run time: {time_training}")

### Part 5. Postprocess model

In [None]:
# Build the postprocessor from the experiment configuration and run it on the
# model; the returned object is kept in `results` and plotted in the next cell.
handler_postprocessing = Postprocessor(config)
results = handler_postprocessing.postprocess(model)

In [None]:
# Basin to inspect
basin_id = "73014"

# Select this basin's results once and reuse the selection for both series.
# NOTE(review): `last_n=365` selects the coordinate value 365 — confirm that is
# the intended slice of the postprocessed output.
basin_results = results.sel(basin=basin_id, last_n=365)

fig, ax = plt.subplots(figsize=(30, 5))
basin_results.y_obs.plot(ax=ax, label="Observed", color="tab:blue")
basin_results.y_hat.plot(ax=ax, label="Predicted", color="tab:orange")

# The `label` arguments above are only shown once a legend is drawn.
ax.legend()
plt.show()