## Read dataset

In [25]:
# -----------------Have a look at the CIFAR-100 dataset-----------------
def unpickle(file):
    """Read the pickle file at path `file` and return the object it contains.

    The payload is decoded with ``encoding='latin1'``.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.
    """
    import pickle

    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='latin1')

# Load the pickled file at data/unpacked/train and list its top-level keys.
data = unpickle('data/unpacked/train')
print(data.keys())



dict_keys(['filenames', 'batch_label', 'fine_labels', 'coarse_labels', 'data'])


In [26]:
# ------------------Load data with data augmentation-------------------
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import v2
from torch.utils.data import default_collate
import torch.nn.functional as F
import numpy as np
import os

# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]="2,3,4"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Make `device` the default device for newly created tensors. The DataLoaders
# built below are given a torch.Generator created on this same device.
torch.set_default_device(device)
print(f"device: {device}")

CIFAR_PATH = "data"
num_coarse_classes = 20
num_fine_classes = 100

# Calculated mean and standard deviation of image channels for normalization
mean = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
std = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]

# Number of worker subprocesses used for loading data (0 = load in the main process)
num_workers = 0

# CutMix augmentation; the class count follows num_fine_classes instead of a
# second hard-coded 100 so the two cannot drift apart.
cutmix = v2.CutMix(num_classes=num_fine_classes)

# Define a custom collate function to apply CutMix
def collate_fn(batch):
    """Collate a list of samples into a batch and run CutMix over it.

    Uses the module-level `cutmix` transform. Any exception is reported on
    stdout and then re-raised to the caller.
    """
    try:
        collated = default_collate(batch)
        return cutmix(*collated)
    except Exception as err:
        print(f"Error in collate_fn: {err}")
        raise err

# Function to load the CIFAR-100 dataset
def cifar100_dataset(batchsize, test_batchsize=100):
    """Build the CIFAR-100 training and test DataLoaders.

    Relies on the module-level settings `CIFAR_PATH`, `mean`, `std`,
    `num_workers`, `device` and on `collate_fn` (which applies CutMix).

    Args:
        batchsize: Batch size of the training loader.
        test_batchsize: Batch size of the test loader. Defaults to 100, the
            value that used to be hard-coded.

    Returns:
        A ``(trainloader, testloader)`` tuple. The training loader shuffles
        and collates with `collate_fn`; the test loader does neither.
    """
    # Training pipeline: random crop / flip / rotation, then tensor + normalization
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # Randomly crop the image with padding
        transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
        transforms.RandomRotation(15),  # Randomly rotate the image
        transforms.ToTensor(),  # Convert the image to a tensor
        transforms.Normalize(mean, std),  # Normalize with the predefined mean and std
    ])

    # Test pipeline: no augmentation, only tensor conversion + normalization
    transform_test = transforms.Compose([
        transforms.ToTensor(),  # Convert the image to a tensor
        transforms.Normalize(mean, std),  # Normalize with the predefined mean and std
    ])

    # Training split (downloaded into CIFAR_PATH when not already present)
    cifar100_training = torchvision.datasets.CIFAR100(
        root=CIFAR_PATH,
        train=True,
        download=True,
        transform=transform_train,
    )
    trainloader = torch.utils.data.DataLoader(
        cifar100_training,
        batch_size=batchsize,
        shuffle=True,
        num_workers=num_workers,  # Worker count comes from the module-level setting
        collate_fn=collate_fn,  # Apply CutMix augmentation per batch
        generator=torch.Generator(device=device),
    )

    # Test split
    cifar100_testing = torchvision.datasets.CIFAR100(
        root=CIFAR_PATH,
        train=False,
        download=True,
        transform=transform_test,
    )
    testloader = torch.utils.data.DataLoader(
        cifar100_testing,
        batch_size=test_batchsize,
        shuffle=False,  # Keep the evaluation order fixed
        num_workers=num_workers,  # Worker count comes from the module-level setting
        generator=torch.Generator(device=device),
    )

    return trainloader, testloader

# Build both loaders; the training batch size is 512*5 = 2560 samples.
trainloader, testloader = cifar100_dataset(batchsize=512*5)

print(trainloader, testloader)


device: cuda
Files already downloaded and verified
Files already downloaded and verified
<torch.utils.data.dataloader.DataLoader object at 0x7f3056b63fa0> <torch.utils.data.dataloader.DataLoader object at 0x7f3056b63ac0>


In [27]:
# Have a look at the dataloader: pull a single training batch and inspect its shapes
for inputs, fine_labels in trainloader:
    print(inputs.shape)  # Should print torch.Size([batch_size, 3, 32, 32])
    # Labels are 2-D here; presumably because collate_fn runs CutMix over each batch
    print(fine_labels.shape)  # Should print torch.Size([batch_size, num_fine_classes])
    break  # One batch is enough for a sanity check

torch.Size([2560, 3, 32, 32])
torch.Size([2560, 100])


# Define the model

In [28]:
# Only for a test


# from vit_pytorch import ViT
# from torch import nn, optim

# model = ViT(
#     image_size = 32,
#     patch_size = 4,
#     num_classes = num_fine_classes,
#     dim = 512,
#     depth = 6,
#     heads = 8,
#     mlp_dim = 1024,
#     dropout = 0.1,
#     emb_dropout = 0.1,
# ).to(device)

# if torch.cuda.device_count() > 1:
#     print(f"Using {torch.cuda.device_count()} GPUs")
#     model = nn.DataParallel(model)

# model = model.to(device)

In [29]:
# --------------------- Use my implementation ------------------------
from torch import nn, optim
from myvit import VisionTransformer

model = VisionTransformer(
    img_size=32,
    patch_size=4,
    d_model=48*16,
    num_heads=16,
    mlp_dim=48*8*3,
    num_layers=6,
    num_classes=num_fine_classes,
)
if torch.cuda.device_count() > 1:
    # Request GPUs 0-4, but only those that actually exist on this machine;
    # asking DataParallel for a missing device id raises an error.
    devices = [idx for idx in [0, 1, 2, 3, 4] if idx < torch.cuda.device_count()]
    print(f"Using {len(devices)} GPUs")
    model = nn.DataParallel(model, device_ids=devices)

print(device)
model = model.to(device)

# Report the total parameter count in millions
total_params = sum(p.numel() for p in model.parameters())
total_params_million = total_params / 10**6
print(f"Total parameters: {total_params_million:.2f}M")

Using 5 GPUs
cuda
Total parameters: 24.99M


# Train the model

In [30]:
# Training loop
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
# TensorBoard writer used by train_model for LR / loss / accuracy scalars (log dir: runs/ViT)
writer = SummaryWriter('runs/ViT')

def train_model(model, trainloader, criterion, optimizer, schedulers=None, num_epochs=10, val_every_iter=False, start_epoch=0):
    """Train `model`, logging to TensorBoard and checkpointing every 20 epochs.

    Relies on module-level names: `device`, `writer` (SummaryWriter),
    `testloader` and `evaluate_model` for validation.

    Args:
        model: Network to optimise.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model parameters.
        schedulers: Optional iterable of LR schedulers; each is stepped once
            per epoch. ``None`` (default) means no scheduler.
        num_epochs: Epoch index at which training stops (exclusive).
        val_every_iter: Validate after every iteration instead of once per
            epoch.
        start_epoch: First epoch index to run (use when resuming).
    """
    # Copy into a fresh list: avoids a shared mutable default argument.
    schedulers = list(schedulers) if schedulers is not None else []

    model.train()
    # Continue the step counter when resuming so earlier TensorBoard points are
    # not overwritten (identical to the old behaviour when start_epoch == 0).
    global_step = start_epoch * len(trainloader)
    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0

        train_loader = tqdm(trainloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)  # tqdm progress bar
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            # Record LR
            current_lr = optimizer.param_groups[0]['lr']
            writer.add_scalar('LR', current_lr, global_step)
            # Loss (read the Python float once instead of three .item() calls)
            loss = criterion(outputs, labels)
            loss_value = loss.item()
            # Add to tensorboard
            writer.add_scalar('Loss/train', loss_value, global_step)
            global_step += 1
            # Backward
            loss.backward()
            optimizer.step()
            running_loss += loss_value

            # Update progress bar
            train_loader.set_postfix(loss=loss_value, lr=current_lr)

            # Val every iteration
            if val_every_iter:
                val_loss, accuracy = evaluate_model(model, testloader)
                writer.add_scalar('Loss/val', val_loss, global_step)
                writer.add_scalar('Accuracy/val', accuracy, global_step)
                # evaluate_model switches to eval mode; go back to training
                # mode before the next optimisation step.
                model.train()

        for scheduler in schedulers:
            scheduler.step()

        loss_epoch = running_loss / len(trainloader)
        if not val_every_iter:
            val_loss, accuracy = evaluate_model(model, testloader)
            writer.add_scalar('Loss/val', val_loss, global_step)
            writer.add_scalar('Accuracy/val', accuracy, global_step)
        # save model
        if (epoch+1) % 20 == 0:
            os.makedirs('weights/ViT', exist_ok=True)
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
            }
            # Only the first scheduler is stored; skipped when none was given
            # (indexing an empty list used to raise IndexError here).
            if schedulers:
                checkpoint['scheduler1_state_dict'] = schedulers[0].state_dict()
            torch.save(checkpoint, f'weights/ViT/epoch{epoch+1}.pth')
            print("*****************model saved*****************\n")
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss_epoch}")

    writer.close()

# Evaluate the model
def evaluate_model(model, testloader, loss_fn=None):
    """Evaluate `model` on `testloader` and print top-1 / top-5 accuracy.

    Relies on the module-level `device`. The model is switched to eval mode
    and is left in eval mode on return.

    Args:
        model: Network to evaluate.
        testloader: Iterable of ``(inputs, labels)`` batches; labels are
            compared directly against predicted class indices.
        loss_fn: Optional loss called as ``loss_fn(outputs, labels)``. When
            omitted, the module-level `criterion` is used, as before.

    Returns:
        ``(mean loss per batch, top-1 accuracy as a fraction)``.
    """
    if loss_fn is None:
        # Backward-compatible fallback to the notebook-level loss object.
        loss_fn = criterion

    model.eval()
    correct = 0
    correct_top5 = 0
    total = 0
    val_loss = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Loss
            val_loss += loss_fn(outputs, labels).item()
            # Accuracy (no `.data` needed: gradients are already disabled)
            _, predicted = torch.max(outputs, 1)
            _, predicted_top5 = outputs.topk(5, dim=1, largest=True, sorted=True)
            # A sample is a top-5 hit when its label appears in any of the 5 columns
            top5_hits = predicted_top5 == labels.view(-1, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            correct_top5 += top5_hits.sum().item()

    accuracy = correct / total
    accuracy_top5 = correct_top5 / total
    print(f"Accuracy: {100 * accuracy:.2f}%, Accuracy(Top5): {100 * accuracy_top5:.2f}%")
    return val_loss/len(testloader), accuracy


# Warm up and decrease of LR
def lr_lambda(epoch, warmup_epochs=30, decay_epochs=300):
    """Return the LR multiplier for `epoch`: linear warm-up, then exponential decay.

    Args:
        epoch: Zero-based epoch index.
        warmup_epochs: Number of warm-up epochs; the multiplier rises linearly
            from ``1 / warmup_epochs`` to 1 over them.
        decay_epochs: Number of epochs after warm-up over which the
            multiplier shrinks by a factor of 10.

    Returns:
        The factor applied to the optimizer's base learning rate.
    """
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    return 0.1 ** ((epoch - warmup_epochs) / decay_epochs)



In [17]:
# Set up the loss, the optimizer and the LR schedule driven by lr_lambda
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
num_epochs = 400

# Create the checkpoint directory if needed; exist_ok avoids the separate
# exists()/makedirs() check and the race between the two calls.
os.makedirs('weights/ViT', exist_ok=True)
torch.save({  # check how big the model is
            'model_state_dict': model.state_dict(),
            }, f'weights/ViT/epoch{0}.pth')



In [None]:
# Load and train the model

# **********************
# checkpoint = torch.load('weights/ViT/epoch1.pth', map_location=device)
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# scheduler.load_state_dict(checkpoint['scheduler1_state_dict'])
# start_epoch = checkpoint['epoch'] + 1
#************************

# NOTE(review): start_epoch is hard-coded to 0 here. If the resume block above is
# uncommented, pass start_epoch=start_epoch so training continues from the checkpoint.
train_model(model, trainloader, criterion, optimizer, schedulers = [scheduler], num_epochs=num_epochs, start_epoch=0)

In [31]:
# Load a saved checkpoint and evaluate the model

# **********************
# Restore model, optimizer and scheduler state from the epoch200.pth checkpoint
checkpoint = torch.load('weights/ViT_patch_size4_d_model48*16_num_heads16_mlp_dim48*8*3_num_layers6/epoch200.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler1_state_dict'])
# The stored epoch is the last finished one; not used by the evaluation call below
start_epoch = checkpoint['epoch'] + 1
#************************

# Prints top-1 / top-5 accuracy; the cell output is (mean val loss, top-1 accuracy)
evaluate_model(model, testloader)

Accuracy: 63.90%, Accuracy(Top5): 86.32%


(1.3963919854164124, 0.639)