In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
from utils.data_util import CarsDataset
from utils.model_util import TabularTransformer
from torch.utils.data import DataLoader
from dataclasses import dataclass


device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using mps device


In [2]:
def save_checkpoint(state, is_best, checkpoint_dir='checkpoints'):
    torch.save(state, os.path.join(checkpoint_dir, 'latest.pth'))
    if is_best:
        torch.save(state, os.path.join(checkpoint_dir, 'best.pth'))

def load_checkpoint(model, optimizer, scheduler, filename='checkpoints/best.pth'):
    if os.path.isfile(filename):
        checkpoint = torch.load(filename)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        print(f"Checkpoint loaded: {filename} (Epoch {start_epoch})")
        return start_epoch, loss
    else:
        print(f"No checkpoint found at {filename}")
        return 0, None

In [3]:
# set parameters

@dataclass
class Train_Parameters:
    batch_size: int = 64 # number of examples per batch
    val_size: float = 0.2 # relative size of validation split
    n_eval: int = 200 # evaluate model performance every n_eval steps
    epochs: int = 50 # number of training epochs

@dataclass
class Model_Parameters:    
    num_features: int = 15 # number of features in input data
    num_bins: int = 64 # number of bins in k-bins discretizer
    d_model: int = 128 # dimension of model
    d_ff: int = 256 # dimension of feed forward layer
    num_layers: int = 4 # number of decoder layers
    num_heads: int = 8 # number of heads
    dropout: float = 0.4 # dropout rate

tparam = Train_Parameters()
mparam = Model_Parameters()

In [4]:
# create dataset and dataloader objects

train_data = CarsDataset(n_bins=mparam.num_bins, subset='train', preprocessors=None, val_size=tparam.val_size)
val_data = CarsDataset(n_bins=mparam.num_bins, subset='val', preprocessors=[train_data.preprocessor, train_data.scaler], val_size=tparam.val_size)

train_loader = DataLoader(train_data, batch_size=tparam.batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=tparam.batch_size)



In [5]:
# instantiate model and optimizer

model = TabularTransformer(
    num_features=mparam.num_features,
    num_bins=mparam.num_bins,
    d_model=mparam.d_model,
    num_layers=mparam.num_layers,
    num_heads=mparam.num_heads,
    d_ff=mparam.d_ff,
    dropout=mparam.dropout
)
model = model.to(device)
print(f'Number of params in model: {sum(p.numel() for p in model.parameters())}')
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = OneCycleLR(
    optimizer=optimizer,
    max_lr=4e-5,
    epochs=tparam.epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.1,
    div_factor=1e2,
    final_div_factor=1e3
    )

Number of params in model: 921857


In [6]:
# training loop
def train_model(model, optimizer, scheduler, train_loader, val_loader, training_parameters, checkpoint_path):
    losses = []
    losses_val = []
    start_epoch = 0

    if os.path.exists(checkpoint_path):
        start_epoch, loss = load_checkpoint(model, optimizer, scheduler, checkpoint_path)

    for epoch in range(start_epoch, training_parameters.epochs):
        model.train()
        print(f'epoch {epoch+1}:')
        for batch, (X, y) in enumerate(train_loader):
            X, y = X.to(device), y.to(device)
            pred, loss = model(X, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            scheduler.step()
            if batch % training_parameters.n_eval == 0:
                with torch.no_grad():
                    model.eval()
                    loss_val = []
                    current = batch * training_parameters.batch_size + len(X)
                    for Xval, yval in val_loader:
                        Xval, yval = Xval.to(device), yval.to(device)
                        pred, loss = model(Xval, yval)
                        loss_val.append(loss.item())
                    losses_val.append(np.mean(loss_val))
                    print(f"loss: {np.mean(losses[-training_parameters.n_eval:]):>7f}  val loss: {losses_val[-1]:>7f}  current lr: {scheduler.get_last_lr()[0]:>7f}  [{current:>7d}/{len(train_loader.dataset):>7d}]")
                    save_checkpoint({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'loss': loss.item(),
                    }, losses_val[-1] == np.min(losses_val))

In [7]:
train_model(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    training_parameters=tparam,
    checkpoint_path='checkpoints/best.pth'
)

epoch 1:
loss: 0.222776  val loss: 0.303354  current lr: 0.000000  [     64/ 150826]
loss: 0.253211  val loss: 0.233876  current lr: 0.000000  [  12864/ 150826]
loss: 0.211925  val loss: 0.215194  current lr: 0.000001  [  25664/ 150826]
loss: 0.201455  val loss: 0.198381  current lr: 0.000001  [  38464/ 150826]
loss: 0.171933  val loss: 0.181971  current lr: 0.000001  [  51264/ 150826]
loss: 0.164995  val loss: 0.166959  current lr: 0.000001  [  64064/ 150826]
loss: 0.164163  val loss: 0.154128  current lr: 0.000001  [  76864/ 150826]
loss: 0.139669  val loss: 0.143819  current lr: 0.000002  [  89664/ 150826]
loss: 0.126784  val loss: 0.135963  current lr: 0.000002  [ 102464/ 150826]
loss: 0.135581  val loss: 0.129478  current lr: 0.000003  [ 115264/ 150826]
loss: 0.135520  val loss: 0.124156  current lr: 0.000003  [ 128064/ 150826]
loss: 0.124092  val loss: 0.119732  current lr: 0.000004  [ 140864/ 150826]
epoch 2:
loss: 0.113701  val loss: 0.117102  current lr: 0.000004  [     64/ 15