In [1]:
import torchvision.models as models
from torchvision.datasets import ImageFolder
import torch.nn as nn
import torch
import os
from torchvision import transforms
import pandas as pd
from PIL import Image

In [2]:
# Dataset locations. Both split directories are resolved once against BASE_PATH.
BASE_PATH = "./"
TRAINED_DATA_DIR = os.path.join(BASE_PATH, 'DATASETS/merged_resized_pngs_splited_augmented/train')
TEST_DATA_DIR = os.path.join(BASE_PATH, 'DATASETS/merged_resized_pngs_splited_augmented/test')

# One class per sub-directory of the training split. Only directories are
# counted so that stray files (e.g. ".DS_Store", "Thumbs.db") cannot inflate
# the size of the classifier head. TRAINED_DATA_DIR already contains BASE_PATH,
# so it must not be joined with BASE_PATH a second time.
NUM_CLASSES = sum(1 for entry in os.scandir(TRAINED_DATA_DIR) if entry.is_dir())

print("number of classes: ", NUM_CLASSES)
print("trained data dir: ", TRAINED_DATA_DIR)


def load_model_from_local(model, model_path, map_location="cpu"):
    """Load a saved state dict into ``model`` and put the model in eval mode.

    Args:
        model: A ``torch.nn.Module`` whose parameter names and shapes match
            the checkpoint stored at ``model_path``.
        model_path: Path to a checkpoint file containing a state dict.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so
            that a checkpoint written on a CUDA machine can still be read on a
            machine without a GPU; ``load_state_dict`` then copies the values
            into the model's own parameters wherever they live.

    Returns:
        The same ``model`` object, with the loaded weights, in eval mode.

    Raises:
        Whatever ``torch.load`` raises for an unreadable file, and whatever
        ``load_state_dict`` raises when the checkpoint keys or shapes do not
        match the model.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Preprocessing applied to every image: convert to a tensor, then normalise
# each channel with the mean/std values below.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH = 'models/(trained on augmented)MobileNetV2_20240531_021014/best_model.pth'

# Build the bare architecture only (weights=None). Every parameter is replaced
# by the fine-tuned checkpoint loaded just below, so fetching the pretrained
# weights first would be a wasted download and would break offline runs.
mobilenetv2_base_model = models.mobilenet_v2(weights=None)
# Swap the final linear layer so the head emits one logit per dataset class.
mobilenetv2_base_model.classifier[1] = nn.Linear(mobilenetv2_base_model.last_channel, NUM_CLASSES)
# load_model_from_local already switches the model to eval mode.
mobilenetv2_model = load_model_from_local(mobilenetv2_base_model, MODEL_PATH)
mobilenetv2_model.to(device)

# Class names are the sub-directories of the training split, sorted by name.
# TRAINED_DATA_DIR already includes BASE_PATH, and plain files are skipped so
# the list cannot pick up stray non-class entries.
classes = sorted(entry.name for entry in os.scandir(TRAINED_DATA_DIR) if entry.is_dir())
print("number of classes: ", len(classes))

number of classes:  64
trained data dir:  ./DATASETS/merged_resized_pngs_splited_augmented/train
number of classes:  64


In [4]:
# Datasets: one ImageFolder per split, sharing the same preprocessing.
train_dataset = ImageFolder(root=TRAINED_DATA_DIR, transform=transform)
test_dataset = ImageFolder(root=TEST_DATA_DIR, transform=transform)

# Loaders: the training split is shuffled, the test split keeps its order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False,
)

# Class names come from the training split; the plant name is the part of
# each class name that precedes the first "__".
class_names = train_dataset.classes
plant_names = [name.partition("__")[0] for name in class_names]

def find_labels_by_plant_name(plant_name, class_names):
    """Return the indices of every class that belongs to ``plant_name``.

    A class belongs to a plant when the text before the first "__" in its
    name equals ``plant_name`` exactly. Indices are positions in
    ``class_names`` and are returned in ascending order; the list is empty
    when nothing matches.
    """
    return [
        index
        for index, name in enumerate(class_names)
        if name.partition("__")[0] == plant_name
    ]

# Disabled single-image example, kept as a bare triple-quoted string so that
# none of it executes. It shows the same "prune" idea used by the evaluation
# loop: restrict the model's logits to the classes of one known plant and take
# the best-scoring class among them.
# NOTE: because this string is the last expression in the cell, the notebook
# echoes its repr as the cell output.
"""example_image_path_1 = os.path.join(BASE_PATH, '0c8432e0-0484-470c-a774-7cce596b9e64___JR_FrgE.S 2841.png')
example_image_path_2 = os.path.join(BASE_PATH, '0a5aacba-0363-4b71-9beb-30183982d415___FREC_Pwd.M 4919_1.png')
example_plant_name_1 = "Apple"
example_plant_name_2 = "Cherry"

example_labels = find_labels_by_plant_name(example_plant_name_2, class_names)
print("example labels: ", example_labels)

mobilenetv2_model.eval()
image = Image.open(example_image_path_2)
image = transform(image).unsqueeze(0)
image = image.to(device)
output = mobilenetv2_model(image)

# prune output
print("before prune: ", output)
output = output[:, example_labels]
print("after prune: ", output)
_, predicted_pruned_index  = torch.max(output, 1)
predicted_class_index = example_labels[predicted_pruned_index.item()]
print("predicted class index: ", predicted_class_index)
predicted_class = class_names[predicted_class_index]
print("predicted class: ", predicted_class)"""

'example_image_path_1 = os.path.join(BASE_PATH, \'0c8432e0-0484-470c-a774-7cce596b9e64___JR_FrgE.S 2841.png\')\nexample_image_path_2 = os.path.join(BASE_PATH, \'0a5aacba-0363-4b71-9beb-30183982d415___FREC_Pwd.M 4919_1.png\')\nexample_plant_name_1 = "Apple"\nexample_plant_name_2 = "Cherry"\n\nexample_labels = find_labels_by_plant_name(example_plant_name_2, class_names)\nprint("example labels: ", example_labels)\n\nmobilenetv2_model.eval()\nimage = Image.open(example_image_path_2)\nimage = transform(image).unsqueeze(0)\nimage = image.to(device)\noutput = mobilenetv2_model(image)\n\n# prune output\nprint("before prune: ", output)\noutput = output[:, example_labels]\nprint("after prune: ", output)\n_, predicted_pruned_index  = torch.max(output, 1)\npredicted_class_index = example_labels[predicted_pruned_index.item()]\nprint("predicted class index: ", predicted_class_index)\npredicted_class = class_names[predicted_class_index]\nprint("predicted class: ", predicted_class)'

In [5]:
from sklearn.metrics import precision_recall_fscore_support

# Per-sample results, rebuilt from scratch each time this cell runs.
true_labels = []
predicted_labels = []

# For each class index, the indices of all classes that share its plant name.
# This depends only on the class list, so it is computed once here rather than
# once per test image.
candidate_labels_by_class = [
    find_labels_by_plant_name(plant_name, class_names) for plant_name in plant_names
]

# Plant-aware evaluation: the plant of each test image is taken as known (from
# its true label), and the prediction is the highest-scoring class among that
# plant's classes only.
mobilenetv2_model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        outputs = mobilenetv2_model(images.to(device))

        # Labels are only used as Python ints, so read them on the host
        # instead of copying them to the device and back one at a time.
        for i, true_label in enumerate(labels.tolist()):
            candidate_labels = candidate_labels_by_class[true_label]
            # Keep only this plant's logits, then map the winning position
            # back to the original class index.
            pruned_logits = outputs[i, candidate_labels]
            _, best_position = torch.max(pruned_logits, 0)
            predicted_class_index = candidate_labels[best_position.item()]

            true_labels.append(true_label)
            predicted_labels.append(predicted_class_index)
            

In [7]:
# Accuracy: share of samples whose predicted class index equals the true one.
num_correct = 0
for truth, guess in zip(true_labels, predicted_labels):
    if truth == guess:
        num_correct += 1
accuracy = num_correct / len(true_labels)
print(f'Accuracy: {accuracy * 100:.4f}%')

# Precision, recall and F1 with 'weighted' averaging; the fourth value
# returned by the helper is not used.
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true=true_labels,
    y_pred=predicted_labels,
    average='weighted',
)
print(f'Precision: {precision * 100:.4f}%')
print(f'Recall: {recall * 100:.4f}%')
print(f'F1 Score: {f1 * 100:.4f}%')

Accuracy: 96.7454%
Precision: 96.7799%
Recall: 96.7454%
F1 Score: 96.7405%


In [None]:
Accuracy: 96.7454%
Precision: 96.7799%
Recall: 96.7454%
F1 Score: 96.7405%