### Warning: Running this on your local/personal laptop will be extremely slow and will drain your battery quickly.
### Run it on a GPU provided by the CS department instead.

In [None]:
# Ensures compatibility for Python 2 and 3 print function
from __future__ import print_function

# Core PyTorch libraries for tensor and neural network operations
import torch
import torch.nn as nn  # Neural network module in PyTorch
import torch.optim as optim  # Optimization algorithms
import torch.nn.functional as F  # Functional interface with activation functions, loss functions, etc.
import torch.backends.cudnn as cudnn  # CUDNN backend to optimize deep learning operations on GPUs

# Numerical operations library
import numpy as np

# PyTorch library for vision-related utilities
import torchvision
import torchvision.transforms as transforms  # Transformations that can be applied to image data

# Standard Python libraries for system operations, argument parsing, and data handling
import os  # Interfaces with the operating system
import argparse  # Parser for command-line options, arguments, and sub-commands
import pandas as pd  # Data analysis and manipulation tool
import csv  # CSV file reading and writing
import time  # Time access and conversions

# Importing custom modules that are presumably part of the user's project
from models import *  # This would import all available models in the models directory
from utils import progress_bar  # Likely a utility to display a progress bar in the console
from randomaug import RandAugment  # Custom module for random data augmentation
from models.vit import ViT  # Importing the Vision Transformer model from the models directory
from models.convmixer import ConvMixer  # Importing the ConvMixer model from the models directory

In [None]:
# Command-line / notebook configuration for the CIFAR-10 training run.
import argparse

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')

# Each option alters how the training run behaves.
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
# Restrict to the optimizers the setup cell actually knows how to build; any other
# value previously left `optimizer` undefined and failed much later with NameError.
parser.add_argument('--opt', default="adam", choices=["adam", "sgd"], help='optimizer')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--noaug', action='store_true', help='disable use randomaug')
parser.add_argument('--noamp', action='store_true', help='disable mixed precision training. for older pytorch versions')
parser.add_argument('--nowandb', action='store_true', help='disable wandb')
# NOTE(review): --mixup is parsed but not referenced anywhere else in this notebook.
parser.add_argument('--mixup', action='store_true', help='add mixup augmentations')
parser.add_argument('--net', default='vit')  # network architecture; default is ViT
# Typed as int so bad values fail at parse time; string defaults are converted by argparse.
parser.add_argument('--bs', default='512', type=int, help='training batch size')
parser.add_argument('--size', default="32", type=int, help='input image size')
parser.add_argument('--n_epochs', type=int, default=200)
parser.add_argument('--patch', default=4, type=int, help="patch size for ViT")
parser.add_argument('--dimhead', default=512, type=int)  # model width `dim` passed to the transformer models
parser.add_argument('--convkernel', default=8, type=int, help="kernel size parameter for ConvMixer")

# parse_known_args rather than parse_args: inside a Jupyter kernel sys.argv holds the
# kernel's own "-f <connection-file>" arguments, which parse_args rejects with an error.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    print('Ignoring unrecognized arguments:', _unknown_args)

# Optional Weights & Biases logging.
usewandb = not args.nowandb
if usewandb:
    import wandb
    watermark = "{}_lr{}".format(args.net, args.lr)  # run name, e.g. "vit_lr0.0001"
    wandb.init(project="cifar10-challenge", name=watermark)
    wandb.config.update(args)

# Batch size and image size as plain ints.
bs = int(args.bs)
imsize = int(args.size)

use_amp = not args.noamp  # automatic mixed precision on unless --noamp
aug = not args.noaug      # RandAugment on unless --noaug

# Train on GPU when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0      # best test accuracy seen so far (may be restored on --resume)
start_epoch = 0   # first epoch to run (may be advanced on --resume)


In [None]:
# ---- Data: transforms, CIFAR-10 datasets and loaders ----
print('==> Preparing data..')

# The timm ViT branch uses a 384px pretrained model; all other nets use --size.
size = 384 if args.net == "vit_timm" else imsize

# Per-channel normalization constants used for both train and test pipelines.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: pad-by-4 random 32px crop, resize to `size`, random
# horizontal flip, tensor conversion, normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Test pipeline: deterministic resize, tensor conversion, normalization.
transform_test = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

if aug:
    # RandAugment is prepended so it runs before every other train transform.
    # (N, M) presumably = number of ops and magnitude; see randomaug.py to confirm.
    N, M = 2, 14
    transform_train.transforms.insert(0, RandAugment(N, M))

data_root = './data'

trainset = torchvision.datasets.CIFAR10(
    root=data_root, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=bs, shuffle=True, num_workers=8)

testset = torchvision.datasets.CIFAR10(
    root=data_root, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=8)

# CIFAR-10 class names, in label-index order.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


In [None]:
# Model factory: build the network selected by --net, then set up device
# placement, loss, optimizer, optional checkpoint resume and the LR scheduler.
print('==> Building model..')
if args.net == 'res18':
    net = ResNet18()
elif args.net == 'vgg':
    net = VGG('VGG19')
elif args.net == 'res34':
    net = ResNet34()
elif args.net == 'res50':
    net = ResNet50()
elif args.net == 'res101':
    net = ResNet101()
elif args.net == "convmixer":
    # Original note: from the paper, accuracy >96%; tune depth/dim to trade accuracy vs speed.
    net = ConvMixer(256, 16, kernel_size=args.convkernel, patch_size=1, n_classes=10)
elif args.net == "mlpmixer":
    from models.mlpmixer import MLPMixer
    net = MLPMixer(
        image_size=size,  # follow the data pipeline's Resize (was hard-coded to 32)
        channels=3,
        patch_size=args.patch,
        dim=512,
        depth=6,
        num_classes=10,
    )
elif args.net == "vit_small":
    # Aliased so the module-level `ViT` (models.vit) is not shadowed.
    from models.vit_small import ViT as SmallViT
    net = SmallViT(
        image_size=size,
        patch_size=args.patch,
        num_classes=10,
        dim=int(args.dimhead),
        depth=6,
        heads=8,
        mlp_dim=512,
        dropout=0.1,
        emb_dropout=0.1,
    )
elif args.net == "vit_tiny":
    from models.vit_small import ViT as SmallViT
    net = SmallViT(
        image_size=size,
        patch_size=args.patch,
        num_classes=10,
        dim=int(args.dimhead),
        depth=4,
        heads=6,
        mlp_dim=256,
        dropout=0.1,
        emb_dropout=0.1,
    )
elif args.net == "simplevit":
    from models.simplevit import SimpleViT
    net = SimpleViT(
        image_size=size,
        patch_size=args.patch,
        num_classes=10,
        dim=int(args.dimhead),
        depth=6,
        heads=8,
        mlp_dim=512,
    )
elif args.net == "vit":
    # ViT for cifar10 (models.vit, imported at the top of the notebook).
    net = ViT(
        image_size=size,
        patch_size=args.patch,
        num_classes=10,
        dim=int(args.dimhead),
        depth=6,
        heads=8,
        mlp_dim=512,
        dropout=0.1,
        emb_dropout=0.1,
    )
elif args.net == "vit_timm":
    import timm
    net = timm.create_model("vit_base_patch16_384", pretrained=True)
    net.head = nn.Linear(net.head.in_features, 10)  # replace classifier with a 10-way head
elif args.net == "cait":
    from models.cait import CaiT
    net = CaiT(
        image_size=size,
        patch_size=args.patch,
        num_classes=10,
        dim=int(args.dimhead),
        depth=6,       # depth of transformer for patch to patch attention only
        cls_depth=2,   # depth of cross attention of CLS tokens to patch
        heads=8,
        mlp_dim=512,
        dropout=0.1,
        emb_dropout=0.1,
        layer_dropout=0.05,
    )
elif args.net == "cait_small":
    from models.cait import CaiT
    net = CaiT(
        image_size=size,
        patch_size=args.patch,
        num_classes=10,
        dim=int(args.dimhead),
        depth=6,       # depth of transformer for patch to patch attention only
        cls_depth=2,   # depth of cross attention of CLS tokens to patch
        heads=6,
        mlp_dim=256,
        dropout=0.1,
        emb_dropout=0.1,
        layer_dropout=0.05,
    )
elif args.net == "swin":
    from models.swin import swin_t
    net = swin_t(window_size=args.patch,
                 num_classes=10,
                 downscaling_factors=(2, 2, 2, 1))
else:
    # Previously fell through silently and failed later with NameError on `net`.
    raise ValueError("Unknown --net {!r}".format(args.net))

# Move parameters to the target device before wrapping and before the optimizer
# is built, so any optimizer state restored below lands on the same device.
net = net.to(device)

# Multi-GPU: wrap in DataParallel when running on CUDA.
if 'cuda' in device:
    print(device)
    print("using data parallel")
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True  # let cuDNN autotune conv algorithms (helps with fixed input sizes)

criterion = nn.CrossEntropyLoss()

if args.opt == "adam":
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
elif args.opt == "sgd":
    optimizer = optim.SGD(net.parameters(), lr=args.lr)
else:
    raise ValueError("Unknown --opt {!r}".format(args.opt))

if args.resume:
    print('==> Resuming from checkpoint..')
    # test() saves to '<net>-<patch>-ckpt.t7'; fall back to the older '<net>-ckpt.t7' name.
    ckpt_path = './checkpoint/{}-{}-ckpt.t7'.format(args.net, args.patch)
    legacy_path = './checkpoint/{}-ckpt.t7'.format(args.net)
    if not os.path.isfile(ckpt_path) and os.path.isfile(legacy_path):
        ckpt_path = legacy_path
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError('Error: no checkpoint found at {}'.format(ckpt_path))
    # map_location lets a checkpoint written on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(ckpt_path, map_location=device)
    # test() stores weights under "model"; older checkpoints used "net".
    net.load_state_dict(checkpoint['model'] if 'model' in checkpoint else checkpoint['net'])
    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        # The saved lr is already annealed; reset it to the base lr so the scheduler
        # created below starts from the schedule's origin (the training loop is
        # responsible for advancing it to start_epoch).
        for group in optimizer.param_groups:
            group['lr'] = group.get('initial_lr', args.lr)
    best_acc = checkpoint.get('acc', best_acc)
    if 'epoch' in checkpoint:
        # The stored epoch has already been trained and evaluated; continue after it.
        start_epoch = checkpoint['epoch'] + 1

# Cosine-annealed learning rate over the full run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_epochs)

In [None]:
# Import the necessary library for mixed precision training.
from torch.cuda.amp import GradScaler, autocast

# Gradient scaler for automatic mixed precision (AMP). With enabled=False the
# scaler is a pass-through (step() just calls optimizer.step()), so the same
# loop below works with and without --noamp.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train(epoch):
    """Train ``net`` for one pass over ``trainloader``.

    Uses the module-level ``net``, ``trainloader``, ``criterion``, ``optimizer``,
    ``scaler``, ``device`` and ``use_amp``, and prints a progress bar per batch.

    Args:
        epoch: Epoch index; only used for the printed header.

    Returns:
        Mean per-batch training loss for the epoch. Returns 0.0 when the loader
        yields no batches (previously this raised NameError on ``batch_idx``).
    """
    print('\nEpoch: %d' % epoch)
    net.train()  # enable training-mode behavior (dropout, batch-norm updates)

    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        # Forward pass under autocast when AMP is enabled.
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = net(inputs)  # raw scores fed straight into CrossEntropyLoss (logits)
            loss = criterion(outputs, targets)

        # Scaled backward + optimizer step; scaler.step may skip the step if
        # inf/NaN gradients were found, and update() adjusts the scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        num_batches = batch_idx + 1
        train_loss += loss.item()
        _, predicted = outputs.max(1)  # index of the highest score = predicted class
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    return train_loss / max(num_batches, 1)


In [None]:
def test(epoch):
    """Evaluate ``net`` on ``testloader``, checkpoint on improvement, and log.

    Saves ``./checkpoint/<net>-<patch>-ckpt.t7`` whenever accuracy beats the
    global ``best_acc``. The checkpoint holds model/optimizer/scaler state plus
    ``acc`` and ``epoch`` so a resume can restore them. Appends one line per call
    to ``log/log_<net>_patch<patch>.txt``.

    Args:
        epoch: Epoch index, stored in the checkpoint and written to the log.

    Returns:
        ``(val_loss, acc)``: mean per-batch validation loss (same scale as the
        value ``train()`` returns) and accuracy in percent.
    """
    global best_acc

    net.eval()  # inference-mode behavior for dropout / batch norm

    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            num_batches = batch_idx + 1
            test_loss += loss.item()
            _, predicted = outputs.max(1)  # class with the highest score
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Average over batches; the raw sum was not comparable with train()'s mean loss.
    test_loss /= max(num_batches, 1)
    acc = 100. * correct / total if total else 0.0

    if acc > best_acc:
        print('Saving..')
        state = {
            "model": net.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict(),
            "acc": acc,        # lets --resume restore best_acc
            "epoch": epoch,    # last completed epoch, for --resume
        }
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/'+args.net+'-{}-ckpt.t7'.format(args.patch))
        best_acc = acc

    os.makedirs("log", exist_ok=True)
    content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, val loss: {test_loss:.5f}, acc: {(acc):.5f}'
    print(content)
    with open(f'log/log_{args.net}_patch{args.patch}.txt', 'a') as appender:
        appender.write(content + "\n")

    return test_loss, acc

# Per-epoch validation history (loss, accuracy), dumped to CSV by the training loop.
list_loss = []
list_acc = []

# With Weights & Biases enabled, let it watch the model.
if usewandb:
    wandb.watch(net)

# Place the model on the selected device. `.to(device)` rather than `.cuda()`
# keeps CPU-only runs working (device is 'cpu' when CUDA is unavailable).
net.to(device)


In [None]:
import warnings

# The LR scheduler is stepped once per completed epoch. When resuming
# (start_epoch > 0), advance it so the learning rate matches where training
# left off. Stepping before any optimizer.step() emits a UserWarning about call
# order; that is expected here, so it is silenced.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    for _ in range(start_epoch):
        scheduler.step()

for epoch in range(start_epoch, args.n_epochs):
    start = time.time()

    trainloss = train(epoch)
    val_loss, acc = test(epoch)

    # Learning rate actually used during this epoch (read before stepping).
    epoch_lr = optimizer.param_groups[0]["lr"]
    # Advance the cosine schedule by exactly one epoch. The previous
    # `scheduler.step(epoch-1)` used the deprecated epoch argument and lagged
    # the schedule.
    scheduler.step()

    list_loss.append(val_loss)
    list_acc.append(acc)

    if usewandb:
        wandb.log({
            'epoch': epoch,
            'train_loss': trainloss,
            'val_loss': val_loss,
            "val_acc": acc,
            "lr": epoch_lr,
            "epoch_time": time.time()-start,
        })

    # Rewrite the full history each epoch: row 1 = val losses, row 2 = val accuracies.
    # newline='' is what the csv module expects for files it writes.
    with open(f'log/log_{args.net}_patch{args.patch}.csv', 'w', newline='') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(list_loss)
        writer.writerow(list_acc)
    print(list_loss)

# NOTE(review): nothing in this notebook writes "wandb_<net>.h5", so this call
# likely uploads nothing — confirm the intended artifact.
if usewandb:
    wandb.save("wandb_{}.h5".format(args.net))
