In [41]:
import os
import random
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

In [42]:
# ======================
# CONFIG
# ======================
DATA_DIR = "../cnn_dataset/test"
MODEL_PATH = "../resnet18_best_finetuned.pth"
IMG_SIZE = 224
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(DEVICE)

cuda


In [43]:
# ======================
# Transforms (same as training)
# ======================
tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225])
])

# ======================
# Classes
# ======================
classes = sorted(os.listdir(DATA_DIR))
NUM_CLASSES = len(classes)

print("Classes:", classes)

Classes: ['1509', 'IRRI-6', 'Super White']


In [44]:
# ======================
# Load model
# ======================
model = models.resnet18(weights=None)

in_features = model.fc.in_features
model.fc = nn.Linear(in_features, NUM_CLASSES)

model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model = model.to(DEVICE)
model.eval()

print("âœ… Model loaded")

âœ… Model loaded


  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))


In [45]:

# ======================
# Predict 1 random image per class
# ======================
softmax = nn.Softmax(dim=1)

for class_name in classes:

    class_path = os.path.join(DATA_DIR, class_name)

    imgs = os.listdir(class_path)
    img_name = random.choice(imgs)

    img_path = os.path.join(class_path, img_name)

    # load image
    img = Image.open(img_path).convert("RGB")
    x = tfms(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(x)
        probs = softmax(outputs)

    conf, pred_idx = torch.max(probs, 1)

    pred_class = classes[pred_idx.item()]
    conf = conf.item()

    print("\n----------------------------")
    print(f"Image : {img_name}")
    print(f"True  : {class_name}")
    print(f"Pred  : {pred_class}")
    print(f"Conf  : {conf:.4f}")

print("\nðŸ”¥ Testing finished")



----------------------------
Image : 1509_1799.jpg
True  : 1509
Pred  : 1509
Conf  : 0.8813

----------------------------
Image : IRRI-6_3739.jpg
True  : IRRI-6
Pred  : IRRI-6
Conf  : 0.9819

----------------------------
Image : SUPER WHITE_6839.jpg
True  : Super White
Pred  : Super White
Conf  : 0.9521

ðŸ”¥ Testing finished
