In [2]:
# evaluate.py
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import argparse
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def eval_model(model_path, data_root, batch_size=32):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    # dataset + dataloader
    ds = datasets.ImageFolder(root=data_root, transform=transform)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=4)

    # model
    model = models.mobilenet_v2(pretrained=False)
    in_f = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(in_f, 2)   # two classes
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device).eval()

    y_true, y_pred = [], []
    with torch.no_grad():
        for imgs, labels in dl:
            imgs = imgs.to(device)
            logits = model(imgs)
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            y_true.extend(labels.numpy().tolist())
            y_pred.extend(preds.tolist())

    # metrics
    print(classification_report(y_true, y_pred, digits=4, target_names=ds.classes))
    cm = confusion_matrix(y_true, y_pred)
    print("Confusion matrix:\n", cm)

    # plot confusion matrix
    plt.figure(figsize=(6,6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=ds.classes, yticklabels=ds.classes)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

    return y_true, y_pred

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='checkpoints/best_model.pth')
    parser.add_argument('--data_root', default='data/clf-data')  # matches your folder
    args = parser.parse_known_args()
    eval_model(args.model, args.data_root)

AttributeError: 'tuple' object has no attribute 'model'