In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage
import timm
import cv2

In [2]:
# Experiment configuration: every value a reader might want to tune lives here.
seed = 42              # fixed seed so repeated runs start from the same RNG state
num_classes = 10       # number of output classes requested for the classifier head
batch_size = 32
learning_rate = 1e-4
epochs = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's global random number generator before any model, dataset or
# loader is created. Without this every Restart & Run All gives different results.
# (GPU kernels may still introduce small non-determinism.)
torch.manual_seed(seed)

In [3]:
# Preprocessing constants shared by both pipelines so they cannot drift apart.
image_size = (224, 224)
norm_mean = [0.485, 0.456, 0.406]  # standard ImageNet per-channel mean
norm_std = [0.229, 0.224, 0.225]   # standard ImageNet per-channel std

# Training pipeline: resize, random geometric augmentation, tensor conversion,
# then channel-wise normalisation.
transform_train = transforms.Compose([
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(45),
    # Augmentations that are currently disabled:
    #transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    #transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    #transforms.RandomAutocontrast(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])

# Validation pipeline: no random augmentation. The resize is enabled so that
# validation images reach the model at the same spatial size as training
# images, and so that files of differing sizes can be stacked into one batch.
transform_val = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])

In [4]:
# One sub-directory per class under each root; ImageFolder assigns label
# indices from the directory names.
train_dataset = datasets.ImageFolder(root='Dataset/training', transform=transform_train)
val_dataset = datasets.ImageFolder(root='Dataset/validation', transform=transform_val)

# Label indices are derived independently for each split, so a missing or extra
# class folder in one split would silently shift its labels. Fail loudly instead.
assert train_dataset.class_to_idx == val_dataset.class_to_idx, (
    "training and validation folders define different class-to-index mappings"
)
assert len(train_dataset.classes) <= num_classes, (
    f"found {len(train_dataset.classes)} class folders but num_classes={num_classes}"
)

# Use the configured batch size rather than a second hard-coded copy of it.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)



In [5]:
# Build the pretrained `davit_small` network with a `num_classes`-way head,
# then move it onto the selected device.
model = timm.create_model('davit_small', pretrained=True, num_classes=num_classes)
model = model.to(device)


  return self.fget.__get__(instance, owner)()


In [6]:
# Optimiser over all model parameters, and the classification loss it minimises.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [7]:
import os

In [8]:
# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs,
                save_dir='models_new', device=None):
    """Train `model` for `epochs` epochs, validating and checkpointing after each.

    Args:
        model: network to optimise; it is switched to train mode every epoch.
        train_loader: iterable yielding ``(images, labels)`` batches.
        val_loader: iterable passed unchanged to ``validate_model``.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar loss.
        optimizer: optimiser already bound to ``model``'s parameters.
        epochs: number of passes over ``train_loader``.
        save_dir: directory (created if missing) that receives one
            ``model_epoch_<n>.pth`` state-dict file per epoch.
        device: device the batches are moved to. Defaults to the device of the
            model's first parameter, so the function no longer depends on a
            notebook-level ``device`` variable.

    Returns:
        None. Progress is printed and checkpoints are written to ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)

    if device is None:
        # Follow the model: batches must live on the same device as its weights.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Batches are counted while iterating (no len() needed), and an empty
        # loader reports NaN instead of raising ZeroDivisionError.
        avg_loss = running_loss / num_batches if num_batches else float('nan')
        accuracy = 100 * correct / total if total else float('nan')
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

        validate_model(model, val_loader, criterion)

        # Keep a checkpoint for every epoch so any of them can be restored later.
        model_path = os.path.join(save_dir, f'model_epoch_{epoch+1}.pth')
        torch.save(model.state_dict(), model_path)
        print(f'Model saved to {model_path}')

# Validation function
def validate_model(model, val_loader, criterion, device=None):
    """Evaluate `model` on `val_loader`, print the metrics and return them.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        val_loader: iterable yielding ``(images, labels)`` batches.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar loss.
        device: device the batches are moved to. Defaults to the device of the
            model's first parameter, so the function no longer depends on a
            notebook-level ``device`` variable.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        percentage of correctly classified samples. Both are NaN when the
        loader yields no batches.
    """
    if device is None:
        # Follow the model: batches must live on the same device as its weights.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            val_loss += criterion(outputs, labels).item()
            num_batches += 1

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Batches are counted while iterating (no len() needed), and an empty
    # loader reports NaN instead of raising ZeroDivisionError.
    avg_loss = val_loss / num_batches if num_batches else float('nan')
    accuracy = 100 * correct / total if total else float('nan')

    print(f'Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}%')
    return avg_loss, accuracy






In [9]:

# Launch the training run; checkpoints are written under models_davit/.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs,
    save_dir='models_davit',
)

Epoch [1/20], Loss: 0.3731, Accuracy: 88.63%
Validation Loss: 0.2484, Accuracy: 92.2762%
Model saved to models_davit/model_epoch_1.pth
Epoch [2/20], Loss: 0.1982, Accuracy: 93.52%
Validation Loss: 0.2004, Accuracy: 93.4788%
Model saved to models_davit/model_epoch_2.pth
Epoch [3/20], Loss: 0.1523, Accuracy: 94.96%
Validation Loss: 0.1784, Accuracy: 94.1049%
Model saved to models_davit/model_epoch_3.pth
Epoch [4/20], Loss: 0.1255, Accuracy: 95.79%
Validation Loss: 0.1816, Accuracy: 94.1607%
Model saved to models_davit/model_epoch_4.pth
Epoch [5/20], Loss: 0.1032, Accuracy: 96.46%
Validation Loss: 0.1667, Accuracy: 94.3962%
Model saved to models_davit/model_epoch_5.pth
Epoch [6/20], Loss: 0.0878, Accuracy: 96.92%
Validation Loss: 0.1893, Accuracy: 94.0739%
Model saved to models_davit/model_epoch_6.pth
Epoch [7/20], Loss: 0.0767, Accuracy: 97.30%
Validation Loss: 0.2044, Accuracy: 94.1793%
Model saved to models_davit/model_epoch_7.pth
Epoch [8/20], Loss: 0.0693, Accuracy: 97.52%
Validation