In [None]:
import torch
from torchvision import models
import torch.nn as nn
import os
import kagglehub

# Fine-tuned checkpoint: a state_dict (loaded via load_state_dict below), path relative to the notebook's CWD.
checkpoint_path = "../models/efficientnet_b4_best.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the architecture the checkpoint was trained with: EfficientNet-B4 whose final
# Linear layer is replaced by a single-logit head (the evaluation cell applies a sigmoid to it).
# weights=None is the non-deprecated spelling of pretrained=False; no ImageNet weights are
# needed because every parameter is overwritten by the checkpoint below.
model = models.efficientnet_b4(weights=None)
in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, 1)
model = model.to(device)

# weights_only=True restricts unpickling to tensors and plain containers, so a tampered
# .pth file cannot execute arbitrary code on load; a plain state_dict loads unchanged.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode for BatchNorm / Dropout

print(f"Loaded weights from: {checkpoint_path}")

In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Evaluation-time preprocessing: deterministic resize + ImageNet normalization, no augmentation.
# NOTE(review): 260x260 is the size this notebook has effectively been using (an earlier 380x380
# definition was immediately overwritten). torchvision's reference EfficientNet-B4 weights use a
# 380px input; whichever size is used here must match the one used when the checkpoint was
# trained/validated — TODO confirm against the training notebook.
IMG_SIZE = 260
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
TEST_BATCH_SIZE = 32

transformation_for_valntest = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

test_path = "../Dataset/Test"
test_dataset = datasets.ImageFolder(root=test_path, transform=transformation_for_valntest)

# The model has a single output logit and the metrics are binary, so exactly two class folders
# are expected. ImageFolder assigns indices in sorted folder-name order; index 1 is the positive class.
assert len(test_dataset.classes) == 2, f"expected 2 classes, found {test_dataset.classes}"

# shuffle=False keeps prediction order aligned with test_dataset.samples for later inspection.
test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)

test_dataset.class_to_idx


In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# Decision threshold applied to the sigmoid probability of the single output logit.
THRESHOLD = 0.5

all_preds = []
all_labels = []
all_probs = []

model.eval()  # make the cell self-contained even if the model was switched to train mode elsewhere
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        # The head emits one logit per image, shape (B, 1); flatten to (B,) explicitly.
        # A bare .squeeze() would turn a final batch of size 1 into a 0-d tensor, and
        # list.extend() on the resulting 0-d array raises TypeError.
        logits = model(inputs).flatten()
        probs = torch.sigmoid(logits)
        preds = (probs > THRESHOLD).long()

        all_probs.extend(probs.cpu().tolist())
        all_preds.extend(preds.cpu().tolist())
        # Labels are only needed on the CPU for sklearn; no device round-trip required.
        all_labels.extend(labels.tolist())

assert all_labels, "test_loader yielded no samples"

accuracy = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds, zero_division=0)
recall = recall_score(all_labels, all_preds, zero_division=0)
f1 = f1_score(all_labels, all_preds, zero_division=0)
# ROC-AUC is undefined when only one class is present (roc_auc_score raises ValueError then).
auc = roc_auc_score(all_labels, all_probs) if len(set(all_labels)) > 1 else float("nan")

print(f"Accuracy:  {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1 Score:  {f1:.4f}")
print(f"AUC-ROC:   {auc:.4f}")