# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import glob
import random
import os
import cv2
from tqdm import tqdm
import numpy as np
from PIL import Image, ImageOps
import timm
import pandas as pd

# ---- 1) Dataset folder ----
# Root directory scanned by load_image_paths_and_labels for class sub-folders.
source_dir = './ruote_catalogate_def/'  # update if necessary

# ---- 2) Helper that collects image paths and labels ----
def load_image_paths_and_labels(base_dir):
    """Collect image paths and (x, y) regression targets from class folders.

    Every sub-directory of ``base_dir`` whose name has the form ``"<x>_<y>"``
    (two integers separated by a single underscore) is treated as one class.
    Each file inside it whose name contains a dot is returned together with
    the label ``(x / 3, y / 3)``. Hidden entries, plain files and folders
    whose name does not match the pattern are skipped.

    Args:
        base_dir: Root directory containing one sub-directory per class.

    Returns:
        A ``(image_paths, labels)`` pair of equally long lists. Paths are
        ordered by class folder name and then by file name, so the result is
        deterministic across platforms.

    Raises:
        FileNotFoundError: If ``base_dir`` does not exist.
    """
    if not os.path.exists(base_dir):
        raise FileNotFoundError(f"La cartella {base_dir} non esiste!")

    image_paths = []
    labels = []

    for cls in sorted(os.listdir(base_dir)):
        cls_path = os.path.join(base_dir, cls)
        if cls.startswith('.') or not os.path.isdir(cls_path):
            continue

        # Both failure modes raise ValueError: a wrong number of "_"-separated
        # parts (unpacking) and a non-numeric part (int()). Parsing inside the
        # try keeps folders such as "a_b" from aborting the whole scan.
        try:
            x, y = (int(part) for part in cls.split('_'))
        except ValueError:
            continue

        # NOTE(review): dividing by 3 presumably maps grid coordinates in
        # 0..3 onto [0, 1] to match the sigmoid output — confirm.
        label = (x / 3, y / 3)

        # glob order is filesystem-dependent; sort for reproducibility.
        for img_file in sorted(glob.glob(os.path.join(cls_path, '*.*'))):
            image_paths.append(img_file)
            labels.append(label)

    return image_paths, labels

# ---- 3) Dataset loading ----
all_paths, all_labels = load_image_paths_and_labels(source_dir)

# Shuffle the (path, label) pairs together so they stay aligned, then split
# 70% train / 30% test.
combined = list(zip(all_paths, all_labels))
random.shuffle(combined)
split_idx = int(len(combined) * 0.7)
train_data = combined[:split_idx]
test_data  = combined[split_idx:]

# With fewer than two images the train split is empty and zip(*...) below
# cannot be unpacked; fail early with an explicit message instead.
if not train_data or not test_data:
    raise ValueError(
        f"Dataset troppo piccolo: trovate {len(combined)} immagini in "
        f"{source_dir}, ne servono almeno 2 per lo split train/test."
    )

# Transpose the list of pairs back into parallel lists of paths and labels.
train_paths, train_labels = (list(column) for column in zip(*train_data))
test_paths,  test_labels  = (list(column) for column in zip(*test_data))

# ---- 4) Custom dataset ----
class WheelDataset(Dataset):
    """Dataset pairing each image with a Canny edge map of itself.

    ``__getitem__`` loads the image with OpenCV, computes a 3-channel edge
    map, runs both through their respective transforms and concatenates the
    results along dimension 0.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of targets, indexed in parallel with ``image_paths``.
        transform: Callable applied to the RGB PIL image.
        edge_transform: Callable applied to the edge-map PIL image.

    Both transforms must return tensors for the final ``torch.cat`` to work;
    leaving either one as ``None`` makes ``__getitem__`` fail on the PIL
    image.
    """

    def __init__(self, image_paths, labels, transform=None, edge_transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.edge_transform = edge_transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(stacked_tensor, label_tensor)`` for sample ``idx``.

        Raises:
            OSError: If the image file cannot be read or decoded.
        """
        path = self.image_paths[idx]
        bgr = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if bgr is None:
            raise OSError(f"Impossibile leggere l'immagine: {path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # Edge detection: Canny on the grayscale image, replicated to 3 channels.
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(gray, 20, 60)
        edges_rgb = np.stack([edges] * 3, axis=-1)

        # Images built from arrays carry no EXIF data, so the former
        # ImageOps.exif_transpose calls had nothing to act on and were dropped.
        # (Per the OpenCV docs, imread applies EXIF orientation when decoding.)
        img = Image.fromarray(rgb)
        edges_rgb = Image.fromarray(edges_rgb)

        if self.transform:
            img = self.transform(img)
        if self.edge_transform:
            edges_rgb = self.edge_transform(edges_rgb)

        # With the 3-channel transforms used in this file this yields 6 channels.
        stacked = torch.cat((img, edges_rgb), dim=0)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return stacked, label

# ---- 5) Transforms ----
class CenterPadCrop:
    """Resize an image to fit a square canvas and pad the remainder.

    The longer side is scaled to ``final_size`` while keeping the aspect
    ratio; the shorter side is then padded symmetrically with zeros (any odd
    pixel goes to the right/bottom) so the result is
    ``final_size x final_size``. Despite the name, nothing is cropped.

    Args:
        final_size: Side length in pixels of the square output.
    """

    def __init__(self, final_size=224):
        self.final_size = final_size

    def _fit_size(self, w, h):
        """Return the ``(width, height)`` after scaling the longer side."""
        if h > w:
            new_h = self.final_size
            new_w = int(w * self.final_size / h)
        else:
            new_w = self.final_size
            new_h = int(h * self.final_size / w)
        # int() truncates towards zero, so a very elongated image could end up
        # with a 0-pixel side, which resize() rejects. Keep at least one pixel.
        return max(new_w, 1), max(new_h, 1)

    def __call__(self, img: "Image.Image") -> "Image.Image":
        """Return ``img`` resized and zero-padded to a square."""
        new_w, new_h = self._fit_size(*img.size)
        img = img.resize((new_w, new_h), resample=Image.BILINEAR)

        pad_left = (self.final_size - new_w) // 2
        pad_right = self.final_size - new_w - pad_left
        pad_top = (self.final_size - new_h) // 2
        pad_bottom = self.final_size - new_h - pad_top

        # torchvision expects padding as (left, top, right, bottom).
        return transforms.functional.pad(
            img,
            padding=(pad_left, pad_top, pad_right, pad_bottom),
            fill=0,
        )

# Edge maps: letterbox to 224x224 and convert to a tensor, no normalisation.
edge_transform = transforms.Compose(
    [
        CenterPadCrop(final_size=224),
        transforms.ToTensor(),
    ]
)

# RGB images: same letterboxing, then per-channel normalisation with the
# mean/std values commonly used for ImageNet-pretrained backbones.
transform_rgb = transforms.Compose(
    [
        CenterPadCrop(final_size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ---- 6) DataLoaders ----
# Training split: small shuffled batches.
train_dataset = WheelDataset(
    train_paths,
    train_labels,
    transform=transform_rgb,
    edge_transform=edge_transform,
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)

# Test split: one sample per batch, in the same order as test_paths.
test_dataset = WheelDataset(
    test_paths,
    test_labels,
    transform=transform_rgb,
    edge_transform=edge_transform,
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---- 7) MobileNetV3-Large with 6 input channels ----
model = timm.create_model('mobilenetv3_large_100', pretrained=True)

# Replace the stem convolution with one that accepts 6 channels
# (3 RGB channels + 3 edge-map channels), keeping the original geometry.
orig_conv = model.conv_stem
model.conv_stem = nn.Conv2d(
    in_channels=6,
    out_channels=orig_conv.out_channels,
    kernel_size=orig_conv.kernel_size,
    stride=orig_conv.stride,
    padding=orig_conv.padding,
    dilation=orig_conv.dilation,
    bias=orig_conv.bias is not None
)
with torch.no_grad():
    # Initialise both channel groups from the pretrained RGB filters.
    model.conv_stem.weight[:, :3] = orig_conv.weight
    model.conv_stem.weight[:, 3:] = orig_conv.weight
    # The new layer mirrors the original's bias flag; carry the pretrained
    # values over as well instead of leaving them randomly initialised.
    if orig_conv.bias is not None:
        model.conv_stem.bias.copy_(orig_conv.bias)

# Final head: 2-D regression output
in_features = model.classifier.in_features
model.classifier = nn.Linear(in_features, 2)

class MobileNetV3Regressor(nn.Module):
    """Wrap a backbone and squash its outputs into (0, 1) with a sigmoid.

    Args:
        base_model: Module whose forward output is passed through the sigmoid.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base = base_model

    def forward(self, x):
        """Return ``sigmoid(base(x))``."""
        return torch.sigmoid(self.base(x))

# Wrap the backbone in the sigmoid regressor, then move it to the target device.
model = MobileNetV3Regressor(model)
model = model.to(device)

# ---- 8) Loss and optimizer ----
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

# ---- 9) Training loop (1 epoch) ----
num_epochs = 1
test_results = []  # per-sample predictions, filled during the last epoch only

for epoch in range(num_epochs):
    # ----- Training -----
    model.train()
    running_loss = 0.0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        inputs, labels = inputs.to(device), labels.float().to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight by batch size so the epoch
        # figure is a true per-sample average even with a short last batch.
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_dataset)
    print(f"Epoch {epoch+1}/{num_epochs} - Training Loss: {epoch_loss:.4f}")

    # ----- Validation/Test -----
    model.eval()
    running_loss_test = 0.0
    is_last_epoch = (epoch == num_epochs - 1)
    seen = 0  # test samples processed so far in this pass

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Testing"):
            inputs, labels = inputs.to(device), labels.float().to(device)
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            running_loss_test += loss.item() * inputs.size(0)

            if is_last_epoch:
                # Record every sample of the batch, not just the first, so the
                # results stay correct if the test batch size is ever raised.
                # The path lookup relies on test_loader iterating test_paths in
                # order (shuffle=False).
                preds = outputs.cpu().numpy()
                targets = labels.cpu().numpy()
                for offset, ((pred_x, pred_y), (real_x, real_y)) in enumerate(zip(preds, targets)):
                    test_results.append({
                        "path": test_paths[seen + offset],
                        "x_pred": float(pred_x),
                        "y_pred": float(pred_y),
                        "x_real": float(real_x),
                        "y_real": float(real_y),
                    })

            seen += inputs.size(0)

    test_loss = running_loss_test / len(test_dataset)
    print(f"Epoch {epoch+1}/{num_epochs} - Test MSE: {test_loss:.4f}\n")

# ---- 10) Save test results to CSV ----
# One row per entry of test_results (filled during the last epoch's test pass).
df_test = pd.DataFrame(test_results)
df_test.to_csv("test_predictions_mobilenetv3.csv", index=False)
print(f"Salvati {len(df_test)} risultati in 'test_predictions_mobilenetv3.csv'")
# Persist the fine-tuned weights as a state_dict of the wrapped model.
torch.save(model.state_dict(), "regression_mobilenetv3_finetuned.pth")
