In [3]:
import torch
import torch.nn as nn
from torchvision import models

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fallback architecture, only used when the checkpoint is a bare state_dict:
# ResNet-18 with a 1-channel first convolution and a 3-output linear head.
model = models.resnet18(weights=None)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 3)

# Allowlist DataParallel for torch's restricted unpickler. Registered once (the
# identical call used to be issued twice). Per the PyTorch docs this allowlist
# is consulted by weights_only=True loads, so it should not influence the
# weights_only=False load below; it is kept so the kernel-wide side effect of
# this cell stays the same.
torch.serialization.add_safe_globals([torch.nn.parallel.DataParallel])

# SECURITY: weights_only=False uses the full pickle machinery, which can execute
# arbitrary code stored in the file. Only load checkpoints from a trusted source.
checkpoint = torch.load(
    "../pytorch/checkpoints/resnet18_full.pth",
    map_location=device,
    weights_only=False,
)

# NOTE(review): the two message strings printed by this cell look mis-encoded
# (UTF-8 text decoded with a legacy code page). They are runtime output and are
# deliberately left untouched here; confirm the encoding of the notebook file.
if isinstance(checkpoint, torch.nn.Module):
    # The file holds a pickled model object rather than a state_dict.
    # DataParallel is itself an nn.Module, so the single check above covers it;
    # unwrap it to get the underlying network.
    print("Checkpoint Ã¨ un modello completo, non uno state_dict.")
    model = checkpoint.module if isinstance(checkpoint, torch.nn.DataParallel) else checkpoint
else:
    # Classic case: a state_dict, either bare or nested under "state_dict".
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    # Strip only a *leading* "module." (added by DataParallel). A blanket
    # str.replace would also rewrite keys that contain "module." further down
    # the dotted path and make load_state_dict fail on them.
    dp_prefix = "module."
    new_state_dict = {
        (key[len(dp_prefix):] if key.startswith(dp_prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(new_state_dict)

# Move the weights to the selected device and switch to inference mode
# (eval() disables dropout and freezes batch-norm statistics).
model = model.to(device)
model.eval()
print("âœ… Modello caricato correttamente.")


Checkpoint Ã¨ un modello completo, non uno state_dict.
âœ… Modello caricato correttamente.


In [6]:
from PIL import Image
from torchvision import transforms
import torch

# Image patch to classify.
img_path = "../data/patches/gsc/zanetti/BerolPhill1516_bzd_page-014patch_008.png"

# Load as single-channel grayscale ("L" mode).
# NOTE(review): presumably the loaded model expects one input channel, like the
# fallback architecture built in the loading cell; confirm for the checkpoint.
img = Image.open(img_path).convert("L")

# Preprocessing pipeline. The original author's note says it mirrors the
# training-time transforms; that cannot be verified from this notebook.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Add a leading batch dimension and move the tensor to the model's device.
input_tensor = transform(img)
input_tensor = input_tensor.unsqueeze(0)
input_tensor = input_tensor.to(device)

# Forward pass without gradient tracking; softmax over dim 1 turns the raw
# outputs into per-class probabilities, argmax picks the most likely index.
with torch.no_grad():
    outputs = model(input_tensor)
    probs = outputs.softmax(dim=1)
    pred_class = probs.argmax(dim=1).item()

print("Predicted class:", pred_class)
print("Probabilities:", probs.cpu().numpy())

Predicted class: 0
Probabilities: [[0.7180325  0.2760529  0.00591472]]


In [8]:
import os
# Alphabetically sorted entries of the patch root directory.
# NOTE(review): presumably these are the class folders and the model's class
# indices follow this order; confirm against the training code.
gsc_root = "../data/patches/gsc"
sorted(os.listdir(gsc_root))

['albini', 'katelos', 'zanetti']