# Important Library Imports

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os
from torch.utils.data import Dataset
import math
from torchvision.utils import make_grid
import time

# Helper Functions

In [None]:
def validate_model(model, val_loader, device):
    """Compute top-1 accuracy (in percent) of ``model`` over ``val_loader``.

    The model is put in eval mode for the pass and afterwards restored to the
    mode (train or eval) it was in before the call, so this can safely be
    called in the middle of a training loop.

    Args:
        model: Module mapping a batch to per-class scores; the predicted class
            is the argmax over dim 1.
        val_loader: Iterable of ``(batch, labels)`` pairs.
        device: Device that each batch and its labels are moved to.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    num_correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch, labels in val_loader:
                batch = batch.to(device)
                labels = labels.to(device)

                pred = model(batch)
                num_correct += (pred.argmax(dim=1) == labels).sum().item()
                total += len(labels)
    finally:
        # Restore the caller's mode; otherwise a mid-training call would leave
        # the rest of training running in eval mode.
        model.train(was_training)

    if total == 0:
        raise ValueError("val_loader yielded no samples; cannot compute accuracy")
    return (num_correct / total) * 100

def test_model(model, test_loader, device):
    with torch.no_grad():
        num_correct = 0
        total = 0
        model.eval()
        for batch, labels in test_loader:
            batch = batch.to(device)
            labels = labels.to(device)

            pred = model(batch)
            num_correct += (pred.argmax(dim=1) == labels).type(torch.float).sum().item()
            total += len(labels)
        accuracy = (num_correct / total) * 100
        return accuracy

def train_model(model, train_loader, val_loader, device, num_epochs=16, learning_rate=0.0001,
                betas=(0.9, 0.999), val_every=13):
    """Train ``model`` with Adam and cross-entropy, validating periodically.

    Args:
        model: Module producing raw class scores (logits). ``nn.CrossEntropyLoss``
            applies log-softmax itself, so the model should not end in a softmax.
        train_loader: Iterable of ``(batch, labels)`` training pairs.
        val_loader: Loader passed to ``validate_model`` every ``val_every`` steps.
        device: Device that each batch and its labels are moved to.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        betas: Adam ``(beta1, beta2)`` coefficients.
        val_every: Validate after every ``val_every``-th step within an epoch.

    Returns:
        ``(epoch_loss, train_loss, validation_acc)``:
        ``epoch_loss[i]`` is the mean training loss over the steps of epoch ``i``;
        ``train_loss`` holds the loss of every step across all epochs;
        ``validation_acc`` holds every periodic validation accuracy (percent).

    Raises:
        ValueError: If ``train_loader`` yields no batches in an epoch.
    """
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, betas=betas)
    criterion = nn.CrossEntropyLoss()

    epoch_loss = []
    train_loss = []
    validation_acc = []

    for epoch in range(num_epochs):
        print("Epoch: %d" % epoch)
        model.train()
        epoch_total = 0.0
        epoch_steps = 0
        for step_num, (batch, labels) in enumerate(train_loader):
            batch = batch.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            pred = model(batch)
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            train_loss.append(step_loss)
            epoch_total += step_loss
            epoch_steps += 1

            if (step_num + 1) % val_every == 0:
                print("Batch %d" % step_num)
                validation_acc.append(validate_model(model=model, val_loader=val_loader, device=device))
                # validate_model may leave the model in eval mode; switch back so
                # the remaining steps train with BatchNorm/Dropout in train mode.
                model.train()

        if epoch_steps == 0:
            raise ValueError("train_loader yielded no batches")
        # Mean over this epoch's steps only, not over the running list of all epochs.
        epoch_loss.append(epoch_total / epoch_steps)

    return epoch_loss, train_loss, validation_acc

# Load Dataset

In [None]:
# Create Data Augmentation
# The pretrained torchvision backbones expect inputs normalized with ImageNet
# channel statistics, so every split is normalized after ToTensor().
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# RandomChoice picks exactly one candidate per image; that candidate is then
# applied with probability 0.2 (RandomApply), so most images pass unchanged.
data_transforms = transforms.Compose([
    transforms.RandomChoice([
        transforms.RandomApply([
            transforms.ElasticTransform(alpha=40.0, sigma=8.0)
        ], p=0.2),
        transforms.RandomApply([
            transforms.RandomAffine(degrees=0, shear=20, fill=255)
        ], p=0.2),
        transforms.RandomApply([
            transforms.RandomAffine(degrees=0, scale=(0.8, 1.2), fill=255)
        ], p=0.2),
        transforms.RandomApply([
            transforms.RandomHorizontalFlip(p=1.0)
        ], p=0.2),
        transforms.RandomApply([
            transforms.RandomVerticalFlip(p=1.0)
        ], p=0.2),
    ]),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation and testing images are not augmented: tensor conversion + normalization only.
eval_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Load Training, Validation, and Testing Images
LABELS = ["Apple Scab", "Apple Black Rot", "Apple Cedar Rust", "Apple Healthy", "Blueberry Healthy", "Cherry Healthy", "Cherry Powdery Mildew", "Corn Cercospora Leaf Spot", "Corn Common Rust", "Corn Healthy", "Corn Northern Leaf Blight", "Grape Black Rot", "Grape Black Measles", "Grape Healthy", "Grape Isariopsis Leaf Spot", "Orange Haunglonbing",
          "Peach Bacterial Spot", "Peach Healthy", "Bell Pepper Bacterial Spot", "Bell Pepper Healthy", "Potato Early Blight", "Potato Healthy", "Potato Late Blight", "Raspberry Healthy", "Soybean Healthy", "Squash Powdery Mildew", "Strawberry Healthy", "Strawberry Leaf Scorch", "Tomato Bacterial Spot", "Tomato Early Blight", "Tomato Healthy",
          "Tomato Late Blight", "Tomato Leaf Mold", "Tomato Septoria Leaf Spot", "Tomato Spider Mites", "Tomato Target Spot", "Tomato Mosaic Virus", "Tomato Yellow Leaf Curl Virus"]

BATCH_SIZE = 128
# Fall back to CPU so the notebook still runs on machines without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
folder_path = "PlantVillage"

# os.path.join is portable; the previous "\Training"-style literals only worked on
# Windows and contain invalid escape sequences.
train_set = ImageFolder(root=os.path.join(folder_path, "Training"), transform=data_transforms)
val_set = ImageFolder(root=os.path.join(folder_path, "Validation"), transform=eval_transforms)
test_set = ImageFolder(root=os.path.join(folder_path, "Testing"), transform=eval_transforms)

# ImageFolder numbers classes by sorted sub-folder name per split, so the splits must
# contain identical class folders for label indices to mean the same thing.
assert train_set.classes == val_set.classes == test_set.classes, "class folders differ between splits"
assert len(LABELS) == len(train_set.classes), "LABELS does not match the number of class folders"
# NOTE(review): confirm LABELS[i] describes train_set.classes[i]; only the count is checked here.

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=12)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Load Pretrained Backbones

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _strip_head(parent, name):
    """Replace ``parent``'s final linear layer with nn.Identity and return its in_features.

    ``name`` is an attribute name (e.g. ``"fc"``) or an integer index into an
    ``nn.Sequential``. The returned width is the size of the feature vector the
    backbone emits once the layer is removed.
    """
    if isinstance(name, int):
        head = parent[name]
        parent[name] = nn.Identity()
    else:
        head = getattr(parent, name)
        setattr(parent, name, nn.Identity())
    return head.in_features


# Load all backbones with their default pretrained weights
resnet101 = torchvision.models.resnet101(weights=torchvision.models.ResNet101_Weights.DEFAULT).to(device=DEVICE)
resnet50 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT).to(device=DEVICE)
efficientnetb0 = torchvision.models.efficientnet_b0(weights=torchvision.models.EfficientNet_B0_Weights.DEFAULT).to(device=DEVICE)
efficientnetb1 = torchvision.models.efficientnet_b1(weights=torchvision.models.EfficientNet_B1_Weights.DEFAULT).to(device=DEVICE)
efficientnetb2 = torchvision.models.efficientnet_b2(weights=torchvision.models.EfficientNet_B2_Weights.DEFAULT).to(device=DEVICE)
efficientnetb3 = torchvision.models.efficientnet_b3(weights=torchvision.models.EfficientNet_B3_Weights.DEFAULT).to(device=DEVICE)
densenet121 = torchvision.models.densenet121(weights=torchvision.models.DenseNet121_Weights.DEFAULT).to(device=DEVICE)
vgg16_bn = torchvision.models.vgg16_bn(weights=torchvision.models.VGG16_BN_Weights.DEFAULT).to(device=DEVICE)

# Remove the last layer of each backbone. The feature width is read from the removed
# layer instead of hard-coded: hand-written widths are easy to get wrong
# (efficientnet_b2 emits 1408 features; vgg16_bn without classifier[6] emits 4096).
resnet101_dim = _strip_head(resnet101, "fc")
resnet50_dim = _strip_head(resnet50, "fc")
efficientnetb0_dim = _strip_head(efficientnetb0.classifier, 1)
efficientnetb1_dim = _strip_head(efficientnetb1.classifier, 1)
efficientnetb2_dim = _strip_head(efficientnetb2.classifier, 1)
efficientnetb3_dim = _strip_head(efficientnetb3.classifier, 1)
densenet121_dim = _strip_head(densenet121, "classifier")
vgg16_bn_dim = _strip_head(vgg16_bn.classifier, 6)

# Prep list of (name, model, output feature size)
backbones = [
    ("resnet101", resnet101, resnet101_dim),
    ("resnet50", resnet50, resnet50_dim),
    ("efficientnetb0", efficientnetb0, efficientnetb0_dim),
    ("efficientnetb1", efficientnetb1, efficientnetb1_dim),
    ("efficientnetb2", efficientnetb2, efficientnetb2_dim),
    ("efficientnetb3", efficientnetb3, efficientnetb3_dim),
    ("densenet121", densenet121, densenet121_dim),
    ("vgg16_bn", vgg16_bn, vgg16_bn_dim),
]


# Benchmark Each Backbone

In [None]:
# Derive the class count from the dataset folders rather than hard-coding it.
NUM_CLASSES = len(train_set.classes)
benchmarks = {}
for name, backbone, backbone_output_size in backbones:
    # Freeze backbone parameters so only the new linear head is trained
    for param in backbone.parameters():
        param.requires_grad = False
    # The head outputs raw logits: nn.CrossEntropyLoss (used in train_model) applies
    # log-softmax itself, so a trailing nn.Softmax would normalize twice and shrink the
    # gradients. Argmax-based accuracy is unaffected by omitting the softmax.
    model = nn.Sequential(backbone, nn.Linear(backbone_output_size, NUM_CLASSES, device=DEVICE))
    epoch_loss, train_loss, val_acc = train_model(model, train_loader, val_loader, DEVICE)
    # Test model; set eval explicitly so the stored model ends up in eval mode
    model.eval()
    test_acc = test_model(model, test_loader, DEVICE)
    benchmarks[name] = (model, epoch_loss, train_loss, val_acc, test_acc)