# Allenamento baseline ResNet50 per Traslazione + Rotazione (Pose Estimation)

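Questo notebook è organizzato nelle seguenti sezioni:
1. **Import e Setup**
2. **Carica dataset LineMOD**
3. **Inizializza PoseEstimator** (backbone ResNet-50)
4. **Nome e path del modello** (per allenare o caricare un checkpoint)
5. **Training** (rotazione + traslazione)
6. **Valutazione su test set** (metrica ADD full pose)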

## 1. Import e Setup

In [None]:
import sys
from pathlib import Path
import torch
from torch import optim
import yaml
from pathlib import Path
import random
import numpy as np
import pandas as pd
from tqdm import tqdm

# Importa il config per usare path e file in altre cartelle
sys.path.insert(0, str(Path.cwd().parent))  # Aggiungi parent al path
from config import Config
from dataset.linemod_pose import create_pose_dataloaders
from utils.visualization import show_pose_samples, plot_training_validation_loss_from_csv, show_pose_samples_with_add, plot_add_per_class, plot_pinhole_error_per_class
from utils.training import train_pose_full
from utils.losses import PoseLossFull
from utils.transforms import quaternion_to_rotation_matrix_batch
from utils.metrics import compute_add_batch_full_pose, load_all_models, load_models_info
from models.pose_estimator_endtoend import PoseEstimator


## 2. Carica dataset LineMOD

In [None]:
# Build the LineMOD train / val / test dataloaders for pose estimation.
# Every loader setting comes from Config, so training and evaluation always
# share the same crop margin, output resolution and batch size.
_loader_kwargs = {
    "dataset_root": Config.LINEMOD_ROOT,
    "batch_size": Config.POSE_BATCH_SIZE,
    "crop_margin": Config.POSE_CROP_MARGIN,
    "output_size": Config.POSE_IMAGE_SIZE,
    "num_workers": Config.NUM_WORKERS,
}
train_loader, val_loader, test_loader = create_pose_dataloaders(**_loader_kwargs)

# Number of batches per split.
_n_train, _n_val, _n_test = map(len, (train_loader, val_loader, test_loader))
print(f"Train batches: {_n_train} | Val batches: {_n_val} | Test batches: {_n_test}")

### 2.1 Visualizza immagini del training set con info sulla rotazione

In [None]:
# Preview a few training samples. The dataset already yields cropped images,
# so the batch can be plotted as-is.
N_PREVIEW = 4  # number of samples shown in the preview grid
batch = next(iter(train_loader))
show_pose_samples(batch, n=N_PREVIEW)

## 3. Inizializza Modello PoseEstimator

In [None]:
# End-to-end pose network (rotation + translation) on a ResNet-50 backbone:
# pretrained weights, dropout taken from Config, backbone left trainable.
model = PoseEstimator(
    pretrained=True,
    dropout=Config.POSE_DROPOUT,
    freeze_backbone=False,
)
model = model.to(Config.DEVICE)

# Short summary: target device plus total / trainable parameter counts
# (formatted with thousands separators).
params_info = model.get_num_parameters()
print(f"Modello PoseEstimator caricato su: {Config.DEVICE}")
for _label, _key in (("Parametri totali", "total"), ("Parametri allenabili", "trainable")):
    print(f"{_label}: {params_info[_key]:,}")

## 4. Nome e path per allenare o caricare un modello

In [None]:
# Run identifier; needed even when only loading a checkpoint for evaluation.
NAME = "test_endtoend_pose_1"

# Run layout: <CHECKPOINT_DIR>/pose/<NAME>/weights/{best,last}.pt
checkpoint_dir = Config.CHECKPOINT_DIR / "pose" / NAME
checkpoint_weights_dir = checkpoint_dir.joinpath("weights")
best_path, last_path = (checkpoint_weights_dir / f"{stem}.pt" for stem in ("best", "last"))

## 5. Training (rotazione + traslazione)

In [None]:
# End-to-end training: rotation + translation optimised jointly with PoseLossFull.
# NOTE(review): EPOCHS is hard-coded to 5 instead of Config.POSE_EPOCHS, which
# looks like a quick test-run override. With ReduceLROnPlateau(patience=5) the
# learning rate can never be reduced within 5 epochs (assuming train_pose_full
# steps the scheduler once per epoch) - restore Config.POSE_EPOCHS for a real run.
EPOCHS = 5  # Config.POSE_EPOCHS
LR = Config.POSE_LR
ACCUMULATION_STEPS = Config.ACCUMULATION_STEPS

# Hyper-parameters of this run: written to args.yaml and also forwarded to
# train_pose_full as `training_config`.
args_dict = {
    'epochs': EPOCHS,
    'learning_rate': LR,
    'accumulation_steps': ACCUMULATION_STEPS,
    'batch_size': Config.POSE_BATCH_SIZE,
    'dropout': Config.POSE_DROPOUT,
    'freeze_backbone': False,
    'device': str(Config.DEVICE),
    'crop_margin': Config.POSE_CROP_MARGIN,
    'output_size': Config.POSE_IMAGE_SIZE,
    'weight_decay': Config.POSE_WEIGHT_DECAY
}
# The run directory must exist before args.yaml is written into it.
checkpoint_dir.mkdir(parents=True, exist_ok=True)
with open(checkpoint_dir / "args.yaml", "w", encoding="utf-8") as f:
    yaml.dump(args_dict, f)

criterion = PoseLossFull()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=Config.POSE_WEIGHT_DECAY)
# Halve the LR (down to 1e-7) when the monitored loss stops decreasing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7)

history, best_loss, best_epoch = train_pose_full(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=Config.DEVICE,
    epochs=EPOCHS,
    accumulation_steps=ACCUMULATION_STEPS,
    checkpoint_dir=str(checkpoint_dir),
    training_config=args_dict,
    save_best=True,
    save_last=True,
    verbose=True
)

# NOTE(review): `best_epoch + 1` assumes train_pose_full returns a 0-based
# epoch index - confirm.
print(f"\nBest loss: {best_loss:.4f} @ epoch {best_epoch+1}")
print("✅ Training completato!")

### 5.1 Visualizzazione loss

In [None]:
# Plot the training / validation loss curves stored in the run directory.
# NOTE(review): training_result.csv is presumably written by train_pose_full.
training_csv_path = checkpoint_dir.joinpath('training_result.csv')
plot_training_validation_loss_from_csv(training_csv_path)

## 6. Test e Valutazione sul Test Set (rotazione + traslazione)

### 6.1 Visualizza predizione su batch immagini

In [None]:
# Re-declare the run name and checkpoint paths so this cell can be executed on
# its own (e.g. evaluation only, without running the training cells first).
NAME = "test_endtoend_pose_1"
checkpoint_dir = Config.CHECKPOINT_DIR / "pose" / NAME
checkpoint_weights_dir = checkpoint_dir / "weights"
best_path = checkpoint_weights_dir / "best.pt"
last_path = checkpoint_weights_dir / "last.pt"

# Load the best checkpoint and switch to eval mode; stop the run if loading fails.
# NOTE(review): assumes best.pt holds a plain state_dict - confirm against what
# train_pose_full saves.
try:
    model.load_state_dict(torch.load(best_path, map_location=Config.DEVICE))
    model.eval()
    print(f"✅ Modello {NAME} caricato e in modalità eval!")
except Exception as e:
    print(f"⚠️  Modello non trovato o già caricato. Errore: {e}")
    raise SystemExit("Stop right there!") from e

# Pick a random window of consecutive test samples. The window is clamped to
# the dataset size so a test set smaller than one batch no longer makes
# random.randint raise ValueError.
print("Batch casuale: estrazione batch random dal dataset di test")
n_test_samples = len(test_loader.dataset)
if n_test_samples == 0:
    raise SystemExit("Il dataset di test è vuoto!")
n_pick = min(test_loader.batch_size, n_test_samples)
random_start = random.randint(0, n_test_samples - n_pick)
print(f"Indice di inizio batch casuale: {random_start}")
indices = list(range(random_start, random_start + n_pick))  # selected sample indices
samples = [test_loader.dataset[i] for i in indices]
print(f"Numero di sample estratti: {len(samples)}")

# Manual collation: tensor fields are stacked along a new batch dimension,
# every other field is kept as a plain Python list.
batch = {}
for k in samples[0]:
    values = [sample[k] for sample in samples]
    batch[k] = torch.stack(values) if isinstance(values[0], torch.Tensor) else values
print(f"Chiavi batch: {list(batch.keys())}")
test_batch = batch

images = test_batch['rgb_crop'].to(Config.DEVICE)
gt_quaternions = test_batch['quaternion'].to(Config.DEVICE)
gt_translations = test_batch['translation'].to(Config.DEVICE) if 'translation' in test_batch else None
obj_ids = test_batch['obj_id']

with torch.no_grad():
    pred_quaternions, pred_translations = model(images)

# Quaternions -> rotation matrices, computed on CPU so that
# compute_add_batch_full_pose receives the same kind of input as in the
# full test-set evaluation (section 6.2).
print("\nConversione quaternioni in matrici di rotazione")
pred_R = quaternion_to_rotation_matrix_batch(pred_quaternions.cpu())
gt_R = quaternion_to_rotation_matrix_batch(gt_quaternions.cpu())

# 3D object models and their metadata for the ADD metric; same models_info
# source as section 6.2.
models_dict = load_all_models()
models_info = load_models_info(Config.MODELS_INFO_PATH)

print("\nCalcolo metrica ADD full pose")
obj_ids_np = obj_ids.cpu().numpy() if hasattr(obj_ids, 'cpu') else obj_ids
results = compute_add_batch_full_pose(
    pred_R,
    pred_translations.cpu().numpy(),
    gt_R,
    # Pass None through (instead of crashing on .cpu()) when the dataset has
    # no 'translation' ground truth, as section 6.2 does.
    gt_translations.cpu().numpy() if gt_translations is not None else None,
    obj_ids_np,
    models_dict,
    models_info,
)

rot_trans_errors = results.get('add_values', None)
print(f"Test completato su {len(images)} sample")
if rot_trans_errors is None:
    print("⚠️  Nessun valore ADD restituito da compute_add_batch_full_pose.")
else:
    print("\n📊 ADD medio sul batch:")
    print(f"   Mean ADD: {np.mean(rot_trans_errors):.4f} ± {np.std(rot_trans_errors):.4f}")
    # Visualise a few samples with their rotation+translation (ADD) error.
    show_pose_samples_with_add(images, gt_quaternions, pred_quaternions, obj_ids, rot_trans_errors)

### 6.2 Statistiche su intero Test set

In [None]:
# Re-declare the run name and checkpoint paths so this cell can be executed on
# its own (e.g. evaluation only, without running the training cells first).
NAME = "test_endtoend_pose_1"
checkpoint_dir = Config.CHECKPOINT_DIR / "pose" / NAME
checkpoint_weights_dir = checkpoint_dir / "weights"
best_path = checkpoint_weights_dir / "best.pt"
last_path = checkpoint_weights_dir / "last.pt"

# Load the best checkpoint and switch to eval mode; stop the run if loading fails.
# NOTE(review): assumes best.pt holds a plain state_dict - confirm against what
# train_pose_full saves.
try:
    model.load_state_dict(torch.load(best_path, map_location=Config.DEVICE))
    model.eval()
    print(f"✅ Modello {NAME} caricato e in modalità eval!")
except Exception as e:
    print(f"⚠️  Modello non trovato o già caricato. Errore: {e}")
    raise SystemExit("Stop right there!") from e

# 3D object models and their metadata for the ADD metric.
models_dict = load_all_models()
models_info = load_models_info(Config.MODELS_INFO_PATH)

# Per-batch accumulators, concatenated after the loop.
all_pred_quaternions = []
all_gt_quaternions = []
all_obj_ids = []
all_pred_translations = []
all_gt_translations = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Valutazione su test set"):
        images = batch['rgb_crop'].to(Config.DEVICE)
        # Ground truth is only stored on the CPU here, so skip the
        # CPU -> device -> CPU round trip.
        gt_quaternions = batch['quaternion'].cpu()
        gt_translations = batch['translation'].cpu() if 'translation' in batch else None
        obj_ids = batch['obj_id'].cpu().numpy()
        pred_quaternions, pred_translations = model(images)
        all_pred_quaternions.append(pred_quaternions.cpu())
        all_gt_quaternions.append(gt_quaternions)
        all_obj_ids.append(obj_ids)
        all_pred_translations.append(pred_translations.cpu().numpy())
        if gt_translations is not None:
            all_gt_translations.append(gt_translations.numpy())

print("\nconcatenazione batch")
all_pred_quaternions = torch.cat(all_pred_quaternions, dim=0)
all_gt_quaternions = torch.cat(all_gt_quaternions, dim=0)
all_obj_ids = np.concatenate(all_obj_ids, axis=0)
all_pred_translations = np.concatenate(all_pred_translations, axis=0)
# None when the dataset provides no 'translation' ground truth.
all_gt_translations = np.concatenate(all_gt_translations, axis=0) if all_gt_translations else None

print("conversione da quaternioni a matrici di rotazione")
pred_R = quaternion_to_rotation_matrix_batch(all_pred_quaternions)
gt_R = quaternion_to_rotation_matrix_batch(all_gt_quaternions)

print("calcolo metriche: ADD full pose")
# ADD on the full pose (rotation + translation).
results_full_pose = compute_add_batch_full_pose(
    pred_R, all_pred_translations, gt_R, all_gt_translations, all_obj_ids, models_dict, models_info
)
results_full_pose['obj_ids'] = all_obj_ids

# Save per-sample results to CSV so the per-class analysis can be re-run later
# without repeating inference.
add_values = results_full_pose.get('add_values', None)
is_correct = results_full_pose.get('is_correct', None)
obj_ids_full = results_full_pose.get('obj_ids', None)
if add_values is not None and is_correct is not None and len(add_values) > 0:
    # Column-wise construction: pandas raises ValueError if the arrays differ
    # in length, instead of silently writing misaligned rows.
    df_val = pd.DataFrame({
        'obj_id': obj_ids_full,
        'add_value': add_values,
        'is_correct': is_correct,
    })
    val_csv_path = checkpoint_dir / 'validation_result.csv'
    df_val.to_csv(val_csv_path, index=False)
    print(f"✅ Risultati di validazione salvati in {val_csv_path}")

### 6.3 Tabella: Media ADD e Accuracy per Classe

La tabella seguente riporta la media dell'errore ADD e l'accuracy (percentuale di pose corrette) per ciascuna classe (oggetto) del dataset LineMOD.

In [None]:
# Load per-sample validation results from CSV when available; otherwise fall
# back to a `results_full_pose` computed earlier in this session (if any).
NAME = "test_endtoend_pose_1"
checkpoint_dir = Config.CHECKPOINT_DIR / "pose" / NAME
checkpoint_weights_dir = checkpoint_dir / "weights"
best_path = checkpoint_weights_dir / "best.pt"
last_path = checkpoint_weights_dir / "last.pt"

val_csv_path = checkpoint_dir / 'validation_result.csv'
if val_csv_path.exists():
    df_val = pd.read_csv(val_csv_path)
    results_full_pose = {
        'obj_ids': df_val['obj_id'].values,
        # Optional columns: set to None when missing from the CSV.
        'add_values': df_val['add_value'].values if 'add_value' in df_val else None,
        'is_correct': df_val['is_correct'].values if 'is_correct' in df_val else None
    }
    print(f"✅ Risultati caricati da {val_csv_path}")
    # Report missing columns now instead of failing later in the table cell.
    missing_cols = [col for col in ('add_value', 'is_correct') if col not in df_val]
    if missing_cols:
        print(f"⚠️  Colonne mancanti nel CSV: {missing_cols}")
else:
    results_full_pose = globals().get('results_full_pose', None)
    if results_full_pose is None:
        print("⚠️  Devi prima calcolare la metrica ADD full pose su tutto il test set e salvare i risultati in 'results_full_pose'.")

In [None]:
# Per-class table of mean ADD error and accuracy (percentage of samples with
# is_correct True) for every LineMOD object, plus the global figures.
_has_results = (
    results_full_pose is not None
    and results_full_pose.get('add_values') is not None
    and results_full_pose.get('is_correct') is not None
)
if not _has_results:
    print("⚠️  Risultati ADD non disponibili: servono 'add_values' e 'is_correct' in 'results_full_pose'.")
else:
    obj_ids_full = np.array(results_full_pose['obj_ids'])
    add_values = np.array(results_full_pose['add_values'])
    is_correct = np.array(results_full_pose['is_correct'])
    data = []
    # Config.LINEMOD_OBJECTS maps obj_id -> info dict (its 'name' entry is shown).
    for obj_id, obj_info in Config.LINEMOD_OBJECTS.items():
        mask = obj_ids_full == obj_id
        if not mask.any():
            continue  # class absent from the evaluated samples
        mean_add = add_values[mask].mean()
        acc = is_correct[mask].mean() * 100
        data.append({
            'Classe': f"{obj_id:02d} - {obj_info.get('name')}",
            'Media ADD (full pose)': f"{mean_add:.2f}",
            'Accuracy (%)': f"{acc:.1f}"
        })
    df = pd.DataFrame(data)
    display(df)
    # Avoid printing nan (and a RuntimeWarning) when there are no samples.
    if add_values.size > 0:
        print("\nMedia globale ADD (full pose):", f"{add_values.mean():.2f}")
        print("Accuracy globale (full pose) (%):", f"{is_correct.mean()*100:.1f}")

### 6.4 Grafico: Media ADD per Classe

Il grafico seguente mostra la media dell'errore ADD per ciascuna classe, per un confronto visivo immediato delle performance del modello sui diversi oggetti.

In [None]:
# Bar chart of the mean ADD error per class (full pose), using the results
# computed or loaded by the previous cells.
results_full_pose = globals().get('results_full_pose', None)
if results_full_pose is None:
    print("⚠️  Devi prima calcolare la metrica ADD full pose su tutto il test set e salvare i risultati in 'results_full_pose'.")
else:
    plot_add_per_class(results_full_pose, Config.LINEMOD_OBJECTS)