In [2]:
import os
import shutil
import cv2
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F

from PIL import Image
from tqdm import tqdm

import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
# TODO: needs work
# References:
#   - https://medium.com/@myringoleMLGOD/simple-convolutional-neural-network-cnn-for-dummies-in-pytorch-a-step-by-step-guide-6f4109f6df80
#   - https://medium.com/@kritiikaaa/convolution-neural-networks-guide-for-your-first-cnn-project-7ea56f7f6960
# Formal dataset definition
from google.colab import drive
drive.mou

!ls -l '/content/drive/MyDrive/ECS_170/data/processed/merged_data'

class PlantHealthDataset(Dataset):
	"""Binary healthy-vs-diseased image dataset built from a folder tree.

	Expected layout: ``root_dir/<plant_type>/<split>/<condition>/<image file>``.
	Every image found under the requested splits is collected into one flat
	list of ``(path, label)`` pairs in ``self.samples``. The label is 0 when
	the condition folder name contains "healthy" (case-insensitive) and 1
	otherwise.

	Args:
		root_dir: Directory containing one sub-folder per plant type.
		transform: Callable applied to the grayscale PIL image in
			``__getitem__``. Defaults to ``transforms.ToTensor()``.
		splits: Names of the split sub-folders to read from each plant folder.
			Missing split folders are skipped.
	"""

	# File extensions (lower-case) that are treated as images.
	IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

	def __init__(self, root_dir, transform=None, splits=("train", "valid")):
		if transform is None:
			transform = transforms.Compose([
				transforms.ToTensor(),
			])
		self.transform = transform
		self.samples = []

		for _, plant_path in self._subdirs(root_dir):
			for split in splits:
				split_path = os.path.join(plant_path, split)
				if not os.path.isdir(split_path):
					continue

				# Conditions are e.g. "healthy", "bacterial spot", ...
				for condition, condition_path in self._subdirs(split_path):
					# For now only healthy (0) vs. not healthy (1) matters.
					# This is a substring match, so any folder whose name merely
					# contains "healthy" (e.g. "unhealthy") is also labelled 0.
					label = 0 if "healthy" in condition.lower() else 1

					for fname in sorted(os.listdir(condition_path)):
						if fname.lower().endswith(self.IMAGE_EXTENSIONS):
							self.samples.append((
								os.path.join(condition_path, fname),
								label,
							))

	@staticmethod
	def _subdirs(parent):
		"""Return ``(name, path)`` for each sub-directory of ``parent``, sorted by name."""
		# os.listdir() returns entries in arbitrary order; sorting keeps sample
		# indices stable from run to run.
		found = []
		for name in sorted(os.listdir(parent)):
			path = os.path.join(parent, name)
			if os.path.isdir(path):
				found.append((name, path))
		return found

	def __len__(self):
		"""Number of collected images."""
		return len(self.samples)

	def __getitem__(self, idx):
		"""Load sample ``idx`` and return ``(transformed image, label tensor)``.

		The image is converted to single-channel grayscale before the
		transform is applied; the label is a ``torch.long`` scalar tensor.
		"""
		img_path, label = self.samples[idx]
		# convert() returns an independent image, so the file can be closed
		# as soon as the conversion is done.
		with Image.open(img_path) as img:
			image = img.convert("L")  # Open as grayscale
		image = self.transform(image)
		return image, torch.tensor(label, dtype=torch.long)


#CNN definition
class plantCNN(nn.Module):
	"""Small three-stage CNN classifier (conv -> ReLU -> 2x2 max-pool, three times).

	The defaults reproduce the original fixed configuration: one input channel,
	128x128 inputs and two output logits (healthy vs. diseased).

	Args:
		in_channels: Number of channels in the input images (1 = grayscale).
		num_classes: Number of output logits.
		image_size: Spatial size of the inputs the network will receive, either
			a single int for square images or a ``(height, width)`` pair. It is
			used only to size the first fully connected layer.

	Raises:
		ValueError: If ``image_size`` is too small to survive three 2x2 pools.
	"""

	def __init__(self, in_channels=1, num_classes=2, image_size=128):
		super(plantCNN, self).__init__()
		# Layer names and creation order are kept as-is so existing state_dict
		# checkpoints still load and seeded initialisation is unchanged.
		self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
		self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
		self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

		self.pool = nn.MaxPool2d(2, 2)
		self.relu = nn.ReLU()
		self.flatten = nn.Flatten()

		if isinstance(image_size, int):
			height = width = image_size
		else:
			height, width = image_size

		# The 3x3 convs use padding=1 and keep the spatial size; each pool
		# halves it (rounding down). After 3 pools: 128 -> 64 -> 32 -> 16.
		feat_h, feat_w = height // 8, width // 8
		if feat_h < 1 or feat_w < 1:
			raise ValueError(
				f"image_size {image_size!r} is too small for three 2x2 pooling stages"
			)

		self.fc1 = nn.Linear(64 * feat_h * feat_w, 128)
		self.fc2 = nn.Linear(128, num_classes)  # Healthy vs Diseased by default

	def forward(self, x):
		"""Map a batch of images ``(N, in_channels, H, W)`` to ``(N, num_classes)`` logits."""
		x = self.pool(self.relu(self.conv1(x)))
		x = self.pool(self.relu(self.conv2(x)))
		x = self.pool(self.relu(self.conv3(x)))
		x = self.flatten(x)
		x = self.relu(self.fc1(x))
		x = self.fc2(x)
		return x

# Load the dataset
data_root = "/content/drive/MyDrive/ECS_170 (Dataset)/processed/PlantDiseasesDataset"
dataset = PlantHealthDataset(data_root)
if len(dataset) == 0:
    # Fail here with a readable message instead of later inside the DataLoader.
    raise RuntimeError(
        f"No images found under {data_root!r}; expected "
        "<plant_type>/<train|valid>/<condition>/<image> folders."
    )

# Seed everything that draws random numbers below (the train/val split,
# weight initialisation, batch shuffling) so a Restart & Run All reproduces
# the same run.
SEED = 42
torch.manual_seed(SEED)

# 80/20 train/validation split
split_idx = int(0.8 * len(dataset))
train_data, val_data = torch.utils.data.random_split(
    dataset,
    [split_idx, len(dataset) - split_idx],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)

# Set up training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = plantCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training and validation loops
def train_epoch(model, loader, optimizer, criterion, device=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; it is switched to train mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss callable taking ``(outputs, labels)``. It is assumed to
            return the mean over the batch (PyTorch's default reduction).
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, so no notebook-level global is needed.

    Returns:
        ``(mean_loss, accuracy_percent)`` computed over every sample seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    loss_sum, correct, total = 0.0, 0, 0
    for imgs, labels in tqdm(loader, desc="Training"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight each batch mean by its size so a short final batch does not
        # count as much as a full one in the epoch average.
        loss_sum += loss.item() * batch_size
        total += batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return loss_sum / total, 100 * correct / total

def eval_epoch(model, loader, criterion, device=None):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to eval mode (and left there).
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``. It is assumed to
            return the mean over the batch (PyTorch's default reduction).
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, so no notebook-level global is needed.

    Returns:
        ``(mean_loss, accuracy_percent)`` computed over every sample seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="Validation"):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight each batch mean by its size so a short final batch does
            # not count as much as a full one in the epoch average.
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

    if total == 0:
        raise ValueError("eval_epoch: loader yielded no samples")
    return loss_sum / total, 100 * correct / total

epochs = 17
for epoch in range(epochs):
    # One pass of weight updates, then a pass over the validation split.
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = eval_epoch(model, val_loader, criterion)

    # Single print call; the emitted text is the same as three separate prints.
    print(
        f"\nEpoch [{epoch+1}/{epochs}]",
        f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%",
        f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%\n",
        sep="\n",
    )

# Persist the weights from the final epoch to Drive, creating the folder first.
os.makedirs("/content/drive/MyDrive/ECS_170/models", exist_ok=True)
final_weights = model.state_dict()
torch.save(final_weights, "/content/drive/MyDrive/ECS_170/models/plant_health_cnn.pth")