In [1]:
import torch
import torch.nn as nn
import os
from torchvision import models, transforms
from PIL import Image


In [2]:
# =========================
# SETTINGS
# =========================
IMG_SIZE = 224
print(torch.cuda.is_available())
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_PATH = "../mobilenet_v2_finetuned_best.pth"   # your saved model
DATA_PATH = "../cnn_dataset/test"  # test folder or single image

THRESHOLD = 0.6  # unknown threshold

CLASS_NAMES = ['1509', 'IRRI-6', 'Super White']  # must match training order
NUM_CLASSES = len(CLASS_NAMES)


True


In [3]:
# =========================
# Transforms (same as val!)
# =========================
tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225])
])


In [4]:
# =========================
# Load model
# =========================
model = models.mobilenet_v2(weights=None)
model.classifier[1] = nn.Linear(model.last_channel, NUM_CLASSES)

model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model = model.to(DEVICE)
model.eval()

print("âœ… Model loaded")

âœ… Model loaded


  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))


In [5]:
# =========================
# Predict function
# =========================
def predict_image(img_path):

    img = Image.open(img_path).convert("RGB")
    x = tfms(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(x)
        probs = torch.softmax(outputs, dim=1)

    probs = probs.cpu().numpy()[0]

    max_prob = probs.max()
    pred_idx = probs.argmax()

    if max_prob < THRESHOLD:
        label = "Unknown"
    else:
        label = CLASS_NAMES[pred_idx]

    return label, max_prob, probs

In [6]:
# =========================
# Test single OR folder
# =========================
if os.path.isfile(DATA_PATH):

    label, conf, probs = predict_image(DATA_PATH)

    print("\nPrediction:", label)
    print("Confidence:", conf)
    print("All probs:", probs)

else:

    print("\nTesting folder...\n")

    correct = 0
    total = 0

    # for class_name in os.listdir(DATA_PATH):

    #     class_folder = os.path.join(DATA_PATH, class_name)

    #     if not os.path.isdir(class_folder):
    #         continue

    #     for img_name in os.listdir(class_folder):

    #         img_path = os.path.join(class_folder, img_name)

    #         pred, conf, _ = predict_image(img_path)

    #         total += 1
    #         if pred == class_name:
    #             correct += 1

    #         print(f"{img_name} â†’ {pred} ({conf:.2f})")

    # acc = correct / total
    # print(f"\nðŸ”¥ Test Accuracy: {acc:.4f}")



    import random

    print("\nTesting 1 image per class...\n")

    correct = 0
    total = 0

    for class_name in CLASS_NAMES:
        print(class_name)
        img_folder = os.path.join(DATA_PATH, class_name)

        if not os.path.isdir(img_folder):
            continue

        imgs = [f for f in os.listdir(img_folder)
                if f.lower().endswith((".jpg", ".png", ".jpeg"))]

        if len(imgs) == 0:
            continue

        # pick only ONE image
        chosen = random.choice(imgs)

        img_path = os.path.join(img_folder, chosen)

        pred, conf, _ = predict_image(img_path)

        total += 1
        if pred == class_name:
            correct += 1

        print(f"Class: {class_name} | File: {chosen} â†’ {pred} ({conf:.2f})")

    acc = correct / total
    print(f"\nðŸ”¥ Accuracy (1 per class): {acc:.4f}")




Testing folder...


Testing 1 image per class...

1509
Class: 1509 | File: 1509_3607.jpg â†’ 1509 (0.96)
IRRI-6
Class: IRRI-6 | File: IRRI-6_5083.jpg â†’ IRRI-6 (0.99)
Super White
Class: Super White | File: SUPER WHITE_6847.jpg â†’ Super White (0.99)

ðŸ”¥ Accuracy (1 per class): 1.0000
