In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
from torchvision.models.segmentation import deeplabv3_resnet50
from torch.utils.data import DataLoader
import torchvision.transforms as T
from sklearn.metrics import roc_auc_score, average_precision_score
from PIL import Image
import matplotlib.pyplot as plt

from lostandfound_dataloader import LostAndFoundObstacleDataset

# === 1. Extract files only if not already extracted ===
import subprocess

zip_path = "/content/drive/MyDrive/SAPIENZA/CV/Project/leftImg8bit_trainvaltest.zip"
extract_path = "/content/drive/MyDrive/SAPIENZA/CV/Project/leftImg8bit_trainvaltest"
# NOTE(review): the skip check assumes the archive unpacks into a folder named exactly
# like `extract_path`; if its top-level folder has a different name, this re-extracts
# on every run — verify against the archive contents.
if not os.path.exists(extract_path):
    # Argument list (no shell) avoids quoting problems. check=True raises
    # CalledProcessError instead of silently continuing when unzip fails
    # (e.g. the archive is missing).
    subprocess.run(
        ["unzip", zip_path, "-d", os.path.dirname(extract_path)],
        check=True,
    )

# === 2. Transforms ===
# Target (height, width) shared by images and masks so they stay pixel-aligned.
resize_dims = (512, 1024)
img_transform = T.Compose(
    [
        T.Resize(resize_dims),
        T.ToTensor(),
        # Standard ImageNet per-channel mean / std.
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Nearest-neighbour resizing keeps mask values discrete (no interpolated labels).
mask_transform = T.Resize(resize_dims, interpolation=T.InterpolationMode.NEAREST)

# === 3. Dataset and DataLoader ===
# LostAndFound obstacle dataset; the transforms are handed to the dataset, which
# presumably applies them per sample (split selection happens inside
# LostAndFoundObstacleDataset — confirm it is the test split).
test_dataset = LostAndFoundObstacleDataset(
    root="/content/drive/MyDrive/SAPIENZA/CV/Project",
    img_transform=img_transform,
    mask_transform=mask_transform,
)
# Fixed iteration order (no shuffling) so batch/sample indices are reproducible.
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=4)

# === 4. Load model ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture must match the checkpoint: DeepLabV3-ResNet50 with module
# classifier[4] swapped for an 8-output 1x1 convolution.
model = deeplabv3_resnet50(pretrained=False)
model.classifier[4] = nn.Conv2d(256, 8, kernel_size=1)

# NOTE(review): unlike every other path in this notebook, this one lacks the
# "SAPIENZA/CV" segment — confirm the checkpoint really lives here.
checkpoint_path = "/content/drive/MyDrive/Project/deeplabv3_binary_cityscapes.pth"
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model = model.eval().to(device)

# === 5. Evaluation + visualization function ===
def evaluate_ood(model, dataloader, device, save_dir="visual_outputs",
                 object_channel=7, threshold=0.5,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Score every pixel for objectness, save mask images, and return AUROC / AP.

    For each batch, ``torch.sigmoid`` of channel ``object_channel`` of
    ``model(images)['out']`` is used as the per-pixel objectness score. The
    scores and ground-truth ``object_mask`` values of all pixels are pooled to
    compute pixel-level AUROC and average precision.

    For every sample, three images are written to ``save_dir``:

    - ``pred_{i}_{j}.png``: the thresholded prediction.
    - ``gt_{i}_{j}.png``: the ground-truth mask.
    - ``rgb_{i}_{j}.png``: the de-normalized RGB input.

    Here ``i`` is the batch index and ``j`` the index within the batch.

    Args:
        model: segmentation model whose forward returns a dict with key ``'out'``.
        dataloader: yields dicts with ``'image'`` (normalized image batch) and
            ``'object_mask'`` entries.
        device: device the images are moved to before the forward pass.
        save_dir: output directory for the PNGs; created if missing.
        object_channel: output channel interpreted as objectness logits.
        threshold: score above which a pixel counts as "object" in the pred PNGs.
        mean, std: per-channel statistics of the input normalization; only used
            to undo it when saving the RGB PNGs.

    Returns:
        Tuple ``(auroc, ap)``.

    Raises:
        ValueError: if the dataloader yields no batches, or (from scikit-learn)
            if the pooled ground truth does not contain both classes.
    """
    model.eval()
    os.makedirs(save_dir, exist_ok=True)
    # Broadcastable (1, 1, C) arrays for undoing the normalization on HWC images.
    mean_hwc = np.asarray(mean, dtype=np.float32).reshape(1, 1, -1)
    std_hwc = np.asarray(std, dtype=np.float32).reshape(1, 1, -1)
    all_scores = []
    all_targets = []

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            images = batch['image'].to(device)
            object_gt = batch['object_mask'].cpu().numpy()  # expected shape: (B, H, W)
            outputs = model(images)['out']
            # NOTE(review): sigmoid on a single channel of a multi-channel head treats
            # it as an independent binary logit; if the head was trained with
            # softmax / cross-entropy, softmax over channels is the matching score —
            # confirm against the training code.
            scores = torch.sigmoid(outputs[:, object_channel, :, :]).cpu().numpy()

            # Pool every pixel for AUROC / AP.
            # NOTE(review): if object_mask carries an ignore label, those pixels
            # should be filtered out here — verify against the dataset.
            all_scores.append(scores.reshape(-1))
            all_targets.append(object_gt.reshape(-1))

            # Undo the input normalization: (B, C, H, W) -> (B, H, W, C) in [0, 1].
            rgb_batch = np.clip(
                images.cpu().permute(0, 2, 3, 1).numpy() * std_hwc + mean_hwc,
                0.0,
                1.0,
            )

            for j in range(images.size(0)):
                pred_mask = (scores[j] > threshold).astype(np.uint8) * 255
                gt_mask = (object_gt[j] * 255).astype(np.uint8)
                rgb_img = (rgb_batch[j] * 255).astype(np.uint8)

                Image.fromarray(pred_mask).save(os.path.join(save_dir, f"pred_{i}_{j}.png"))
                Image.fromarray(gt_mask).save(os.path.join(save_dir, f"gt_{i}_{j}.png"))
                Image.fromarray(rgb_img).save(os.path.join(save_dir, f"rgb_{i}_{j}.png"))

    if not all_scores:
        raise ValueError("evaluate_ood: dataloader yielded no batches")

    all_scores = np.concatenate(all_scores)
    all_targets = np.concatenate(all_targets)

    auroc = roc_auc_score(all_targets, all_scores)
    ap = average_precision_score(all_targets, all_scores)

    return auroc, ap

# === 6. Run the evaluation ===
# Writes the per-sample mask PNGs to "visual_outputs" and reports pixel-level metrics.
auroc, ap = evaluate_ood(
    model,
    test_loader,
    device,
    save_dir="visual_outputs",
)
print(f"[✅ TEST su LostAndFound] AUROC: {auroc:.4f}, AP: {ap:.4f}")


In [None]:
import matplotlib.pyplot as plt

# Undo the input normalization so images render with their real colors. The
# statistics are read from the Normalize step of img_transform to stay in sync.
_normalize = next(t for t in img_transform.transforms if isinstance(t, T.Normalize))
img_mean = np.asarray(_normalize.mean, dtype=np.float32)
img_std = np.asarray(_normalize.std, dtype=np.float32)

num_samples = min(4, len(test_dataset))  # visualize the first samples
samples = [test_dataset[k] for k in range(num_samples)]

# `S` inside evaluate_ood is a local variable and does not exist at module level,
# so the objectness scores for these samples are recomputed here (same channel and
# sigmoid as evaluate_ood). `S` keeps its name because later cells read it.
with torch.no_grad():
    sample_images = torch.stack([s['image'] for s in samples]).to(device)
    S = torch.sigmoid(model(sample_images)['out'][:, 7, :, :])

for i in range(num_samples):
    img = samples[i]['image'].permute(1, 2, 0).numpy()  # CHW -> HWC
    img = np.clip(img * img_std + img_mean, 0.0, 1.0)
    pred = (S[i] > 0.5).cpu().numpy()
    gt = samples[i]['object_mask'].numpy()

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs[0].imshow(img)
    axs[0].set_title("RGB Image")

    axs[1].imshow(pred, cmap='gray')
    axs[1].set_title("Predicted Objectness")

    axs[2].imshow(gt, cmap='gray')
    axs[2].set_title("Ground Truth Mask")

    for ax in axs:
        ax.axis('off')
    plt.tight_layout()
    plt.show()


In [None]:
# Heatmap of the continuous objectness scores for sample `i`.
# Relies on module-level `S` (per-sample score maps) and `i` set by an earlier cell.
score_map = S[i].cpu().numpy()
plt.imshow(score_map, cmap='hot')
plt.colorbar()
plt.title("Objectness Score Heatmap")
plt.axis('off')
plt.show()


In [None]:
# Paint the predicted object pixels red on a copy of the RGB image.
# Relies on `img` (HWC image) and `pred` (HW mask) set by an earlier cell.
red_rgb = [1.0, 0.0, 0.0]
overlay = img.copy()
overlay[pred == 1] = red_rgb

plt.imshow(overlay)
plt.title("Object Detection Overlay")
plt.axis('off')
plt.show()
