In [2]:
import torch
import torch.nn as nn
from torchvision import models

# 1. Load a ResNet-18 pre-trained on ImageNet (or resnet50 with ResNet50_Weights).
#    `pretrained=True` is deprecated since torchvision 0.13; IMAGENET1K_V1 is the
#    checkpoint it used to load, so newer versions get the same weights without
#    the deprecation warning. Older torchvision falls back to the legacy flag.
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:  # torchvision < 0.13
    model = models.resnet18(pretrained=True)

# 2. Freeze every pre-trained parameter; only the new head created below trains.
for param in model.parameters():
    param.requires_grad = False

# 3. Replace the final classifier with one output per entry of `class_names`.
#    A freshly created nn.Linear has requires_grad=True, so the head stays trainable.
class_names = ['zanetti', 'albini', 'katelos']
num_classes = len(class_names)  # 3 classes here
model.fc = nn.Linear(model.fc.in_features, num_classes)

# NOTE(review): the ImageNet weights expect RGB input normalized with ImageNet
# statistics; the 3 input channels here are presumably the stacked grayscale
# images from ThreeChannelDataset — confirm the normalization in the transform.

# 4. Move the model to the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# 5. Loss and optimizer. Only the head's parameters are optimized because the
#    backbone was frozen in step 2.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)


In [5]:
import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset

class ThreeChannelDataset(Dataset):
    """Dataset that stacks several grayscale versions of the same image into one tensor.

    Expected directory layout (from the path joins below)::

        base_dir/<type>/<class>/<manuscript>/<image file>

    For every class and manuscript, only image file names present under *all*
    ``types`` are kept. Each manuscript's common images are sorted by name; the
    first ``split_ratio`` fraction goes to the train split, the rest to val.

    Each item is ``(tensor, class_idx)``. The tensor has one channel per entry
    of ``types`` (in that order), i.e. shape ``(len(types), H, W)``.
    """

    def __init__(self, base_dir, classes, types, transform=None, split="train", split_ratio=0.8, seed=42):
        """
        Args:
            base_dir: root folder holding one sub-folder per image type
                (e.g. ``bzd``/``ctr``/``gsc``).
            classes: list of class names, e.g. ``["zanetti", "albini", "katelos"]``.
                The position of a class in this list is its label.
            types: list of image types, e.g. ``["bzd", "ctr", "gsc"]``; one
                channel per type. Manuscripts are discovered under ``types[0]``.
            transform: optional callable applied to each grayscale PIL image
                separately. It may return a PIL image, an ``(H, W)`` tensor or a
                ``(1, H, W)`` tensor (e.g. ``transforms.ToTensor()``).
            split: ``"train"`` selects the first part of each manuscript's
                images; any other value selects the remainder (intended: ``"val"``).
            split_ratio: fraction of each manuscript's images assigned to train.
            seed: passed to ``random.seed``. The split itself is deterministic
                (sorted by file name) and does not draw random numbers.
        """
        # Each entry is (manuscript_dir_name, class_idx, image_file_name).
        self.samples = []
        self.base_dir = base_dir
        self.classes = classes
        self.types = types
        self.transform = transform

        # Kept for backward compatibility: this reseeds the global `random`
        # module, although nothing below uses it.
        random.seed(seed)

        for cls_idx, cls in enumerate(classes):
            cls_dir = os.path.join(base_dir, types[0], cls)
            # Sort for a reproducible sample order (os.listdir order is arbitrary)
            # and skip stray files (e.g. .DS_Store) that are not manuscript folders.
            manuscripts = sorted(
                name for name in os.listdir(cls_dir)
                if os.path.isdir(os.path.join(cls_dir, name))
            )
            for manuscript in manuscripts:
                # Keep only the image names that exist for every type.
                imgs = set(os.listdir(os.path.join(cls_dir, manuscript)))
                for t in types[1:]:
                    imgs &= set(os.listdir(os.path.join(base_dir, t, cls, manuscript)))
                imgs = sorted(imgs)

                # NOTE(review): the split is done inside each manuscript, so train
                # and val contain pages from the same manuscripts. If the goal is
                # to generalize to unseen manuscripts, split by manuscript instead.
                n_train = int(len(imgs) * split_ratio)
                selected_imgs = imgs[:n_train] if split == "train" else imgs[n_train:]

                self.samples.extend((manuscript, cls_idx, img) for img in selected_imgs)

    def __len__(self):
        """Number of samples in this split."""
        return len(self.samples)

    def _load_channel(self, img_path):
        """Load one image as grayscale, apply ``transform`` and return an ``(H, W)`` tensor."""
        with Image.open(img_path) as img:
            # convert() returns a new, fully loaded image, so closing the file is safe.
            channel = img.convert("L")
        if self.transform is not None:
            channel = self.transform(channel)
        if not isinstance(channel, torch.Tensor):
            # No transform (or a PIL-only transform): raw uint8 pixels, one byte
            # per pixel in "L" mode, row-major.
            channel = torch.frombuffer(
                bytearray(channel.tobytes()), dtype=torch.uint8
            ).reshape(channel.height, channel.width)
        if channel.dim() == 3 and channel.shape[0] == 1:
            # ToTensor() on an "L" image yields (1, H, W); drop the singleton
            # axis so stacking gives (C, H, W) rather than (C, 1, H, W).
            channel = channel[0]
        return channel

    def __getitem__(self, idx):
        """Return ``(tensor, class_idx)`` with tensor shape ``(len(types), H, W)``.

        All per-type images must end up with the same ``H, W`` (e.g. through a
        resizing ``transform``); otherwise ``torch.stack`` raises.
        """
        manuscript, cls_idx, img_name = self.samples[idx]
        class_name = self.classes[cls_idx]
        channels = [
            self._load_channel(os.path.join(self.base_dir, t, class_name, manuscript, img_name))
            for t in self.types
        ]
        # One channel per type -> (len(types), H, W).
        return torch.stack(channels, dim=0), cls_idx
