In [1]:
#!pip install torch torchvision timm


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm

In [3]:
num_classes = 10  
batch_size = 8 #32
learning_rate = 1e-4
epochs = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(45),  
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)), 
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1), 
    
    transforms.RandomAutocontrast(p=0.2), 
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         
])

transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


In [5]:
train_dataset = datasets.ImageFolder(root='Dataset/training', transform=transform_train)
val_dataset = datasets.ImageFolder(root='Dataset/validation', transform=transform_val)


In [6]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [7]:
model = timm.create_model('mvitv2_large', pretrained=True, num_classes=num_classes)
model = model.to(device)

In [8]:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)

216851482


In [9]:
print(model)

MultiScaleVit(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 144, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
  )
  (stages): ModuleList(
    (0): MultiScaleVitStage(
      (blocks): ModuleList(
        (0): MultiScaleBlock(
          (norm1): LayerNorm((144,), eps=1e-06, elementwise_affine=True)
          (attn): MultiScaleAttention(
            (qkv): Linear(in_features=144, out_features=432, bias=True)
            (proj): Linear(in_features=144, out_features=144, bias=True)
            (pool_q): Conv2d(72, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=72, bias=False)
            (norm_q): LayerNorm((72,), eps=1e-06, elementwise_affine=True)
            (pool_k): Conv2d(72, 72, kernel_size=(3, 3), stride=(4, 4), padding=(1, 1), groups=72, bias=False)
            (norm_k): LayerNorm((72,), eps=1e-06, elementwise_affine=True)
            (pool_v): Conv2d(72, 72, kernel_size=(3, 3), stride=(4, 4), padding=(1, 1), groups=72, bias=False)
            (norm_v): L

In [10]:
for name, param in model.named_parameters():
    print(name, param.requires_grad)


patch_embed.proj.weight True
patch_embed.proj.bias True
stages.0.blocks.0.norm1.weight True
stages.0.blocks.0.norm1.bias True
stages.0.blocks.0.attn.rel_pos_h True
stages.0.blocks.0.attn.rel_pos_w True
stages.0.blocks.0.attn.qkv.weight True
stages.0.blocks.0.attn.qkv.bias True
stages.0.blocks.0.attn.proj.weight True
stages.0.blocks.0.attn.proj.bias True
stages.0.blocks.0.attn.pool_q.weight True
stages.0.blocks.0.attn.norm_q.weight True
stages.0.blocks.0.attn.norm_q.bias True
stages.0.blocks.0.attn.pool_k.weight True
stages.0.blocks.0.attn.norm_k.weight True
stages.0.blocks.0.attn.norm_k.bias True
stages.0.blocks.0.attn.pool_v.weight True
stages.0.blocks.0.attn.norm_v.weight True
stages.0.blocks.0.attn.norm_v.bias True
stages.0.blocks.0.norm2.weight True
stages.0.blocks.0.norm2.bias True
stages.0.blocks.0.mlp.fc1.weight True
stages.0.blocks.0.mlp.fc1.bias True
stages.0.blocks.0.mlp.fc2.weight True
stages.0.blocks.0.mlp.fc2.bias True
stages.0.blocks.1.norm1.weight True
stages.0.blocks.1.

In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [13]:
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs,save_path):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")
        torch.save(model.state_dict(), f"{save_path}/model_epoch_{epoch+1}.pth") #we r saving model after every epoch 

        validate_model(model, val_loader, criterion)


def validate_model(model, val_loader, criterion):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    print(f'Validation Loss: {val_loss/len(val_loader):.4f}, Accuracy: {100 * correct / total:.2f}%')


In [None]:
train_model(model, train_loader, val_loader, criterion, optimizer, epochs, save_path = "models_augmentations") 
#kindly change path 

Epoch [1/20], Loss: 0.4140, Accuracy: 87.40%
Validation Loss: 0.2666, Accuracy: 91.51%
Epoch [2/20], Loss: 0.2611, Accuracy: 91.69%
Validation Loss: 0.2754, Accuracy: 91.19%
Epoch [3/20], Loss: 0.2144, Accuracy: 92.89%
Validation Loss: 0.2173, Accuracy: 93.06%
Epoch [4/20], Loss: 0.1906, Accuracy: 93.84%
Validation Loss: 0.2192, Accuracy: 92.59%
Epoch [5/20], Loss: 0.1696, Accuracy: 94.51%
Validation Loss: 0.2686, Accuracy: 91.51%
Epoch [6/20], Loss: 0.1534, Accuracy: 94.89%
Validation Loss: 0.1959, Accuracy: 93.66%
Epoch [7/20], Loss: 0.1399, Accuracy: 95.26%
Validation Loss: 0.1984, Accuracy: 93.66%
Epoch [8/20], Loss: 0.1287, Accuracy: 95.68%
Validation Loss: 0.1863, Accuracy: 94.07%
Epoch [9/20], Loss: 0.1183, Accuracy: 96.00%
Validation Loss: 0.2036, Accuracy: 93.58%
Epoch [10/20], Loss: 0.1104, Accuracy: 96.16%
Validation Loss: 0.1845, Accuracy: 94.40%
Epoch [11/20], Loss: 0.1010, Accuracy: 96.53%
Validation Loss: 0.2147, Accuracy: 93.93%
Epoch [12/20], Loss: 0.1000, Accuracy: 96