In [1]:
import os
from torchvision import datasets, models, transforms
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Paths
data_dir = 'PlantVillage'

# Transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Load dataset
dataset = datasets.ImageFolder(data_dir, transform=transform)
class_names = dataset.classes

# Split into train and val
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# Load pretrained model
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, len(class_names))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loss & optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)




In [2]:
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += torch.sum(preds == labels).item()

    accuracy = 100 * correct / len(train_dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss:.4f}, Accuracy: {accuracy:.2f}%")


Epoch 1/5, Loss: 188.9766, Accuracy: 88.02%
Epoch 2/5, Loss: 79.9216, Accuracy: 95.00%
Epoch 3/5, Loss: 56.4977, Accuracy: 96.38%
Epoch 4/5, Loss: 43.7405, Accuracy: 97.11%
Epoch 5/5, Loss: 44.1292, Accuracy: 97.18%


In [3]:
torch.save(model.state_dict(), 'plant_disease_model.pth')
print("Model saved as plant_disease_model.pth")

Model saved as plant_disease_model.pth
