In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import glob
import random
import os
import cv2
from tqdm import tqdm
import numpy as np
from PIL import Image, ImageOps
import timm

# Source directory; its sub-folders are named "x,y" and hold the images for that label pair
source_dir = './ruote_catalogate_def/'


# Load image paths and their (x, y) labels from the class-folder names
def load_image_paths_and_labels(base_dir):
    """Collect image files and regression targets from a folder-per-class tree.

    Every non-hidden sub-directory of ``base_dir`` must be named ``"x,y"``
    with integer ``x`` and ``y`` (for example ``"0,3"``). Each image file
    found directly inside it is labelled ``(x / 3, y / 3)``.

    Only files with a common image extension are returned, so stray files
    (``Thumbs.db``, notes, ...) never reach the image decoder. Directories and
    files are visited in sorted order so the result is reproducible across
    operating systems.

    Args:
        base_dir: Directory containing the class sub-directories.

    Returns:
        A ``(image_paths, labels)`` pair of equally long lists; ``labels[i]``
        is the ``(x / 3, y / 3)`` tuple belonging to ``image_paths[i]``.

    Raises:
        ValueError: If a sub-directory name is not two comma-separated
            integers.
    """
    image_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp')
    image_paths = []
    labels = []

    for cls in sorted(os.listdir(base_dir)):  # e.g. ['0,1', '1,0']
        if cls.startswith('.'):
            continue
        cls_path = os.path.join(base_dir, cls)
        if not os.path.isdir(cls_path):
            continue

        # Parse x,y from the folder name
        try:
            x, y = (int(part) for part in cls.split(','))
        except ValueError as err:
            raise ValueError(
                f"Class folder {cls_path!r} is not named 'x,y' with integer x and y"
            ) from err
        label = (x / 3, y / 3)

        # Collect the images of this class
        for img_file in sorted(glob.glob(os.path.join(cls_path, '*.*'))):
            if not img_file.lower().endswith(image_exts):
                continue
            image_paths.append(img_file)
            labels.append(label)

    return image_paths, labels

# Load every (path, label) pair found under source_dir
all_paths, all_labels = load_image_paths_and_labels(source_dir)

# Shuffle paths and labels together so each image keeps its own label,
# then cut the shuffled list 70% / 30% into train and test
combined = list(zip(all_paths, all_labels))
random.shuffle(combined)
split_idx = int(len(combined) * 0.7)
train_data, test_data = combined[:split_idx], combined[split_idx:]

# Unzip each split back into parallel lists of paths and labels
train_paths, train_labels = map(list, zip(*train_data))
test_paths, test_labels = map(list, zip(*test_data))

class WheelDataset(Dataset):
    """Dataset that pairs each image with a Canny edge map of itself.

    Every item is ``(combined, label)``: ``combined`` is the transformed RGB
    image concatenated along dim 0 with the transformed edge map, and
    ``label`` is a ``float32`` tensor built from ``labels[idx]``.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, parallel to ``image_paths``.
        transform: Callable applied to the RGB PIL image.
        edge_transform: Callable applied to the 3-channel edge PIL image.

    Both transforms must return tensors with the same spatial size, otherwise
    the concatenation in ``__getitem__`` fails.
    """

    def __init__(self, image_paths, labels, transform=None, edge_transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.edge_transform = edge_transform

    def __len__(self):
        """Return the number of images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(image + edges tensor, label tensor)``.

        Raises:
            OSError: If the file cannot be read or decoded by OpenCV.
        """
        path = self.image_paths[idx]
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise OSError(f"Could not read image file: {path!r}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # torch models expect RGB order

        # Edge detection on the grayscale image
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 20, 60)
        # Replicate the single edge channel three times so the edge map can go
        # through the same kind of PIL-based transform as the RGB image
        edges_rgb = np.stack([edges] * 3, axis=-1)

        # No ImageOps.exif_transpose here: PIL images built from arrays carry
        # no EXIF data, so the call would be a no-op (cv2.imread already
        # applies the EXIF orientation with its default flags).
        img = Image.fromarray(img)
        edges_rgb = Image.fromarray(edges_rgb)

        if self.transform:
            img = self.transform(img)
        if self.edge_transform:
            edges_rgb = self.edge_transform(edges_rgb)

        # Stack image and edge channels (6 channels when both transforms
        # return 3-channel tensors)
        combined = torch.cat((img, edges_rgb), dim=0)

        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return combined, label

class CenterPadCrop:
    """Resize keeping the aspect ratio, then zero-pad to a centered square.

    The longer side of the image is scaled to ``final_size``; the shorter
    side is scaled by the same factor and padded equally on both sides (the
    odd pixel, if any, goes to the right/bottom) with value 0. Despite the
    name, nothing is cropped.

    Args:
        final_size: Side length in pixels of the square output image.
    """

    def __init__(self, final_size=256):
        self.final_size = final_size

    def __call__(self, img: "Image.Image"):
        """Return ``img`` letterboxed to ``final_size`` x ``final_size``."""
        size = self.final_size

        # --- Resize so the longer side equals final_size ---
        w, h = img.size
        if h > w:
            new_h = size
            # Clamp to 1 px: truncation would give 0 for very elongated
            # images, which resize() rejects
            new_w = max(1, int(w * size / h))
        else:
            new_w = size
            new_h = max(1, int(h * size / w))
        img = img.resize((new_w, new_h), resample=Image.BILINEAR)

        # --- Padding needed to center the image in the square ---
        pad_left = (size - new_w) // 2
        pad_right = size - new_w - pad_left
        pad_top = (size - new_h) // 2
        pad_bottom = size - new_h - pad_top

        # --- Apply the padding (left, top, right, bottom) ---
        img = transforms.functional.pad(
            img, padding=(pad_left, pad_top, pad_right, pad_bottom), fill=0
        )

        return img

# --- Transformation pipelines ---
# Edge maps: letterbox to 224x224 and convert to a tensor (no normalisation)
edge_transform = transforms.Compose([
    CenterPadCrop(final_size=224),
    transforms.ToTensor(),
])

# RGB images: same letterboxing, followed by per-channel normalisation
# (the mean/std values match the standard ImageNet statistics)
transform_rgb = transforms.Compose([
    CenterPadCrop(final_size=224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Datasets share the same pair of transforms
train_dataset = WheelDataset(train_paths, train_labels, transform_rgb, edge_transform)
test_dataset = WheelDataset(test_paths, test_labels, transform_rgb, edge_transform)

# Training batches are shuffled; the test set is read one sample at a time, in order
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained ViT
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# Replace the patch-embedding conv with a 6-channel one (RGB + edge RGB),
# keeping every other setting of the original layer
orig_conv = model.patch_embed.proj
model.patch_embed.proj = nn.Conv2d(
    in_channels=6,
    out_channels=orig_conv.out_channels,
    kernel_size=orig_conv.kernel_size,
    stride=orig_conv.stride,
    padding=orig_conv.padding,
    bias=orig_conv.bias is not None
)

# A fresh Conv2d is randomly initialised, which would discard the pretrained
# patch embedding while the rest of the network still expects it. Seed the new
# layer from the pretrained kernel instead: tile it over both 3-channel halves
# and halve it so the summed response keeps a comparable scale (this mirrors
# how timm adapts input convs when in_chans differs from 3).
# Assumes the pretrained conv has 3 input channels.
with torch.no_grad():
    model.patch_embed.proj.weight.copy_(orig_conv.weight.repeat(1, 2, 1, 1) / 2)
    if orig_conv.bias is not None:
        model.patch_embed.proj.bias.copy_(orig_conv.bias)

# Output layer producing the two regression values
model.head = nn.Linear(model.head.in_features, 2)

class ViTRegressor(nn.Module):
    """Wrap a backbone and squash its outputs into (0, 1) with a sigmoid."""

    def __init__(self, vit_model):
        super().__init__()
        self.vit = vit_model

    def forward(self, x):
        """Return the element-wise sigmoid of ``self.vit(x)``."""
        raw = self.vit(x)
        return torch.sigmoid(raw)

# Wrap the backbone in the sigmoid regressor, then move it to the selected device
model = ViTRegressor(model)
model = model.to(device)

In [None]:
# --- Fine-tuning setup for regression on two values ---
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 1
test_results = []    # per-sample predictions collected during the last epoch

for epoch in range(num_epochs):

    # -------- Training --------
    model.train()
    running_loss = 0.0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):

        inputs = inputs.to(device)
        labels = labels.float().to(device)      # shape: [batch, 2]
        optimizer.zero_grad()
        outputs = model(inputs)                 # shape: [batch, 2]

        loss = criterion(outputs, labels)       # MSE over both values
        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; weight it by the batch size so
        # the epoch figure stays a per-sample average even when the last
        # batch is smaller
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_dataset)
    print(f"Epoch {epoch+1}/{num_epochs} - Training Loss: {epoch_loss:.4f}")


    # -------- Evaluation --------
    model.eval()
    running_loss_test = 0.0
    is_last_epoch = (epoch == num_epochs - 1)
    # Position of the next test sample inside test_paths. Relies on the test
    # loader walking the test dataset in order (shuffle=False).
    sample_idx = 0

    print("\n--- VALORI PREDETTI NEL TEST SET ---")

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Testing"):

            inputs = inputs.to(device)
            labels = labels.float().to(device)        # shape: [batch, 2]

            outputs = model(inputs)                   # shape: [batch, 2]

            loss = criterion(outputs, labels)
            running_loss_test += loss.item() * inputs.size(0)

            if is_last_epoch:
                # Record every sample of the batch, not just the first one,
                # so the bookkeeping stays correct for any test batch size
                predictions = outputs.cpu().tolist()
                targets = labels.cpu().tolist()
                for (pred_x, pred_y), (real_x, real_y) in zip(predictions, targets):
                    test_results.append({
                        "path": test_paths[sample_idx],
                        "x_pred": float(pred_x),
                        "y_pred": float(pred_y),
                        "x_real": float(real_x),
                        "y_real": float(real_y),
                    })
                    sample_idx += 1

    test_loss = running_loss_test / len(test_dataset)
    print(f"Epoch {epoch+1}/{num_epochs} - Test MSE: {test_loss:.4f}\n")
