# Imports

In [None]:
import os
import sys

# Repository root inside the container. Later cells use relative paths
# ('config/...', 'data/saved'), so work from the root.
PROJ_ROOT = os.path.abspath('/home/jovyan')
os.chdir(PROJ_ROOT)
PROJ_SRC = os.path.join(PROJ_ROOT, 'src')

# The imports below are written as `from src....`, so the root itself must be
# importable; don't rely on the kernel keeping the cwd on sys.path. `src` is
# also appended, as before. The membership checks keep re-executions of this
# cell from piling duplicate entries onto sys.path.
for _path in (PROJ_ROOT, PROJ_SRC):
    if _path not in sys.path:
        sys.path.append(_path)

In [None]:
import time
from pprint import pprint
from typing import Any, Dict

from src.helpers.utils import dump_json, load_json, create_dir
from src.s1_data_loaders.TransformerDataLoader import TransformerDataLoader as DataLoader
from src.s2_data_transformers.TransformerDataTransformer import TransformerDataTransformer as DataTransformer
from src.s3_models.TransformerModel import TransformerModel as Model
from src.s4_trainers.TransformerTrainer import TransformerTrainer as Trainer
from src.s5_evaluaters.TransformerEvaluater import TransformerEvaluater as Evaluater

# Parameters

In [None]:
# Default value only. For papermill to override it, this cell must carry the
# `parameters` tag; papermill then inserts an injected-parameters cell right
# after this one.
config = "config/transformer.json"

In [None]:
print(f'Using config: {config}')
# From here on `config` is the parsed dict, no longer the path string.
config = load_json(config)

# Run identifier in local time, e.g. 2024-01-31-14h05. It has minute
# resolution only, so two runs started within the same minute share the same
# output directory in the saving cell.
experiment_time = time.strftime("%Y-%m-%d-%Hh%M")

# Data

In [None]:
data_loader = DataLoader(config)

# One transformer per split, each built only from that split's raw data.
# NOTE(review): if DataTransformer learns state from its input (vocabulary,
# scaling, ...), the test split is not transformed with train-derived state.
# Confirm in TransformerDataTransformer.
train_data_transformer = DataTransformer(config, data_loader.train_data())
test_data_transformer = DataTransformer(config, data_loader.test_data())

data = dict(
    train=train_data_transformer.get_transformed_data(),
    test=test_data_transformer.get_transformed_data(),
)

# Model

In [None]:
# Build the model wrapper from the loaded config dict. The trainer cell is
# handed its `.model` attribute, while the evaluater and saving cells use the
# wrapper object itself.
model = Model(config)

# Trainer

In [None]:
# The trainer gets the wrapper's `.model` attribute plus the whole `data`
# dict (both "train" and "test"), and the run identifier.
# NOTE(review): if the trainer uses data["test"] for validation/early stopping,
# the later test-set evaluation is not on unseen data. Confirm in
# TransformerTrainer.
network = model.model
trainer = Trainer(config, network, data, experiment_time=experiment_time)
trainer.train()

# Evaluater

In [None]:
# Evaluate the model wrapper on the test split. The summary is printed here
# and stored in the run report.
test_split = data["test"]
evaluater = Evaluater(model, test_split)
performances = evaluater.summary()
pprint(performances)

# Report

In [None]:
# Run summary; the saving cell writes it to report.json.
report = dict(
    model=config["name"],
    training_data=len(data["train"]["X"]),  # len() of the training inputs X
    performances=performances,
    created_at=experiment_time,
    config=config,  # full parsed config, for reproducibility
)

# Saving

In [None]:
# Directory
output_path = os.path.join("data/saved", config["name"], experiment_time)
create_dir(output_path)

# Report
report_path = os.path.join(output_path, "report.json")
dump_json(report, report_path, sort_keys=True, indent=2)

# Model
model.save(output_path)