In [1]:
import torch
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, random_split

from torch import nn, optim

In [2]:
# loading out plant disease detector class
from models import PlantDiseaseDetector

In [3]:
# Creating image transformers
transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [4]:
# Loading the datasets
train_dir = "training_data/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)/train"
valid_dir = "training_data/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)/valid"

In [5]:
train_data = datasets.ImageFolder(train_dir, transform = transform)
valid_data = datasets.ImageFolder(valid_dir, transform = transform)

In [6]:
# defining the dataloaders
batch_size = 64
data_loaders = {
    'train': DataLoader(train_data, batch_size = batch_size, shuffle = True),
    'valid': DataLoader(valid_data, batch_size = batch_size, shuffle = True),
}

In [7]:
# Creating an instance of the PlantDiseaseDetector model
plant_model = PlantDiseaseDetector(num_classes=len(train_data.classes))

In [8]:
print(plant_model)

PlantDiseaseDetector(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(

## Training the Model

In [9]:
#setting the loss function
loss_fn = nn.CrossEntropyLoss()

# setting the optimizer
optimizer = optim.Adam(plant_model.parameters(), lr=0.001)

In [10]:
# setting the device for device-agnostic code
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [11]:
# putting the model to device
plant_model.to(device)

PlantDiseaseDetector(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(

In [12]:
# Training loop
epochs = 15

for epoch in range(epochs):
    train_loss = 0.0
    valid_loss = 0.0

    train_correct = 0
    train_total = 0
    valid_correct = 0
    valid_total = 0

    plant_model.train()
    # getting data from generators
    for inputs, labels in data_loaders['train']:
        # putting data to device
        inputs, labels = inputs.to(device), labels.to(device)
        labels_preds = plant_model(inputs)

        loss = loss_fn(labels_preds, labels)

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        train_loss  += loss.item()

        # Calculate training accuracy
        _, predicted = torch.max(labels_preds.data, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

    plant_model.eval()
    with torch.inference_mode():
        for inputs, labels in data_loaders['valid']:
            inputs, labels = inputs.to(device), labels.to(device)
            labels_preds_valid = plant_model(inputs)
            loss_valid = loss_fn(labels_preds_valid, labels)
            valid_loss += loss_valid.item()

            # Calculate validation accuracy
            _, predicted = torch.max(labels_preds_valid.data, 1)
            valid_total += labels.size(0)
            valid_correct += (predicted == labels).sum().item()


    # Calculate and print accuracies
    train_accuracy = 100 * train_correct / train_total
    valid_accuracy = 100 * valid_correct / valid_total

    print(f"Epoch {epoch+1}/{epochs}.. "
          f"Train loss: {train_loss/len(data_loaders['train']):.3f}.. "
          f"Validation loss: {valid_loss/len(data_loaders['valid']):.3f}.. "
          f"Train Accuracy: {train_accuracy:.3f}%.. "
          f"Validation Accuracy: {valid_accuracy:.3f}%")


Epoch 1/15.. Train loss: 0.335.. Validation loss: 0.136.. Train Accuracy: 91.115%.. Validation Accuracy: 96.085%
Epoch 2/15.. Train loss: 0.175.. Validation loss: 0.108.. Train Accuracy: 94.492%.. Validation Accuracy: 96.813%
Epoch 3/15.. Train loss: 0.153.. Validation loss: 0.109.. Train Accuracy: 95.078%.. Validation Accuracy: 97.001%
Epoch 4/15.. Train loss: 0.130.. Validation loss: 0.094.. Train Accuracy: 95.714%.. Validation Accuracy: 97.251%
Epoch 5/15.. Train loss: 0.122.. Validation loss: 0.103.. Train Accuracy: 95.967%.. Validation Accuracy: 97.388%
Epoch 6/15.. Train loss: 0.111.. Validation loss: 0.092.. Train Accuracy: 96.323%.. Validation Accuracy: 97.439%
Epoch 7/15.. Train loss: 0.106.. Validation loss: 0.108.. Train Accuracy: 96.367%.. Validation Accuracy: 97.308%
Epoch 8/15.. Train loss: 0.097.. Validation loss: 0.076.. Train Accuracy: 96.700%.. Validation Accuracy: 97.786%
Epoch 9/15.. Train loss: 0.094.. Validation loss: 0.090.. Train Accuracy: 96.822%.. Validation A

In [13]:
# Saving the entire model
torch.save(plant_model, 'trained_models/model_full_v2.pth')

In [14]:
# saving dictionary
torch.save(plant_model.state_dict(), 'trained_models/plant_model_v3.pth')

In [16]:
# Saving the classes for the model for future use
import pickle

classes = train_data.classes
classes

['Apple___Apple_scab',
 'Apple___Black_rot',
 'Apple___Cedar_apple_rust',
 'Apple___healthy',
 'Blueberry___healthy',
 'Cherry_(including_sour)___Powdery_mildew',
 'Cherry_(including_sour)___healthy',
 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
 'Corn_(maize)___Common_rust_',
 'Corn_(maize)___Northern_Leaf_Blight',
 'Corn_(maize)___healthy',
 'Grape___Black_rot',
 'Grape___Esca_(Black_Measles)',
 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
 'Grape___healthy',
 'Orange___Haunglongbing_(Citrus_greening)',
 'Peach___Bacterial_spot',
 'Peach___healthy',
 'Pepper,_bell___Bacterial_spot',
 'Pepper,_bell___healthy',
 'Potato___Early_blight',
 'Potato___Late_blight',
 'Potato___healthy',
 'Raspberry___healthy',
 'Soybean___healthy',
 'Squash___Powdery_mildew',
 'Strawberry___Leaf_scorch',
 'Strawberry___healthy',
 'Tomato___Bacterial_spot',
 'Tomato___Early_blight',
 'Tomato___Late_blight',
 'Tomato___Leaf_Mold',
 'Tomato___Septoria_leaf_spot',
 'Tomato___Spider_mites Two-spotted_

In [17]:
for index,value in enumerate(classes):
    print(f'{index}-{value}')

0-Apple___Apple_scab
1-Apple___Black_rot
2-Apple___Cedar_apple_rust
3-Apple___healthy
4-Blueberry___healthy
5-Cherry_(including_sour)___Powdery_mildew
6-Cherry_(including_sour)___healthy
7-Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot
8-Corn_(maize)___Common_rust_
9-Corn_(maize)___Northern_Leaf_Blight
10-Corn_(maize)___healthy
11-Grape___Black_rot
12-Grape___Esca_(Black_Measles)
13-Grape___Leaf_blight_(Isariopsis_Leaf_Spot)
14-Grape___healthy
15-Orange___Haunglongbing_(Citrus_greening)
16-Peach___Bacterial_spot
17-Peach___healthy
18-Pepper,_bell___Bacterial_spot
19-Pepper,_bell___healthy
20-Potato___Early_blight
21-Potato___Late_blight
22-Potato___healthy
23-Raspberry___healthy
24-Soybean___healthy
25-Squash___Powdery_mildew
26-Strawberry___Leaf_scorch
27-Strawberry___healthy
28-Tomato___Bacterial_spot
29-Tomato___Early_blight
30-Tomato___Late_blight
31-Tomato___Leaf_Mold
32-Tomato___Septoria_leaf_spot
33-Tomato___Spider_mites Two-spotted_spider_mite
34-Tomato___Target_Spot
35-Tom

In [None]:
# Save the classes to a file using pickle
with open('classes.pkl', 'wb') as f:
    pickle.dump(classes, f)