In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out

class MyResNet18_LeafSpecieClassification(nn.Module):
    def __init__(self, block, layers, num_classes=14):
        super(MyResNet18_LeafSpecieClassification, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlockHealth(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlockHealth, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out

class MyResNet18_LeafHealthClassification(nn.Module):
    def __init__(self, block, layers, num_classes):
        super(MyResNet18_LeafHealthClassification, self).__init__()

        self.in_channels = 32
        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 32, layers[0])
        self.layer2 = self._make_layer(block, 64, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 128, layers[2], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


In [12]:
leaf_species_model = None
apple_model = None
cherry_model = None
corn_model = None
grape_model = None
peach_model = None
pepperbell_model = None
potato_model = None
strawberry_model = None
tomato_model = None

species_dictionary = {
      0: 'apple',
      1: 'blueberry',
      2: 'grape',
      3: 'peach',
      4: 'potato',
      5: 'raspberry',
      6: 'soybean',
      7: 'tomato',
      8: 'cherry_(including_sour)',
      9: 'corn_(maize)',
      10: 'orange',
      11: 'pepper,_bell',
      12: 'squash',
      13: 'strawberry'
}

apple_dictionary = {0: 'apple_scab', 1: 'black_rot', 2: 'healthy', 3: 'cedar_apple_rust'}
cherry_dictionary = {0: 'healthy', 1: 'powdery_mildew'}
corn_dictionary = {0: 'cercospora_leaf_spot gray_leaf_spot', 1: 'common_rust_', 2: 'healthy', 3: 'northern_leaf_blight'}
grape_dictionary = {0: 'black_rot', 1: 'healthy', 2: 'esca_(black_measles)', 3: 'leaf_blight_(isariopsis_leaf_spot)'}
peach_dictionary = {0: 'healthy', 1: 'bacterial_spot'}
pepperbell_dictionary = {0: 'bacterial_spot', 1: 'healthy'}
potato_dictionary = {0: 'healthy', 1: 'late_blight', 2: 'early_blight'}
strawberry_dictionary = {0: 'healthy', 1: 'leaf_scorch'}
tomato_dictionary = {
    0: 'healthy',
    1: 'leaf_mold',
    2: 'bacterial_spot',
    3: 'early_blight',
    4: 'late_blight',
    5: 'septoria_leaf_spot',
    6: 'spider_mites two-spotted_spider_mite',
    7: 'target_spot',
    8: 'tomato_mosaic_virus',
    9: 'tomato_yellow_leaf_curl_virus'
}

In [13]:
import os

def load_leaf_species_classification_model():
    num_classes = 14
    model_path = '/content/drive/MyDrive/Colab Notebooks/dl/leafs_model/leaf_specie_clmodel'

    pretrain_model = MyResNet18_LeafSpecieClassification(BasicBlock, [2, 2, 2, 2], num_classes)
    pretrain_model.load_state_dict(torch.load(model_path))
    pretrain_model.to('cuda')

    return pretrain_model

def load_health_classification_model(model_name, num_classes):
    checkpoint_path = os.path.join('/content/drive/MyDrive/Colab Notebooks/dl/leafs_model', model_name)
    pretrain_model = MyResNet18_LeafHealthClassification(BasicBlockHealth, [2, 2, 2], num_classes)
    pretrain_model.load_state_dict(torch.load(checkpoint_path))
    pretrain_model.to('cuda')

    return pretrain_model

def load_apple_health_classification_model():
    pretrain_model = load_health_classification_model('apple_model', 4)
    return pretrain_model

def load_cherry_health_classification_model():
    pretrain_model = load_health_classification_model('cherry_model', 2)
    return pretrain_model

def load_corn_health_classification_model():
    pretrain_model = load_health_classification_model('corn_model', 4)
    return pretrain_model

def load_grape_health_classification_model():
    pretrain_model = load_health_classification_model('grape_model', 4)
    return pretrain_model

def load_peach_health_classification_model():
    pretrain_model = load_health_classification_model('peach_model', 2)
    return pretrain_model

def load_pepperbell_health_classification_model():
    pretrain_model = load_health_classification_model('pepperbell_model', 2)
    return pretrain_model

def load_potato_health_classification_model():
    pretrain_model = load_health_classification_model('potato_model', 3)
    return pretrain_model

def load_strawberry_health_classification_model():
    pretrain_model = load_health_classification_model('strawberry_model', 2)
    return pretrain_model

def load_tomato_health_classification_model():
    pretrain_model = load_health_classification_model('tomato_model', 10)
    return pretrain_model

In [14]:
leaf_species_model = load_leaf_species_classification_model()
apple_model = load_apple_health_classification_model()
cherry_model = load_cherry_health_classification_model()
corn_model = load_corn_health_classification_model()
grape_model = load_grape_health_classification_model()
peach_model = load_peach_health_classification_model()
pepper_bell_model = load_pepperbell_health_classification_model()
potato_model = load_potato_health_classification_model()
strawberry_model = load_strawberry_health_classification_model()
tomato_model = load_tomato_health_classification_model()

In [15]:
from PIL import Image
import torch
import torchvision.transforms as transforms

def get_health_classification_model(specie):
  health_model = None

  if specie == 'apple':
    health_model = apple_model
  elif specie == 'grape':
    health_model = grape_model
  elif specie == 'peach':
    health_model = peach_model
  elif specie == 'potato':
    health_model = potato_model
  elif specie == 'tomato':
    health_model = tomato_model
  elif specie == 'cherry_(including_sour)':
    health_model = cherry_model
  elif specie == 'corn_(maize)':
    health_model = corn_model
  elif specie == 'pepper,_bell':
    health_model = pepperbell_model
  elif specie == 'strawberry':
    health_model = strawberry_model

  return health_model


def get_health_classes_dictionary(specie):
  health_classes_dictionary = None

  if specie == 'apple':
    health_classes_dictionary = apple_dictionary
  elif specie == 'grape':
    health_classes_dictionary = grape_dictionary
  elif specie == 'peach':
    health_classes_dictionary = peach_dictionary
  elif specie == 'potato':
    health_classes_dictionary = potato_dictionary
  elif specie == 'tomato':
    health_classes_dictionary = tomato_dictionary
  elif specie == 'cherry_(including_sour)':
    health_classes_dictionary = cherry_dictionary
  elif specie == 'corn_(maize)':
    health_classes_dictionary = corn_dictionary
  elif specie == 'pepper,_bell':
    health_classes_dictionary = pepperbell_dictionary
  elif specie == 'strawberry':
    health_classes_dictionary = strawberry_dictionary

  return health_classes_dictionary


def classify_leaf_specie(image):

    leaf_species_model.eval()
    with torch.no_grad():
        output = leaf_species_model(image)

    probabilities = torch.nn.functional.softmax(output, dim=1)
    _, pred_idx = torch.max(probabilities, 1)
    predicted_specie = species_dictionary[pred_idx.item()]

    return predicted_specie

def classify_leaf_health(image, model, classes):
    model.eval()
    with torch.no_grad():
        output = model(image)

    probabilities = torch.nn.functional.softmax(output, dim=1)
    _, pred_idx = torch.max(probabilities, 1)

    predicted_health = classes[pred_idx.item()]

    return predicted_health

def classify_leaf_specie_and_health(image_path):

    transformations = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    image = Image.open(image_path).convert('RGB')
    transformed_image = transformations(image).unsqueeze(0).to('cuda')

    predicted_specie = classify_leaf_specie(transformed_image)
    predicted_health = ''

    health_classification_model = get_health_classification_model(predicted_specie)

    if health_classification_model != None:
      health_classes_dictionary = get_health_classes_dictionary(predicted_specie)
      predicted_health = classify_leaf_health(transformed_image, health_classification_model, health_classes_dictionary)

    return predicted_specie, predicted_health

In [16]:
import pandas as pd

base_path = '/content/drive/MyDrive/Colab Notebooks/dl/leafs-dataset/'
image_paths = [
        'apple___apple_scab/97d9b70c-15fc-4463-a7cd-129e7bccef1c___FREC_Scab 2910.JPG',
        'apple___black_rot/f15c02d9-9aa6-45b5-84fc-97e22f7fabaa___JR_FrgE.S 2831.JPG',
        'apple___healthy/bdcce69f-7520-48cc-b22b-c3ad33d5be71___RS_HL 7878.JPG',
        'blueberry___healthy/d2dec908-fb8f-464b-b427-8f112340d649___RS_HL 0570.JPG',
        'grape___black_rot/e2dca6ce-4b58-4850-b6c0-1f15189fe944___FAM_B.Rot 0431.JPG',
        'grape___healthy/2b4a92e2-a039-45eb-ad1a-ee9914dfeefe___Mt.N.V_HL 6177.JPG',
        'peach___healthy/95f38d4e-5e8d-4b72-ac83-bd7e794a3093___Rutg._HL 2498.JPG',
        'potato___healthy/ff700844-68ad-4e99-8427-58a39c07f817___RS_HL 1860.JPG',
        'potato___late_blight/1789b0d1-d850-4e04-8a20-bc3901d4ab0a___RS_LB 5251.JPG',
        'raspberry___healthy/1bf5e5e4-d00c-442c-b153-497aa3ed9278___Mary_HL 6240.JPG',
        'soybean___healthy/51720e52-cf2a-4818-889f-a7426e764421___RS_HL 2841.JPG',
        'tomato___healthy/67549f03-2c86-44b7-9730-32d4a23287f3___GH_HL Leaf 495.2.JPG',
        'tomato___leaf_mold/59be5fe5-77f9-4538-a3c9-8c3be20f662d___Crnl_L.Mold 8938.JPG',
        'apple___cedar_apple_rust/9ace7aaf-8950-43b5-bafb-2c63f8839c20___FREC_C.Rust 9867.JPG',
        'cherry_(including_sour)___healthy/f9236ca5-ea7c-4a87-bac0-396883f95ba6___JR_HL 9838.JPG',
        'cherry_(including_sour)___powdery_mildew/bda9d7d5-617a-4159-92f2-a30a05396091___FREC_Pwd.M 4924.JPG',
        'corn_(maize)___cercospora_leaf_spot gray_leaf_spot/9103b8e5-919c-4d08-a282-25176874769c___RS_GLSp 4653.JPG',
        'corn_(maize)___common_rust_/RS_Rust 1591.JPG',
        'corn_(maize)___healthy/b914d1d7-db9b-4db3-9360-1ba44beef18b___R.S_HL 8176 copy 2.jpg',
        'corn_(maize)___northern_leaf_blight/7220bfaa-b955-46c6-994b-00810d0e65b3___RS_NLB 3964.JPG',
        'grape___esca_(black_measles)/c1015ba9-4629-43eb-90a2-accfed27b787___FAM_B.Msls 1830.JPG',
        'grape___leaf_blight_(isariopsis_leaf_spot)/a77906c9-3b5d-4612-bbd7-0ce3f223882c___FAM_L.Blight 4876.JPG',
        'orange___haunglongbing_(citrus_greening)/28cc3cb0-6bf8-4476-a6fb-59c5af49ff89___UF.Citrus_HLB_Lab 1456.JPG',
        'peach___bacterial_spot/a0b7956c-841a-4f49-a163-c627c47fe43a___Rutg._Bact.S 1853.JPG',
        'pepper,_bell___bacterial_spot/a72dbf23-65d1-40c6-a7bc-82caed00c6d3___JR_B.Spot 3333.JPG',
        'pepper,_bell___healthy/53c14233-6e5f-4775-acfe-212000c81ffa___JR_HL 6009.JPG',
        'potato___early_blight/b76550de-8e3a-46f1-b06f-6bd4ed3dc8a5___RS_Early.B 8456.JPG',
        'squash___powdery_mildew/df266ee0-67e4-439f-b3d0-d00d8257aa55___MD_Powd.M 0865.JPG',
        'strawberry___healthy/4005fb13-0d7c-4a30-9ee3-73e9e4cee05e___RS_HL 1688.JPG',
        'strawberry___leaf_scorch/f7d50599-f99b-4b36-ada8-712428030a2e___RS_L.Scorch 0945.JPG',
        'tomato___bacterial_spot/144352ee-0f8d-44cc-9db1-c4f27eb5a00a___GCREC_Bact.Sp 3284.JPG',
        'tomato___early_blight/d53bdcff-1ea7-46c7-9d46-701d104bf130___RS_Erly.B 8353.JPG',
        'tomato___late_blight/01a68044-9c5b-4658-a944-6108c6862ce7___GHLB Leaf 2.1 Day 16.JPG',
        'tomato___septoria_leaf_spot/efb0fe31-821d-4259-8daf-495b04f0a0d1___JR_Sept.L.S 8503.JPG',
        'tomato___spider_mites two-spotted_spider_mite/a67b81b4-ba9a-48eb-99cd-be36350a787e___Com.G_SpM_FL 8906.JPG',
        'tomato___target_spot/40d81563-afde-4956-a67d-695d22449170___Com.G_TgS_FL 7993.JPG',
        'tomato___tomato_mosaic_virus/57b10f20-7819-40b1-924c-b484c42c515f___PSU_CG 2356.JPG',
        'tomato___tomato_yellow_leaf_curl_virus/a52cd19b-99e6-495e-bd59-c8695a0b3d17___YLCV_GCREC 5256.JPG']

data = []

for image_path in image_paths:
    real_specie, real_health = image_path.split('/')[0].split('___')
    predicted_specie, predicted_health = classify_leaf_specie_and_health(os.path.join(base_path, image_path))
    specie_correct = "SUCCESS" if real_specie == predicted_specie else "FAIL"
    health_correct = "SUCCESS" if real_health == predicted_health else "FAIL"

    data.append({
        'Image_Path': image_path,
        'Real_Specie': real_specie,
        'Predicted_Specie': predicted_specie,
        'Specie_Correct': specie_correct,
        'Real_Health': real_health,
        'Predicted_Health': predicted_health,
        'Health_Correct': health_correct
    })

predictions_df = pd.DataFrame(data)
predictions_df

Unnamed: 0,Image_Path,Real_Specie,Predicted_Specie,Specie_Correct,Real_Health,Predicted_Health,Health_Correct
0,apple___apple_scab/97d9b70c-15fc-4463-a7cd-129...,apple,squash,FAIL,apple_scab,,FAIL
1,apple___black_rot/f15c02d9-9aa6-45b5-84fc-97e2...,apple,"pepper,_bell",FAIL,black_rot,,FAIL
2,apple___healthy/bdcce69f-7520-48cc-b22b-c3ad33...,apple,apple,SUCCESS,healthy,healthy,SUCCESS
3,blueberry___healthy/d2dec908-fb8f-464b-b427-8f...,blueberry,blueberry,SUCCESS,healthy,,FAIL
4,grape___black_rot/e2dca6ce-4b58-4850-b6c0-1f15...,grape,grape,SUCCESS,black_rot,leaf_blight_(isariopsis_leaf_spot),FAIL
5,grape___healthy/2b4a92e2-a039-45eb-ad1a-ee9914...,grape,grape,SUCCESS,healthy,leaf_blight_(isariopsis_leaf_spot),FAIL
6,peach___healthy/95f38d4e-5e8d-4b72-ac83-bd7e79...,peach,peach,SUCCESS,healthy,healthy,SUCCESS
7,potato___healthy/ff700844-68ad-4e99-8427-58a39...,potato,potato,SUCCESS,healthy,healthy,SUCCESS
8,potato___late_blight/1789b0d1-d850-4e04-8a20-b...,potato,cherry_(including_sour),FAIL,late_blight,powdery_mildew,FAIL
9,raspberry___healthy/1bf5e5e4-d00c-442c-b153-49...,raspberry,grape,FAIL,healthy,healthy,SUCCESS
