## Read dataset

In [30]:
# -----------------Have a look at the CIFAR-100 dataset-----------------
def unpickle(file):
    """Load and return the object stored in a pickle file.

    Args:
        file: Path of the pickle file; it is opened in binary read mode.

    Returns:
        Whatever object the file contains, decoded with ``encoding='latin1'``.
    """
    import pickle

    # NOTE(review): unpickling can execute arbitrary code, so only point this
    # at files from a trusted source.
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='latin1')

# Load the unpacked training file and list the top-level keys it contains
data = unpickle('data/unpacked/train')
print(data.keys())



dict_keys(['filenames', 'batch_label', 'fine_labels', 'coarse_labels', 'data'])


In [51]:
# ------------------Load data with data augmentation-------------------
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import v2
from torch.utils.data import default_collate
import numpy as np

# Use the GPU when one is available and make it the default device for newly
# created tensors (the data loaders below are given generators on this device)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)
print(f"device: {device}")

CIFAR_PATH = "data"
num_coarse_classes = 20
num_fine_classes = 100

# Calculated mean and standard deviation of image channels for normalization
mean = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
std = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]

# Number of DataLoader worker processes (0 = load batches in the main process)
num_workers = 0

# CutMix augmentation; the class count is tied to num_fine_classes so the two
# cannot drift apart
cutmix = v2.CutMix(num_classes=num_fine_classes)

# Define a custom collate function to apply CutMix
def collate_fn(batch):
    """Collate a list of samples into a batch and apply CutMix to it.

    The samples are first stacked with ``default_collate``; the resulting
    fields are then passed positionally to the module-level ``cutmix``
    transform, whose result is returned.

    Any exception is reported on stdout and then re-raised unchanged.
    """
    try:
        collated = default_collate(batch)
        return cutmix(*collated)
    except Exception as err:
        print(f"Error in collate_fn: {err}")
        raise

# Function to load the CIFAR-100 dataset
def cifar100_dataset(batchsize, test_batchsize=100):
    """Build the CIFAR-100 training and test data loaders.

    Training images are randomly cropped (32x32 after 4 px padding), randomly
    flipped horizontally and randomly rotated (``RandomRotation(15)``) before
    being converted to tensors and normalized with the module-level
    ``mean``/``std``. Test images are only converted and normalized.
    Training batches are assembled by ``collate_fn``.

    The dataset is read from (and, if missing, downloaded to) ``CIFAR_PATH``.

    Args:
        batchsize: Batch size of the training loader.
        test_batchsize: Batch size of the test loader. Defaults to 100, the
            value that used to be hard-coded.

    Returns:
        tuple: ``(trainloader, testloader)``.
    """
    # Define the data transformation for training data
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # Randomly crop the image with padding
        transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
        transforms.RandomRotation(15),  # Randomly rotate the image
        transforms.ToTensor(),  # Convert the image to a tensor
        transforms.Normalize(mean, std)  # Normalize the image using the predefined mean and std
    ])

    # Define the data transformation for test data (no augmentation)
    transform_test = transforms.Compose([
        transforms.ToTensor(),  # Convert the image to a tensor
        transforms.Normalize(mean, std)  # Normalize the image using the predefined mean and std
    ])

    # Load the training dataset
    cifar100_training = torchvision.datasets.CIFAR100(
        root=CIFAR_PATH,  # Root directory of the dataset
        train=True,  # Load training data
        download=True,  # Download the dataset if not available
        transform=transform_train  # Apply the transformations
    )
    # Create a data loader for the training dataset
    trainloader = torch.utils.data.DataLoader(
        cifar100_training,  # The training dataset
        batch_size=batchsize,  # Batch size
        shuffle=True,  # Shuffle the data
        num_workers=num_workers,  # Taken from the module-level num_workers setting
        collate_fn=collate_fn,  # Apply CutMix augmentation
        # NOTE(review): presumably needed because the default device is set to
        # `device`, so the shuffling RNG must live on the same device - confirm
        generator=torch.Generator(device=device),
    )

    # Load the test dataset
    cifar100_testing = torchvision.datasets.CIFAR100(
        root=CIFAR_PATH,  # Root directory of the dataset
        train=False,  # Load test data
        download=True,  # Download the dataset if not available
        transform=transform_test,  # Apply the transformations
    )
    # Create a data loader for the test dataset (default collation, no CutMix)
    testloader = torch.utils.data.DataLoader(
        cifar100_testing,  # The test dataset
        batch_size=test_batchsize,  # Batch size
        shuffle=False,  # Do not shuffle the data
        num_workers=num_workers,  # Taken from the module-level num_workers setting
        generator=torch.Generator(device=device),
    )

    # Return the training and test data loaders
    return trainloader, testloader

# Build both loaders; training batches contain 512 images
trainloader, testloader = cifar100_dataset(batchsize=512)

# Printing a DataLoader only shows its object repr; this just confirms both were created
print(trainloader, testloader)


device: cuda
Files already downloaded and verified
Files already downloaded and verified
<torch.utils.data.dataloader.DataLoader object at 0x00000187532E6290> <torch.utils.data.dataloader.DataLoader object at 0x00000187250DA410>


In [52]:
# Have a look at the dataloader
# Pull a single batch from the training loader and report its tensor shapes
first_batch = next(iter(trainloader), None)
if first_batch is not None:
    inputs, fine_labels = first_batch
    # Expected: torch.Size([batch_size, 3, 32, 32])
    print(inputs.shape)
    # Expected with CutMix in the collate function:
    # torch.Size([batch_size, num_fine_classes])
    # NOTE(review): the saved output shows torch.Size([512]) - re-run to confirm
    print(fine_labels.shape)

torch.Size([512, 3, 32, 32])
torch.Size([512])


# Define the model

In [45]:
from vit_pytorch import SimpleViT
from torch import nn, optim

# Vision Transformer classifier from vit_pytorch, moved to the selected device.
# image_size=32 matches the 32x32 crops produced by the training transform and
# num_classes is tied to num_fine_classes (one output per fine label).
model = SimpleViT(
    image_size = 32,
    patch_size = 4,
    num_classes = num_fine_classes,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 512
).to(device)

# Train the model

In [53]:
# Training loop
from tqdm import tqdm

def train_model(model, trainloader, criterion, optimizer, num_epochs=10):
    """Train ``model`` for ``num_epochs`` passes over ``trainloader``.

    Each batch is moved to the module-level ``device``, run through the model,
    scored with ``criterion`` and used for one ``optimizer`` step. A tqdm bar
    shows the current batch loss, and the mean batch loss of every epoch is
    printed when the epoch finishes.

    Args:
        model: Network to optimize; it is switched to training mode.
        trainloader: Iterable of ``(inputs, labels)`` batches that supports ``len()``.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer that updates the model's parameters.
        num_epochs: Number of passes over ``trainloader``.

    Returns:
        None.
    """
    model.train()
    for epoch_idx in range(num_epochs):
        # Training mode is set again at the top of every epoch
        model.train()
        epoch_total = 0.0

        progress = tqdm(trainloader, desc=f"Epoch {epoch_idx+1}/{num_epochs}", leave=True)  # tqdm progress bar
        for batch_inputs, batch_targets in progress:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            # Standard step: clear gradients, forward, backward, update
            optimizer.zero_grad()
            batch_loss = criterion(model(batch_inputs), batch_targets)
            batch_loss.backward()
            optimizer.step()

            loss_value = batch_loss.item()
            epoch_total += loss_value
            progress.set_postfix(loss=loss_value)

        # Mean of the per-batch losses for this epoch
        loss_epoch = epoch_total / len(trainloader)
        print(f"Epoch {epoch_idx+1}/{num_epochs}, Loss: {loss_epoch}")
    

# Evaluate the model
def evaluate_model(model, testloader, criterion=None):
    """Evaluate ``model`` on ``testloader`` and report its accuracy.

    The model is switched to evaluation mode and every batch is moved to the
    module-level ``device``. Predictions are the arg-max over dimension 1 of
    the model output and are compared with the integer class labels.

    Args:
        model: Network to evaluate.
        testloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Optional callable ``criterion(outputs, labels)`` returning
            the loss. When omitted, ``torch.nn.CrossEntropyLoss()`` is used,
            so the function no longer depends on a notebook-global
            ``criterion`` having been defined by another cell.

    Returns:
        tuple: ``(val_loss, accuracy)``. ``val_loss`` is the SUM of the
        per-batch loss values (not their mean) and ``accuracy`` is the
        fraction of correct predictions in ``[0, 1]``. An empty loader gives
        ``(0.0, 0.0)`` instead of raising ``ZeroDivisionError``.
    """
    if criterion is None:
        criterion = torch.nn.CrossEntropyLoss()

    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Loss
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            # Accuracy (no_grad is active, so the legacy `.data` detour is unnecessary)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against an empty loader, where no sample was counted
    accuracy = correct / total if total else 0.0
    print(f"Accuracy: {100 * accuracy}%")
    return val_loss, accuracy
    

# Cross-entropy loss; training batches come from trainloader, whose collate function applies CutMix
criterion = nn.CrossEntropyLoss()
# NOTE(review): lr=1e-2 is 10x Adam's documented default (1e-3) and looks high
# for a transformer - confirm that training is stable at this rate
optimizer = optim.Adam(model.parameters(), lr=1e-2)

# Train the model
train_model(model, trainloader, criterion, optimizer, num_epochs=10)

Epoch 1/10:  43%|████▎     | 42/98 [00:47<01:03,  1.14s/it, loss=4.37]


KeyboardInterrupt: 

# Evaluate the model

In [None]:
# Prints the test accuracy; the call returns (val_loss, accuracy)
evaluate_model(model, testloader)