In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage
import timm
from collections import Counter
import numpy as np
import os

In [3]:
# Hyperparameters and runtime configuration (tune here, not further down).
num_classes = 10          # must equal the number of class folders in the training split
batch_size = 32
learning_rate = 1e-4
epochs = 20
seed = 42                 # makes runs reproducible; see torch.manual_seed below

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (all devices). The WeightedRandomSampler below is created without
# an explicit generator, so it draws from this RNG. Presumably the random torchvision transforms
# and the new classifier-head init do too — confirm if exact reproducibility matters.
torch.manual_seed(seed)

In [4]:
# Model input size and the standard ImageNet normalisation statistics
# (assumed to match the pretrained DaVit config — confirm via timm.data.resolve_data_config).
image_size = (224, 224)
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: resize to the model input size, then random geometric augmentation.
transform_train = transforms.Compose([
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(45),
    # Optional extra augmentations, currently disabled:
    #transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    #transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    #transforms.RandomAutocontrast(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Validation pipeline: deterministic, but it MUST resize to the same input size as training.
# Without the Resize the model is evaluated at a different resolution than it was trained on.
# If source images differ in size, the DataLoader also cannot stack them into a batch.
transform_val = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

In [6]:
# Root folder holding one sub-folder per split, each with one sub-folder per class.
data_root = 'Dataset'

train_dataset = datasets.ImageFolder(
    root=os.path.join(data_root, 'training'),
    transform=transform_train,
)
val_dataset = datasets.ImageFolder(
    root=os.path.join(data_root, 'validation'),
    transform=transform_val,
)

# ImageFolder derives class -> index from the folders present in EACH split separately.
# A class folder missing from one split would silently shift every later label index,
# so fail loudly instead of reporting meaningless validation accuracy.
assert train_dataset.class_to_idx == val_dataset.class_to_idx, \
    "training and validation class folders differ; label indices would not match"
assert len(train_dataset.classes) == num_classes, \
    f"found {len(train_dataset.classes)} class folders but num_classes = {num_classes}"


In [7]:
# Read the labels ImageFolder recorded when it scanned the folders. Iterating the dataset
# instead would load, decode and augment every training image just to get its label.
labels = list(train_dataset.targets)
counter = Counter(labels)
total_samples = len(labels)

# Inverse-frequency weight per class: every class then has the same total weight
# (total_samples), so each draw picks every class with roughly equal probability.
class_weights = {cls: total_samples / count for cls, count in counter.items()}
weights = [class_weights[label] for label in labels]

# One "epoch" has as many draws as the dataset has samples. replacement=True is needed
# so minority classes can be drawn more often than they occur.
train_sampler = torch.utils.data.WeightedRandomSampler(
    weights=weights, num_samples=len(weights), replacement=True
)


In [19]:
# Show per-class counts, dataset size, and the resulting per-class sampling weights.
for stat in (counter, total_samples, class_weights):
    print(stat)

Counter({6: 28663, 2: 2694, 7: 1162, 0: 1154, 1: 834, 5: 796, 4: 792, 3: 691, 8: 663, 9: 158})
37607
{0: 32.58838821490468, 1: 45.092326139088726, 2: 13.95953971789161, 3: 54.424023154848044, 4: 47.48358585858586, 5: 47.244974874371856, 6: 1.3120399120817778, 7: 32.36402753872633, 8: 56.72247360482655, 9: 238.01898734177215}


In [13]:
# Training batches are drawn through the class-balancing sampler (a sampler and
# shuffle=True are mutually exclusive); validation is read in a fixed order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [14]:
# Pretrained DaVit-Small from timm, with its classifier sized to num_classes outputs.
model = timm.create_model(
    'davit_small',
    pretrained=True,
    num_classes=num_classes,
)
model = model.to(device)

# Plain (unweighted) cross-entropy: class imbalance is already handled by the
# WeightedRandomSampler, so loss weights here would double-count it.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [15]:
def _model_device(model):
    """Return the device of ``model``'s first parameter (CPU if it has no parameters)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def train_model(model, train_loader, val_loader, criterion, optimizer, epochs, save_dir='models_new', device=None):
    """Train ``model``; after every epoch, validate it and save a checkpoint.

    Args:
        model: ``nn.Module`` to train; it is updated in place.
        train_loader: iterable of ``(images, labels)`` batches used for optimisation.
        val_loader: iterable of ``(images, labels)`` batches passed to ``validate_model``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``train_loader``.
        save_dir: directory for per-epoch ``state_dict`` files (created if missing).
        device: device the batches are moved to. Defaults to the device of the model's
            parameters, so the model must already be placed (e.g. ``model.to(device)``).

    Raises:
        ValueError: if ``train_loader`` yields no samples in an epoch.
    """
    if device is None:
        device = _model_device(model)
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epochs):
        model.train()  # validate_model switches to eval mode, so re-enable training mode every epoch
        running_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += labels.size(0)

        # Guard the divisions below; an empty loader would otherwise raise ZeroDivisionError.
        if total == 0:
            raise ValueError("train_loader produced no samples")

        avg_loss = running_loss / num_batches  # mean of per-batch losses
        accuracy = 100 * correct / total
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

        validate_model(model, val_loader, criterion, device=device)

        model_path = os.path.join(save_dir, f'model_epoch_{epoch+1}.pth')
        torch.save(model.state_dict(), model_path)
        print(f'Model saved to {model_path}')


def validate_model(model, val_loader, criterion, device=None):
    """Evaluate ``model`` on ``val_loader`` without gradients; print and return the metrics.

    Args:
        model: ``nn.Module`` to evaluate; it is left in eval mode.
        val_loader: iterable of ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        device: device the batches are moved to; defaults to the model's parameter device.

    Returns:
        ``(avg_loss, accuracy_percent)``. ``avg_loss`` is the mean of the per-batch losses.
        Both values are ``nan`` when the loader yields no samples.
    """
    if device is None:
        device = _model_device(model)
    model.eval()
    val_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            val_loss += criterion(outputs, labels).item()
            num_batches += 1

            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += labels.size(0)

    if total == 0:
        print('Validation skipped: val_loader produced no samples')
        return float('nan'), float('nan')

    avg_loss = val_loss / num_batches
    accuracy = 100 * correct / total
    print(f'Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}%')
    return avg_loss, accuracy


In [None]:
# Launch training; checkpoints for every epoch land in models_davit_wrs/.
train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    epochs=epochs,
    save_dir='models_davit_wrs',
)

In [20]:
# Dump the module tree of the network (layer names, shapes, sub-modules).
print(f"{model}")

DaVit(
  (stem): Stem(
    (conv): Conv2d(3, 96, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
    (norm): LayerNorm2d((96,), eps=1e-05, elementwise_affine=True)
  )
  (stages): Sequential(
    (0): DaVitStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): Sequential(
          (0): SpatialBlock(
            (cpe1): ConvPosEnc(
              (proj): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96)
              (act): Identity()
            )
            (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=96, out_features=288, bias=True)
              (proj): Linear(in_features=96, out_features=96, bias=True)
              (softmax): Softmax(dim=-1)
            )
            (drop_path1): Identity()
            (cpe2): ConvPosEnc(
              (proj): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96)
              