This notebook ports Multi-Key-Attention-in-Vision-Transformers, which uses plain PyTorch, to PyTorch Lightning.

Multi-Key-Attention-in-Vision-Transformers trains vit_small_patch16_224 from scratch (no pre-trained weights) on the Food-101 dataset with the default attention block.

The original notebook reduces the model and trains it with:
10 of the 12 attention blocks
a drop rate of 0.3
a drop path rate of 0.1
a batch size of 32
a weight decay of 1e-4
a base learning rate of 1e-5, adjusted over training by the LR schedule defined in cell 4.

Changes as of 10/2025:

drop rate of 0.5
drop path rate of 0.2
weight decay of 0.5

This cell imports the libraries used for deep learning, data handling, and model creation.

In [1]:
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Food101
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from timm import create_model


This cell defines the image transformations for training and validation. Training data is augmented with random resized cropping, horizontal flipping, and color jitter, while validation data is resized and center-cropped. Both are converted to tensors and normalized with ImageNet statistics.

# Training: random crop covering 80-100% of the image area, resized to 224x224, plus flips and color jitter
train_transform = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# Validation: deterministic resize + center crop, same normalization as training
val_transform = transforms.Compose([transforms.Resize(256),
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

Organize the datasets into a LightningDataModule
In PyTorch Lightning, it is best practice to use a LightningDataModule to handle dataset loading and transforms. The transforms above move into Food101DataModule.__init__, setup() creates the datasets with those transforms, train_dataloader() builds the training DataLoader, and val_dataloader() builds the validation DataLoader. Lightning calls these hooks automatically, which keeps the training code cleaner.

In [2]:
from torch.utils.data import DataLoader
from torchvision.datasets import Food101
from torchvision import transforms
import pytorch_lightning as pl


# ==========================
# DataModule
# ==========================

class Food101DataModule(pl.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=32, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,0.224,0.225])
        ])

        self.val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
        ])

    def setup(self, stage=None):
        self.train_dataset = Food101(root=self.data_dir, split="train", transform=self.train_transform, download=True)
        self.val_dataset = Food101(root=self.data_dir, split="test", transform=self.val_transform, download=True)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, 
                          batch_size=self.batch_size, 
                          shuffle=True, 
                          num_workers=self.num_workers,
                          persistent_workers=True
                          )

    def val_dataloader(self):
        return DataLoader(self.val_dataset, 
                          batch_size=self.batch_size, 
                          shuffle=False, 
                          num_workers=self.num_workers,
                          persistent_workers=True
                          )

    # Fix for Lightning 2.x compatibility
    # @property
    # def allow_zero_length_dataloader_with_multiple_devices(self):
    #     return False


create_model instantiates the small Vision Transformer variant from timm; num_classes=101 sets the final linear head to output 101 logits, one per Food-101 class.
This is still the ViTLayerReduction model from the original notebook, now trained with Lightning.
This time, weight_decay is set to 0.5 to counter the overfitting seen in the previous code.
drop_rate is also raised to 0.5 and drop_path_rate to 0.2; these regularization hyperparameters control dropout and stochastic depth, respectively.

In [3]:
from timm import create_model

# By default, PyTorch runs float32 matrix multiplications at full float32 precision.
# 'medium' lets it use faster reduced-precision internal math (e.g. bfloat16 on GPUs
# with Tensor Cores), trading a little numerical precision for faster training.
torch.set_float32_matmul_precision('medium')

# ==========================
# ViT model
# ==========================
class ViTLayerReduction(nn.Module):
    def __init__(self, num_blocks=10, num_classes=101):
        super().__init__()
        full_model = create_model(
            "vit_small_patch16_224",
            pretrained=False,
            num_classes=num_classes,
            drop_rate=0.5,
            drop_path_rate=0.2
        )
        self.patch_embed = full_model.patch_embed
        self.cls_token = full_model.cls_token
        self.pos_embed = full_model.pos_embed
        self.pos_drop = full_model.pos_drop
        self.blocks = nn.Sequential(*list(full_model.blocks[:num_blocks]))
        self.norm = full_model.norm
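        self.head = full_model.head
        # Newer timm versions apply drop_rate as head dropout just before the
        # classifier; older versions have no head_drop, so fall back to Identity.
        self.head_drop = getattr(full_model, "head_drop", nn.Identity())

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)                        # (B, 196, D) patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        x = torch.cat((cls_tokens, x), dim=1)          # (B, 197, D)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        x = self.blocks(x)
        x = self.norm(x)
        return self.head(self.head_drop(x[:, 0]))      # logits from the class token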

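Shape example (for a 224x224 image with patch_size=16):
Image x: (B, 3, 224, 224)
After patch_embed: (B, 196, D), since 224/16 = 14 and 14 x 14 = 196 patches (D = 384 for vit_small)
After prepending the class token: (B, 197, D)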
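The shape flow above can be traced with plain PyTorch. This sketch mirrors the strided-convolution patch embedding timm uses for ViTs, with a freshly initialized (untrained) layer and a zero class token as assumptions:

```python
import torch
import torch.nn as nn

B, D, patch = 2, 384, 16  # D=384 is the vit_small embedding width

# Patch embedding as a strided convolution: each 16x16 patch -> one D-dim token.
patch_embed = nn.Conv2d(3, D, kernel_size=patch, stride=patch)
cls_token = nn.Parameter(torch.zeros(1, 1, D))

x = torch.randn(B, 3, 224, 224)
x = patch_embed(x)                   # (B, D, 14, 14)
x = x.flatten(2).transpose(1, 2)     # (B, 196, D)
cls = cls_token.expand(B, -1, -1)    # (B, 1, D)
x = torch.cat((cls, x), dim=1)       # (B, 197, D)
print(x.shape)  # torch.Size([2, 197, 384])
```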

full_model.blocks is a sequence of Transformer encoder blocks (attention + MLP). Slicing it keeps only the first 10 of the 12 blocks (layer reduction), and wrapping them in nn.Sequential lets you call them as one module.

Why? Fewer blocks mean fewer parameters and FLOPs, and possibly less overfitting.

norm is the final LayerNorm applied to the token embeddings.
head is the classification head (a linear mapping D -> num_classes).

The class token sits at index 0 after concatenation; x[:, 0] extracts it for every entry in the batch.
head returns raw logits (no softmax), which is what loss functions like CrossEntropyLoss expect.

In [4]:
import math
import time

# ==========================
# Lightning Module
# ==========================
class LitViT(pl.LightningModule):
    def __init__(self, num_classes=101, base_lr=1e-5, peak_lr=1e-4, final_lr_fraction=0.1,
                 num_epochs=60, warmup_epochs=15, rampup_epochs=15, weight_decay=0.5):
        super().__init__()
        self.save_hyperparameters()
        self.model = ViTLayerReduction(num_blocks=10, num_classes=num_classes)
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.epoch_start_time = None

    def on_train_epoch_start(self):
        self.epoch_start_time = time.time()

    def on_train_epoch_end(self):
        duration = time.time() - self.epoch_start_time
        train_loss = self.trainer.callback_metrics.get("train_loss_epoch")
        train_acc = self.trainer.callback_metrics.get("train_acc_epoch")
        val_loss = self.trainer.callback_metrics.get("val_loss")
        val_acc = self.trainer.callback_metrics.get("val_acc")
        # Epoch number starting at 1
        epoch_num = self.current_epoch + 1
        print(f"Epoch {epoch_num} completed in {duration:.2f} seconds")
        print(f"Train Loss: {train_loss:.4f}  Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}  Val Acc: {val_acc:.4f}\n")

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        self.log("train_acc", acc, prog_bar=True, on_step=True, on_epoch=True)

        # Log LR
        optimizer = self.optimizers()
        current_lr = optimizer.param_groups[0]['lr']
        self.log("lr", current_lr, prog_bar=True, on_step=True, on_epoch=False)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_acc", acc, prog_bar=True, on_step=False, on_epoch=True)

    # Step-level LR schedule
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.peak_lr, weight_decay=self.hparams.weight_decay)

        def lr_lambda(step):
            steps_per_epoch = self.trainer.estimated_stepping_batches / self.hparams.num_epochs
            epoch_step = step / steps_per_epoch

            if epoch_step < self.hparams.warmup_epochs:
                return self.hparams.base_lr / self.hparams.peak_lr
            elif epoch_step < self.hparams.warmup_epochs + self.hparams.rampup_epochs:
                progress = (epoch_step - self.hparams.warmup_epochs) / self.hparams.rampup_epochs
                lr = self.hparams.base_lr + progress * (self.hparams.peak_lr - self.hparams.base_lr)
                return lr / self.hparams.peak_lr
            else:
                decay_progress = (epoch_step - self.hparams.warmup_epochs - self.hparams.rampup_epochs) / max(1, self.hparams.num_epochs - self.hparams.warmup_epochs - self.hparams.rampup_epochs)
                cosine_decay = 0.5 * (1 + math.cos(math.pi * decay_progress))
                lr = self.hparams.final_lr_fraction * self.hparams.peak_lr + (1 - self.hparams.final_lr_fraction) * self.hparams.peak_lr * cosine_decay
                return lr / self.hparams.peak_lr

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}
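To see what lr_lambda does, the same piecewise schedule can be written as a plain function of (fractional) epoch. This sketch re-implements the formula from configure_optimizers with LitViT's defaults; lr_at is a hypothetical helper, and it returns the absolute LR (peak_lr times the lambda's multiplier) rather than the ratio LambdaLR expects:

```python
import math


def lr_at(epoch, base_lr=1e-5, peak_lr=1e-4, final_lr_fraction=0.1,
          num_epochs=60, warmup_epochs=15, rampup_epochs=15):
    """Hypothetical helper: the notebook's schedule as an absolute LR at a (fractional) epoch."""
    if epoch < warmup_epochs:
        # Phase 1: hold at the low base LR
        return base_lr
    if epoch < warmup_epochs + rampup_epochs:
        # Phase 2: linear ramp from base_lr to peak_lr
        progress = (epoch - warmup_epochs) / rampup_epochs
        return base_lr + progress * (peak_lr - base_lr)
    # Phase 3: cosine decay from peak_lr down to final_lr_fraction * peak_lr
    decay_epochs = max(1, num_epochs - warmup_epochs - rampup_epochs)
    decay_progress = (epoch - warmup_epochs - rampup_epochs) / decay_epochs
    cosine = 0.5 * (1 + math.cos(math.pi * decay_progress))
    return final_lr_fraction * peak_lr + (1 - final_lr_fraction) * peak_lr * cosine


for epoch in [0, 15, 22.5, 30, 45, 60]:
    print(f"epoch {epoch:>4}: lr = {lr_at(epoch):.2e}")
```

With these defaults the LR holds at 1e-5 for 15 epochs, ramps to 1e-4 by epoch 30, and decays back to 1e-5 by epoch 60. In the training log, validation accuracy dips slightly during epochs 17 to 20, roughly when the ramp begins; that may be related, though the log alone does not establish it.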


Lightning wraps the training boilerplate, so you only write the core logic.
LitViT wraps the model and stores its configuration (learning-rate schedule, weight decay, epoch counts) through save_hyperparameters(). Its forward() simply delegates to the underlying ViTLayerReduction module, which keeps inference simple.

training_step is called for each training batch.
batch is the (inputs, labels) tuple produced by the DataLoader.
It computes logits, then the cross-entropy loss (which combines log_softmax and nll_loss), here with label smoothing of 0.1.
self.log records metrics; prog_bar=True shows them in the progress bar.
Returning the loss tells Lightning to run backpropagation on it.
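The claim that cross-entropy combines log_softmax and nll_loss can be checked numerically. This sketch uses random logits for a batch of 4 over 101 classes, and also reproduces LitViT's label_smoothing=0.1 loss by hand, assuming PyTorch's documented definition (a mix of the one-hot target and a uniform distribution):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 101)          # raw head outputs for a batch of 4
targets = torch.tensor([0, 5, 42, 100])

# CrossEntropyLoss on raw logits == NLL loss on log-softmax outputs
ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, manual))  # True

# With label_smoothing=0.1 (as in LitViT), the target distribution puts
# 0.9 + 0.1/101 on the true class and 0.1/101 on every other class.
smoothed = nn.CrossEntropyLoss(label_smoothing=0.1)(logits, targets)
log_probs = F.log_softmax(logits, dim=1)
eps, K = 0.1, logits.size(1)
target_dist = torch.full_like(log_probs, eps / K)
target_dist[torch.arange(4), targets] += 1 - eps
manual_smoothed = -(target_dist * log_probs).sum(dim=1).mean()
print(torch.allclose(smoothed, manual_smoothed))  # True
```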

validation_step mirrors training_step for validation batches, but there is no optimizer step.

Training the model
The Trainer orchestrates the whole run: device placement, the training loop, checkpointing (if configured), logging, etc.
accelerator="gpu", devices=1 runs on one GPU; the code below falls back to the CPU when no GPU is available.
trainer.fit(model, datamodule=data_module) starts training; data_module supplies train_dataloader() and val_dataloader().

In [5]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# ==========================
# Training Script
# ==========================
if __name__ == "__main__":
    batch_size = 32
    max_epochs = 60

    data_module = Food101DataModule(batch_size=batch_size)
    model = LitViT(num_classes=101, num_epochs=max_epochs)

    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc",
        mode="max",
        save_top_k=1,
        filename="best-vit-{epoch:02d}-{val_acc:.4f}"
    )

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        precision="16-mixed" if torch.cuda.is_available() else "32-true",  # AMP on GPU; plain 16 is deprecated
        callbacks=[checkpoint_callback],
        log_every_n_steps=50
    )

    trainer.fit(model, datamodule=data_module)

    print(f"Best model saved at: {checkpoint_callback.best_model_path}")

c:\Users\MILLAC24\Miniconda3\envs\pytorch_env\lib\site-packages\lightning_fabric\connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
c:\Users\MILLAC24\Miniconda3\envs\pytorch_env\lib\site-packages\pytorch_lightning\utilities\model_summary\model_summary.py:231: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | model     | ViTLayerReduction | 18.2 M | train
1 | criterion | CrossEntropyLoss  | 0      | train
--------------------------------------------------------
18.2 M    Trainable params

Epoch 1 completed in 141.39 seconds
Train Loss: 4.4828  Train Acc: 0.0354
Val Loss: 4.2580  Val Acc: 0.0806

Epoch 2 completed in 144.81 seconds
Train Loss: 4.2679  Train Acc: 0.0721
Val Loss: 4.0717  Val Acc: 0.1081

Epoch 3 completed in 141.57 seconds
Train Loss: 4.1365  Train Acc: 0.0994
Val Loss: 3.9539  Val Acc: 0.1330

Epoch 4 completed in 144.49 seconds
Train Loss: 4.0258  Train Acc: 0.1179
Val Loss: 3.8349  Val Acc: 0.1638

Epoch 5 completed in 144.57 seconds
Train Loss: 3.9347  Train Acc: 0.1391
Val Loss: 3.7509  Val Acc: 0.1752

Epoch 6 completed in 142.85 seconds
Train Loss: 3.8548  Train Acc: 0.1550
Val Loss: 3.6643  Val Acc: 0.1916

Epoch 7 completed in 142.40 seconds
Train Loss: 3.7779  Train Acc: 0.1706
Val Loss: 3.6124  Val Acc: 0.2048

Epoch 8 completed in 145.73 seconds
Train Loss: 3.7184  Train Acc: 0.1831
Val Loss: 3.5160  Val Acc: 0.2289

Epoch 9 completed in 141.76 seconds
Train Loss: 3.6625  Train Acc: 0.1957
Val Loss: 3.4589  Val Acc: 0.2445

Epoch 10 completed in 141.97 seconds
Train Loss: 3.6162  Train Acc: 0.2027
Val Loss: 3.4453  Val Acc: 0.2487

Epoch 11 completed in 138.40 seconds
Train Loss: 3.5732  Train Acc: 0.2142
Val Loss: 3.3712  Val Acc: 0.2613

Epoch 12 completed in 138.91 seconds
Train Loss: 3.5318  Train Acc: 0.2229
Val Loss: 3.3540  Val Acc: 0.2652

Epoch 13 completed in 138.25 seconds
Train Loss: 3.4960  Train Acc: 0.2327
Val Loss: 3.3128  Val Acc: 0.2768

Epoch 14 completed in 137.77 seconds
Train Loss: 3.4622  Train Acc: 0.2403
Val Loss: 3.2573  Val Acc: 0.2906

Epoch 15 completed in 138.23 seconds
Train Loss: 3.4239  Train Acc: 0.2492
Val Loss: 3.2472  Val Acc: 0.2921

Epoch 16 completed in 137.62 seconds
Train Loss: 3.4227  Train Acc: 0.2486
Val Loss: 3.2430  Val Acc: 0.2935

Epoch 17 completed in 138.29 seconds
Train Loss: 3.4559  Train Acc: 0.2389
Val Loss: 3.2953  Val Acc: 0.2794

Epoch 18 completed in 138.24 seconds
Train Loss: 3.4713  Train Acc: 0.2373
Val Loss: 3.2827  Val Acc: 0.2826

Epoch 19 completed in 136.79 seconds
Train Loss: 3.4778  Train Acc: 0.2359
Val Loss: 3.2927  Val Acc: 0.2789

Epoch 20 completed in 135.66 seconds
Train Loss: 3.4686  Train Acc: 0.2364
Val Loss: 3.2929  Val Acc: 0.2776

Epoch 21 completed in 136.24 seconds
Train Loss: 3.4621  Train Acc: 0.2402
Val Loss: 3.2331  Val Acc: 0.2962

Epoch 22 completed in 139.98 seconds
Train Loss: 3.4413  Train Acc: 0.2459
Val Loss: 3.2300  Val Acc: 0.2872

Epoch 23 completed in 139.20 seconds
Train Loss: 3.4270  Train Acc: 0.2480
Val Loss: 3.1876  Val Acc: 0.2978

Epoch 24 completed in 138.73 seconds
Train Loss: 3.4097  Train Acc: 0.2513
Val Loss: 3.1822  Val Acc: 0.3082

Epoch 25 completed in 140.06 seconds
Train Loss: 3.3901  Train Acc: 0.2544
Val Loss: 3.2015  Val Acc: 0.3048

Epoch 26 completed in 139.76 seconds
Train Loss: 3.3769  Train Acc: 0.2597
Val Loss: 3.2067  Val Acc: 0.2989

Epoch 27 completed in 140.38 seconds
Train Loss: 3.3565  Train Acc: 0.2647
Val Loss: 3.1247  Val Acc: 0.3256

Epoch 28 completed in 139.04 seconds
Train Loss: 3.3314  Train Acc: 0.2711
Val Loss: 3.0877  Val Acc: 0.3322

Epoch 29 completed in 138.85 seconds
Train Loss: 3.3174  Train Acc: 0.2745
Val Loss: 3.0953  Val Acc: 0.3270

Epoch 30 completed in 140.99 seconds
Train Loss: 3.2978  Train Acc: 0.2789
Val Loss: 3.0819  Val Acc: 0.3315

Epoch 31 completed in 134.80 seconds
Train Loss: 3.2728  Train Acc: 0.2862
Val Loss: 3.0616  Val Acc: 0.3390

Epoch 32 completed in 137.10 seconds
Train Loss: 3.2466  Train Acc: 0.2907
Val Loss: 2.9807  Val Acc: 0.3612

Epoch 33 completed in 138.51 seconds
Train Loss: 3.2195  Train Acc: 0.3015
Val Loss: 3.0451  Val Acc: 0.3379

Epoch 34 completed in 137.54 seconds
Train Loss: 3.1964  Train Acc: 0.3047
Val Loss: 2.9563  Val Acc: 0.3608

Epoch 35 completed in 139.31 seconds
Train Loss: 3.1738  Train Acc: 0.3112
Val Loss: 2.9381  Val Acc: 0.3664

Epoch 36 completed in 138.95 seconds
Train Loss: 3.1517  Train Acc: 0.3170
Val Loss: 3.0078  Val Acc: 0.3471

Epoch 37 completed in 140.03 seconds
Train Loss: 3.1385  Train Acc: 0.3197
Val Loss: 2.9505  Val Acc: 0.3652

Epoch 38 completed in 139.32 seconds
Train Loss: 3.1150  Train Acc: 0.3286
Val Loss: 2.8803  Val Acc: 0.3896

Epoch 39 completed in 138.30 seconds
Train Loss: 3.1007  Train Acc: 0.3314
Val Loss: 2.8266  Val Acc: 0.3972

Epoch 40 completed in 136.56 seconds
Train Loss: 3.0840  Train Acc: 0.3346
Val Loss: 2.8898  Val Acc: 0.3817

Epoch 41 completed in 136.39 seconds
Train Loss: 3.0670  Train Acc: 0.3406
Val Loss: 2.8522  Val Acc: 0.3909

Epoch 42 completed in 139.76 seconds
Train Loss: 3.0487  Train Acc: 0.3445
Val Loss: 2.8396  Val Acc: 0.3971

Epoch 43 completed in 135.76 seconds
Train Loss: 3.0292  Train Acc: 0.3506
Val Loss: 2.8188  Val Acc: 0.3952

Epoch 44 completed in 135.87 seconds
Train Loss: 3.0103  Train Acc: 0.3562
Val Loss: 2.7974  Val Acc: 0.4087

Epoch 45 completed in 136.21 seconds
Train Loss: 2.9906  Train Acc: 0.3604
Val Loss: 2.7565  Val Acc: 0.4184

Epoch 46 completed in 137.40 seconds
Train Loss: 2.9732  Train Acc: 0.3610
Val Loss: 2.7484  Val Acc: 0.4168

Epoch 47 completed in 136.66 seconds
Train Loss: 2.9453  Train Acc: 0.3704
Val Loss: 2.7813  Val Acc: 0.4099

Epoch 48 completed in 136.05 seconds
Train Loss: 2.9210  Train Acc: 0.3777
Val Loss: 2.7077  Val Acc: 0.4292

Epoch 49 completed in 136.56 seconds
Train Loss: 2.9002  Train Acc: 0.3803
Val Loss: 2.6885  Val Acc: 0.4341

Epoch 50 completed in 136.07 seconds
Train Loss: 2.8754  Train Acc: 0.3890
Val Loss: 2.6571  Val Acc: 0.4457

Epoch 51 completed in 135.35 seconds
Train Loss: 2.8539  Train Acc: 0.3948
Val Loss: 2.6476  Val Acc: 0.4439

Epoch 52 completed in 138.15 seconds
Train Loss: 2.8252  Train Acc: 0.4018
Val Loss: 2.6117  Val Acc: 0.4568

Epoch 53 completed in 137.71 seconds
Train Loss: 2.8062  Train Acc: 0.4085
Val Loss: 2.5872  Val Acc: 0.4639

Epoch 54 completed in 137.68 seconds
Train Loss: 2.7839  Train Acc: 0.4130
Val Loss: 2.5638  Val Acc: 0.4655

Epoch 55 completed in 136.96 seconds
Train Loss: 2.7603  Train Acc: 0.4184
Val Loss: 2.5589  Val Acc: 0.4690

Epoch 56 completed in 136.53 seconds
Train Loss: 2.7394  Train Acc: 0.4227
Val Loss: 2.5452  Val Acc: 0.4740

Epoch 57 completed in 136.77 seconds
Train Loss: 2.7168  Train Acc: 0.4297
Val Loss: 2.5353  Val Acc: 0.4773

Epoch 58 completed in 137.16 seconds
Train Loss: 2.7021  Train Acc: 0.4352
Val Loss: 2.5286  Val Acc: 0.4790

Epoch 59 completed in 136.82 seconds
Train Loss: 2.6922  Train Acc: 0.4371
Val Loss: 2.5143  Val Acc: 0.4817

Epoch 60 completed in 135.88 seconds
Train Loss: 2.6815  Train Acc: 0.4424
Val Loss: 2.5121  Val Acc: 0.4832



`Trainer.fit` stopped: `max_epochs=60` reached.


Best model saved at: c:\Users\MILLAC24\Documents\GitHub\Multi-Key-Attention-in-Vision-Transformers\lightning_logs\version_8\checkpoints\best-vit-epoch=59-val_acc=0.4832.ckpt


Shapes: ensure the DataLoader returns (images, labels), with labels as integers 0..100 for Food-101.
Final layer: num_classes=101 must match the number of dataset labels. If you change the dataset, update the head.
Pretrained weights: create_model is called with pretrained=False. If you switch to True, you may need to adapt or freeze layers.
Regularization: drop_rate, drop_path_rate, and weight_decay all affect overfitting; tune them one at a time.
Batch size/num_workers: set them to fit your GPU memory and CPU core count.
Logging metrics: accuracy is computed as (logits.argmax(dim=1) == y).float().mean(); torchmetrics offers ready-made accuracy metrics as an alternative.
Learning-rate schedulers: alternatives such as ReduceLROnPlateau or CosineAnnealingLR can also be returned from configure_optimizers. This notebook uses a LambdaLR whose schedule is defined in epochs (step / steps_per_epoch) but updated every optimizer step ("interval": "step"), so the LR changes smoothly within each epoch.
Defining the schedule in epochs rather than raw steps is more stable and reduces noise.
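For the ReduceLROnPlateau alternative mentioned above, Lightning needs a "monitor" key naming the logged metric to watch. A minimal sketch, written as a standalone function rather than a LitViT method; the function name, factor, and patience are assumptions, while weight_decay=0.5 mirrors LitViT's default:

```python
import torch
import torch.nn as nn


def configure_plateau(model: nn.Module):
    """Hypothetical sketch of a configure_optimizers return value using ReduceLROnPlateau."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.5)
    # mode="max" because the monitored metric (val_acc) should increase;
    # the LR is halved once val_acc fails to improve for more than 3 epochs in a row.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=3
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val_acc",  # Lightning passes this logged metric to scheduler.step()
            "interval": "epoch",   # check for a plateau once per epoch
        },
    }


cfg = configure_plateau(nn.Linear(4, 2))
sched = cfg["lr_scheduler"]["scheduler"]
for _ in range(5):  # one improvement, then four epochs with no improvement
    sched.step(0.5)
print(cfg["optimizer"].param_groups[0]["lr"])  # 5e-05
```

Inside LitViT the same dict would be built from self.parameters() and returned from configure_optimizers; Lightning then calls scheduler.step() with the logged val_acc at the end of each epoch.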