This notebook ports Multi-Key-Attention-in-Vision-Transformers, which is written in plain PyTorch, to PyTorch Lightning.

Multi-Key-Attention-in-Vision-Transformers trains vit_small_patch16_224 from scratch (no pre-trained weights) on the Food-101 dataset with the default attention block.

## The original notebook reduces the model and trains it with:
- 10 of the 12 attention blocks
- a drop rate of 0.3
- a drop path rate of 0.1
- a batch size of 32
- a weight decay of 1e-4
- a base learning rate of 1e-5, which is then adjusted over training by the LR schedule defined in cell 4
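Cell 4's schedule is not reproduced in this section, so the sketch below uses a hypothetical linear-warmup-plus-cosine-decay schedule built on `torch.optim.lr_scheduler.LambdaLR` to show how a base learning rate of 1e-5 acts as the peak value that a schedule scales. `warmup_cosine_lambda`, the step counts, and the `nn.Linear` stand-in model are illustrative assumptions, not the original notebook's choices.

```python
import math
import torch
import torch.nn as nn

def warmup_cosine_lambda(warmup_steps: int, total_steps: int, min_factor: float = 0.0):
    """Multiplier for LambdaLR: linear warmup, then cosine decay (hypothetical stand-in for cell 4)."""
    def factor(step: int) -> float:
        if step < warmup_steps:
            # Ramp linearly from 1/warmup_steps up to 1.0
            return (step + 1) / warmup_steps
        # Fraction of the decay phase completed, clamped to [0, 1]
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_factor + (1.0 - min_factor) * cosine
    return factor

model = nn.Linear(8, 2)  # stand-in for the ViT
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, warmup_cosine_lambda(warmup_steps=5, total_steps=50)
)

lrs = []
for step in range(50):
    optimizer.step()   # a real loop would compute the loss and call backward() first
    scheduler.step()   # LambdaLR sets lr = base_lr * factor(step count)
    lrs.append(optimizer.param_groups[0]["lr"])
```

With `LambdaLR`, the configured `lr=1e-5` is never exceeded as long as the multiplier stays at or below 1.0, so "base learning rate" here means the peak of the schedule.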

## Changes as of 10/2025 ##

- a drop rate of 0.5
- a drop path rate of 0.2
- a weight decay of 0.05
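Keeping the two hyperparameter sets side by side in a small frozen dataclass makes the 10/2025 run an explicit override of the original instead of an in-place edit. `TrainConfig`, its field names, and `changed_fields` are hypothetical. If the model is built with timm, these values would presumably be passed as `timm.create_model("vit_small_patch16_224", pretrained=False, num_classes=101, depth=10, drop_rate=..., drop_path_rate=...)`; that is an assumption about how the original notebook constructs the model, so it appears only as a comment below.

```python
from dataclasses import dataclass, replace, asdict
import torch
import torch.nn as nn

@dataclass(frozen=True)
class TrainConfig:
    # Hypothetical field names; values come from the original notebook's settings
    depth: int = 10            # attention blocks kept (of 12)
    drop_rate: float = 0.3
    drop_path_rate: float = 0.1
    batch_size: int = 32
    weight_decay: float = 1e-4
    base_lr: float = 1e-5

original = TrainConfig()
# 10/2025: only the regularization settings change
oct_2025 = replace(original, drop_rate=0.5, drop_path_rate=0.2, weight_decay=0.05)

def changed_fields(a: TrainConfig, b: TrainConfig) -> list:
    """Names of the fields whose values differ between two configs."""
    da, db = asdict(a), asdict(b)
    return sorted(k for k in da if da[k] != db[k])

# Assumed model construction (requires timm, not run here):
# model = timm.create_model("vit_small_patch16_224", pretrained=False, num_classes=101,
#                           depth=cfg.depth, drop_rate=cfg.drop_rate, drop_path_rate=cfg.drop_path_rate)
model = nn.Linear(8, 2)  # stand-in for the ViT
optimizer = torch.optim.AdamW(model.parameters(), lr=oct_2025.base_lr,
                              weight_decay=oct_2025.weight_decay)

print(changed_fields(original, oct_2025))  # ['drop_path_rate', 'drop_rate', 'weight_decay']
```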

## This cell imports libraries for deep learning, data handling, and visualization, and sets random seeds for reproducibility

In [None]:
import time
import copy
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import matplotlib.pyplot as pyplot
import numpy as np
import random

#Seeding random
torch.manual_seed(31)
random.seed(31)
np.random.seed(31)
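The three seeding calls above can be wrapped in one helper that also seeds every CUDA device, and a quick re-seed check confirms the draws repeat. `seed_all` is a hypothetical name; Lightning ships a built-in equivalent, `pl.seed_everything(31)` (with `workers=True` to also seed DataLoader workers). Seeding alone does not guarantee bit-identical GPU results, since some CUDA kernels are nondeterministic.

```python
import random
import numpy as np
import torch

def seed_all(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # deferred call; safe when CUDA is unavailable

# Re-seeding with the same value should reproduce the same draws
seed_all(31)
first = (random.random(), np.random.rand(), torch.rand(3))
seed_all(31)
second = (random.random(), np.random.rand(), torch.rand(3))
```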

## This cell defines image transformations for training and validation.  Training data is augmented with cropping, flipping, and color jittering, while validation data is resized and center cropped.  Both are converted to tensors and normalized using ImageNet statistics. ##

# RandomResizedCrop replaces the removed RandomSizedCrop; scale is a keyword argument
train_transform = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# Validation: deterministic resize + center crop, same ImageNet normalization
val_transform = transforms.Compose([transforms.Resize(256),
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
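Because both pipelines normalize with the ImageNet mean and std, images pulled from a DataLoader need that step reversed before they can be displayed with `pyplot.imshow`. The sketch below mirrors what `transforms.Normalize` computes, `(x - mean) / std` per channel, and inverts it with `x * std + mean`; `normalize` and `denormalize` are hypothetical helper names.

```python
import torch

# ImageNet statistics, shaped (C, 1, 1) so they broadcast over (C, H, W) or (B, C, H, W)
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def normalize(img: torch.Tensor) -> torch.Tensor:
    """Same per-channel math as transforms.Normalize."""
    return (img - IMAGENET_MEAN) / IMAGENET_STD

def denormalize(img: torch.Tensor) -> torch.Tensor:
    """Undo the normalization and clamp into [0, 1] for display."""
    return (img * IMAGENET_STD + IMAGENET_MEAN).clamp(0.0, 1.0)

img = torch.rand(3, 224, 224)           # stand-in for a ToTensor() output
restored = denormalize(normalize(img))
# For plotting: pyplot.imshow(restored.permute(1, 2, 0).numpy())
```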
            
## This cell moves the transforms above into an ImageDataModule: setup() creates the datasets with those transforms, train_dataloader() builds the training DataLoader, and val_dataloader() builds the validation DataLoader. The transforms are defined in __init__ so the cell stays cleaner. ##
In PyTorch Lightning, it is best practice to organize datasets in a LightningDataModule, which handles dataset loading and transforms. Wrapping the transforms and dataloaders:

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pytorch_lightning as pl

# ImageNet statistics shared by both pipelines
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

class ImageDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size=32, num_workers=4):
        # data_dir: root folder where Food-101 is downloaded and stored
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Compose expects a list of transforms
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])
        self.val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])

    def prepare_data(self):
        # Download once; Lightning calls this on a single process
        datasets.Food101(root=self.data_dir, split='train', download=True)
        datasets.Food101(root=self.data_dir, split='test', download=True)

    def setup(self, stage=None):
        # ImageFolder accepts no split/download arguments; Food101 does.
        # Food-101 ships only 'train' and 'test' splits, so 'test' serves as validation
        # (the earlier version validated on 'train', i.e. on the training images).
        self.train_dataset = datasets.Food101(root=self.data_dir, split='train',
                                              transform=self.train_transform)
        self.val_dataset = datasets.Food101(root=self.data_dir, split='test',
                                            transform=self.val_transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

Adding this to the Trainer cell later:

data_module = ImageDataModule(data_dir='path/to/dataset', batch_size=32)
model = YourLightningModel()  # placeholder; will change once the model is ready
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, datamodule=data_module)