In [1]:
import torch
from vit_pytorch.vit_for_small_dataset import ViT

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision import transforms
from numpy.random import RandomState
from torch.utils.data import Subset
import numpy as np
import torch


batch_size = 128
n_workers = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cifar_data = CIFAR10(root="data", train=True, download=True,  transform= transforms.Compose([
            # transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            # transforms.GaussianBlur(3),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        ]))

cifar_data_val = CIFAR10(root='.',train=True, transform= transforms.Compose([
            # transforms.RandomResizedCrop(image_size),
            # transforms.RandomHorizontalFlip(),
            
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        ]), download=True)



def get_model(patch_size=16, dim=512, depth=3, heads=16):
    return ViT(
        image_size = 32,
        patch_size = patch_size,
        num_classes = 10,
        dim = dim,
        depth = depth,
        heads = heads,
        mlp_dim = 1024,
        dropout = 0.1,
        emb_dropout = 0.1
    )

def training_epoch_transformer(
    model, data_loader, optimizer, criterion
):
    all_loss = []
    accuracies = []
    model.train()
    
    for batch, (
        images, labels
    ) in enumerate(data_loader):
        optimizer.zero_grad()
        
        classification_scores = model(images.to(device))
        
        
        loss = criterion(classification_scores, labels.to(device))
        loss.backward()
        optimizer.step()
        accuracy = (classification_scores.argmax(dim=1) == labels.to(device)).float().mean()
        accuracies.append(accuracy.item())
        all_loss.append(loss.item())

        # tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss), all_loss, accuracies


def evaluate(
    model, data_loader,
):
    accuracies = []
    model.train()
    
    for batch, (
        images, labels
    ) in enumerate(data_loader):
        model.eval()
        classification_scores = model(images.to(device))
        
        accuracy = (classification_scores.argmax(dim=1) == labels.to(device)).float().mean()
        accuracies.append(accuracy.item())

    return  accuracies

def get_dataloader_transformer(dataset,dataset_val, n_workers, seed, batch_size=128):
    prng = RandomState(seed)
    random_permute = prng.permutation(np.arange(0, 500))
    classes =  prng.permutation(np.arange(0,10))
    indx_train = np.concatenate([np.where(np.array(dataset.targets) == classe)[0][random_permute[0:25]] for classe in classes[0:2]])
    indx_val = np.concatenate([np.where(np.array(dataset_val.targets) == classe)[0][random_permute[25:225]] for classe in classes[0:2]])
    train_targets = np.array(dataset.targets)[indx_train]
    val_targets = np.array(dataset.targets)[indx_val]
    train_data = Subset(cifar_data, indx_train)
    val_data = Subset(cifar_data_val, indx_val)
        
    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        pin_memory=True,
    )    
    
    val_loader = DataLoader(
        val_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        pin_memory=True,
    )
    return train_loader, val_loader

Files already downloaded and verified
Files already downloaded and verified


In [3]:
from requests import head
from sklearn import metrics
from torch.optim import SGD, Adam
from torch.nn import CrossEntropyLoss
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.loss import LabelSmoothingCrossEntropy
import pickle 
import json
from statistics import mean
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


n_epochs = 150

patch = 4
dim = 512
depth = 3
head = 4

metrics_map = {}


                
model_info = {'patch': patch, 'dim': dim, 'depth': depth, 'head': head}

model_name = json.dumps(model_info)
print(model_name)
v = get_model(patch, dim, depth, head)
v = v.to(device)


train_optimizer = Adam(v.parameters(), lr=3e-4)
criterion = LabelSmoothingCrossEntropy()

train_scheduler = CosineLRScheduler(
        train_optimizer,  t_initial=5
    )



losses = np.array([])
train_accuracies = np.array([])
test_accuracies = np.array([])
times = np.array([])
for seed in range(50):
    print("seed", seed)
    
    train_loader, val_loader = get_dataloader_transformer(cifar_data,cifar_data_val, n_workers, seed)
    
    start = torch.cuda.Event(enable_timing=True)
    start.record()
    for epoch in range(n_epochs):
        # print(f"Epoch {epoch}")
        average_loss, loss, accuracy = training_epoch_transformer(v, train_loader, train_optimizer, criterion)

        train_scheduler.step(average_loss)
        if epoch % 10 == 0:
            print(f"Epoch {epoch} - Loss: {average_loss}, lr: {train_optimizer.param_groups[0]['lr']}")
        losses = np.append(losses,loss)
        train_accuracies = np.append(train_accuracies,accuracy)
    end = torch.cuda.Event(enable_timing=True)
    end.record()
    torch.cuda.synchronize()
    times = np.append(times,(start.elapsed_time(end)/1000))
    #wait for everything to finish running
    test_accuracy = evaluate(
        v, val_loader
    )
    print(
        f"Accuracy: {100* np.array(test_accuracy).mean():.3f}%, time: {start.elapsed_time(end)/1000:.3f}s"
    )
    test_accuracies = np.append(test_accuracies,test_accuracy)

print(
        f"Accuracy: {100* np.array(test_accuracies).mean():.3f}% \u00B1 {100*np.array(test_accuracies).std():.3f}%, time: {np.array(times).mean():.3f}s \u00B1 {np.array(times).std():.3f}s"
    )
metrics_map[model_name] = {'losses': losses, 'train_accuracies': train_accuracies, 'test_accuracies': test_accuracies, 'times': times}
with open('metrics-vit.pkl', 'wb') as f:
    pickle.dump(metrics_map, f)



{"patch": 4, "dim": 512, "depth": 3, "head": 4}
seed 0
Epoch 0 - Loss: 2.724707841873169, lr: 0.00012889207611893125
Epoch 10 - Loss: 1.081413745880127, lr: 0.0002666856640680531
Epoch 20 - Loss: 0.7824822664260864, lr: 0.0002822334027554061
Epoch 30 - Loss: 0.5846447348594666, lr: 0.0002899927233517413
Epoch 40 - Loss: 0.538875162601471, lr: 0.000291483828394471
Epoch 50 - Loss: 0.5197283029556274, lr: 0.00029207296203177655
Epoch 60 - Loss: 0.5171729922294617, lr: 0.00029215003355009367
Epoch 70 - Loss: 0.5104131698608398, lr: 0.0002923521512686787
Epoch 80 - Loss: 0.5076977014541626, lr: 0.0002924326206005556
Epoch 90 - Loss: 0.5072008371353149, lr: 0.00029244729965323646
Epoch 100 - Loss: 0.5131145119667053, lr: 0.00029227168938500474
Epoch 110 - Loss: 0.5059419274330139, lr: 0.0002924844299501199
Epoch 120 - Loss: 0.5097729563713074, lr: 0.0002923711604941912
Epoch 130 - Loss: 0.5065673589706421, lr: 0.0002924659946070497
Epoch 140 - Loss: 0.5057558417320251, lr: 0.000292489910801