In [21]:
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from architectures.Transformer import Model
from utils.data_util import MushroomDataset


# Choose the compute backend in order of preference:
# NVIDIA CUDA, then Apple Metal (MPS), then plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using mps device


In [19]:
# set parameters

@dataclass
class Train_Parameters:
    """Settings that control the training run (batching, validation split, schedule)."""

    # Number of examples in each mini-batch.
    batch_size: int = 256
    # Relative size of the validation split.
    val_size: float = 0.1
    # Evaluate model performance once every `n_eval` steps.
    n_eval: int = 100
    # Maximum number of training steps.
    max_steps: int = 1001

@dataclass
class Model_Parameters:
    """Hyper-parameters for the model.

    `head_size` is derived per instance as `n_embed // n_heads` when it is not
    passed explicitly. Computing it as a class-level default would freeze it at
    the value implied by the default `n_embed`/`n_heads`, so overriding either
    of those would silently leave a stale head size.
    """
    d_model: int = 16 # dimension of model
    n_embed: int = 8 # dimension of embedding
    n_heads: int = 2 # number of heads
    head_size: Optional[int] = None # head size; None -> n_embed // n_heads
    dropout: float = 0.4 # dropout rate
    n_in: int = 20 # number of columns in input tensor

    def __post_init__(self):
        # Resolve the derived default from this instance's own values.
        if self.head_size is None:
            self.head_size = self.n_embed // self.n_heads


# Instantiate both configuration objects with their default values.
tparam, mparam = Train_Parameters(), Model_Parameters()

In [20]:
# create dataset and dataloader objects

# The training subset is built with no preprocessors; the ones it exposes
# (`preprocessor`, `label_enc`) are then passed on to the other subsets.
train_data = MushroomDataset(
    n_bins=mparam.n_in,
    subset='train',
    preprocessors=None,
    val_size=tparam.val_size,
)
val_data = MushroomDataset(
    n_bins=mparam.n_in,
    subset='val',
    preprocessors=[train_data.preprocessor, train_data.label_enc],
    val_size=tparam.val_size,
)
# The test subset receives only the feature preprocessor (no label encoder, no val_size).
test_data = MushroomDataset(
    n_bins=mparam.n_in,
    subset='test',
    preprocessors=[train_data.preprocessor],
)

# Training batches are shuffled; the validation set is served as one single batch.
train_loader = DataLoader(train_data, batch_size=tparam.batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=len(val_data))
test_loader = DataLoader(test_data, batch_size=tparam.batch_size)

In [22]:
# Build the model from the model config and place it on the selected device.
model = Model(config=mparam).to(device)
# AdamW over all model parameters with a fixed learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Training loop: ten passes over `train_loader`, one optimizer step per mini-batch.
# The model is called with both inputs and targets and hands back a
# (prediction, loss) pair, so the loss is taken straight from the forward pass.
size = len(train_loader.dataset)
losses = []  # per-step training loss, accumulated across all epochs
# Make sure training-mode layers (e.g. dropout) are active even if an earlier
# cell left the model in eval mode.
model.train()
for epoch in range(10):
    print(f'epoch {epoch}:')
    for batch, (X, y) in enumerate(train_loader):
        X, y = X.to(device), y.to(device)
        pred, loss = model(X, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if batch % 10000 == 0:
            # Samples seen so far in this epoch. Use the loader's actual batch
            # size (not a hard-coded 32) so the counter is comparable to `size`.
            current = batch * train_loader.batch_size + len(X)
            # Running mean over (up to) the last 10000 recorded steps.
            print(f"loss: {np.mean(losses[-10000:]):>7f}  [{current:>5d}/{size:>5d}]")