In [None]:
import copy
import math
import os

import numpy as np
import torch
from accelerate import Accelerator
from gluonts.dataset.field_names import FieldName
from torch.optim import AdamW
from tqdm import tqdm

from dataloader import load_raw, create_datasets, create_dataloaders
from visual import learning_curves, forecast_plot
from models import ModelConfig, forward, predict
from measures import compute_intermittent_indicators, label_intermittent

In [None]:
# Reproducibility: seed the global NumPy and PyTorch RNGs before any data sampling,
# shuffling or weight initialisation happens further down (presumably the project's
# dataloaders / model builder draw from these RNGs — confirm).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Import data
data_raw, data_info = load_raw(dataset_name="carparts", datasets_folder_path="../data")

# Compute intermittent indicators (ADI and CV^2) and label each series accordingly
adi, cv2 = compute_intermittent_indicators(data_raw)
data_info['intermittent'] = label_intermittent(adi, cv2, f="intermittent")
data_info['lumpy'] = label_intermittent(adi, cv2, f="lumpy")

# Create Datasets (train, valid, test) objects
datasets = create_datasets(data_raw, data_info)

# Model config
CONFIG = ModelConfig(datasets, data_info, model="deepAR")
CONFIG.batch_size = 128
# TODO: placeholder name — it is used as the file name when the trained model is saved.
CONFIG.model_name = "xxx"

# Dataloaders
train_dataloader, valid_dataloader, test_dataloader = create_dataloaders(CONFIG, datasets, data_info, batch_size=CONFIG.batch_size)

# Build the model
model = CONFIG.build_model(distribution_head="negative_binomial", scaling=None)

In [None]:
# Training setup
accelerator = Accelerator(cpu=True)
device = accelerator.device
model.to(device)
optimizer = AdamW(model.parameters(), lr=6e-4, betas=(0.9, 0.95), weight_decay=1e-1)
model, optimizer, train_dataloader, valid_dataloader = accelerator.prepare(model, optimizer, train_dataloader, valid_dataloader)
print(f'Training on device={device}')

# Training loop
NUM_EPOCHS = 10
# Epochs without improvement before stopping. NOTE: PATIENCE > NUM_EPOCHS, so with
# these values early stopping can never trigger; lower it to make it effective.
PATIENCE = 20
min_delta = 0.001  # minimum drop in validation loss that counts as an improvement
best_val_loss = np.inf
best_model = None  # snapshot (deep copy) of the weights with the best validation loss
current_patience = 0
history = {'train_loss': [], 'val_loss': []}

for epoch in range(NUM_EPOCHS):
    # --- training pass ---
    model.train()
    train_loss, n_train_batches = 0.0, 0
    for batch in train_dataloader:
        optimizer.zero_grad()
        loss = forward(model, batch, device, CONFIG)
        train_loss += loss.item()
        n_train_batches += 1
        accelerator.backward(loss)
        optimizer.step()
    if n_train_batches == 0:
        raise RuntimeError("train_dataloader yielded no batches")
    # Mean loss per batch: divide by the batch count, not by the last enumerate index.
    history['train_loss'].append(train_loss / n_train_batches)

    # --- validation pass ---
    model.eval()
    val_loss, n_val_batches = 0.0, 0
    with torch.no_grad():
        for batch in valid_dataloader:
            loss = forward(model, batch, device, CONFIG)
            val_loss += loss.item()
            n_val_batches += 1
    if n_val_batches == 0:
        raise RuntimeError("valid_dataloader yielded no batches")
    history['val_loss'].append(val_loss / n_val_batches)

    print(f"Epoch {epoch+1} \t Train Loss: {history['train_loss'][-1]:.3f} \t Val Loss: {history['val_loss'][-1]:.3f}")  # Log

    # Early stopping: a new best must beat the previous best by at least min_delta.
    if history['val_loss'][-1] < best_val_loss - min_delta:
        best_val_loss = history['val_loss'][-1]
        # state_dict() holds references to the live parameter tensors; deep-copy it so
        # subsequent optimizer steps do not overwrite the snapshot of the best weights.
        best_model = copy.deepcopy(model.state_dict())
        print(f"Early stopping, new validation best: {best_val_loss:.3f}, keep training!")
        current_patience = 0
    else:
        current_patience += 1

    if current_patience >= PATIENCE:
        print(f"Early stopping after {epoch+1} epochs. Validation best: {best_val_loss:.3f}")
        break

# Load the best model state into the model (only if some epoch actually improved,
# e.g. a NaN validation loss never does and would leave best_model as None).
if best_model is not None:
    model.load_state_dict(best_model)
else:
    print("No validation improvement recorded; keeping the final model weights.")

# Learning curves plot
learning_curves(history, figsize=(10,3))

In [None]:
# Save the model weights. The target directory is created first so the save does not
# fail on a fresh checkout where ../trained_models/ does not exist yet.
MODEL_DIR = os.path.join('..', 'trained_models')
os.makedirs(MODEL_DIR, exist_ok=True)
model_path = os.path.join(MODEL_DIR, CONFIG.model_name + '.model')
torch.save(model.state_dict(), model_path)

# Load the model back (round-trip check); map_location puts the tensors on the
# device used for training regardless of where they were saved from.
model.load_state_dict(torch.load(model_path, map_location=device))

In [None]:
# Prediction: run the model over every test batch and stack the per-batch forecasts
# into a single array (one row block per batch, concatenated along axis 0).
model.eval()
# Only used as the progress-bar total for tqdm.
n_test_batches = math.ceil(len(datasets['test']) / CONFIG.batch_size)
batch_forecasts = []
for batch in tqdm(test_dataloader, total=n_test_batches):
    batch_forecasts.append(predict(model, batch, device, CONFIG))
forecasts = np.vstack(batch_forecasts)
forecasts.shape

In [None]:
# Forecast plot: hand the stacked forecasts, datasets and metadata to the project's
# plotting helper, identifying the series values by GluonTS's target field name.
# NOTE(review): the meaning of the leading literal 10 is defined by visual.forecast_plot — confirm.
target_field = FieldName.TARGET
forecast_plot(10, forecasts, datasets, data_info, target_field)