In [None]:
import os
from collections import defaultdict
from shutil import copyfile

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, models, transforms

In [None]:
# Configuration
data_dir = "Dataset/Plant_Leaf_Dataset"
batch_size = 20
num_epochs = 10
image_size = 224
learning_rate = 0.001
num_workers = 4
seed = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so the new head's weight init, dropout masks and
# DataLoader shuffling repeat across Restart & Run All (up to any
# nondeterministic GPU kernels).
torch.manual_seed(seed)

print(device)

In [None]:
# Data transformations with augmentation.
# Training-time pipeline applied to `full_dataset`:
#   - random horizontal flip (50%),
#   - random rotation within ±20 degrees,
#   - random scale/aspect crop resized to image_size,
#   - tensor conversion and per-channel normalization with the ImageNet
#     mean/std that torchvision's pretrained weights expect.
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=20),
        transforms.RandomResizedCrop(size=image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Load every image under data_dir. ImageFolder treats each subfolder as one
# class, and `classes` lists the folder names in label-index order.
full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
class_names = full_dataset.classes
print("Class Names:", class_names)

In [None]:
# Split 70% / 15% / 15% into train / validation / test.
# The fixed generator seed makes the split identical on every run. Without it,
# a later session that reloads the saved model and calls test_model would be
# scored on a fresh random split that overlaps the images it was trained on.
split_seed = 42
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_data, val_data, test_data = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# random_split only wraps full_dataset, so every subset would inherit its random
# training augmentation. Validation/test instead index a second ImageFolder over
# the same directory, with deterministic preprocessing (resize, tensor,
# normalize) that mirrors predict_image.
eval_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
eval_dataset = datasets.ImageFolder(root=data_dir, transform=eval_transform)
# Split indices are only meaningful if both datasets list samples identically.
assert eval_dataset.samples == full_dataset.samples, "eval dataset sample order differs"
val_data = Subset(eval_dataset, val_data.indices)
test_data = Subset(eval_dataset, test_data.indices)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

In [None]:
# Export images from the test split for future testing.
# Iterate the split's own indices. `test_data.dataset.samples` alone is the
# *whole* dataset, which would export mostly training images.
output_dir = "Tests/test_images"
images_per_class = 5
os.makedirs(output_dir, exist_ok=True)

class_counts = defaultdict(int)
for sample_idx in test_data.indices:
    image_path, label = test_data.dataset.samples[sample_idx]
    if class_counts[label] >= images_per_class:
        continue

    class_dir = os.path.join(output_dir, class_names[label])
    os.makedirs(class_dir, exist_ok=True)
    copyfile(image_path, os.path.join(class_dir, os.path.basename(image_path)))
    class_counts[label] += 1

    # Stop once every class has its quota (the test split may not contain
    # enough images for some classes, in which case the loop simply ends).
    if len(class_counts) == len(class_names) and all(
        count >= images_per_class for count in class_counts.values()
    ):
        break

print(
    f"Exported {sum(class_counts.values())} test-split images "
    f"(up to {images_per_class} per class) to {output_dir}"
)

In [None]:
class ModifiedResNet(nn.Module):
    """ResNet backbone followed by a new dropout + linear classification head.

    The last two children of ``base_model`` are dropped; for torchvision
    ResNets these are the global ``avgpool`` and the ``fc`` layer. The remaining
    convolutional trunk is kept, followed by adaptive average pooling to 1x1,
    dropout (p=0.5) and a linear layer. That layer maps
    ``base_model.fc.in_features`` features to ``num_classes`` logits.

    Args:
        base_model: a model exposing ``.fc.in_features`` whose ``children()``
            end with a pooling layer and the final classifier (e.g.
            ``torchvision.models.resnet18``).
        num_classes: number of output logits.
    """

    def __init__(self, base_model, num_classes):
        super(ModifiedResNet, self).__init__()
        feature_dim = base_model.fc.in_features
        backbone_layers = list(base_model.children())[:-2]

        self.base = nn.Sequential(*backbone_layers)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(feature_dim, num_classes),
        )

    def forward(self, x):
        """Map an image batch to logits of shape (batch, num_classes)."""
        pooled = self.pool(self.base(x))
        return self.fc(torch.flatten(pooled, 1))

In [None]:
# Pretrained ResNet-18 backbone (torchvision's default weights) with a new
# dropout + linear head sized to the dataset's classes, placed on `device`.
base_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_classes = len(full_dataset.classes)
model = ModifiedResNet(base_model, num_classes).to(device)


In [None]:
# Loss and optimizer: cross-entropy over class logits, and Adam over every
# model parameter with a small L2 penalty (weight_decay).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=1e-4)

In [None]:
# Validation function
def validate_model(model, val_loader, loss_fn=None, eval_device=None):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    Args:
        model: network to evaluate; it is switched to eval mode.
        val_loader: iterable of ``(images, labels)`` batches.
        loss_fn: loss to report. Defaults to the notebook-level ``criterion``.
            It is assumed to average over the batch (``reduction='mean'``, as
            ``nn.CrossEntropyLoss`` does by default).
        eval_device: device each batch is moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        ``(val_loss, val_accuracy)``: the mean loss per sample and the
        accuracy in percent, both over every sample in ``val_loader``.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    eval_device = device if eval_device is None else eval_device

    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(eval_device), labels.to(eval_device)

            outputs = model(images)
            n_batch = labels.size(0)
            # Weight each batch's mean loss by its size so a short final batch
            # does not count as much as a full one.
            loss_sum += loss_fn(outputs, labels).item() * n_batch

            _, predicted = torch.max(outputs, 1)
            total += n_batch
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("validate_model: loader yielded no samples")

    val_loss = loss_sum / total
    val_accuracy = 100 * correct / total
    return val_loss, val_accuracy

In [None]:
# Early stopping class
class EarlyStopping:
    """Signal when validation loss has stopped improving.

    Call the instance once per epoch with the latest validation loss. A loss
    counts as an improvement only if it is more than ``delta`` below the best
    loss seen so far; the first call always counts. After ``patience``
    consecutive calls without an improvement, ``early_stop`` becomes True.
    Nothing resets it afterwards.
    """

    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.counter = 0  # calls since the last improvement
        self.best_loss = None  # best loss so far; None until the first call
        self.early_stop = False

    def __call__(self, val_loss):
        improved = self.best_loss is None or val_loss < self.best_loss - self.delta
        if not improved:
            self.counter += 1
            self.early_stop = self.early_stop or self.counter >= self.patience
            return
        self.best_loss = val_loss
        self.counter = 0

In [None]:
# Training function
def train_model(early_stopping=None):
    """Train the notebook-level ``model`` for up to ``num_epochs`` epochs.

    Each epoch:
      - runs one pass over ``train_loader``, minimizing ``criterion`` with ``optimizer``;
      - evaluates on ``val_loader`` via ``validate_model``;
      - prints both sets of metrics.

    Training stops early once ``early_stopping.early_stop`` becomes True.

    Args:
        early_stopping: callable taking the epoch's validation loss and exposing
            an ``early_stop`` flag (e.g. ``EarlyStopping``). Defaults to a fresh
            ``EarlyStopping()`` for each call. The original code referenced an
            ``early_stopping`` object that was never created, and a fresh one
            avoids inheriting a stale patience counter on re-runs.
    """
    if early_stopping is None:
        early_stopping = EarlyStopping()

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        loss_sum = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Statistics. criterion averages over the batch, so scale by the
            # batch size so that a short last batch does not skew the epoch mean.
            n_batch = labels.size(0)
            loss_sum += loss.item() * n_batch
            _, predicted = torch.max(outputs, 1)
            total += n_batch
            correct += (predicted == labels).sum().item()

        train_loss = loss_sum / total
        train_accuracy = 100 * correct / total

        # Validation phase
        val_loss, val_accuracy = validate_model(model, val_loader)

        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"  Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%")
        print(f"  Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break


In [None]:
def save_model(model, file_path="Saved-Models/leaf_disease_model.pth"):
    """Save ``model.state_dict()`` to ``file_path`` and report where it went.

    The parent directory is created first. ``torch.save`` does not create it,
    so the first save in a fresh checkout would otherwise fail.
    """
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), file_path)
    print(f"Model saved to {file_path}")

In [None]:
if __name__ == "__main__":
    # In a notebook kernel __name__ is "__main__", so this cell always trains
    # and then writes the final weights to disk.
    train_model()
    save_model(model, "Saved-Models/leaf_disease_model.pth")

In [None]:
def load_model(base_model, num_classes, file_path="Saved-Models/leaf_disease_model.pth"):
    """Build a ``ModifiedResNet`` and load saved weights into it if present.

    Args:
        base_model: backbone passed to ``ModifiedResNet``.
        num_classes: number of output classes.
        file_path: checkpoint holding a ``state_dict`` written by ``save_model``.

    Returns:
        The model, moved to ``device`` in both cases. The original moved it only
        when a checkpoint was found, so a fresh model stayed on CPU and failed
        against GPU batches.
    """
    model = ModifiedResNet(base_model, num_classes)

    if os.path.exists(file_path):
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
        model.load_state_dict(torch.load(file_path, map_location=device))
        print(f"Model loaded from {file_path}")
    else:
        print(f"No model found at {file_path}. Starting with a fresh model.")

    return model.to(device)

In [None]:
# Rebuild the architecture on an untrained ResNet-18 (no weight download);
# load_model then fills it from the saved checkpoint when the file exists.
num_classes = len(full_dataset.classes)
base_model = models.resnet18(weights=None)
loaded_model = load_model(base_model, num_classes, file_path="Saved-Models/leaf_disease_model.pth")

In [None]:
def test_model(model, test_loader):
    """Evaluate ``model`` on the held-out test split and print the metrics.

    Delegates to ``validate_model`` instead of duplicating its eval-mode,
    no-grad loop, so validation and test numbers are computed the same way.

    Returns:
        The test accuracy in percent.
    """
    test_loss, test_accuracy = validate_model(model, test_loader)
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    return test_accuracy

In [None]:
# Score the reloaded model on the held-out test split.
test_accuracy = test_model(model=loaded_model, test_loader=test_loader)

In [None]:
def predict_image(image_path, model, class_names):
    """Predict the class name of a single image file.

    The image is converted to RGB and resized straight to ``image_size`` x
    ``image_size`` (no random augmentation). It is then turned into a tensor,
    normalized with the same mean/std as training, and passed through ``model``
    in eval mode on ``device``.

    Returns:
        ``class_names[i]`` where ``i`` is the index of the highest logit.
    """
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Close the file handle promptly. convert() returns a fully loaded copy,
    # so the pixels stay usable after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = preprocess(image).unsqueeze(0).to(device)  # add batch dimension

    model.eval()
    with torch.no_grad():
        output = model(image_tensor)
        _, predicted_class = torch.max(output, 1)

    return class_names[predicted_class.item()]


In [None]:
def predict_directory(directory_path, model, class_names, max_images=5):
    """Predict up to ``max_images`` images in each subfolder of ``directory_path``.

    Subfolders and files are visited in sorted order, so the output is
    deterministic. Non-image files are filtered out *before* the first
    ``max_images`` are taken. Stray files (e.g. ``.DS_Store``) therefore do not
    reduce how many images are predicted, and the printed numbering counts
    images only.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')

    for subdir in sorted(os.listdir(directory_path)):
        subdir_path = os.path.join(directory_path, subdir)

        # Skip if not a directory
        if not os.path.isdir(subdir_path):
            continue

        print(f"\nProcessing directory: {subdir}")

        image_names = [
            name for name in sorted(os.listdir(subdir_path))
            if name.lower().endswith(image_extensions)
        ]
        for idx, image_name in enumerate(image_names[:max_images]):
            image_path = os.path.join(subdir_path, image_name)
            predicted_class = predict_image(image_path, model, class_names)
            print(f"Image {idx + 1}: {image_name} -> Predicted class: {predicted_class}")

In [None]:
# Define the test directory and call the function
# Predict the exported test images, grouped by class folder.
test_directory = "Tests/test_images"
predict_directory(directory_path=test_directory, model=loaded_model, class_names=class_names)

In [None]:
class_names = full_dataset.classes

# Predict one exported image of a single class. Take the first image file in
# that class folder rather than a hard-coded name, because which files get
# exported depends on the dataset split.
sample_class_dir = "Tests/test_images/Tomato__Tomato_Yellow_Leaf_Curl_Virus"
sample_images = sorted(
    name for name in os.listdir(sample_class_dir)
    if name.lower().endswith(('.png', '.jpg', '.jpeg'))
)
if not sample_images:
    raise FileNotFoundError(f"No exported images found in {sample_class_dir}")
image_path = os.path.join(sample_class_dir, sample_images[0])
predict_image(image_path, loaded_model, class_names)
