In [1]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet18

from src.models import resmlp_s12, vit_s12, STN, Localizer
from src.resmlp import resmlp_s12_224
from src.tokenizers import foveated_tokenizer, Patchify
from src.training import train, test

In [2]:
def resnet(n_classes=10):
    """Build a torchvision ResNet-18 whose final layer outputs ``n_classes`` logits.

    ``resnet18()`` is called without arguments, so no pretrained weights are
    requested here.

    Args:
        n_classes: Number of output classes for the replacement head.
            Defaults to 10, the value this notebook previously hard-coded.

    Returns:
        The ResNet-18 module with its ``fc`` layer replaced.
    """
    model = resnet18()
    # Read the feature width from the existing head instead of hard-coding 512,
    # so the replacement layer always matches the backbone's output size.
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model

In [3]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

def get_transforms():
    """Build the training and evaluation preprocessing pipelines.

    Both pipelines resize to 256, end up at a 224x224 crop, convert to a
    tensor and normalize each channel. The training pipeline additionally
    applies ``TrivialAugmentWide`` and takes a random resized crop, while the
    evaluation pipeline takes a deterministic center crop.

    Returns:
        A ``(train_transform, test_transform)`` tuple of ``transforms.Compose``.
    """
    crop_size = (224, 224)

    def _tensor_and_normalize():
        # Shared tail of both pipelines; the statistics are the standard
        # ImageNet per-channel mean / std.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]

    train_pipeline = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.TrivialAugmentWide(),
            transforms.RandomResizedCrop(crop_size, scale=(0.08, 1.0)),
        ]
        + _tensor_and_normalize()
    )

    eval_pipeline = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(crop_size),
        ]
        + _tensor_and_normalize()
    )

    return train_pipeline, eval_pipeline

def get_dataloaders(root, batch_size, seed=None, **kwargs):
    """Create train / validation / test loaders for the ``imagewoof2-320`` folders.

    The ``train`` folder is opened twice, once with the training transforms and
    once with the evaluation transforms. A random permutation of its indices is
    then split so that the last tenth (integer division) becomes the validation
    subset, read through the evaluation transforms, and the rest becomes the
    training subset. The ``val`` folder is used as the test set.

    Args:
        root: Directory that contains ``imagewoof2-320``. A trailing separator
            is optional because the path is built with ``os.path.join``.
        batch_size: Batch size used by all three loaders.
        seed: Optional integer seed for the train/validation permutation. When
            ``None`` (default) the global torch RNG is used, as before.
        **kwargs: Extra keyword arguments forwarded to every ``DataLoader``.
            ``num_workers`` defaults to 4 but may be overridden here.

    Returns:
        A ``(train_loader, valid_loader, test_loader)`` tuple.
    """
    train_transform, test_transform = get_transforms()

    train_dir = os.path.join(root, "imagewoof2-320", "train")
    test_dir = os.path.join(root, "imagewoof2-320", "val")

    _train_dataset = ImageFolder(train_dir, transform=train_transform)
    _valid_dataset = ImageFolder(train_dir, transform=test_transform)
    test_dataset = ImageFolder(test_dir, transform=test_transform)

    # Split the dataset. A dedicated generator keeps a seeded split independent
    # of whatever else has consumed the global RNG.
    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(seed)

    n_total = len(_train_dataset)
    val_size = n_total // 10
    # Slice on an explicit boundary: with val_size == 0 the negative-index form
    # (indices[:-0] / indices[-0:]) would hand *all* samples to validation and
    # none to training.
    split_at = n_total - val_size
    indices = torch.randperm(n_total, generator=generator).tolist()
    train_dataset = torch.utils.data.Subset(_train_dataset, indices[:split_at])
    valid_dataset = torch.utils.data.Subset(_valid_dataset, indices[split_at:])

    # Let callers override num_workers instead of raising a duplicate-keyword
    # TypeError when it is passed through **kwargs.
    loader_kwargs = {"num_workers": 4, **kwargs}

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, **loader_kwargs)
    valid_loader = DataLoader(valid_dataset, batch_size, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, batch_size, shuffle=False, **loader_kwargs)

    return train_loader, valid_loader, test_loader
    
def init_model(n_classes):
    """Assemble the classification model: a patch tokenizer followed by ResMLP-S12.

    Args:
        n_classes: Number of classes passed to ``resmlp_s12``.

    Returns:
        An ``nn.Sequential`` that runs the tokenizer and then the classifier.
    """
    # NOTE(review): the Patchify arguments are presumably image size 224,
    # patch size 16, 3 input channels and embedding dim 384 - confirm against
    # src.tokenizers.
    return nn.Sequential(
        Patchify(224, 16, 3, 384),
        resmlp_s12(n_classes, init_value=0.1),
    )

def init_stn_optimizer(model, steps_per_epoch, num_epochs, max_lr, weight_decay):
    """Create the AdamW optimizer, one-cycle LR schedule and AMP gradient scaler.

    Parameters are split into two groups by name: any parameter whose name
    contains ``bias``, ``norm``, ``saliency`` or ``layerscale`` gets a weight
    decay of 0.0; every other parameter gets ``weight_decay``.

    Args:
        model: Module whose ``named_parameters()`` are optimized.
        steps_per_epoch: Number of optimizer steps in one epoch.
        num_epochs: Number of epochs; with ``steps_per_epoch`` this fixes the
            schedule's total step count.
        max_lr: Learning rate set on both groups and used as the one-cycle peak.
        weight_decay: Weight decay applied to the decayed group.

    Returns:
        An ``(optimizer, scheduler, scaler)`` tuple.
    """
    exempt_markers = ('bias', 'norm', 'saliency', 'layerscale')

    decayed, exempt = [], []
    for param_name, tensor in model.named_parameters():
        is_exempt = any(marker in param_name for marker in exempt_markers)
        (exempt if is_exempt else decayed).append(tensor)

    param_groups = [
        dict(params=decayed, lr=max_lr, weight_decay=weight_decay),
        dict(params=exempt, lr=max_lr, weight_decay=0.0),
    ]
    optimizer = AdamW(param_groups)

    lr_schedule = OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=steps_per_epoch * num_epochs,
    )

    grad_scaler = torch.cuda.amp.GradScaler()
    return optimizer, lr_schedule, grad_scaler
                 

In [4]:
# Experiment configuration: every tunable value lives here instead of being
# inlined in the calls below.
SEED = 0
EPOCHS = 90
BATCH_SIZE = 256
N_CLASSES = 10
MAX_LR = 0.004
WEIGHT_DECAY = 0.0125
DATA_ROOT = "../Datasets/"

# Seed torch's global RNG before anything random happens (the train/validation
# permutation in get_dataloaders and the model's weight initialisation).
# NOTE(review): full run-to-run determinism may additionally depend on
# DataLoader workers and CUDA kernels - confirm if exact reproducibility matters.
torch.manual_seed(SEED)

train_loader, valid_loader, test_loader = get_dataloaders(DATA_ROOT, batch_size=BATCH_SIZE)

model = init_model(N_CLASSES)
optimizer, scheduler, scaler = init_stn_optimizer(
    model, len(train_loader), EPOCHS, MAX_LR, WEIGHT_DECAY
)

train(model, optimizer, scheduler, scaler, train_loader, valid_loader, EPOCHS)

test(model, test_loader)

### Training ###

Epoch 0: [------------------------------]

Batch: 0 - Loss: 2.36 --- Accuracy: 8.59
Batch: 5 - Loss: 2.29 --- Accuracy: 15.23
Batch: 10 - Loss: 2.32 --- Accuracy: 14.84
Batch: 15 - Loss: 2.25 --- Accuracy: 17.58
Batch: 20 - Loss: 2.27 --- Accuracy: 13.28
Batch: 25 - Loss: 2.23 --- Accuracy: 13.67
Batch: 30 - Loss: 2.23 --- Accuracy: 16.80

Average Train Loss: 2.26 - Average Train Accuracy 15.14

### Validation and Checkpointing ###

Average Validation Loss: 2.19 - Average Validation Accuracy 21.18

Epoch 1: [------------------------------]

Batch: 0 - Loss: 2.22 --- Accuracy: 18.36
Batch: 5 - Loss: 2.20 --- Accuracy: 21.48
Batch: 10 - Loss: 2.21 --- Accuracy: 18.36
Batch: 15 - Loss: 2.21 --- Accuracy: 15.23
Batch: 20 - Loss: 2.20 --- Accuracy: 17.19
Batch: 25 - Loss: 2.18 --- Accuracy: 22.66
Batch: 30 - Loss: 2.19 --- Accuracy: 16.41

Average Train Loss: 2.21 - Average Train Accuracy 18.60

### Validation and Checkpointing ###

Average Validation Loss: 2.09 - Average 


KeyboardInterrupt

