In [1]:
import torchvision.models as models
from torchvision.datasets import ImageFolder
import torch.nn as nn
import torch
import os
from torchvision import transforms
import pandas as pd
from PIL import Image

In [2]:
# Root directory that the dataset paths below are resolved against.
BASE_PATH = "./"
TRAINED_DATA_DIR = os.path.join(BASE_PATH, 'DATASETS/merged_resized_pngs_splited_augmented/train')
TEST_DATA_DIR = os.path.join(BASE_PATH, 'DATASETS/merged_resized_pngs_splited_augmented/test')

# TRAINED_DATA_DIR already starts with BASE_PATH, so it is listed directly;
# joining BASE_PATH a second time breaks for other relative roots such as
# "data/" (it would resolve to "data/data/DATASETS/...").
# Only sub-directories are counted, because each class is one folder and a
# stray file next to them (e.g. .DS_Store) must not inflate the class count.
NUM_CLASSES = sum(
    os.path.isdir(os.path.join(TRAINED_DATA_DIR, entry))
    for entry in os.listdir(TRAINED_DATA_DIR)
)

print("number of classes: ", NUM_CLASSES)
print("trained data dir: ", TRAINED_DATA_DIR)


def load_model_from_local(model, model_path, map_location="cpu"):
    """Load saved weights into ``model`` and switch it to evaluation mode.

    Args:
        model: Module whose parameters match the saved state dict.
        model_path: Path of the checkpoint file; its content is passed
            straight to ``model.load_state_dict``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved from a CUDA device also loads on a machine
            without a GPU; ``load_state_dict`` then copies the values into
            the model's own parameters.

    Returns:
        The same ``model`` instance, with the weights loaded and ``eval()``
        applied.
    """
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Preprocessing applied to every image before it reaches the network:
# convert to a float tensor, then normalise each of the three channels with
# the standard ImageNet channel statistics.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH = 'models/MobileNetV2_20240531_021014/best_model.pth'

# weights=None: the checkpoint loaded below overwrites every parameter, so
# fetching the pretrained weights first would only cost a download (and
# require network access) for values that are thrown away.
mobilenetv2_base_model = models.mobilenet_v2(weights=None)
# Replace the final classifier layer so the output size matches our classes.
mobilenetv2_base_model.classifier[1] = nn.Linear(mobilenetv2_base_model.last_channel, NUM_CLASSES)
# load_model_from_local already leaves the model in eval mode.
mobilenetv2_model = load_model_from_local(mobilenetv2_base_model, MODEL_PATH)
mobilenetv2_model.to(device)

# TRAINED_DATA_DIR already contains BASE_PATH, so it is listed directly.
# Only sub-directories are kept: each class is one folder, and stray files
# in the train directory are not classes.
classes = sorted(
    entry
    for entry in os.listdir(TRAINED_DATA_DIR)
    if os.path.isdir(os.path.join(TRAINED_DATA_DIR, entry))
)
print("number of classes: ", len(classes))

number of classes:  64
trained data dir:  ./DATASETS/merged_resized_pngs_splited_augmented/train
number of classes:  64


In [3]:
# Training split, served in shuffled mini-batches of 32 images.
train_dataset = ImageFolder(root=TRAINED_DATA_DIR, transform=transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)

# Test split, served in a fixed order so runs are comparable.
test_dataset = ImageFolder(root=TEST_DATA_DIR, transform=transform)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False,
)

# The class list of the training set defines the label indices used below.
class_names = train_dataset.classes
# The plant name is whatever precedes the first "__" in a class name
# (the whole name when there is no "__").
plant_names = [name.partition("__")[0] for name in class_names]

def find_labels_by_plant_name(plant_name, class_names):
    """Return the indices of all classes that belong to ``plant_name``.

    A class belongs to a plant when the text before the first ``"__"`` in its
    name equals ``plant_name`` exactly.

    Args:
        plant_name: Plant prefix to look for.
        class_names: Sequence of class names; positions are the label indices.

    Returns:
        List of matching indices in ascending order (empty if none match).
    """
    return [
        index
        for index, name in enumerate(class_names)
        if name.split("__")[0] == plant_name
    ]

example_image_path_1 = os.path.join(BASE_PATH, '0c8432e0-0484-470c-a774-7cce596b9e64___JR_FrgE.S 2841.png')
example_image_path_2 = os.path.join(BASE_PATH, '0a5aacba-0363-4b71-9beb-30183982d415___FREC_Pwd.M 4919_1.png')
example_plant_name_1 = "Apple"
example_plant_name_2 = "Cherry"

# Class indices the prediction is allowed to fall into for this plant.
example_labels = find_labels_by_plant_name(example_plant_name_2, class_names)
print("example labels: ", example_labels)
if not example_labels:
    # Without this guard the slice below is empty and torch.max fails with an
    # error that does not mention the plant name.
    raise ValueError(f"no class found for plant name {example_plant_name_2!r}")

mobilenetv2_model.eval()
# Force three RGB channels: a PNG may be RGBA, greyscale or palettised, and
# the 3-channel Normalize in `transform` would then fail. This matches what
# ImageFolder's default loader does for the dataset images. The context
# manager closes the file handle once the pixels have been copied.
with Image.open(example_image_path_2) as image_file:
    image = image_file.convert("RGB")
image = transform(image).unsqueeze(0)  # add the batch dimension
image = image.to(device)
# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    output = mobilenetv2_model(image)

# prune output: keep only the logits of the classes belonging to the plant
print("before prune: ", output)
output = output[:, example_labels]
print("after prune: ", output)
_, predicted_pruned_index  = torch.max(output, 1)
# Map the position inside the pruned logits back to the full class index.
predicted_class_index = example_labels[predicted_pruned_index.item()]
print("predicted class index: ", predicted_class_index)
predicted_class = class_names[predicted_class_index]
print("predicted class: ", predicted_class)

example labels:  [4, 5]
before prune:  tensor([[ 4.1964e-02, -8.4229e-03,  7.2694e-01,  1.0631e+00, -1.0895e+00,
          1.0599e+01, -6.0295e-01,  6.2821e-01,  5.2467e-01,  2.8852e-01,
         -1.1933e+00,  5.8391e-01, -1.7345e+00,  1.0210e-01,  5.2058e-01,
         -1.5617e+00,  8.0694e-02,  9.7825e-01,  2.9813e-01,  6.3134e-01,
         -3.9957e-01, -8.3195e-01, -9.0069e-01,  1.2850e+00, -5.6000e-01,
         -5.2574e-01, -7.5184e-01,  1.6021e-01, -3.8565e-01, -1.4111e+00,
         -1.2495e+00, -1.3900e+00, -9.1956e-01, -5.5036e-01,  1.5790e+00,
         -1.2418e+00, -7.0023e-01, -5.2419e-01, -1.6087e-01, -6.1729e-01,
         -6.0234e-01, -6.7896e-01, -1.0629e+00, -1.3607e+00, -3.8894e-01,
          2.6982e-02,  6.6836e-01, -8.4422e-01, -2.2255e-01,  1.6911e+00,
          1.9547e-01, -1.6722e-02,  4.2677e-04,  3.6578e+00,  6.6889e-01,
          1.2079e-01, -3.2807e-01, -4.1721e-01, -3.4990e-02,  5.0905e-01,
          8.2462e-02, -8.4921e-01, -6.4979e-01, -7.8063e-01]], device='cu

In [6]:
# testing
# Accuracy on the test set when the prediction is restricted to the classes
# of the sample's own plant. The plant is taken from the ground-truth label,
# so this measures accuracy given that the plant is already known.
mobilenetv2_model.eval()

# Candidate class indices per plant, built once instead of once per sample.
labels_by_plant = {
    plant_name: find_labels_by_plant_name(plant_name, class_names)
    for plant_name in set(plant_names)
}

correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = mobilenetv2_model(images)

        # One conversion of the whole label batch to Python ints replaces the
        # two per-sample .item() calls.
        for i, true_label in enumerate(labels.tolist()):
            example_labels = labels_by_plant[plant_names[true_label]]
            output = outputs[i, example_labels]
            _, predicted = torch.max(output, 0)
            # Map the position inside the pruned logits back to a class index.
            predicted_class_index = example_labels[predicted.item()]

            if predicted_class_index == true_label:
                correct += 1
            total += 1

if total == 0:
    # An empty loader would otherwise surface as a bare ZeroDivisionError.
    raise ValueError("test_loader yielded no samples; cannot compute accuracy")

accuracy = correct / total
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 96.75%
