In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data augmentation and normalization for training
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),      # random crops
    transforms.RandomHorizontalFlip(),         # random horizontal flip
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # normalize to [-1,1]
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset  = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader  = DataLoader(testset,  batch_size=128, shuffle=False, num_workers=2)

100%|██████████| 170M/170M [00:03<00:00, 43.4MB/s]


In [2]:
class ViT(nn.Module):
    def __init__(self, image_size=32, patch_size=4, emb_dim=128, depth=6, heads=8, mlp_dim=256, num_classes=10, dropout=0.1):
        super().__init__()
        assert image_size % patch_size == 0, "Image dimensions must be divisible by patch size."
        num_patches = (image_size // patch_size) ** 2

        # Patch embedding using a conv layer
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        self.patch_embed = nn.Conv2d(3, emb_dim, kernel_size=patch_size, stride=patch_size)

        # Class token and positional embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim))

        # Transformer encoder blocks
        self.transformer_blocks = nn.ModuleList([])
        for _ in range(depth):
            block = nn.ModuleDict({
                'norm1': nn.LayerNorm(emb_dim),
                'attn': nn.MultiheadAttention(embed_dim=emb_dim, num_heads=heads, batch_first=True),
                'drop1': nn.Dropout(dropout),
                'norm2': nn.LayerNorm(emb_dim),
                'mlp': nn.Sequential(
                    nn.Linear(emb_dim, mlp_dim),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(mlp_dim, emb_dim),
                    nn.Dropout(dropout),
                )
            })
            self.transformer_blocks.append(block)

        self.norm = nn.LayerNorm(emb_dim)
        self.classifier = nn.Linear(emb_dim, num_classes)

        # Initialize parameters
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        B = x.size(0)
        # Patchify and embed
        x = self.patch_embed(x)           # shape: (B, emb_dim, H/ps, W/ps)
        x = x.flatten(2)                  # shape: (B, emb_dim, num_patches)
        x = x.transpose(1, 2)             # shape: (B, num_patches, emb_dim)
        # Prepend class token
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B,1,emb_dim)
        x = torch.cat([cls_tokens, x], dim=1)          # (B, num_patches+1, emb_dim)
        # Add positional embeddings
        x = x + self.pos_embed
        # Transformer blocks
        for block in self.transformer_blocks:
            # Multi-Head Self-Attention with residual
            x = x + block['drop1'](block['attn'](block['norm1'](x), block['norm1'](x), block['norm1'](x))[0])
            # MLP with residual
            x = x + block['mlp'](block['norm2'](x))
        x = self.norm(x)
        cls_out = x[:, 0]  # Take the CLS token representation
        logits = self.classifier(cls_out)
        return logits

# Instantiate the model
model = ViT(image_size=32, patch_size=4, emb_dim=128, depth=6, heads=8, mlp_dim=256, num_classes=10, dropout=0.1).to(device)
print(model)


ViT(
  (patch_embed): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
  (transformer_blocks): ModuleList(
    (0-5): 6 x ModuleDict(
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (drop1): Dropout(p=0.1, inplace=False)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=256, out_features=128, bias=True)
        (4): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (classifier): Linear(in_features=128, out_features=10, bias=True)
)


# Training

In [5]:
# Loss function, optimizer, and LR scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)  # cosine decay over 20 epochs

num_epochs = 100
best_acc = 0.0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    scheduler.step()

    # Evaluate on test set
    model.eval()
    correct = 0; total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = outputs.max(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    acc = 100 * correct / total
    if acc > best_acc:
        best_acc = acc
        best_config = {
            'patch_size': 4, 'depth': 6, 'emb_dim': 128, 'heads': 8,
            'mlp_dim': 256, 'lr': 1e-3, 'weight_decay': 5e-4
        }
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(trainloader):.4f}, Test Accuracy: {acc:.2f}%")

print(f"Best Test Accuracy: {best_acc:.2f}%")
print("Best Config:", best_config)


Epoch [1/100], Loss: 0.9927, Test Accuracy: 64.41%
Epoch [2/100], Loss: 0.9856, Test Accuracy: 66.51%
Epoch [3/100], Loss: 0.9650, Test Accuracy: 66.11%
Epoch [4/100], Loss: 0.9483, Test Accuracy: 67.08%
Epoch [5/100], Loss: 0.9281, Test Accuracy: 67.59%
Epoch [6/100], Loss: 0.9062, Test Accuracy: 67.70%
Epoch [7/100], Loss: 0.8849, Test Accuracy: 67.72%
Epoch [8/100], Loss: 0.8590, Test Accuracy: 69.90%
Epoch [9/100], Loss: 0.8373, Test Accuracy: 69.72%
Epoch [10/100], Loss: 0.8135, Test Accuracy: 71.35%
Epoch [11/100], Loss: 0.7886, Test Accuracy: 71.69%
Epoch [12/100], Loss: 0.7568, Test Accuracy: 72.00%
Epoch [13/100], Loss: 0.7327, Test Accuracy: 72.58%
Epoch [14/100], Loss: 0.7060, Test Accuracy: 73.11%
Epoch [15/100], Loss: 0.6867, Test Accuracy: 73.40%
Epoch [16/100], Loss: 0.6638, Test Accuracy: 73.92%
Epoch [17/100], Loss: 0.6482, Test Accuracy: 73.89%
Epoch [18/100], Loss: 0.6313, Test Accuracy: 73.97%
Epoch [19/100], Loss: 0.6243, Test Accuracy: 74.29%
Epoch [20/100], Loss:

In [6]:
# Final evaluation on test set
model.eval()
correct = 0; total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = outputs.max(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
final_acc = 100 * correct / total
print(f"Final Test Accuracy: {final_acc:.2f}%")
print("Best Config:", best_config)


Final Test Accuracy: 76.67%
Best Config: {'patch_size': 4, 'depth': 6, 'emb_dim': 128, 'heads': 8, 'mlp_dim': 256, 'lr': 0.001, 'weight_decay': 0.0005}


In [1]:
!pip install -q timm

In [2]:
import math, time, os, copy, random
from pathlib import Path
from functools import partial
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import InterpolationMode
from torch.cuda.amp import autocast, GradScaler

# reproducibility
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cuda


In [3]:
# helper functions: MixUp, CutMix, accuracy
def accuracy(output, target):
    preds = output.argmax(dim=1)
    return (preds == target).float().mean().item()

def mixup_data(x, y, alpha=1.0):
    if alpha <= 0:
        return x, y, None, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def cutmix_data(x, y, alpha=1.0):
    if alpha <= 0:
        return x, y, None, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size, _, H, W = x.size()
    index = torch.randperm(batch_size).to(x.device)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    cut_w = int(W * math.sqrt(1 - lam))
    cut_h = int(H * math.sqrt(1 - lam))
    x1 = np.clip(cx - cut_w // 2, 0, W)
    y1 = np.clip(cy - cut_h // 2, 0, H)
    x2 = np.clip(cx + cut_w // 2, 0, W)
    y2 = np.clip(cy + cut_h // 2, 0, H)
    x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]
    lam = 1 - ((x2 - x1) * (y2 - y1) / (W * H))
    y_a, y_b = y, y[index]
    return x, y_a, y_b, lam


In [4]:
# DropPath (stochastic depth) implementation
class DropPath(nn.Module):
    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob
    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor


In [5]:
# ViT model (from scratch) with configurable stochastic depth & layer scale
class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=128):
        super().__init__()
        assert img_size % patch_size == 0
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2
    def forward(self, x):
        x = self.proj(x)  # B x D x H' x W'
        x = x.flatten(2).transpose(1,2)  # B x N x D
        return x

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x); x = self.act(x); x = self.drop(x)
        x = self.fc2(x); x = self.drop(x)
        return x

class AttentionBlock(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, batch_first=True, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.drop_path = DropPath(drop_path) if drop_path>0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, hidden_features=dim*4, drop=proj_drop)
        # layer scaling (small init)
        self.gamma_1 = nn.Parameter(1e-2 * torch.ones(dim))  # layer scale
        self.gamma_2 = nn.Parameter(1e-2 * torch.ones(dim))
    def forward(self, x):
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + self.drop_path(self.gamma_1 * attn_out)
        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x

class ViT(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10,
                 embed_dim=192, depth=12, num_heads=3, mlp_ratio=4., drop_path_rate=0.1, attn_drop=0., drop=0.):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop)
        # stochastic depth linear decay
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            AttentionBlock(embed_dim, num_heads, qkv_bias=True, attn_drop=attn_drop, proj_drop=drop, drop_path=dpr[i])
            for i in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        # init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return self.head(x[:,0])


In [6]:
# Data and augmentations (strong)
from torchvision.transforms import RandAugment

NUM_CLASSES = 10
BATCH_SIZE = 128
NUM_WORKERS = 2

mean = (0.4914, 0.4822, 0.4465)
std  = (0.2470, 0.2435, 0.2616)
train_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    RandAugment(num_ops=2, magnitude=9),  # strong augmentation
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transforms)
test_set  = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transforms)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
test_loader  = DataLoader(test_set,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)


100%|██████████| 170M/170M [00:04<00:00, 40.9MB/s]


In [7]:
# training utilities: label smoothing loss, warmup cosine scheduler, EMA
class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
    def forward(self, pred, target):
        log_probs = F.log_softmax(pred, dim=1)
        n_classes = pred.size(1)
        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (n_classes - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), 1.0 - self.smoothing)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=1))

class CosineWarmupScheduler(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, max_epochs, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        super().__init__(optimizer, last_epoch)
    def get_lr(self):
        cur = self.last_epoch
        if cur < self.warmup_epochs:
            return [base_lr * float(cur + 1) / float(self.warmup_epochs) for base_lr in self.base_lrs]
        else:
            # cosine decay
            t = (cur - self.warmup_epochs) / max(1, self.max_epochs - self.warmup_epochs)
            return [base_lr * 0.5 * (1 + math.cos(math.pi * t)) for base_lr in self.base_lrs]

# Simple EMA
class ModelEMA:
    def __init__(self, model, decay=0.9999):
        self.ema = copy.deepcopy(model).eval()
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)
    def update(self, model):
        with torch.no_grad():
            msd = model.state_dict()
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= self.decay
                    v += (1.0 - self.decay) * msd[k].detach()


In [8]:
#instantiate model, optimizer, scheduler, loss

CFG = {
    'img_size': 32, 'patch_size': 4, 'embed_dim': 192, 'depth': 12,
    'num_heads': 3, 'mlp_ratio': 4.0, 'drop_path_rate': 0.1,
    'drop': 0.0, 'attn_drop': 0.0
}
model = ViT(img_size=CFG['img_size'], patch_size=CFG['patch_size'], embed_dim=CFG['embed_dim'],
            depth=CFG['depth'], num_heads=CFG['num_heads'], drop_path_rate=CFG['drop_path_rate']).to(device)

criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
base_lr = 3e-3
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.05)
NUM_EPOCHS = 120
WARMUP_EPOCHS = 5
scheduler = CosineWarmupScheduler(optimizer, warmup_epochs=WARMUP_EPOCHS, max_epochs=NUM_EPOCHS)
scaler = GradScaler()
ema = ModelEMA(model, decay=0.9998)
print("Total params (M):", sum(p.numel() for p in model.parameters())/1e6)


Total params (M): 5.36737


  scaler = GradScaler()


In [9]:
# training loop with MixUp/CutMix + eval
use_mixup = True
mixup_alpha = 0.8
use_cutmix = False   # choose one or both — I've seen MixUp + CutMix both work; set True to try CutMix.

best_acc = 0.0
save_path = "best_vit_cifar.pth"

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    n = 0
    t0 = time.time()
    for images, targets in train_loader:
        images = images.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        # augment via mixup/cutmix
        if use_cutmix:
            images, y_a, y_b, lam = cutmix_data(images, targets, alpha=mixup_alpha)
            with autocast():
                outputs = model(images)
                loss = lam * criterion(outputs, y_a) + (1 - lam) * criterion(outputs, y_b)
            scaler.scale(loss).backward()
        elif use_mixup:
            images, y_a, y_b, lam = mixup_data(images, targets, alpha=mixup_alpha)
            with autocast():
                outputs = model(images)
                loss = lam * criterion(outputs, y_a) + (1 - lam) * criterion(outputs, y_b)
            scaler.scale(loss).backward()
        else:
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, targets)
            scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * images.size(0)
        n += images.size(0)
        # update EMA
        ema.update(model)
    scheduler.step()
    # eval
    model.eval()
    correct = 0; total = 0
    with torch.no_grad():
        for images, targets in test_loader:
            images, targets = images.to(device), targets.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.size(0)
    acc = 100 * correct / total
    if acc > best_acc:
        best_acc = acc
        # save model + ema
        torch.save({'model': model.state_dict(), 'ema': ema.ema.state_dict(), 'cfg': CFG}, save_path)
    print(f"Epoch {epoch+1:03d}/{NUM_EPOCHS}  Loss: {running_loss/n:.4f}  Test Acc: {acc:.2f}%  Best: {best_acc:.2f}%  LR: {optimizer.param_groups[0]['lr']:.2e}  Time: {(time.time()-t0):.1f}s")

print("Training finished. Best test acc:", best_acc)


  with autocast():


Epoch 001/120  Loss: 2.1053  Test Acc: 45.80%  Best: 45.80%  LR: 1.20e-03  Time: 61.3s
Epoch 002/120  Loss: 1.9824  Test Acc: 50.00%  Best: 50.00%  LR: 1.80e-03  Time: 56.9s
Epoch 003/120  Loss: 1.9556  Test Acc: 52.23%  Best: 52.23%  LR: 2.40e-03  Time: 56.5s
Epoch 004/120  Loss: 1.9614  Test Acc: 50.91%  Best: 52.23%  LR: 3.00e-03  Time: 56.7s
Epoch 005/120  Loss: 1.9535  Test Acc: 52.59%  Best: 52.59%  LR: 3.00e-03  Time: 56.8s
Epoch 006/120  Loss: 1.9343  Test Acc: 51.84%  Best: 52.59%  LR: 3.00e-03  Time: 57.2s
Epoch 007/120  Loss: 1.9030  Test Acc: 54.97%  Best: 54.97%  LR: 3.00e-03  Time: 58.2s
Epoch 008/120  Loss: 1.8845  Test Acc: 56.91%  Best: 56.91%  LR: 2.99e-03  Time: 57.2s
Epoch 009/120  Loss: 1.8630  Test Acc: 58.30%  Best: 58.30%  LR: 2.99e-03  Time: 56.5s
Epoch 010/120  Loss: 1.8553  Test Acc: 57.67%  Best: 58.30%  LR: 2.99e-03  Time: 57.0s
Epoch 011/120  Loss: 1.8317  Test Acc: 60.23%  Best: 60.23%  LR: 2.98e-03  Time: 57.0s
Epoch 012/120  Loss: 1.8112  Test Acc: 63.3

In [10]:
# load best and evaluate EMA
ckpt = torch.load("best_vit_cifar.pth", map_location=device)
model.load_state_dict(ckpt['model'])
# optionally use EMA weights for final evaluation:
ema_model = ViT(img_size=CFG['img_size'], patch_size=CFG['patch_size'], embed_dim=CFG['embed_dim'],
            depth=CFG['depth'], num_heads=CFG['num_heads'], drop_path_rate=CFG['drop_path_rate']).to(device)
ema_model.load_state_dict(ckpt['ema'])
ema_model.eval()
# ema eval
correct = 0; total = 0
with torch.no_grad():
    for images, targets in test_loader:
        images, targets = images.to(device), targets.to(device)
        outputs = ema_model(images)
        preds = outputs.argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += targets.size(0)
print("EMA final Accuracy:", 100*correct/total)


EMA final Accuracy: 90.68
