In [None]:
import os
import sys
import traceback

import torch
import wandb

import dataset_utils
import training_utils
import vit

In [None]:
# Target device for the model and every batch: the third CUDA GPU (index 2).
device = torch.device('cuda', 2)

In [None]:
# Hyper-parameters for model, optimisation and data. All three dicts are also
# logged to wandb verbatim, so keys must match what vit / training_utils /
# dataset_utils read.

# Transformer width. d_ffn is derived from it so the two cannot drift apart.
D_MODEL = 256

model_config: vit.ViTConfig = {
    'd_model': D_MODEL,
    'num_heads': 8,
    'num_layers': 12,
    'd_ffn': 4 * D_MODEL,  # 4x feed-forward expansion, tracks d_model
    'dropout': 0.2,

    'image_size': 32,      # CIFAR-10 images are 32x32 RGB
    'image_channels': 3,
    'patch_size': 4,       # 32 / 4 = 8 -> 8x8 = 64 patches, assuming non-overlapping patches
    'out_classes': 10
}

train_config: training_utils.TrainConfig = {
    'num_steps': 10_000,
    'warmup_steps': 2_000,
    'optimizer': {
        'optim': "SGD",
        # NOTE(review): with the "like_transformer" schedule this presumably
        # acts as a multiplier on the schedule rather than an absolute LR;
        # confirm in training_utils.get_scheduler.
        'base_lr': 1,
        'args': {}
    },
    'batches_per_step': 1,
    'eval_interval': 1000,
    'log_interval': 100,
    'autocast': True,
    'lr_scheduler': "like_transformer",
    'label_smoothing': 0.1,  # passed to CrossEntropyLoss below
    'clip_grad': None
}

dataset_config: dataset_utils.DatasetConfig = {
    'dataset': "CIFAR10",
    'augmentation': "AutoCIFAR10",  # presumably the AutoAugment CIFAR10 policy; verify in dataset_utils
    'batch_size': 512,
    'num_workers': 8
}

In [None]:
# Start a wandb run and record every hyper-parameter group under its own key,
# so the run can be reproduced from the wandb UI alone.
wandb.init(
    config=dict(
        dataset=dataset_config,
        model=model_config,
        train=train_config,
    ),
    project='vit-classifier',
)

In [None]:
# Build the model and move it to the target device *before* creating the
# optimizer. PyTorch recommends this ordering so the optimizer is built around
# the on-device parameters. A later `.to(device)` on the same device (e.g.
# inside training_utils.train) is a no-op.
model = vit.get_model(model_config).to(device)
optim = training_utils.get_optim(train_config, model)
# d_model is forwarded because the "like_transformer" schedule presumably
# scales with d_model (Noam-style); confirm in training_utils.get_scheduler.
lr_scheduler = training_utils.get_scheduler(train_config, optim, d_model=model_config['d_model'])
# 'data' is presumably the dataset root directory (download/cache location).
train_loader, test_loader = dataset_utils.get_dataloader(dataset_config, 'data')

criterion = torch.nn.CrossEntropyLoss(label_smoothing=train_config['label_smoothing'])

In [None]:
def calc_train_loss(model, batch: list[torch.Tensor]) -> torch.Tensor:
    """Return the training loss of ``model`` on a single ``(images, labels)`` batch.

    Switches the model to training mode, moves the batch to the global
    ``device``, converts the images into patch sequences with
    ``vit.image_to_patches`` and scores the logits with the global
    (label-smoothed) ``criterion``. This function does not call
    ``backward()`` or step the optimizer.
    """
    model.train()
    images, labels = batch
    images = images.to(device)
    labels = labels.to(device)

    patches = vit.image_to_patches(images, model_config['patch_size'])
    logits = model(patches)
    return criterion(logits, labels)

def eval_model(model, loader=None) -> float:
    """Return the top-1 accuracy of ``model`` on ``loader`` (default: ``test_loader``).

    Evaluation runs in eval mode without gradient tracking. The model's
    previous train/eval mode is restored afterwards, even on error, so calling
    this in the middle of training does not leave dropout disabled.

    Args:
        model: the ViT. Its ``predict(patches)`` output is compared element-wise
            with the targets; assumed to be predicted class indices, so confirm
            against ``vit``.
        loader: iterable of ``(images, targets)`` batches. Defaults to the
            global ``test_loader``.

    Returns:
        Fraction of correctly classified samples, in ``[0, 1]``.

    Raises:
        ValueError: if ``loader`` yields no samples (would otherwise be a
            ZeroDivisionError).
    """
    if loader is None:
        loader = test_loader

    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, targets in loader:
                images = images.to(device)
                targets = targets.to(device)
                patches = vit.image_to_patches(images, model_config['patch_size'])

                predicted = model.predict(patches)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    if total == 0:
        raise ValueError("eval_model: loader yielded no samples")
    return correct / total

In [None]:
# Train, then always checkpoint the model and close the wandb run, then
# hard-exit the process with os._exit. In a Jupyter kernel this kills the
# kernel, presumably to release GPU memory / dataloader workers; confirm intent.
#
# Exceptions are printed *before* exiting. os._exit terminates immediately, so
# an exception that is merely propagating through a `finally` would otherwise
# disappear without a traceback. The exit status is non-zero whenever
# training, checkpointing or wandb teardown failed.
exit_code = 0
try:
    try:
        training_utils.train(
            model,
            train_config,
            optim,
            lr_scheduler,
            calc_train_loss,
            train_loader,
            eval_model,
            device
        )
    except BaseException:  # includes KeyboardInterrupt: still checkpoint the partial model
        traceback.print_exc()
        exit_code = 1

    try:
        os.makedirs('models', exist_ok=True)  # torch.save does not create parent dirs
        torch.save(model.state_dict(), 'models/cifar10.pt')
        wandb.log_model(path='models/cifar10.pt', name='cifar10or100')
    except BaseException:
        traceback.print_exc()
        exit_code = 1

    # A non-zero exit code marks the wandb run as failed.
    wandb.finish(exit_code=exit_code)
except BaseException:
    traceback.print_exc()
    exit_code = 1
finally:
    # os._exit skips interpreter cleanup (including stdio flushing), so flush
    # first to make sure any traceback above is actually emitted.
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(exit_code)