In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.data_util import MushroomDataset
from utils.model_util import TabularTransformer
from torch.utils.data import DataLoader
from dataclasses import dataclass


device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using mps device


In [2]:
# set parameters

@dataclass
class Train_Parameters:
    batch_size: int = 256 # number of examples per batch
    val_size: float = 0.1 # relative size of validation split
    n_eval: int = 1000 # evaluate model performance every n_eval steps
    epochs: int = 3 # number of training epochs

@dataclass
class Model_Parameters:    
    num_features: int = 20 # number of features in input data
    num_bins: int = 8 # number of bins in k-bins discretizer
    d_model: int = 32 # dimension of model
    d_ff: int = 64 # dimension of feed forward layer
    num_layers: int = 4 # number of decoder layers
    num_heads: int = 4 # number of heads
    dropout: float = 0.3 # dropout rate

tparam = Train_Parameters()
mparam = Model_Parameters()

In [3]:
# create dataset and dataloader objects

train_data = MushroomDataset(n_bins=mparam.num_bins, subset='train', preprocessors=None, val_size=tparam.val_size)
val_data = MushroomDataset(n_bins=mparam.num_bins, subset='val', preprocessors=[train_data.preprocessor, train_data.label_enc], val_size=tparam.val_size)
#test_data = MushroomDataset(n_bins=mparam.num_bins, subset='test', preprocessors=[train_data.preprocessor])

train_loader = DataLoader(train_data, batch_size=tparam.batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=tparam.batch_size)#len(val_data))
#test_loader = DataLoader(test_data, batch_size=tparam.batch_size)

In [4]:
# instantiate model and optimizer
 
model = TabularTransformer(
    num_features=mparam.num_features,
    num_bins=mparam.num_bins,
    d_model=mparam.d_model,
    num_layers=mparam.num_layers,
    num_heads=mparam.num_heads,
    d_ff=mparam.d_ff,
    dropout=mparam.dropout
)
model = model.to(device)
print(f'Number of params in model: {sum(p.numel() for p in model.parameters())}')
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

Number of params in model: 57729


In [5]:
# training loop

losses = []
losses_val = []
for epoch in range(tparam.epochs):
    model.train()
    print(f'epoch {epoch}:')
    for batch, (X, y) in enumerate(train_loader):
        X, y = X.to(device), y.to(device)
        pred, loss = model(X, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if batch % tparam.n_eval == 0:
            with torch.no_grad():
                model.eval()
                loss_val = []
                current = batch * tparam.batch_size + len(X)
                for Xval, yval in val_loader:
                    Xval, yval = Xval.to(device), yval.to(device)
                    pred, loss = model(Xval, yval)
                    loss_val.append(loss.item())
                losses_val.append(np.mean(loss_val))
                print(f"loss: {np.mean(losses[-tparam.n_eval:]):>7f}  val loss: {losses_val[-1]:>7f}  [{current:>7d}/{len(train_loader.dataset):>7d}]")

epoch 0:
loss: 0.690962  val loss: 0.684374  [    256/2805250]
loss: 0.185256  val loss: 0.064665  [ 256256/2805250]
loss: 0.058561  val loss: 0.053011  [ 512256/2805250]
loss: 0.053062  val loss: 0.051209  [ 768256/2805250]
loss: 0.049275  val loss: 0.048218  [1024256/2805250]
loss: 0.049464  val loss: 0.047888  [1280256/2805250]
loss: 0.047933  val loss: 0.050003  [1536256/2805250]
loss: 0.046487  val loss: 0.045751  [1792256/2805250]
loss: 0.046397  val loss: 0.045278  [2048256/2805250]
loss: 0.045703  val loss: 0.046107  [2304256/2805250]
loss: 0.045942  val loss: 0.044801  [2560256/2805250]
epoch 1:
loss: 0.045455  val loss: 0.044144  [    256/2805250]
loss: 0.043912  val loss: 0.043903  [ 256256/2805250]
loss: 0.044347  val loss: 0.044172  [ 512256/2805250]
loss: 0.042591  val loss: 0.044402  [ 768256/2805250]
loss: 0.044300  val loss: 0.044664  [1024256/2805250]
loss: 0.042631  val loss: 0.043206  [1280256/2805250]
loss: 0.043845  val loss: 0.043747  [1536256/2805250]
loss: 0.04