02 -- Train a ViT-B/16 Classifier on the RiceLeafs Dataset

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, models, datasets
from torch.optim import Adam
from tqdm import tqdm
import matplotlib.pyplot as plt
import os


In [None]:
# Run configuration: every value a reader might want to tune lives in this cell.
DATA_DIR = "processed/riceleaf"  # dataset root; the "train" and "val" splits are read from here
BATCH_SIZE = 32                  # mini-batch size used by both data loaders
LR = 1e-4                        # learning rate passed to Adam
EPOCHS = 15                      # number of passes over the training loader
IMG_SIZE = 224                   # images are resized to IMG_SIZE x IMG_SIZE
SEED = 42                        # seed for torch's global random number generator

# Seed once, up front, so shuffling, random augmentations and the freshly
# initialised classifier head are repeatable under Restart & Run All.
# NOTE: some CUDA kernels are nondeterministic, so GPU runs may still differ slightly.
torch.manual_seed(SEED)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


In [None]:
# ImageNet per-channel statistics. The ViT loaded later in this notebook uses
# ImageNet-pretrained weights, whose documented preprocessing normalizes inputs
# with these values; feeding raw [0, 1] tensors shifts the input distribution
# away from what the pretrained layers expect.
# NOTE: any inference code using the saved weights must apply the same normalization.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: fixed-size resize, light augmentation, tensor conversion, normalization.
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),  # random angle in [-10, 10] degrees
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation pipeline: same resize and normalization, no augmentation.
val_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])


# ImageFolder derives the class labels from the sub-directory names of each split.
train_ds = datasets.ImageFolder(os.path.join(DATA_DIR, "train"), transform=train_tfms)
val_ds   = datasets.ImageFolder(os.path.join(DATA_DIR, "val"),   transform=val_tfms)

# Only the training loader is shuffled; validation order stays fixed.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

# The number of output classes is taken from the training split.
NUM_CLASSES = len(train_ds.classes)
train_ds.classes


In [None]:
# Build the classifier: ViT-B/16 with ImageNet-pretrained weights, with its
# final linear head swapped for one that outputs NUM_CLASSES logits.
model = models.vit_b_16(weights="IMAGENET1K_V1")
head_in_features = model.heads.head.in_features
model.heads.head = nn.Linear(head_in_features, NUM_CLASSES)
model = model.to(DEVICE)

# All parameters (backbone and new head) are optimised with a single learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LR)


# Fine-tune for EPOCHS passes over the training loader.
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0  # sum of per-batch losses for this epoch

    for batch_imgs, batch_labels in tqdm(train_dl):
        batch_imgs = batch_imgs.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        # Standard step: clear gradients, forward, loss, backward, update.
        optimizer.zero_grad()
        logits = model(batch_imgs)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()

    # Reported value is the mean of the per-batch losses.
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {total_loss/len(train_dl):.4f}")


In [None]:
# Top-1 accuracy on the validation loader.
model.eval()
correct = 0  # number of correctly classified samples so far
total = 0    # number of samples seen so far

# Gradients are not needed for evaluation.
with torch.no_grad():
    for batch_imgs, batch_labels in val_dl:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        # The predicted class is the index of the largest logit per sample.
        predicted = model(batch_imgs).argmax(dim=1)
        hits = predicted.eq(batch_labels)

        correct += hits.sum().item()
        total += batch_labels.shape[0]

# Fraction of validation samples classified correctly; shown as the cell output.
accuracy = correct / total
accuracy


In [None]:
# Save the fine-tuned weights (state_dict only, not the full module object).
# The file is named after the rice-leaf dataset this notebook trains on, so it
# cannot be mistaken for, or overwrite, a checkpoint trained on another dataset.
# NOTE: these are the weights after the final epoch; this notebook does no
# best-epoch selection. Code that loads the checkpoint must use this file name.
torch.save(model.state_dict(), "riceleaf_vit_best.pth")
