# Important Library Imports

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os
from torch.utils.data import Dataset
import math
from torchvision.utils import make_grid
import time
import pickle
import torch.nn.functional

# Helper Functions

In [None]:
def validate_model(model, val_loader, device):
    """Return the top-1 accuracy of ``model`` over ``val_loader`` as a percentage.

    The model is switched to eval mode for the duration of the pass and then
    returned to whatever mode it was in before, so calling this from inside a
    training loop does not silently leave the model in eval mode.

    Args:
        model: ``nn.Module`` whose output holds one score per class along dim 1.
        val_loader: Iterable yielding ``(batch, labels)`` tensor pairs.
        device: Device each batch and its labels are moved to before inference.

    Returns:
        Accuracy as a float percentage. ``0.0`` if the loader yields no samples
        (previously this raised ``ZeroDivisionError``).
    """
    was_training = model.training
    model.eval()

    num_correct = 0
    total = 0
    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for batch, labels in val_loader:
                batch = batch.to(device)
                labels = labels.to(device)

                pred = model(batch)
                num_correct += (pred.argmax(dim=1) == labels).sum().item()
                total += len(labels)
    finally:
        # Restore the caller's train/eval mode even if inference failed.
        model.train(was_training)

    if total == 0:
        return 0.0
    return (num_correct / total) * 100

def test_model(model, test_loader, device):
    with torch.no_grad():
        num_correct = 0
        total = 0
        model.eval()
        for batch, labels in test_loader:
            batch = batch.to(device)
            labels = labels.to(device)

            pred = model(batch)
            num_correct += (pred.argmax(dim=1) == labels).type(torch.float).sum().item()
            total += len(labels)
        accuracy = (num_correct / total) * 100
        return accuracy

# Load and Augment Dataset

In [None]:
# Create Data Augmentation
# RandomChoice selects exactly one of the five branches per image, and each
# branch is itself wrapped in RandomApply(p=0.2), so a given training image is
# augmented (by a single transform) with an overall probability of 0.2.
data_transforms = transforms.Compose([
    transforms.RandomChoice([
        transforms.RandomApply([
            transforms.ElasticTransform(alpha=40.0, sigma=8.0)
        ], p=0.2),
        transforms.RandomApply([
            transforms.RandomAffine(degrees=0, shear=20, fill=255)
        ], p=0.2),
        transforms.RandomApply([
            transforms.RandomAffine(degrees=0, scale=(0.8, 1.2), fill=255)
        ], p=0.2),
        transforms.RandomApply([
            transforms.RandomHorizontalFlip(p=1.0)
        ], p=0.2),
        transforms.RandomApply([
            transforms.RandomVerticalFlip(p=1.0)
        ], p=0.2),
    ]),
    transforms.ToTensor()
])

# Load Training, Validation, and Testing Images
# Human-readable names for the 38 classes.
# NOTE(review): ImageFolder assigns class indices from the sorted sub-directory
# names; confirm this list follows the same order before indexing it with a
# predicted class id.
# NOTE(review): "Haunglonbing" is probably a misspelling of "Huanglongbing";
# left unchanged in case other code matches on this exact string.
LABELS = ["Apple Scab", "Apple Black Rot", "Apple Cedar Rust", "Apple Healthy", "Blueberry Healthy", "Cherry Healthy", "Cherry Powdery Mildew", "Corn Cercospora Leaf Spot", "Corn Common Rust", "Corn Healthy", "Corn Northern Leaf Blight", "Grape Black Rot", "Grape Black Measles", "Grape Healthy", "Grape Isariopsis Leaf Spot", "Orange Haunglonbing",
          "Peach Bacterial Spot", "Peach Healthy", "Bell Pepper Bacterial Spot", "Bell Pepper Healthy", "Potato Early Blight", "Potato Healthy", "Potato Late Blight", "Raspberry Healthy", "Soybean Healthy", "Squash Powdery Mildew", "Strawberry Healthy", "Strawberry Leaf Scorch", "Tomato Bacterial Spot", "Tomato Early Blight", "Tomato Healthy",
          "Tomato Late Blight", "Tomato Leaf Mold", "Tomato Septoria Leaf Spot", "Tomato Spider Mites", "Tomato Target Spot", "Tomato Mosaic Virus", "Tomato Yellow Leaf Curl Virus"]

BATCH_SIZE = 128
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Running on {DEVICE}")
folder_path = "PlantVillage"

# Build the split paths with os.path.join: the previous "\Training"-style
# literals relied on invalid escape sequences and only resolved on Windows.
# Augmentation is applied to the training split only; validation and testing
# images are just converted to tensors.
train_set = ImageFolder(root=os.path.join(folder_path, "Training"), transform=data_transforms)
val_set = ImageFolder(root=os.path.join(folder_path, "Validation"), transform=transforms.ToTensor())
test_set = ImageFolder(root=os.path.join(folder_path, "Testing"), transform=transforms.ToTensor())

# Only the training loader shuffles; evaluation order is kept fixed.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=12)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Construct Student

In [None]:
class StudentModel(nn.Module):
    """Compact CNN student: three conv stages, global average pooling, one linear head.

    Takes 3-channel images and produces 38 raw class scores (logits) per image.
    No softmax is applied here because the losses used in this file
    (``CrossEntropyLoss`` and the temperature-scaled ``log_softmax``) normalise
    the scores themselves; applying softmax in the model would do it twice.
    ``argmax`` over the logits gives the same predicted class as over softmax
    probabilities.
    """

    def __init__(self):
        super().__init__()

        self.layers = nn.Sequential(
            # Stage 1: 3 -> 16 channels, spatial size halved.
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2),
            # Stage 2: 16 -> 32 channels, spatial size halved again.
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2),
            # Stage 3: 32 -> 64 channels, then collapse each map to a single value.
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d(1),
            # (N, 64, 1, 1) -> (N, 64) so the linear layer sees 64 features.
            nn.Flatten(),
            nn.Linear(64, 38),
        )

    def forward(self, x):
        """Return class logits of shape ``(N, 38)`` for a batch ``x`` of shape ``(N, 3, H, W)``."""
        return self.layers(x)

def distillation_loss(student_pred, teacher_pred, labels, T, alpha):
    """Combine a hard-label loss and a teacher-matching loss for distillation.

    Args:
        student_pred: Student class scores, one row per sample.
        teacher_pred: Teacher class scores with the same layout.
        labels: Ground-truth class indices.
        T: Temperature both score sets are divided by before the soft comparison.
        alpha: Weight of the hard-label term; the soft term gets ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * hard_loss + (1 - alpha) * soft_loss``.
    """
    # Soft-target term: KL divergence between the temperature-softened
    # teacher distribution and the student's log-distribution, scaled by T^2.
    student_log_probs = nn.functional.log_softmax(student_pred / T, dim=1)
    teacher_probs = nn.functional.softmax(teacher_pred / T, dim=1)
    divergence = nn.functional.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    soft_loss = divergence * (T * T)

    # Hard-target term: cross-entropy against the ground-truth labels.
    hard_loss = nn.functional.cross_entropy(student_pred, labels)

    # Weighted sum of the two terms.
    return alpha * hard_loss + (1 - alpha) * soft_loss

# Train Student

In [None]:
from DiseasedCNN import DiseasedCNN

NUM_EPOCHS = 50

# Load teacher
# map_location lets a checkpoint saved on another device load here, and the
# teacher must sit on the same device as the batches it is fed below.
teacher = DiseasedCNN()
teacher.load_state_dict(torch.load('DiseasedCNN_statedict.pth', map_location=DEVICE))
teacher.to(DEVICE)
teacher.eval()


# Init hyperparameters
learning_rate = 0.0005 # 0.0001
adam_beta1 = 0.9
adam_beta2 = 0.999
T = 4          # distillation temperature
ALPHA = 0.5    # weight of the hard-label term in distillation_loss
# Create student
student = StudentModel()
student.to(DEVICE)
optimizer = torch.optim.Adam(params=student.parameters(), lr=learning_rate, betas=(adam_beta1, adam_beta2))
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

epoch_loss = []        # mean training loss per epoch
validation_acc = []    # validation accuracy (percent) per epoch

for epoch in range(NUM_EPOCHS):
    # Validation at the end of the previous epoch may have switched the
    # student to eval mode, so re-enter training mode every epoch.
    student.train()
    train_loss = []
    print("Epoch: %d" % epoch)
    for step_num, (batch, labels) in enumerate(train_loader):
        batch = batch.to(DEVICE)
        labels = labels.to(DEVICE)

        student_pred = student(batch)

        # The teacher is frozen: it only supplies targets, so skip autograd.
        with torch.no_grad():
            teacher_pred = teacher(batch)

        loss = distillation_loss(student_pred, teacher_pred, labels, T, ALPHA)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())

    # Perform validation and store accuracy
    # validate_model's first parameter is named `model`; the earlier
    # `student=` keyword raised a TypeError.
    validation_accuracy = validate_model(model=student, val_loader=val_loader, device=DEVICE)
    validation_acc.append(validation_accuracy)
    print(validation_accuracy)

    # ExponentialLR multiplies the learning rate by gamma on each step(); it
    # is stepped only on every fifth epoch (including epoch 0).
    if epoch % 5 == 0:
        scheduler.step()

    # Track average loss for each epoch
    epoch_loss.append(sum(train_loss) / len(train_loss))