In [10]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
from utils.data_util import MushroomDataset
from utils.model_util import TabularTransformer
from torch.utils.data import DataLoader
from dataclasses import dataclass


# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using mps device


In [2]:
def save_checkpoint(state, is_best, checkpoint_dir='checkpoints'):
    """Write a training checkpoint to ``checkpoint_dir``.

    The state is always written to ``latest.pth``; when ``is_best`` is truthy
    it is additionally written to ``best.pth``.

    Args:
        state: picklable object (typically a dict of state dicts) to store.
        is_best: whether this checkpoint should also replace ``best.pth``.
        checkpoint_dir: directory receiving the files; created if missing.
    """
    # torch.save does not create parent directories, so make sure the target exists.
    # An empty string means "current directory" and needs no creation.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    torch.save(state, os.path.join(checkpoint_dir, 'latest.pth'))
    if is_best:
        torch.save(state, os.path.join(checkpoint_dir, 'best.pth'))

def load_checkpoint(model, optimizer, scheduler, filename='checkpoints/best.pth', map_location='cpu'):
    """Restore model, optimizer and scheduler state from a checkpoint file.

    Args:
        model: module whose ``load_state_dict`` receives ``'model_state_dict'``.
        optimizer: optimizer whose ``load_state_dict`` receives ``'optimizer_state_dict'``.
        scheduler: scheduler whose ``load_state_dict`` receives ``'scheduler_state_dict'``.
        filename: path of the checkpoint to read.
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint written on one accelerator (e.g. CUDA) can still be
            restored on a machine that lacks it; ``load_state_dict`` then
            copies the values into the objects' existing tensors.

    Returns:
        ``(epoch, loss)`` as stored in the checkpoint, or ``(0, None)`` when
        ``filename`` does not exist.
    """
    if not os.path.isfile(filename):
        print(f"No checkpoint found at {filename}")
        return 0, None

    # NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded: {filename} (Epoch {start_epoch})")
    return start_epoch, loss

In [3]:
# set parameters

@dataclass
class Train_Parameters:
    """Settings that control the training run.

    Attributes:
        batch_size: number of examples per batch.
        val_size: relative size of the validation split.
        n_eval: evaluate model performance every ``n_eval`` steps.
        epochs: number of training epochs.
    """

    batch_size: int = 128
    val_size: float = 0.2
    n_eval: int = 1000
    epochs: int = 20

@dataclass
class Model_Parameters:
    """Architecture hyper-parameters passed to the model constructor.

    Attributes:
        num_features: number of features in the input data.
        num_bins: number of bins in the k-bins discretizer.
        d_model: dimension of the model.
        d_ff: dimension of the feed-forward layer.
        num_layers: number of decoder layers.
        num_heads: number of attention heads.
        dropout: dropout rate.
    """

    num_features: int = 20
    num_bins: int = 16
    d_model: int = 64
    d_ff: int = 128
    num_layers: int = 4
    num_heads: int = 8
    dropout: float = 0.3

# Instantiate both parameter sets with their default values; later cells read
# training settings from `tparam` and model settings from `mparam`.
tparam = Train_Parameters()
mparam = Model_Parameters()

In [4]:
# Build the train / validation datasets and wrap each in a DataLoader.
# The validation dataset is handed the preprocessor and label encoder
# exposed by the training dataset instead of creating its own.
train_data = MushroomDataset(
    n_bins=mparam.num_bins,
    subset='train',
    preprocessors=None,
    val_size=tparam.val_size,
)
val_data = MushroomDataset(
    n_bins=mparam.num_bins,
    subset='val',
    preprocessors=[train_data.preprocessor, train_data.label_enc],
    val_size=tparam.val_size,
)

# Only the training loader shuffles; validation keeps the dataset order.
train_loader = DataLoader(dataset=train_data, batch_size=tparam.batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=tparam.batch_size)

In [5]:
# Build the transformer from the model hyper-parameters and move it to the selected device.
model = TabularTransformer(
    num_features=mparam.num_features,
    num_bins=mparam.num_bins,
    d_model=mparam.d_model,
    d_ff=mparam.d_ff,
    num_layers=mparam.num_layers,
    num_heads=mparam.num_heads,
    dropout=mparam.dropout,
).to(device)
print(f'Number of params in model: {sum(p.numel() for p in model.parameters())}')

# AdamW driven by a one-cycle schedule spanning every optimisation step of the run.
# The scheduler controls the learning rate from the first step on
# (starting at max_lr / div_factor), so the lr given to AdamW is only a placeholder.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = OneCycleLR(
    optimizer,
    max_lr=4e-3,
    epochs=tparam.epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.1,
    div_factor=100.0,
    final_div_factor=1000.0,
)

Number of params in model: 224001


In [6]:
# training loop
def train_model(model, optimizer, scheduler, train_loader, val_loader, training_parameters, checkpoint_path):
    """Train ``model``, periodically validating and checkpointing.

    Every ``training_parameters.n_eval`` batches (including batch 0 of each
    epoch) the model is evaluated on ``val_loader``, a progress line is
    printed, and a checkpoint is written through ``save_checkpoint``; the
    checkpoint is flagged as best when its mean validation loss is the lowest
    seen during this call.

    Args:
        model: callable as ``model(X, y)`` returning ``(pred, loss)``.
        optimizer: optimizer over the model's parameters.
        scheduler: learning-rate scheduler, stepped once per batch.
        train_loader: loader yielding ``(X, y)`` training batches; must expose ``.dataset``.
        val_loader: loader yielding ``(X, y)`` validation batches.
        training_parameters: object with ``epochs``, ``n_eval`` and ``batch_size``.
        checkpoint_path: checkpoint to resume from if the file exists.

    Uses the module-level ``device`` to place each batch.
    """
    losses = []
    losses_val = []
    start_epoch = 0

    if os.path.exists(checkpoint_path):
        # NOTE(review): resuming restarts the stored epoch from its first batch while the
        # scheduler continues from its saved step count; with a fixed-length schedule
        # such as OneCycleLR this looks like it can run past the scheduled total - confirm.
        start_epoch, _ = load_checkpoint(model, optimizer, scheduler, checkpoint_path)

    for epoch in range(start_epoch, training_parameters.epochs):
        model.train()
        print(f'epoch {epoch+1}:')
        for batch, (X, y) in enumerate(train_loader):
            X, y = X.to(device), y.to(device)
            _, loss = model(X, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            scheduler.step()

            if batch % training_parameters.n_eval != 0:
                continue

            # --- periodic validation ---
            model.eval()
            batch_val_losses = []
            with torch.no_grad():
                for Xval, yval in val_loader:
                    Xval, yval = Xval.to(device), yval.to(device)
                    # Separate name so the training loss above is not overwritten.
                    _, val_loss = model(Xval, yval)
                    batch_val_losses.append(val_loss.item())
            # Switch back to training mode; otherwise every following batch of the
            # epoch would run with the model left in eval mode.
            model.train()

            mean_val_loss = float(np.mean(batch_val_losses))
            losses_val.append(mean_val_loss)
            current = batch * training_parameters.batch_size + len(X)
            print(f"loss: {np.mean(losses[-training_parameters.n_eval:]):>7f}  val loss: {losses_val[-1]:>7f}  current lr: {scheduler.get_last_lr()[0]:>7f}  [{current:>7d}/{len(train_loader.dataset):>7d}]")
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                # Store the same quantity that decides `is_best`: the mean validation loss.
                'loss': mean_val_loss,
            }, losses_val[-1] == np.min(losses_val))

In [7]:
# Run the training loop with the objects built above. If 'checkpoints/best.pth'
# already exists, train_model loads it first and continues from the stored epoch.
train_model(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    training_parameters=tparam,
    checkpoint_path='checkpoints/best.pth'
)

epoch 1:
loss: 0.700886  val loss: 0.733996  current lr: 0.000040  [    128/2493556]
loss: 0.403431  val loss: 0.229578  current lr: 0.000046  [ 128128/2493556]
loss: 0.132397  val loss: 0.075370  current lr: 0.000066  [ 256128/2493556]
loss: 0.071698  val loss: 0.061263  current lr: 0.000098  [ 384128/2493556]
loss: 0.061488  val loss: 0.057201  current lr: 0.000142  [ 512128/2493556]
loss: 0.057911  val loss: 0.052744  current lr: 0.000199  [ 640128/2493556]
loss: 0.053713  val loss: 0.048775  current lr: 0.000267  [ 768128/2493556]
loss: 0.053073  val loss: 0.047904  current lr: 0.000347  [ 896128/2493556]
loss: 0.049654  val loss: 0.050179  current lr: 0.000438  [1024128/2493556]
loss: 0.047023  val loss: 0.046762  current lr: 0.000539  [1152128/2493556]
loss: 0.046543  val loss: 0.049181  current lr: 0.000650  [1280128/2493556]
loss: 0.046981  val loss: 0.057246  current lr: 0.000769  [1408128/2493556]
loss: 0.050369  val loss: 0.044877  current lr: 0.000897  [1536128/2493556]
los

In [8]:
# Build the test dataset and loader used to produce the submission.
# Only the training dataset's preprocessor is passed (no label encoder),
# and the loader keeps the dataset order (no shuffling).
test_data = MushroomDataset(
    n_bins=mparam.num_bins,
    subset='test',
    preprocessors=[train_data.preprocessor],
)
test_loader = DataLoader(dataset=test_data, batch_size=tparam.batch_size)

In [31]:
def make_submission(model, test_loader, submission_path='submission.csv', label_encoder=None, test_csv_path='test.csv'):
    """Predict class labels for the test set and write a submission CSV.

    Args:
        model: callable as ``model(X, None)`` returning ``(logits, _)``.
        test_loader: iterable yielding feature batches ``X`` in the same row
            order as ``test_csv_path``.
        submission_path: destination CSV; written with columns ``id, class``.
        label_encoder: object whose ``inverse_transform`` maps 0/1 predictions
            back to class labels. Defaults to the module-level
            ``train_data.label_enc``.
        test_csv_path: CSV file providing the ``id`` column for the submission.

    Raises:
        ValueError: if the number of predictions differs from the number of ids.

    Uses the module-level ``device`` to place each batch.
    """
    if label_encoder is None:
        label_encoder = train_data.label_enc

    model.eval()
    predicted_labels = []

    with torch.no_grad():
        for X in test_loader:
            X = X.to(device)
            logits, _ = model(X, None)
            preds = (torch.sigmoid(logits) > 0.5).long()
            # reshape(-1) keeps the batch one-dimensional even when it holds a single
            # row; squeeze() would collapse that case to a 0-d tensor.
            flat_preds = preds.cpu().reshape(-1).numpy()
            predicted_labels.extend(label_encoder.inverse_transform(flat_preds).tolist())

    ids = pd.read_csv(test_csv_path, usecols=['id'])['id']
    if len(ids) != len(predicted_labels):
        # A mismatch would otherwise pair predictions with the wrong ids.
        raise ValueError(
            f'Got {len(predicted_labels)} predictions for {len(ids)} ids in {test_csv_path}'
        )

    submission_df = pd.DataFrame({'id': ids, 'class': predicted_labels})
    submission_df.to_csv(submission_path, index=False)

    print(f'Submission file created at {submission_path}')


In [32]:
# Predict on the test loader with the current `model` and write the result to submission.csv.
make_submission(
    model=model,
    test_loader=test_loader,
    submission_path='submission.csv'
)

Submission file created at submission.csv
