In [8]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import os

from datetime import datetime
from datetime import timedelta



The ViT model code below (FeedForward, Attention, Transformer, ViT) is adapted from lucidrains/vit-pytorch: https://github.com/lucidrains/vit-pytorch

In [9]:


# --- Custom Dataset ---
class CustomDataset(Dataset):
    """Image-classification dataset laid out as ``root_dir/<class_name>/<image file>``.

    Every non-hidden sub-directory of ``root_dir`` is one class; integer labels
    are assigned in sorted order of the sub-directory names.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded PIL image.

    Attributes:
        classes: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to integer label.
        image_paths: Path of every sample, grouped by class and sorted by file
            name within a class.
        labels: Integer label of every sample, aligned with ``image_paths``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Only non-hidden directories are classes; a stray top-level file such as
        # ``.DS_Store`` would otherwise become a class and crash ``os.listdir`` below.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if not entry.startswith('.') and os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        self.image_paths = []
        self.labels = []
        for cls in self.classes:
            cls_path = os.path.join(root_dir, cls)
            # Sorting makes sample order (and therefore any index-based split)
            # independent of the filesystem's directory listing order.
            for img_name in sorted(os.listdir(cls_path)):
                img_path = os.path.join(cls_path, img_name)
                # Skip hidden files and nested directories; they are not images.
                if img_name.startswith('.') or not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``image`` is RGB, transformed if a transform is set."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # The context manager closes the file handle; ``convert`` returns a new,
        # fully decoded image that outlives it.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label


In [10]:

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    """Normalise a size argument to a 2-tuple.

    A tuple is passed through untouched; any other value ``t`` is duplicated
    into ``(t, t)`` (e.g. an int image/patch side becomes ``(height, width)``).
    """
    if isinstance(t, tuple):
        return t
    return (t, t)

# classes

class FeedForward(nn.Module):
    """Pre-norm position-wise MLP used inside each transformer block.

    Pipeline: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Dropout
    -> Linear(hidden_dim, dim) -> Dropout. The last dimension of the output
    equals that of the input.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability after the activation and after the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Pre-norm multi-head self-attention.

    The input is layer-normalised, projected to queries/keys/values for
    ``heads`` heads of width ``dim_head``, combined with
    softmax(Q K^T * dim_head**-0.5) V, and the concatenated heads are projected
    back to ``dim``. The output projection is replaced by an identity when there
    is a single head whose width already equals ``dim``.

    Args:
        dim: Token feature size (last dimension of the input).
        heads: Number of attention heads.
        dim_head: Width of each head.
        dropout: Dropout on the attention weights and after the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        # Split the fused projection into Q, K, V and move heads to their own axis.
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        context = torch.matmul(weights, v)
        merged = rearrange(context, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``depth`` residual blocks followed by a final LayerNorm.

    Each block applies ``Attention`` then ``FeedForward``, each wrapped in a
    residual connection; both sub-layers normalise their own input, so the
    residual path carries the un-normalised activations.

    Args:
        dim: Token feature size.
        depth: Number of blocks.
        heads: Attention heads per block.
        dim_head: Width of each attention head.
        mlp_dim: Hidden width of each block's MLP.
        dropout: Dropout used inside attention and MLP.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList(
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout),
            ])
            for _ in range(depth)
        )

    def forward(self, x):
        for attention_block, mlp_block in self.layers:
            x = x + attention_block(x)
            x = x + mlp_block(x)
        return self.norm(x)

class ViT(nn.Module):
    """Vision Transformer classifier (adapted from lucidrains/vit-pytorch).

    The image is cut into non-overlapping patches, each patch is flattened and
    linearly embedded, a learnable CLS token is prepended, learned positional
    embeddings are added, and the token sequence goes through a ``Transformer``.
    The pooled representation (CLS token or mean of all tokens) feeds a linear head.

    Args:
        image_size: int or (height, width) the model is built for; it fixes the
            number of positional embeddings.
        patch_size: int or (height, width) of each patch; must divide image_size.
        num_classes: Number of output logits.
        dim: Token embedding size.
        depth: Number of transformer blocks.
        heads: Attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        pool: 'cls' to classify from the CLS token, 'mean' to average all tokens.
        channels: Number of input image channels.
        dim_head: Width of each attention head.
        dropout: Dropout inside the transformer blocks.
        emb_dropout: Dropout applied to the token embeddings.

    Raises:
        AssertionError: if the image size is not divisible by the patch size or
            ``pool`` is not 'cls' / 'mean'.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Number of patch tokens the positional embedding is sized for;
        # forward() checks incoming images against it.
        self.num_patches = num_patches

        # (b, c, H, W) -> (b, num_patches, patch_dim) -> (b, num_patches, dim)
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # One learned position per patch plus one for the CLS token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        """Return class logits of shape (batch, num_classes).

        ``img`` is (batch, channels, height, width); height and width must be
        multiples of the patch size (required by the patch rearrange).

        Raises:
            RuntimeError: if the image yields more patches than the model has
                positional embeddings for (input larger than ``image_size``).
        """
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        if n > self.num_patches:
            raise RuntimeError(
                f'Input yields {n} patches but the model was built for {self.num_patches}; '
                'the input resolution must not exceed image_size.'
            )
        if n < self.num_patches:
            # Smaller inputs still run (only a prefix of the positional
            # embeddings is used), but this almost always means the input
            # resize and the model's image_size disagree.
            import warnings
            warnings.warn(
                f'Input yields {n} patches but the model was built for {self.num_patches}; '
                'only the first positional embeddings are used. Check that the input '
                'resolution matches image_size.',
                stacklevel=2,
            )

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # 'cls' pooling keeps token 0; 'mean' averages all tokens, CLS included.
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

In [11]:


def evaluate_model(model, loader, criterion, device):
    """Evaluate ``model`` on every batch of ``loader`` without gradient tracking.

    Puts ``model`` in eval mode (it is left in eval mode afterwards).

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function; assumed to use mean reduction (the
            ``nn.CrossEntropyLoss`` default) so ``loss * batch_size`` is the batch total.
        device: Device the batches are moved to.

    Returns:
        ``(loss, accuracy, precision, recall, f1, preds, targets)`` where ``loss``
        is the per-sample mean loss, precision/recall/F1 are macro-averaged, and
        ``preds`` / ``targets`` are lists of class indices.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss, num_samples = 0.0, 0
    preds, targets = [], []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            # Weight each batch by its size so a short final batch does not
            # count as much as a full one in the mean.
            total_loss += criterion(outputs, labels).item() * batch_size
            num_samples += batch_size
            preds.extend(outputs.argmax(1).cpu().numpy())
            targets.extend(labels.cpu().numpy())

    if num_samples == 0:
        raise ValueError('evaluate_model received a loader with no samples')

    acc = accuracy_score(targets, preds)
    prec, recall, f1, _ = precision_recall_fscore_support(targets, preds, average='macro', zero_division=0)
    return total_loss / num_samples, acc, prec, recall, f1, preds, targets

def _train_one_epoch(model, criterion, optimizer, loader, device):
    """Run one optimisation pass over ``loader``.

    Returns ``(mean_loss, preds, targets)``. ``mean_loss`` is weighted by batch
    size (assumes ``criterion`` uses mean reduction), and ``preds`` / ``targets``
    are class indices collected from the training-mode forward passes.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss, num_samples = 0.0, 0
    preds, targets = [], []
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size
        preds.extend(outputs.argmax(1).cpu().numpy())
        targets.extend(labels.cpu().numpy())

    if num_samples == 0:
        raise ValueError('train_loader yielded no samples')
    return total_loss / num_samples, preds, targets


def _infer_class_names(loader):
    """Return the ``classes`` list of the dataset behind ``loader``, or None.

    Follows ``.dataset`` attributes so wrappers such as ``Subset`` are unwrapped.
    """
    dataset = getattr(loader, 'dataset', None)
    seen = set()
    while dataset is not None and id(dataset) not in seen:
        seen.add(id(dataset))
        classes = getattr(dataset, 'classes', None)
        if classes is not None:
            return list(classes)
        dataset = getattr(dataset, 'dataset', None)
    return None


def _save_confusion_matrix(targets, preds, class_names, path):
    """Plot the confusion matrix of ``targets`` vs ``preds`` as a heatmap and save it to ``path``."""
    if class_names is not None:
        # Pin the label set so the matrix is always len(class_names) square and
        # lines up with the tick labels, even if a class never appears.
        cm = confusion_matrix(targets, preds, labels=list(range(len(class_names))))
        tick_labels = class_names
    else:
        cm = confusion_matrix(targets, preds)
        tick_labels = 'auto'
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=tick_labels, yticklabels=tick_labels, ax=ax)
    ax.set_title("Confusion Matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.savefig(path)
    plt.close(fig)


def train_model(model, criterion, optimizer, train_loader, val_loader, test_loader, device, num_epochs, save_policy='min', save_dir='models', class_names=None):
    """Train ``model``, checkpoint it, log metrics and evaluate on the test set.

    Each epoch trains on ``train_loader``, evaluates on ``val_loader`` and prints
    loss / accuracy / macro precision, recall and F1 for both. Checkpoints go to
    ``save_dir`` according to ``save_policy``:

    * ``'min'``  -- save when validation loss reaches a new minimum.
    * ``'max'``  -- save when validation accuracy reaches a new maximum.
    * ``'last'`` -- save after every epoch.

    Every save writes ``epoch_XXX.pt`` and overwrites ``best_model.pt``. After
    training, per-epoch logs are written to ``train_log.csv`` / ``val_log.csv``.
    The final-epoch model is then evaluated on ``test_loader`` and its confusion
    matrix saved to ``confusion_matrix.png``. Finally the checkpoint saved
    during this run is reloaded into ``model`` and evaluated on ``test_loader`` again.

    Args:
        model, criterion, optimizer: The network, its loss and its optimizer.
        train_loader, val_loader, test_loader: Iterables of ``(images, labels)`` batches.
        device: Device batches are moved to (the model must already be on it).
        num_epochs: Number of training epochs.
        save_policy: One of ``'min'``, ``'max'``, ``'last'`` (see above).
        save_dir: Output directory; created if missing.
        class_names: Confusion-matrix tick labels. If None, taken from a
            ``classes`` attribute of the dataset behind ``test_loader`` when one
            exists; otherwise integer labels are shown.

    Raises:
        ValueError: if ``save_policy`` is not one of the supported values.
    """
    if save_policy not in ('min', 'max', 'last'):
        raise ValueError(f"save_policy must be 'min', 'max' or 'last', got {save_policy!r}")
    os.makedirs(save_dir, exist_ok=True)
    if class_names is None:
        class_names = _infer_class_names(test_loader)

    best_model_path = os.path.join(save_dir, 'best_model.pt')
    best_val_acc = float('-inf')   # tracked for save_policy='max'
    best_val_loss = float('inf')   # tracked for save_policy='min'
    saved_checkpoint = False
    start_training_time = time.time()
    print("Starting time:", datetime.now().strftime("%d-%m-%Y %H:%M:%S"))
    train_logs, val_logs = [], []

    for epoch in range(1, num_epochs + 1):
        epoch_start_time = time.time()

        train_loss, train_preds, train_targets = _train_one_epoch(model, criterion, optimizer, train_loader, device)
        train_acc = accuracy_score(train_targets, train_preds)
        train_prec, train_rec, train_f1, _ = precision_recall_fscore_support(train_targets, train_preds, average='macro', zero_division=0)

        val_loss, val_acc, val_prec, val_rec, val_f1, _, _ = evaluate_model(model, val_loader, criterion, device)

        epoch_time = time.time() - epoch_start_time

        print(f"Epoch [{epoch}/{num_epochs}]")
        print(f"Train | Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | Prec: {train_prec:.4f} | Rec: {train_rec:.4f} | F1: {train_f1:.4f}")
        print(f"Val   | Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | Prec: {val_prec:.4f} | Rec: {val_rec:.4f} | F1: {val_f1:.4f}")
        # str(timedelta) is already H:MM:SS.ffffff, so no unit suffix.
        print(f"Epoch Time: {timedelta(seconds=epoch_time)}\n")

        train_logs.append([epoch, train_loss, train_acc, train_prec, train_rec, train_f1, epoch_time])
        val_logs.append([epoch, val_loss, val_acc, val_prec, val_rec, val_f1, epoch_time])

        # Model saving policy
        if save_policy == 'last':
            should_save = True
        elif save_policy == 'max':
            should_save = val_acc > best_val_acc
            if should_save:
                best_val_acc = val_acc
        else:  # 'min'
            should_save = val_loss < best_val_loss
            if should_save:
                best_val_loss = val_loss

        if should_save:
            state = model.state_dict()
            torch.save(state, os.path.join(save_dir, f"epoch_{epoch:03d}.pt"))
            torch.save(state, best_model_path)
            saved_checkpoint = True

    total_time = time.time() - start_training_time
    print("Finish time:", datetime.now().strftime("%d-%m-%Y %H:%M:%S"))
    print(f"Total training time: {timedelta(seconds=total_time)}")

    log_columns = ['Epoch', 'Loss', 'Accuracy', 'Precision', 'Recall', 'F1', 'Time']
    pd.DataFrame(train_logs, columns=log_columns).to_csv(os.path.join(save_dir, 'train_log.csv'), index=False)
    pd.DataFrame(val_logs, columns=log_columns).to_csv(os.path.join(save_dir, 'val_log.csv'), index=False)

    print("\n🎯 Final Evaluation on Test Set:")
    test_loss, test_acc, test_prec, test_rec, test_f1, test_preds, test_targets = evaluate_model(model, test_loader, criterion, device)
    print(f"Test | Loss: {test_loss:.4f} | Acc: {test_acc:.4f} | Prec: {test_prec:.4f} | Rec: {test_rec:.4f} | F1: {test_f1:.4f}")

    _save_confusion_matrix(test_targets, test_preds, class_names, os.path.join(save_dir, "confusion_matrix.png"))

    if not saved_checkpoint:
        # Never fall back to a best_model.pt left over from an earlier run.
        print("No checkpoint was saved during this run; skipping best-model evaluation.")
        return

    # map_location lets a checkpoint written on GPU be restored on the current device.
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    _, acc, prec, rec, f1, _, _ = evaluate_model(model, test_loader, criterion, device)
    print(f"Best Model | Acc: {acc:.4f} | Prec: {prec:.4f} | Rec: {rec:.4f} | F1: {f1:.4f}")
   

In [12]:
# --- Settings ---
# Data location: one sub-directory per class inside image_dir.
image_dir = 'DataSet/top-agriculture-crop-disease'
# Output folder for this experiment.
save_dir = 'vit_models'

# Input pipeline: the transforms resize every image to image_size x image_size.
image_size = 32
batch_size = 32
num_workers = 4          # DataLoader worker processes

# Fractions of the dataset held out for validation and for testing.
val_ratio, test_ratio = 0.1, 0.1

In [13]:

# --- Transforms ---
# The two pipelines are currently identical; any augmentation belongs in train_transform only.
train_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

val_test_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])


class TransformedSubset(Dataset):
    """Wrap a dataset view so it applies its own ``transform`` to every image.

    ``random_split`` returns ``Subset`` objects that all share ``full_dataset``.
    Assigning ``subset.dataset.transform`` per split therefore overwrites a single
    attribute, and the last assignment wins for every split. Wrapping each subset
    keeps the per-split transforms independent.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset      # named like Subset.dataset so wrappers can be unwrapped
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label


# --- Load full dataset (no transform here: each split applies its own) ---
full_dataset = CustomDataset(image_dir)
num_samples = len(full_dataset)
num_val = int(num_samples * val_ratio)
num_test = int(num_samples * test_ratio)
num_train = num_samples - num_val - num_test
assert num_train > 0, 'val_ratio + test_ratio leave no training samples'

# --- Split dataset ---
# A seeded generator makes the split reproducible. Re-running this cell (e.g. to
# re-evaluate a saved checkpoint) then gives the same test set, not one that
# overlaps the images the checkpoint was trained on.
split_seed = 42
train_subset, val_subset, test_subset = random_split(
    full_dataset, [num_train, num_val, num_test],
    generator=torch.Generator().manual_seed(split_seed),
)

train_dataset = TransformedSubset(train_subset, train_transform)
val_dataset = TransformedSubset(val_subset, val_test_transform)
test_dataset = TransformedSubset(test_subset, val_test_transform)

# --- DataLoaders ---
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# --- Class Info ---
class_names = full_dataset.classes
num_classes = len(class_names)
print(f"Loaded dataset with {num_classes} classes: {class_names}")


Loaded dataset with 3 classes: ['Potato___Early_Blight', 'Potato___Healthy', 'Potato___Late_Blight']


In [None]:

# Seed weight initialisation (and DataLoader shuffling) so runs are comparable.
torch.manual_seed(0)

model = ViT(
    # Must match the Resize in the transforms cell. The previous hard-coded 256
    # disagreed with the image_size inputs actually fed to the model.
    image_size = image_size,
    patch_size = 16,
    num_classes = num_classes,
    dim = 192,
    depth = 9,
    heads = 12,
    mlp_dim = 384,
    dropout = 0.,
    emb_dropout = 0.
)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# 'max' keeps the epoch with the best validation accuracy as best_model.pt, so the
# later "Best Model" evaluation really uses the best epoch rather than the last one.
train_model(model, criterion, optimizer, train_loader, val_loader, test_loader, device,
            num_epochs=10, save_policy='max', save_dir=save_dir)


Starting time: 14-07-2025 23:16:07
Epoch [1/50]
Train | Loss: 0.6319 | Acc: 0.7182 | Prec: 0.7186 | Rec: 0.7171 | F1: 0.7175
Val   | Loss: 0.3635 | Acc: 0.8529 | Prec: 0.8712 | Rec: 0.8553 | F1: 0.8573
Epoch Time: 0:00:07.985401 sec

Epoch [2/50]
Train | Loss: 0.3093 | Acc: 0.8768 | Prec: 0.8755 | Rec: 0.8757 | F1: 0.8754
Val   | Loss: 0.2557 | Acc: 0.9020 | Prec: 0.9070 | Rec: 0.9038 | F1: 0.9034
Epoch Time: 0:00:06.632752 sec

Epoch [3/50]
Train | Loss: 0.2146 | Acc: 0.9192 | Prec: 0.9186 | Rec: 0.9182 | F1: 0.9179
Val   | Loss: 0.1445 | Acc: 0.9608 | Prec: 0.9597 | Rec: 0.9609 | F1: 0.9599
Epoch Time: 0:00:07.343273 sec

Epoch [4/50]
Train | Loss: 0.1437 | Acc: 0.9478 | Prec: 0.9474 | Rec: 0.9471 | F1: 0.9471
Val   | Loss: 0.2225 | Acc: 0.9118 | Prec: 0.9196 | Rec: 0.9151 | F1: 0.9140
Epoch Time: 0:00:07.674053 sec

Epoch [5/50]
Train | Loss: 0.1053 | Acc: 0.9637 | Prec: 0.9634 | Rec: 0.9633 | F1: 0.9633
Val   | Loss: 0.1687 | Acc: 0.9216 | Prec: 0.9266 | Rec: 0.9241 | F1: 0.9207
Ep

KeyboardInterrupt: 

In [59]:
# Reload the checkpoint chosen by train_model's save policy and score it on the test split.
# map_location lets a checkpoint saved on GPU be restored in a CPU-only kernel.
best_state = torch.load(os.path.join(save_dir, 'best_model.pt'), map_location=device)
model.load_state_dict(best_state)
_, acc, prec, rec, f1, _, _ = evaluate_model(model, test_loader, criterion, device)
print(f"Best Model | Acc: {acc:.4f} | Prec: {prec:.4f} | Rec: {rec:.4f} | F1: {f1:.4f}")

Best Model | Acc: 0.9790 | Prec: 0.9770 | Rec: 0.9790 | F1: 0.9779
