In [None]:
from dataloader import load_raw, create_datasets, create_dataloaders
from visual import learning_curves, forecast_plot, Logger
from models import ModelConfigBuilder, forward, predict, EarlyStop
from measures import compute_intermittent_indicators, label_intermittent, quantile_loss

import os
from datetime import datetime
import numpy as np
from tqdm import tqdm
import pickle as pkl
import math
import torch
from gluonts.dataset.field_names import FieldName
from accelerate import Accelerator
from torch.optim import AdamW

In [None]:
# Load the raw series together with the dataset metadata
DNAME = "carparts"
data_raw, data_info = load_raw(dataset_name=DNAME, datasets_folder_path="../data")

# Compute the (adi, cv2) indicators once, then store one label entry per
# pattern name in data_info (key and `f` argument are the same string)
adi, cv2 = compute_intermittent_indicators(data_raw)
for pattern in ("intermittent", "lumpy"):
    data_info[pattern] = label_intermittent(adi, cv2, f=pattern)

# Build the train / valid / test datasets and record the size of the train split
datasets = create_datasets(data_raw, data_info)
data_info['N'] = len(datasets['train'])

In [None]:
# Batch size handed to create_dataloaders (also reused later for progress-bar totals)
BATCH_SIZE = 128

# Model configuration: the builder is completed from data_info and exposes
# the resulting parameters through `.params`
model_builder = ModelConfigBuilder(
    model="transformer",
    distribution_head="tweedie",
    scaling="mean-demand",
)
model_builder.build(data_info)
CONFIG = model_builder.params

# Train / valid / test dataloaders
dataloaders = create_dataloaders(CONFIG, datasets, data_info, batch_size=BATCH_SIZE)
train_dataloader, valid_dataloader, test_dataloader = dataloaders

# Instantiate the model from the builder
model = model_builder.get_model()

In [None]:
# Debug helper (disabled): uncomment to pull one batch from the training
# dataloader and print each entry's key, shape and tensor type.
# batch = next(iter(train_dataloader))
# for k, v in batch.items():
#     print(k, v.shape, v.type())

In [None]:
# Training setup
accelerator = Accelerator(cpu=True)
device = accelerator.device
model.to(device)
optimizer = AdamW(model.parameters(), lr=6e-4, betas=(0.9, 0.95), weight_decay=1e-1)
model, optimizer, train_dataloader, valid_dataloader = accelerator.prepare(model, optimizer, train_dataloader, valid_dataloader)

logger = Logger()
# NOTE(review): patience (20) is larger than NUM_EPOCHS (10) below, so early
# stopping presumably never triggers with these settings - confirm this is intended.
early_stop = EarlyStop(logger, patience=20, min_delta=1e-3)

# Training loop
NUM_EPOCHS = 10
history = { 'train_loss': [], 'val_loss': []}
logger.log(f'Training on device={device}')
for epoch in range(NUM_EPOCHS):
    # 1. Training
    model.train()
    train_loss = 0.0
    # Batches are counted explicitly: the mean loss must be divided by the
    # number of batches, not by the last zero-based index, and the dataloader
    # is not assumed to implement __len__.
    n_train_batches = 0
    for batch in train_dataloader:
        optimizer.zero_grad()
        loss = forward(model, batch, device, CONFIG)
        train_loss += loss.item()
        accelerator.backward(loss)
        optimizer.step()
        n_train_batches += 1
    # max(..., 1) avoids a ZeroDivisionError if the dataloader yields nothing.
    history['train_loss'].append(train_loss / max(n_train_batches, 1))
    # 2. Validation
    model.eval()
    val_loss = 0.0
    n_val_batches = 0
    with torch.no_grad():
        for batch in valid_dataloader:
            loss = forward(model, batch, device, CONFIG)
            val_loss += loss.item()
            n_val_batches += 1
    history['val_loss'].append(val_loss / max(n_val_batches, 1))
    logger.log_epoch(epoch, history)
    # 3. Early Stopping
    if early_stop.update(model, epoch, history['val_loss'][-1]):
        break

# 4. Output folder: <model>__<dataset>__<timestamp>
# NOTE(review): the timestamp contains ':' which is not a valid character in
# Windows file names; the format is kept unchanged for compatibility with
# existing folders.
model_folder_name = model_builder.model+"__"+DNAME+"__"+datetime.now().strftime("%Y-%m-%d-%H:%M")
model_folder_path = os.path.join("..","trained_models",model_folder_name)
# Always (re)write the artifacts: the timestamp only has minute resolution, so
# a re-run within the same minute must overwrite the previous files rather than
# silently skipping the save and leaving stale weights for the next cell to load.
os.makedirs(model_folder_path, exist_ok=True)
# 5. Plot of Learning curves
learning_curves(history, path=model_folder_path)
# 6. Save the model
# early_stop.best_model is presumably a state dict (the next cell feeds the
# saved file to model.load_state_dict) - TODO confirm.
torch.save(early_stop.best_model, os.path.join(model_folder_path,"model_state.model"))
# Context manager so the file handle is flushed and closed deterministically.
with open(os.path.join(model_folder_path,"model_builder.config"), "wb") as config_file:
    pkl.dump(model_builder, config_file)

In [None]:
# Load the weights written by the training cell into the model
model.load_state_dict(torch.load(os.path.join(model_folder_path,"model_state.model")))

# Prediction
model.eval()
# The progress-bar total is derived from the dataset size and the batch size
# rather than from len(test_dataloader).
n_test_batches = math.ceil(len(datasets['test']) / BATCH_SIZE)
forecast_batches = [predict(model, batch, device, CONFIG)
                    for batch in tqdm(test_dataloader, total=n_test_batches)]
# Stack the per-batch outputs along the first axis
forecasts = np.vstack(forecast_batches)
# Ground truth: the last data_info['h'] target values of every test series
actuals = np.array([x[FieldName.TARGET][-data_info['h']:] for x in datasets['test']])

# Persist (actuals, forecasts); the context manager closes the file handle
with open(os.path.join(model_folder_path,"forecasts.pkl"), "wb") as forecasts_file:
    pkl.dump((actuals, forecasts), forecasts_file)

# Kept as the last expression of the cell so the shapes are actually displayed
actuals.shape, forecasts.shape

In [None]:
# Quantile loss of the forecasts against the actuals at the listed quantile levels
metrics = {}
metrics['quantile_loss'] = quantile_loss(actuals, forecasts, q=[0.25, 0.5, 0.8, 0.9, 0.95, 0.99])
# Save next to the model artifacts; the context manager closes the file handle
with open(os.path.join(model_folder_path,"metrics.pkl"), "wb") as metrics_file:
    pkl.dump(metrics, metrics_file)