In [1]:
import os
import gc
import pickle
from pathlib import Path

import mxnet as mx
from gluonts.mx import Trainer
from gluonts.evaluation import Evaluator
from gluonts.core.component import validated
from gluonts.mx.trainer.callback import Callback
from gluonts.evaluation.backtest import make_evaluation_predictions
mx.random.seed(0)

from model import *
from utils import *

import warnings
warnings.filterwarnings("ignore")


In [2]:
class EarlyStoppingCallback(Callback):
    """Stop gluonts training once the validation loss stops improving.

    ``on_validation_epoch_end`` receives the loss of each validation epoch. A loss
    counts as an improvement only if it is lower than the best loss seen so far by
    more than ``min_delta``. After ``patience`` consecutive epochs without
    improvement the callback returns ``False`` to ask the trainer to stop.

    Args:
        patience: number of consecutive non-improving epochs tolerated (must be >= 1).
        min_delta: minimum decrease in loss that counts as an improvement. The
            default 0.0 keeps the plain "strictly lower than best" rule.

    Raises:
        ValueError: if ``patience`` < 1. A patience of 0 would stop training after
            the very first validation epoch, whether or not it improved.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        if patience < 1:
            raise ValueError(f"patience must be >= 1, got {patience}")
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.wait_count = 0

    def on_train_start(self, max_epochs: int) -> None:
        """Reset the tracking state so one instance can be reused across training runs."""
        self.best_loss = float('inf')
        self.wait_count = 0

    def on_validation_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: mx.gluon.HybridBlock,
        trainer: mx.gluon.Trainer,
    ) -> bool:
        """Track the best validation loss; return False to stop training, True to continue."""
        # A NaN loss compares False here, so it is treated as "no improvement".
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.wait_count = 0
        else:
            self.wait_count += 1

        if self.wait_count >= self.patience:
            print(
                f"\nEarly stopping triggered at epoch {epoch_no}: no improvement for "
                f"{self.patience} epochs (best loss {self.best_loss:.6f})"
            )
            return False

        return True

class EarlyStoppingTrainer(Trainer):
    """gluonts ``Trainer`` that additionally runs an ``EarlyStoppingCallback``.

    Args:
        patience: forwarded to ``EarlyStoppingCallback``.
        **kwargs: forwarded to ``Trainer`` unchanged, except ``callbacks``. That list
            is copied and the early-stopping callback is appended to the copy.

    NOTE(review): gluonts' mx ``Trainer`` appears to add default callbacks, including a
    learning-rate reduction with its own patience. If both patiences are equal, the LR
    drop and the early stop can fire on the same epoch. Confirm this against the
    installed gluonts version.
    """

    @validated()
    def __init__(self, patience: int = 10, **kwargs):
        # Build a new list rather than appending in place. ``callbacks`` may be the
        # caller's own list (appending would leak into it), or an explicit None
        # (which has no ``append``).
        user_callbacks = list(kwargs.get('callbacks') or [])
        kwargs['callbacks'] = user_callbacks + [EarlyStoppingCallback(patience=patience)]
        super().__init__(**kwargs)

In [None]:
def _load_level_dataset(subdir, level_idx):
    """Unpickle ``../dataset/<subdir>/dataset_level_<level_idx>.pkl`` and return it."""
    path = os.path.join('../dataset', subdir, f'dataset_level_{level_idx}.pkl')
    # NOTE(review): pickle.load can execute arbitrary code; only use it on trusted local files.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _dump_results(prefix, results):
    """Pickle every value of ``results`` to ``f"{prefix}_{key}.pkl"``, in insertion order."""
    for suffix, obj in results.items():
        with open(f"{prefix}_{suffix}.pkl", "wb") as f:
            pickle.dump(obj, f)


def _predict(dataset, predictor):
    """Run gluonts ``make_evaluation_predictions`` and return ``(forecasts, labels)`` as lists."""
    forecasts_it, labels_it = make_evaluation_predictions(dataset=dataset, predictor=predictor)
    # Materialise the forecasts first, then the labels.
    return list(forecasts_it), list(labels_it)


def train_models(epochs, learning_rate, patience=10, levels=range(1, 13)):
    """Train, predict with, and evaluate every estimator for each level; persist all results.

    For each level, this function:
      - loads the pickled dataset dict(s), which are indexed with ``'train'`` / ``'test'``;
      - builds the estimators via ``create_estimators``;
      - trains each one with ``EarlyStoppingTrainer``, using ``'test'`` as validation data;
      - serializes the predictor;
      - pickles train/test labels, forecasts and ``Evaluator`` metrics.

    Args:
        epochs: maximum training epochs per model; also part of the result directory name.
        learning_rate: optimizer learning rate; also part of the result directory name.
        patience: early-stopping patience passed to ``EarlyStoppingTrainer`` (default 10).
        levels: level indices to process (default 1..12).

    Outputs go to ``../result/epochs_{epochs}-learning_rate_{learning_rate}/level_{k}/``:
      - the predictor in ``<name>/``;
      - ``<name>_{train,test}_{labels,forecasts}.pkl``;
      - ``<name>_{train,test}_metrics_{all,per}_id.pkl``.
    """
    save_dir = f'../result/epochs_{epochs}-learning_rate_{learning_rate}'
    os.makedirs(save_dir, exist_ok=True)

    for level_idx in levels:
        level_dir = os.path.join(save_dir, f'level_{level_idx}')
        os.makedirs(level_dir, exist_ok=True)

        # The data must be loaded before the estimators are built, because
        # create_estimators needs a training split.
        highlight_print(f"Level {level_idx}: Loading dataset")
        default_dataset = _load_level_dataset('else', level_idx)
        # NOTE(review): the estimators are configured from the '../dataset/else' training
        # split. Confirm that this is also appropriate for the TFT estimator.
        estimators = create_estimators(
            level_idx=level_idx,
            train_dataset=default_dataset['train'],
        )
        tft_dataset = None  # loaded lazily, only if a 'TFT' estimator is present

        for estimator_name, estimator in estimators.items():
            # TFT has its own prepared data under ../dataset/tft, so the dataset is chosen
            # per estimator. This assumes create_estimators keys that model as 'TFT'
            # (TODO confirm).
            if estimator_name == 'TFT':
                if tft_dataset is None:
                    tft_dataset = _load_level_dataset('tft', level_idx)
                dataset = tft_dataset
            else:
                dataset = default_dataset

            estimator_dir = os.path.join(level_dir, estimator_name)
            os.makedirs(estimator_dir, exist_ok=True)

            highlight_print(f"Level {level_idx}: Training {estimator_name}")
            estimator.trainer = EarlyStoppingTrainer(
                epochs=epochs,
                learning_rate=learning_rate,
                num_batches_per_epoch=get_optimal_num_batches(mx.context.num_gpus()),
                patience=patience,
            )
            predictor = estimator.train(
                training_data=dataset['train'],
                validation_data=dataset['test'],
            )
            predictor.serialize(Path(estimator_dir))

            highlight_print(f"Level {level_idx}: Making predictions")
            train_forecasts, train_labels = _predict(dataset['train'], predictor)
            test_forecasts, test_labels = _predict(dataset['test'], predictor)

            highlight_print(f"Level {level_idx}: Saving predictions")
            _dump_results(estimator_dir, {
                'train_labels': train_labels,
                'train_forecasts': train_forecasts,
                'test_labels': test_labels,
                'test_forecasts': test_forecasts,
            })

            highlight_print(f"Level {level_idx}: Evaluating predictions")
            evaluator = Evaluator(quantiles=(0.5,), ignore_invalid_values=True)
            train_metrics_all_id, train_metrics_per_id = evaluator(train_labels, train_forecasts)
            test_metrics_all_id, test_metrics_per_id = evaluator(test_labels, test_forecasts)

            highlight_print(f"Level {level_idx}: Saving evaluations")
            _dump_results(estimator_dir, {
                'train_metrics_all_id': train_metrics_all_id,
                'train_metrics_per_id': train_metrics_per_id,
                'test_metrics_all_id': test_metrics_all_id,
                'test_metrics_per_id': test_metrics_per_id,
            })

            # Free only the per-model objects. `dataset` is just an alias of the level's
            # data, and the next estimator still needs that data.
            del dataset, predictor, train_forecasts, train_labels, test_forecasts, test_labels
            gc.collect()

        # Every estimator of this level is done, so its datasets can be released.
        del default_dataset, tft_dataset, estimators
        gc.collect()

# Run the full sweep over every level: at most 300 epochs per model, learning rate 1e-3.
train_models(learning_rate=1e-3, epochs=300)