In [None]:
import os
import random
import warnings

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.utils import make_grid

from modules.plant_disease_cnn import PlantDiseaseCNN

# Suppress warnings
# NOTE(review): a blanket 'ignore' also hides actionable warnings (e.g. DataLoader
# pin_memory notices); consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

# Seed every RNG this notebook relies on (Python, NumPy, PyTorch) so data shuffling,
# augmentation and weight initialisation repeat across runs. Some GPU kernels may
# still be non-deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Sanity check: prints True when this PyTorch build can use Apple's Metal (MPS) backend.
print(f"{torch.backends.mps.is_available()}")

In [None]:
# !pip install --upgrade torch torchvision torchaudio

In [None]:
# Pick the preferred available backend: NVIDIA CUDA, then Apple Metal (MPS), then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
data_path = 'data/'
dataset_root = os.path.join(data_path, 'New Plant Diseases Dataset(Augmented)', 'New Plant Diseases Dataset(Augmented)')
train_folder = os.path.join(dataset_root, 'train')
valid_folder = os.path.join(dataset_root, 'valid')

# One sub-directory per class. Keep directories only, so stray files (e.g. macOS
# ".DS_Store") are not counted as a plant. Sort so the order is deterministic
# rather than whatever os.listdir happens to return.
classes = sorted(
    entry for entry in os.listdir(train_folder)
    if os.path.isdir(os.path.join(train_folder, entry))
)

# Assumes PlantVillage-style names "<Plant>___<Condition>" (triple underscore).
# Splitting on '___' keeps plant names that contain a single '_' intact
# (e.g. "Pepper,_bell"). Names without '___' are kept whole.
# dict.fromkeys de-duplicates while preserving first-seen order.
unique_plants = list(dict.fromkeys(item.split('___')[0] for item in classes))
print("Number of unique plants:", len(unique_plants))
print("Plants:", unique_plants)

In [None]:
# Define transformations for training and validation.
# Random augmentation belongs only in the training pipeline. Flipping validation
# images would make the reported accuracy change from run to run.
IMAGE_SIZE = (128, 128)        # Resize target (height, width)
NORM_MEAN = [0.5, 0.5, 0.5]    # with std 0.5, maps ToTensor's [0, 1] range to [-1, 1]
NORM_STD = [0.5, 0.5, 0.5]

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),  # Resize to 128x128 pixels
    transforms.RandomHorizontalFlip(),  # Augment data (training only)
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),  # Normalize
])

# Deterministic pipeline for evaluation: same resize/normalisation, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Load dataset
train_dataset = datasets.ImageFolder(root=train_folder, transform=transform)
val_dataset = datasets.ImageFolder(root=valid_folder, transform=val_transform)

In [None]:
import multiprocessing

BATCH_SIZE = 32
num_workers = min(8, multiprocessing.cpu_count() // 2)  # Use half of available CPUs, max 8
print(f"Using {num_workers} workers for DataLoader.")

# Pinned (page-locked) host memory only speeds up host->CUDA copies. On MPS or
# CPU-only machines it is unsupported and only triggers a (filtered) warning.
pin_memory = torch.cuda.is_available()

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=num_workers, pin_memory=pin_memory)
# No shuffling for validation: order does not affect accuracy.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)

In [None]:
# Dataset sizes. val_dataset is built from the 'valid' folder, so report it as validation.
print("Number of training images:", len(train_dataset))
print("Number of validation images:", len(val_dataset))

In [None]:
# Summarise the label space ImageFolder discovered; list order = class index order.
print(f"Number of classes: {len(train_dataset.classes)}")
print("Classes:", train_dataset.classes)

In [None]:
# Choose the compute device with a consistent priority: CUDA, then Apple MPS, then CPU.
# Checking all three here means this cell never downgrades an available MPS device to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

num_classes = len(train_dataset.classes)

# Initialize model (assumes PlantDiseaseCNN's first argument is the number of
# output classes — confirm in modules/plant_disease_cnn.py).
model = PlantDiseaseCNN(num_classes).to(device)

# Define loss function and optimizer
LEARNING_RATE = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0   # sum of per-sample losses over the epoch
    samples_seen = 0

    for images, labels in train_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion (default CrossEntropyLoss) returns the batch *mean*. Weight it by
        # batch size so a smaller final batch does not skew the epoch average.
        n_in_batch = labels.size(0)
        running_loss += loss.item() * n_in_batch
        samples_seen += n_in_batch

    # max(..., 1) avoids ZeroDivisionError if the loader yields no batches.
    epoch_loss = running_loss / max(samples_seen, 1)
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

In [None]:
# Evaluate top-1 accuracy on the validation split (no gradients, eval-mode layers).
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        # Index of the highest logit per row is the predicted class.
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    print("Validation set is empty; accuracy is undefined.")
else:
    print(f"Validation Accuracy: {100 * correct / total:.2f}%")

In [None]:
MODEL_PATH = "plant_disease_cnn.pth"

# Save model: persist only the learned weights (state_dict), not the whole module.
torch.save(model.state_dict(), MODEL_PATH)

# Load model: reload the weights just written. map_location places the tensors on
# the current device even if the checkpoint came from another one (e.g. CUDA -> MPS/CPU).
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

In [None]:
def predict(image_path, class_names=None):
    """Classify a single image with the trained global ``model``.

    Args:
        image_path: Path to an image file readable by PIL.
        class_names: Optional sequence mapping output index -> label. Defaults to
            ``train_dataset.classes``, the label order used during training.

    Returns:
        The predicted class name.

    Side effect: switches the global ``model`` into eval mode.
    """
    # torchvision's default ImageFolder loader converts images to RGB; do the same so
    # RGBA/greyscale files do not break the 3-channel Normalize. The context manager
    # closes the file handle; convert() returns an independent in-memory copy.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Deterministic preprocessing (no augmentation), matching the training resize/normalisation.
    preprocess = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Add a batch dimension of 1.
    img_tensor = preprocess(image).unsqueeze(0).to(device)

    # Ensure dropout/batch-norm layers behave as at inference time, whatever
    # mode the model was left in.
    model.eval()
    with torch.no_grad():
        logits = model(img_tensor)
    predicted_idx = logits.argmax(dim=1).item()

    if class_names is None:
        class_names = train_dataset.classes
    return class_names[predicted_idx]

# Test prediction
print(predict("data/test/test/PotatoHealthy1.JPG"))

In [None]:
# Report the on-disk size of the saved checkpoint; "MB" here means 2**20 bytes.
model_path = "plant_disease_cnn.pth"  # Update if needed
size_in_bytes = os.path.getsize(model_path)
size_in_mb = size_in_bytes / 2 ** 20
print(f"Model size: {size_in_mb:.2f} MB")

In [None]:
# Display the class names; index i here is the label predict() returns for output index i.
train_dataset.classes