## Load the trained model and run inference

In [None]:
import torch
from torchvision import transforms
from PIL import Image

# --- Config ---
# Run on the current CUDA device when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Number of output logits; must match the head size stored in the checkpoint,
# otherwise load_state_dict below will reject it.
num_classes = 2
# State-dict file that is loaded into the model below (relative to the CWD).
checkpoint_path = "best_resnet18.pth"

import torch
import torch.nn as nn
from torchvision import models

class ResNet18Classifier(nn.Module):
    """ResNet-18 backbone with a small two-layer classification head.

    The backbone's final ``fc`` layer is replaced by
    ``Linear(in_features, 128) -> ReLU -> Dropout(dropout) -> Linear(128, num_classes)``,
    so the parameter names (``backbone.fc.0.*``, ``backbone.fc.3.*``) stay the
    same as in checkpoints produced by the original version of this class.

    Args:
        num_classes: Number of output logits produced by the head.
        dropout: Dropout probability applied between the two head layers.
        pretrained: If True (the default, matching the original behaviour),
            initialise the backbone with torchvision's ImageNet weights, which
            may require a download. Pass False when every parameter is about
            to be overwritten by ``load_state_dict`` anyway (e.g. inference
            from a saved checkpoint) to skip that work.
    """

    def __init__(self, num_classes: int = 2, dropout: float = 0.3, pretrained: bool = True):
        super().__init__()
        self.backbone = self._build_backbone(pretrained)
        in_features = self.backbone.fc.in_features

        # Replace the ImageNet FC head with a small task-specific head.
        self.backbone.fc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    @staticmethod
    def _build_backbone(pretrained: bool) -> nn.Module:
        """Return a torchvision ResNet-18, avoiding the deprecated ``pretrained=`` kwarg when possible."""
        # torchvision >= 0.13 exposes weight enums and deprecates ``pretrained=``.
        # ResNet18_Weights.DEFAULT is IMAGENET1K_V1, i.e. the same weights the
        # legacy ``pretrained=True`` flag loaded.
        if hasattr(models, "ResNet18_Weights"):
            weights = models.ResNet18_Weights.DEFAULT if pretrained else None
            return models.resnet18(weights=weights)
        # Older torchvision only understands the legacy boolean flag.
        return models.resnet18(pretrained=pretrained)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (unnormalised) class logits for the input batch ``x``."""
        return self.backbone(x)



# --- Recreate model and load weights ---
# Rebuild the training architecture, then overwrite its parameters with the
# saved state dict. NOTE: torch.load unpickles the file — only load
# checkpoints that come from a trusted source.
model = ResNet18Classifier(num_classes=num_classes, dropout=0.3)
model.load_state_dict(
    torch.load(checkpoint_path, map_location=device)
)
# Move to the target device and switch Dropout to inference behaviour;
# both calls return the module itself, so they can be chained.
model.to(device).eval()

# --- Label map (same as during training) ---
# Position in this tuple is the class index the model's logits are ordered by.
idx2label = dict(enumerate(("NONPD", "PD")))

# --- Inference on a single image ---
def predict_image(img_path: str):
    """Classify one image file with the module-level ``model``.

    The image is converted to RGB, resized to 224x224, turned into a tensor and
    normalised with the ImageNet mean/std, then run through ``model`` on
    ``device`` without gradient tracking.

    Args:
        img_path: Path to the image file to classify.

    Returns:
        A ``(label, confidence)`` tuple: the ``idx2label`` name of the class
        with the highest softmax probability, and that probability as a float
        in [0, 1].
    """
    # NOTE(review): this pipeline must match the evaluation transforms used
    # during training (plain 224x224 resize, ImageNet statistics) — confirm
    # against the training code.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Open inside ``with`` so the file handle is released deterministically;
    # convert() loads the pixels and returns a new image, so ``img`` remains
    # usable after the source file is closed.
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    batch = preprocess(img).unsqueeze(0).to(device)  # (1, 3, 224, 224)
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        conf, pred = probs.max(dim=1)
    return idx2label[pred.item()], conf.item()

# Example: classify one spectrogram and print the predicted class together
# with its softmax confidence as a percentage.
# NOTE(review): the filename marks this as a NonPD sample, yet the saved run
# printed "PD" at ~54% confidence — worth checking the model/preprocessing.
label, confidence = predict_image(
    "exampleSpectrogram/NonPD Spectrogram 31.png"
)
print(f"Predicted: {label} ({confidence*100:.1f}%)")

Predicted: PD (54.1%)
