In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

In [3]:
# Config

BATCH_SIZE = 32
IMG_SIZE = 224
EPOCHS = 10
LEARNING_RATE = 1e-3
DATA_ROOT = r"D:\dataset_split"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [4]:
# Data

train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2),
    transforms.ToTensor(),
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor()
])

train_dataset = datasets.ImageFolder(root = f"{DATA_ROOT}/train", transform = train_transform)
val_dataset   = datasets.ImageFolder(root = f"{DATA_ROOT}/val", transform = val_transform)
test_dataset  = datasets.ImageFolder(root = f"{DATA_ROOT}/test", transform = val_transform)

train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)
val_loader   = DataLoader(val_dataset, batch_size = BATCH_SIZE, shuffle = False)
test_loader  = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = False)

num_classes = len(train_dataset.classes)
print("Classes:", train_dataset.classes)

Classes: ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']


In [5]:
# MobileNetV2 (Phase A)

mobilenet = models.mobilenet_v2(pretrained = True)

# Freeze backbone
for param in mobilenet.parameters():
    param.requires_grad = False

# Replace classifier
mobilenet.classifier[1] = nn.Linear(mobilenet.classifier[1].in_features, num_classes)

model = mobilenet.to(device)

# Loss & Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.classifier[1].parameters(), lr = LEARNING_RATE)

# Train & Validate

for epoch in range(EPOCHS):
    # Train
    model.train()
    running_loss, running_correct = 0.0, 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * imgs.size(0)
        running_correct += (outputs.argmax(1) == labels).sum().item()

    train_loss = running_loss / len(train_dataset)
    train_acc = running_correct / len(train_dataset)

    # Validate
    model.eval()
    val_loss, val_correct = 0.0, 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * imgs.size(0)
            val_correct += (outputs.argmax(1) == labels).sum().item()

    val_loss /= len(val_dataset)
    val_acc = val_correct / len(val_dataset)

    print(f"Epoch [{epoch+1}/{EPOCHS}] "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} "
          f"| Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to C:\Users\hoang/.cache\torch\hub\checkpoints\mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:03<00:00, 3.63MB/s]


Epoch [1/10] Train Loss: 0.7315 | Train Acc: 0.7974 | Val Loss: 0.3120 | Val Acc: 0.9155
Epoch [2/10] Train Loss: 0.3727 | Train Acc: 0.8849 | Val Loss: 0.2165 | Val Acc: 0.9344
Epoch [3/10] Train Loss: 0.3218 | Train Acc: 0.8960 | Val Loss: 0.2260 | Val Acc: 0.9252
Epoch [4/10] Train Loss: 0.2950 | Train Acc: 0.9047 | Val Loss: 0.1790 | Val Acc: 0.9446
Epoch [5/10] Train Loss: 0.2835 | Train Acc: 0.9078 | Val Loss: 0.1593 | Val Acc: 0.9524
Epoch [6/10] Train Loss: 0.2783 | Train Acc: 0.9077 | Val Loss: 0.1727 | Val Acc: 0.9407
Epoch [7/10] Train Loss: 0.2656 | Train Acc: 0.9101 | Val Loss: 0.1660 | Val Acc: 0.9436
Epoch [8/10] Train Loss: 0.2586 | Train Acc: 0.9109 | Val Loss: 0.1580 | Val Acc: 0.9446
Epoch [9/10] Train Loss: 0.2560 | Train Acc: 0.9131 | Val Loss: 0.1771 | Val Acc: 0.9422
Epoch [10/10] Train Loss: 0.2566 | Train Acc: 0.9134 | Val Loss: 0.1825 | Val Acc: 0.9354


In [6]:
# Test

model.eval()
test_loss, test_correct = 0.0, 0
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        test_loss += loss.item() * imgs.size(0)
        test_correct += (outputs.argmax(1) == labels).sum().item()

test_loss /= len(test_dataset)
test_acc = test_correct / len(test_dataset)

print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")

# Save model
torch.save(model.state_dict(), "D:/saved_models/mobilenetv2.pt")

Test Loss: 0.2088 | Test Acc: 0.9239
