# Imports

In [None]:

import os
import glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from tqdm.notebook import tqdm
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchvision.transforms.functional as TF
import cv2
from tqdm.notebook import tqdm
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from fastdepth_model import MobileNetSkipAdd


# Globals

In [None]:
# Configuration dictionary shared by the rest of the notebook.
CONFIG = {
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "TRAIN_DIR": "nyu_data/data/nyu2_train",
    "TEST_DIR": "nyu_data/data/nyu2_test",
    "LEARNING_RATE": 1e-3,
    "BATCH_SIZE": 16,
    "EPOCHS": 20,
    "IMG_WIDTH": 224,
    "IMG_HEIGHT": 224,
    "MAX_DEPTH_METERS": 10.0,
    "ENABLE_SELECTIVITY": True,  # switch the selectivity loss on/off
    "SELECTIVITY_LAYER_NAME": "decode_conv5",  # layer the selectivity loss is applied to
    "SELECTIVITY_LAMBDA": 0.1,  # weight of the selectivity loss
    "NUM_DEPTH_BINS": 27,  # number of depth bins used for selectivity
    # Read by the training loop when selectivity is enabled: the best model is
    # only saved from this epoch index onwards (0 = from the first epoch).
    "WARMUP_EPOCHS": 0,
    # Checkpoint path to resume training from; None starts from scratch.
    # This is the single place to set it (it is not overridden further down).
    "RESUME_CHECKPOINT": None,
}

# Suffix that keeps interpretable and baseline runs in separate files.
CONFIG["MODEL_SUFFIX"] = "_interpretable" if CONFIG["ENABLE_SELECTIVITY"] else "_baseline"

# File names are kept exactly as they were so previously saved models are
# still found (MODEL_SUFFIX already starts with an underscore).
CONFIG["BEST_MODEL_PATH"] = f"models/best_model_{CONFIG['MODEL_SUFFIX']}.pth"
CONFIG["LATEST_CHECKPOINT_PATH"] = f"checkpoint/latest_checkpoint{CONFIG['MODEL_SUFFIX']}.pth"

print(f"Device in use: {CONFIG['DEVICE']}")
print(f"Training with selectivity: {'Enabled' if CONFIG['ENABLE_SELECTIVITY'] else 'Disabled'}")
if CONFIG['ENABLE_SELECTIVITY']:
    print(f"Layer Target: {CONFIG['SELECTIVITY_LAYER_NAME']}")
    print(f"Lambda: {CONFIG['SELECTIVITY_LAMBDA']}")
    print(f"Num Bin: {CONFIG['NUM_DEPTH_BINS']}")
print(f"Best model to save in: {CONFIG['BEST_MODEL_PATH']}")
print(f"Last checkpoint: {CONFIG['LATEST_CHECKPOINT_PATH']}")

# Utils

In [None]:

# Map metric depth values to logarithmically spaced bin indices.
def discretize_depth(depth, num_bins, max_depth, min_depth=0.1):
    """Return a long tensor of bin indices with the same shape as ``depth``.

    Values strictly between ``min_depth`` and ``max_depth`` are placed in one
    of ``num_bins`` bins that are equally wide in log space; every other
    value is marked with ``-1``.
    """
    in_range = (depth > min_depth) & (depth < max_depth)  # pixels that get a bin

    # Log-space limits, created on the same device as the input.
    log_lo = torch.log(torch.tensor(min_depth, device=depth.device))
    log_hi = torch.log(torch.tensor(max_depth, device=depth.device))

    # Position of each valid pixel inside the log range, as a fraction in [0, 1].
    fraction = (torch.log(depth[in_range]) - log_lo) / (log_hi - log_lo)
    raw_ids = torch.floor(fraction * num_bins).long()

    bin_ids = torch.full_like(depth, -1, dtype=torch.long)  # -1 marks "no bin"
    bin_ids[in_range] = raw_ids.clamp(0, num_bins - 1)
    return bin_ids

# Forward-hook plumbing used to grab the output of a chosen layer.
activations_capture = {}  # layer name -> output of that layer's latest forward pass


def get_activation(name):
    """Build a forward hook that stores the layer output under ``name``."""
    def hook(module, inputs, output):
        # Only the output is kept; the module and its inputs are ignored.
        activations_capture[name] = output
    return hook


## Calculate valid bins

In [None]:
# Count how many ground-truth pixels of the training set fall into each depth
# bin; bins that never occur are left out of the unit assignment later on.
train_dir = CONFIG['TRAIN_DIR']
num_bins = CONFIG['NUM_DEPTH_BINS']
max_depth = CONFIG['MAX_DEPTH_METERS']
depth_files = sorted(glob.glob(os.path.join(train_dir, '**', '*.png'), recursive=True))
pixel_counts_np = np.zeros(num_bins, dtype=np.int64)  # running per-bin pixel totals

for depth_path in tqdm(depth_files, desc="Count pixels in bins"):
    with Image.open(depth_path) as depth_pil:
        depth_array_raw = np.array(depth_pil)  # decode while the file is open

    # Files whose maximum is <= 255 are treated as 8-bit maps scaled to
    # max_depth; anything else is treated as millimetres.
    depth_float = depth_array_raw.astype(np.float32)
    if depth_array_raw.max() <= 255:
        depth_meters = depth_float / 255.0 * max_depth
    else:
        depth_meters = depth_float / 1000.0

    depth_tensor = torch.from_numpy(depth_meters)
    gt_bins = discretize_depth(depth_tensor, num_bins, max_depth).numpy()

    valid_bins = gt_bins[gt_bins >= 0]  # drop the -1 "no bin" markers
    if valid_bins.size == 0:
        continue
    bin_counts_batch = np.bincount(valid_bins, minlength=num_bins)
    pixel_counts_np += bin_counts_batch

pixel_counts = torch.from_numpy(pixel_counts_np)
# Indices of the bins that contain at least one pixel, moved to the training device.
valid_bin_ids = torch.where(pixel_counts > 0)[0].to(CONFIG['DEVICE'])
print(f"\nValid bins: {valid_bin_ids.tolist()}")
if valid_bin_ids.numel() == 0:
    raise ValueError("No depth bins found")

# Spread the units of a layer evenly over the valid depth bins.
def assign_depths_to_units(num_units, valid_bins):
    """Return one target bin per unit as a long tensor of length ``num_units``.

    Unit ``u`` receives ``valid_bins[floor(u * len(valid_bins) / num_units)]``,
    so consecutive units share a bin when there are more units than bins.
    With no valid bins every unit is assigned ``-1``.
    """
    n_valid = len(valid_bins)
    target_device = valid_bins.device
    if n_valid == 0:
        return torch.full((num_units,), -1, dtype=torch.long, device=target_device)

    bins_per_unit = n_valid / num_units
    unit_ids = torch.arange(num_units, device=target_device)
    # Position of each unit inside valid_bins, clamped to guard against rounding.
    slots = torch.floor(unit_ids * bins_per_unit).long().clamp(0, n_valid - 1)
    return valid_bins[slots]

## Define Selectivity Loss

In [None]:
class SelectivityLoss(nn.Module):
    """Loss that pushes each unit (channel) of a layer to prefer one depth bin.

    Every unit gets a target bin from ``assign_depths_to_units``. For a batch,
    the mean absolute activation of each unit is computed per depth bin and
    per image. A unit's score is ``(r_target - r_other) / (r_target + r_other)``,
    where ``r_other`` is its mean response over the non-target bins that have
    a non-zero response in the batch. The loss is the negated score, averaged
    over the batch and over the units that produced a finite value.
    """

    def __init__(self, num_units, num_bins, valid_bins):
        """Store the layer size and register the per-unit target bins.

        Args:
            num_units: number of units (channels) in the monitored layer.
            num_bins: number of depth bins used by ``discretize_depth``.
            valid_bins: 1-D tensor of bin indices that units may be assigned to.
        """
        super().__init__()
        self.num_units = num_units
        self.num_bins_config = num_bins

        assigned_depths = assign_depths_to_units(num_units, valid_bins)
        # A buffer follows the module across .to(device) calls.
        self.register_buffer('assigned_depths', assigned_depths)

    def forward(self, activations, gt_depths, max_depth, min_depth=0.1):
        """Compute the selectivity loss for one batch.

        Args:
            activations: layer output, unpacked as (batch, units, H, W).
            gt_depths: ground-truth depth in metres; reduced over dims 1-3 and
                broadcast against the activations, so it is assumed to be
                (batch, 1, H', W') -- TODO confirm against the dataset.
            max_depth: upper depth limit passed to ``discretize_depth``.
            min_depth: lower depth limit passed to ``discretize_depth``.

        Returns:
            Scalar tensor; ``0.0`` (without gradient) when no unit contributed.
        """
        device = activations.device
        batch_size, num_units, _, _ = activations.shape

        gt_bins = discretize_depth(gt_depths, self.num_bins_config, max_depth, min_depth)

        # Bring the activations to the resolution of the depth map if needed.
        if activations.shape[-2:] != gt_depths.shape[-2:]:
            activations_resized = F.interpolate(
                activations, size=gt_depths.shape[-2:], mode='bilinear', align_corners=False
            )
        else:
            activations_resized = activations

        abs_activations = torch.abs(activations_resized)

        # avg_responses_batch[b, u, d]: mean |activation| of unit u over the
        # pixels of image b that fall into bin d.
        avg_responses_batch = torch.zeros(batch_size, num_units, self.num_bins_config, device=device)
        for d in range(self.num_bins_config):
            mask = (gt_bins == d).float()
            pixel_counts = torch.sum(mask, dim=[1, 2, 3]) + 1e-8  # epsilon avoids 0/0 for empty bins
            sum_activations = torch.sum(abs_activations * mask, dim=[2, 3])
            avg_responses_batch[:, :, d] = sum_activations / pixel_counts.unsqueeze(1)

        total_loss = 0.0
        data_units = 0

        new_num_bins = avg_responses_batch.shape[2]

        # One device-to-host copy instead of an .item() call per unit.
        assigned_bins = self.assigned_depths.tolist()

        for i, bin_i in enumerate(assigned_bins):
            if bin_i < 0:
                # -1 means "no bin assigned"; used as an index it would
                # silently select the last bin.
                continue

            if bin_i >= new_num_bins:
                print(f"Jump the unit {i} because its assigned bin ({bin_i}) is out of range "
                      f"from the tensor with dimension: {new_num_bins}).")
                continue

            unit_responses = avg_responses_batch[:, i, :]
            r_bin_i = unit_responses[:, bin_i]  # response in the assigned bin

            other_bins_mask = torch.ones(new_num_bins, dtype=torch.bool, device=device)
            other_bins_mask[bin_i] = False
            other_responses = unit_responses[:, other_bins_mask]

            # Only compare against bins that actually occur in this batch.
            valid_bins_in_batch = torch.sum(other_responses, dim=0) > 0
            if torch.any(valid_bins_in_batch):
                r_mean = torch.mean(other_responses[:, valid_bins_in_batch], dim=1)
            else:
                r_mean = torch.zeros_like(r_bin_i)

            num = r_bin_i - r_mean
            den = r_bin_i + r_mean + 1e-8
            ds_score = num / den  # selectivity score
            loss = -torch.mean(ds_score)  # minimising the loss maximises the score

            if not torch.isnan(loss):
                total_loss += loss
                data_units += 1

        if data_units == 0:
            return torch.tensor(0.0, device=device)

        return total_loss / data_units

# Data

In [None]:
class NYUDepthDataset(Dataset):
    """RGB / depth pairs read from an NYU-style folder.

    With ``train`` or ``val`` set, ``*.jpg`` images and ``*.png`` depth maps
    are collected recursively below ``data_dir``. Otherwise only the top level
    of ``data_dir`` is scanned for ``*.png`` files, which are split by whether
    the file name contains ``colors`` or ``depth``. Images and depth maps are
    paired by their position in the sorted lists.
    """

    def __init__(self, data_dir, transform=None, depth_transform=None, train=True, val=False):
        self.data_dir = data_dir
        self.transform = transform
        self.depth_transform = depth_transform

        # Colour augmentation is used for training samples only.
        self.augmentation = train and not val
        if self.augmentation:
            self.color_jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)

        if train or val:
            # Training layout: images and depth maps differ by extension.
            rgb_pattern = os.path.join(data_dir, '**', '*.jpg')
            depth_pattern = os.path.join(data_dir, '**', '*.png')
            self.image_paths = sorted(glob.glob(rgb_pattern, recursive=True))
            self.depth_paths = sorted(glob.glob(depth_pattern, recursive=True))
        else:
            # Test layout: everything is a PNG, told apart by the file name.
            self.image_paths = []
            self.depth_paths = []
            for png_path in sorted(glob.glob(os.path.join(data_dir, '*.png'))):
                file_name = os.path.basename(png_path)
                if 'colors' in file_name:
                    self.image_paths.append(png_path)
                if 'depth' in file_name:
                    self.depth_paths.append(png_path)

        n_images = len(self.image_paths)
        n_depths = len(self.depth_paths)
        assert n_images > 0, f"No images found in '{data_dir}' (mode: train={train}, val={val})"
        assert n_depths > 0, f"No depth map found in '{data_dir}' (mode: train={train}, val={val})"
        assert n_images == n_depths, \
            f"Number of images ({len(self.image_paths)}) and depth map ({len(self.depth_paths)}) doesn't match"

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, depth)`` for sample ``idx``; depth is in metres with a leading channel axis."""
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.augmentation:
            image = self.color_jitter(image)

        depth = np.array(Image.open(self.depth_paths[idx]))

        # Single-channel maps with values <= 255 are treated as 8-bit and
        # scaled to the configured maximum depth; anything else is treated
        # as millimetres.
        if depth.max() <= 255 and depth.ndim == 2:
            depth = depth.astype(np.float32) / 255.0 * CONFIG['MAX_DEPTH_METERS']
        else:
            depth = depth.astype(np.float32) / 1000.0

        if self.transform:
            image = self.transform(image)

        depth = torch.from_numpy(depth.copy()).unsqueeze(0)  # add the channel axis
        if self.depth_transform:
            depth = self.depth_transform(depth)

        return image, depth

# RGB pipeline: resize, convert to tensor, normalise each channel.
image_transform = transforms.Compose([
    transforms.Resize((CONFIG['IMG_HEIGHT'], CONFIG['IMG_WIDTH'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Depth pipeline: nearest-neighbour resize so that no new depth values are invented.
depth_transform = transforms.Compose([
    transforms.Resize((CONFIG['IMG_HEIGHT'], CONFIG['IMG_WIDTH']), interpolation=transforms.InterpolationMode.NEAREST),
])

try:
    shared_transforms = {'transform': image_transform, 'depth_transform': depth_transform}

    # Training view of the training folder (augmentation on, not validation).
    train_dataset = NYUDepthDataset(CONFIG['TRAIN_DIR'], train=True, val=False, **shared_transforms)

    # Validation view of the same folder: train=False turns augmentation off,
    # val=True keeps the training folder layout.
    val_dataset = NYUDepthDataset(CONFIG['TRAIN_DIR'], train=False, val=True, **shared_transforms)

    # Test set, which uses the flat folder layout.
    test_dataset = NYUDepthDataset(CONFIG['TEST_DIR'], train=False, val=False, **shared_transforms)

    # Both views must list the same files, otherwise the shared indices are meaningless.
    assert len(train_dataset) == len(val_dataset)

    # Split the indices once, then apply the split to both views.
    val_percent = 0.1
    n_total = len(train_dataset)
    n_val = int(n_total * val_percent)
    n_train = n_total - n_val

    generator = torch.Generator().manual_seed(19)  # fixed seed: reproducible split
    train_ids, val_ids = random_split(range(n_total), [n_train, n_val], generator=generator)

    train_subset = Subset(train_dataset, train_ids)
    val_subset = Subset(val_dataset, val_ids)

    # Only the training loader shuffles.
    loader_options = {'batch_size': CONFIG['BATCH_SIZE'], 'num_workers': 0, 'pin_memory': True}
    train_loader = DataLoader(train_subset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_subset, shuffle=False, **loader_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

    print(f"Dataset loaded successfully:")
    print(f"Training set:   {len(train_subset)} samples ")
    print(f"Validation set: {len(val_subset)} samples ")
    print(f"Test set:       {len(test_dataset)} samples ")

except (FileNotFoundError, AssertionError, TypeError) as e:
    # Setup problems are reported and not re-raised; later cells will fail
    # on the missing loaders.
    print(e)

# Network

In [None]:

# Build the network with the output resolution taken from CONFIG, so it stays
# in sync with the image and depth transforms instead of repeating 224 here.
# NOTE(review): assumes output_size is (height, width) -- confirm against
# fastdepth_model.MobileNetSkipAdd.
model = MobileNetSkipAdd(output_size=(CONFIG['IMG_HEIGHT'], CONFIG['IMG_WIDTH']), pretrained=True)

model.to(CONFIG['DEVICE'])


## Assign Depths to Units

In [None]:


# Preview of the unit -> depth-bin assignment for the configured layer.
try:
    target_layer = dict(model.named_modules())[CONFIG['SELECTIVITY_LAYER_NAME']]  # KeyError if the name is unknown

    # The number of units is the output width of the last Conv2d in the module.
    conv_layer = [m for m in target_layer.modules() if isinstance(m, nn.Conv2d)]
    if not conv_layer:
        raise ValueError(f"No Conv2d layers found in the target module '{CONFIG['SELECTIVITY_LAYER_NAME']}'")
    layer_units = conv_layer[-1].out_channels

    assign_depths = assign_depths_to_units(layer_units, valid_bin_ids)

    print(f"Assigned {layer_units} units to {len(valid_bin_ids)} valid bins")

except Exception as e:
    # Diagnostic cell: report the problem and carry on, because the training
    # cell performs its own setup and raises there if the layer is unusable.
    print(f"\nNo bins assigned: {e}")

# Train

In [None]:

def train_one_epoch(model, dataloader, optimizer, loss_fn, interpretable_loss_fn, curr_lambda, config):
    """Run one training epoch and return the mean L1 and selectivity losses.

    Args:
        model: network to train; it is put in training mode.
        dataloader: yields ``(images, depths)`` batches.
        optimizer: optimizer stepped once per usable batch.
        loss_fn: depth loss applied to pixels whose ground truth is > 0.
        interpretable_loss_fn: selectivity loss module, or ``None``.
        curr_lambda: weight of the selectivity loss; ``0`` disables it.
        config: dict providing ``DEVICE``, ``SELECTIVITY_LAYER_NAME`` and
            ``MAX_DEPTH_METERS``.

    Returns:
        ``(avg_l1_loss, avg_sel_loss)`` averaged over the batches that were
        actually used for an update. Batches whose total loss is NaN are
        skipped and do not count. ``(nan, nan)`` if no batch was usable.
    """
    model.train()
    total_l1_loss = 0.0
    total_sel_loss = 0.0
    counted_batches = 0  # batches that produced a finite loss

    pbar = tqdm(dataloader, desc="Training Epoch")
    for images, depths in pbar:
        images, depths = images.to(config['DEVICE']), depths.to(config['DEVICE'])
        optimizer.zero_grad()  # drop the gradients of the previous step

        pred_depths = model(images)

        # Compare at ground-truth resolution, on valid (> 0) pixels only.
        predicted_depths_resized = F.interpolate(pred_depths, size=depths.shape[-2:], mode='bilinear', align_corners=False)
        mask = depths > 0
        l1_loss = loss_fn(predicted_depths_resized[mask], depths[mask])

        sel_loss = torch.tensor(0.0, device=config['DEVICE'])
        if curr_lambda > 0 and interpretable_loss_fn is not None:
            layer_name = config['SELECTIVITY_LAYER_NAME']
            # The forward hook fills activations_capture during model(images).
            if layer_name in activations_capture:
                interpretable_activations = activations_capture[layer_name]
                base_sel_loss = interpretable_loss_fn(interpretable_activations, depths, config['MAX_DEPTH_METERS'])
                sel_loss = curr_lambda * base_sel_loss

        total_loss = l1_loss + sel_loss
        if not torch.isnan(total_loss):
            total_loss.backward()
            optimizer.step()
            total_l1_loss += l1_loss.item()
            total_sel_loss += sel_loss.item()
            counted_batches += 1

        pbar.set_postfix(l1_loss=f"{l1_loss.item():.4f}", sel_loss=f"{sel_loss.item():.4f}")

    if counted_batches == 0:
        # Empty loader or every batch skipped: there is nothing to average.
        return float('nan'), float('nan')

    # Dividing by the counted batches keeps skipped batches from pulling the
    # averages towards zero.
    return total_l1_loss / counted_batches, total_sel_loss / counted_batches

def validate_one_epoch(model, dataloader, loss_fn, device, desc="Validating"):
    """Return the mean loss of ``model`` over ``dataloader`` without training.

    The loss is computed at ground-truth resolution on pixels whose depth is
    > 0. Batches whose loss is NaN are skipped and do not count towards the
    average. If no batch is usable the result is NaN, which can never compare
    as a new best validation loss.
    """
    model.eval()
    total_loss = 0.0
    counted_batches = 0  # batches that produced a finite loss
    with torch.no_grad():
        for images, depths in tqdm(dataloader, desc=desc):
            images, depths = images.to(device), depths.to(device)
            predicted_depths = model(images)
            predicted_depths = F.interpolate(predicted_depths, size=depths.shape[-2:], mode='bilinear', align_corners=False)
            mask = depths > 0
            loss = loss_fn(predicted_depths[mask], depths[mask])
            if not torch.isnan(loss):
                total_loss += loss.item()
                counted_batches += 1

    if counted_batches == 0:
        return float('nan')
    return total_loss / counted_batches

# Depth loss, optimizer and learning-rate schedule.
loss_fn = nn.L1Loss()
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['LEARNING_RATE'])
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

# Both stay None when selectivity training is switched off.
interpretable_loss_fn = None
hook_handle = None

if CONFIG['ENABLE_SELECTIVITY']:
    try:
        target_layer_module = dict(model.named_modules())[CONFIG['SELECTIVITY_LAYER_NAME']]

        # The number of units is the output width of the last Conv2d in the module.
        conv_layers = [
            module for module in target_layer_module.modules()
            if isinstance(module, nn.Conv2d)
        ]
        if not conv_layers:
            raise ValueError(f"No Conv2d layers found in the target module '{CONFIG['SELECTIVITY_LAYER_NAME']}'")
        num_units = conv_layers[-1].out_channels

        selectivity_loss = SelectivityLoss(
            num_units=num_units,
            num_bins=CONFIG['NUM_DEPTH_BINS'],
            valid_bins=valid_bin_ids,
        )
        interpretable_loss_fn = selectivity_loss.to(CONFIG['DEVICE'])

        # The hook stores the layer output in activations_capture on every forward pass.
        target_layer = target_layer_module
        capture_hook = get_activation(CONFIG['SELECTIVITY_LAYER_NAME'])
        hook_handle = target_layer.register_forward_hook(capture_hook)

    except (KeyError, AttributeError, NameError) as e:
        print(f"Error in the setup of the loss: {e}")
        raise

# Restore the training state if a resume checkpoint is configured and exists.
start_epoch = 0
best_val_loss = float('inf')
resume_path = CONFIG.get('RESUME_CHECKPOINT')
if resume_path and os.path.exists(resume_path):
    print(f"Resume training from {resume_path}")
    checkpoint = torch.load(resume_path, map_location=CONFIG['DEVICE'])
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch']
    best_val_loss = checkpoint['best_val_loss']
    print(f"Checkpoint loaded {start_epoch}, best_val_loss = {best_val_loss:.4f}")
else:
    if resume_path:
        # A configured but missing checkpoint should not go unnoticed.
        print(f"Resume checkpoint '{resume_path}' not found")
    print("Start a new training")

# torch.save does not create missing folders, so make sure they exist.
for save_path in (CONFIG['BEST_MODEL_PATH'], CONFIG['LATEST_CHECKPOINT_PATH']):
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

# Epoch index from which the best model may be saved when selectivity is
# enabled. The key is optional; 0 allows saving from the first epoch.
warmup_epochs = CONFIG.get('WARMUP_EPOCHS', 0)

try:
    for epoch in range(start_epoch, CONFIG['EPOCHS']):
        print(f"\n--- Epoch [{epoch+1}/{CONFIG['EPOCHS']}] ---")

        current_lambda = 0.0
        if CONFIG['ENABLE_SELECTIVITY']:
            current_lambda = CONFIG['SELECTIVITY_LAMBDA']

        avg_l1_loss, avg_sel_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, interpretable_loss_fn, current_lambda, CONFIG)

        avg_val_loss = validate_one_epoch(model, val_loader, loss_fn, CONFIG['DEVICE'])
        scheduler.step()

        print(f"Epoch [{epoch+1}/{CONFIG['EPOCHS']}], Avg L1 Loss: {avg_l1_loss:.4f}, Avg Sel Loss: {avg_sel_loss:.4f}, Avg Val Loss: {avg_val_loss:.4f}")

        past_warmup = not CONFIG['ENABLE_SELECTIVITY'] or epoch >= warmup_epochs
        if avg_val_loss < best_val_loss and past_warmup:
            best_val_loss = avg_val_loss
            torch.save({'model_state_dict': model.state_dict()}, CONFIG['BEST_MODEL_PATH'])
            print(f"New best model saved in '{CONFIG['BEST_MODEL_PATH']}' (Val Loss: {best_val_loss:.4f})")

        # Written every epoch so training can be resumed; 'epoch' is the index
        # of the next epoch to run.
        latest_checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_val_loss
        }
        torch.save(latest_checkpoint, CONFIG['LATEST_CHECKPOINT_PATH'])
finally:
    # Remove the hook even if training is interrupted, so later cells do not
    # keep capturing activations through a stale hook.
    if hook_handle:
        hook_handle.remove()

print("\nTraining completed")

# Evaluation

## Load best model or checkpoint path

In [None]:
# Load the weights to evaluate. By default the best model (lowest validation
# loss) is used; to evaluate the latest checkpoint instead, swap in the two
# commented lines below.
#checkpoint = torch.load(CONFIG['LATEST_CHECKPOINT_PATH'], map_location=CONFIG['DEVICE'])
#model.load_state_dict(checkpoint['model_state_dict'])
checkpoint = torch.load(CONFIG['BEST_MODEL_PATH'], map_location=CONFIG['DEVICE'])
model.load_state_dict(checkpoint['model_state_dict'])
model.to(CONFIG['DEVICE'])


## Evaluate depth performance

In [None]:
# Depth-estimation metrics over a whole dataloader.
def evaluate_performance(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and return standard depth metrics.

    Predictions are resized to the ground-truth resolution and only pixels
    whose ground-truth depth is > 0 are scored.

    Returns:
        dict with ``rmse``, ``mae``, ``rel`` (mean absolute relative error)
        and the threshold accuracies ``delta1``, ``delta2``, ``delta3``
        (fraction of pixels with ``max(gt/pred, pred/gt)`` below 1.25,
        1.25**2 and 1.25**3).
    """
    model.eval()
    preds, gts = [], []
    with torch.no_grad():
        for images, gt_depths in tqdm(dataloader, desc="Evaluating Performance"):
            images, gt_depths_gpu = images.to(device), gt_depths.to(device)
            pred_depths = model(images)
            pred_depths = F.interpolate(pred_depths, size=gt_depths_gpu.shape[-2:], mode='bilinear', align_corners=False)

            mask = gt_depths_gpu > 0  # valid ground-truth pixels
            preds.append(pred_depths[mask].cpu())
            gts.append(gt_depths_gpu[mask].cpu())

    preds = torch.cat(preds).numpy()
    gts = torch.cat(gts).numpy()

    # Error metrics; gts is > 0 everywhere thanks to the mask above.
    diff = gts - preds
    abs_diff = np.abs(diff)
    mae = np.mean(abs_diff)
    rmse = np.sqrt(np.mean(diff ** 2))
    abs_rel = np.mean(abs_diff / gts)

    # Threshold accuracies. A prediction of zero would divide by zero, and a
    # negative one would give a negative ratio that passes every threshold,
    # so predictions are floored at a tiny positive value for the ratio only.
    safe_preds = np.maximum(preds, 1e-6)
    ratio = np.maximum(gts / safe_preds, safe_preds / gts)
    delta1 = (ratio < 1.25).mean()
    delta2 = (ratio < 1.25**2).mean()
    delta3 = (ratio < 1.25**3).mean()

    results = {
        "rmse": rmse, "mae": mae, "rel": abs_rel,
        "delta1": delta1, "delta2": delta2, "delta3": delta3
    }
    return results

# Evaluate the saved model on the test set and print the metrics.
# Note: this message is printed before the weights are actually loaded below.
print(f"\nModel loaded:'{CONFIG['BEST_MODEL_PATH']}'")

# The weights are (re)loaded here so this cell can run on its own. To evaluate
# the latest checkpoint instead of the best model, swap in the two commented lines.
#checkpoint = torch.load(CONFIG['LATEST_CHECKPOINT_PATH'], map_location=CONFIG['DEVICE'])
#model.load_state_dict(checkpoint['model_state_dict'])
checkpoint = torch.load(CONFIG['BEST_MODEL_PATH'], map_location=CONFIG['DEVICE'])
model.load_state_dict(checkpoint['model_state_dict'])
model.to(CONFIG['DEVICE'])

performance_metrics = evaluate_performance(model, test_loader, CONFIG['DEVICE'])  # dict of metric name -> value


# Error metrics, printed with four decimals.
print("\n Performance on Test Set:")
print(f"RMSE:{performance_metrics['rmse']:.4f}")
print(f"MAE: {performance_metrics['mae']:.4f}")
print(f"REL: {performance_metrics['rel']:.4f}")

# Threshold accuracies, printed as percentages.
print(f"δ < 1.25: {performance_metrics['delta1']:.4%}")
print(f"δ < 1.25²: {performance_metrics['delta2']:.4%}")
print(f"δ < 1.25³: {performance_metrics['delta3']:.4%}")

# Show RGB input, ground-truth depth and predicted depth side by side.
def visualize_predictions(model, dataloader, device, num_samples=5):
    """Plot the first ``num_samples`` samples of ``dataloader`` with their predictions."""
    model.eval()
    # Statistics used to undo the input normalisation for display.
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    shown = 0
    with torch.no_grad():
        for images, gt_depths in dataloader:
            images = images.to(device)
            predicted_depths = model(images).cpu()

            for i, image in enumerate(images):
                if shown >= num_samples:
                    return

                # Channels-last layout for matplotlib, then de-normalise.
                img = image.cpu().permute(1, 2, 0).numpy()
                img = np.clip(channel_std * img + channel_mean, 0, 1)

                gt = gt_depths[i].squeeze().numpy()
                pred = predicted_depths[i].squeeze().numpy()

                fig, axes = plt.subplots(1, 3, figsize=(18, 6))

                axes[0].imshow(img)
                axes[0].set_title("RGB image")
                axes[0].axis('off')

                # Both depth panels share one colour scale.
                axes[1].imshow(gt, cmap='magma', vmin=0, vmax=CONFIG['MAX_DEPTH_METERS'])
                axes[1].set_title("Ground Truth")
                axes[1].axis('off')

                im = axes[2].imshow(pred, cmap='magma', vmin=0, vmax=CONFIG['MAX_DEPTH_METERS'])
                axes[2].set_title("Prediction")
                axes[2].axis('off')

                fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.7)
                plt.show()
                shown += 1

visualize_predictions(model, test_loader, CONFIG['DEVICE'])

## Selectivity evaluation

### Neuron selectivity

In [None]:
# Plot, for a sample of units, the average response at every depth bin.

def visualize_neuron_selectivity(model, dataloader, layer_name, config, valid_bin_ids, num_units_to_show=32, num_batches=50):
    """Bar-plot the mean absolute activation per depth bin for selected units.

    Args:
        model: network to inspect; it is put in eval mode.
        dataloader: yields ``(images, gt_depths)`` batches.
        layer_name: key into ``model.named_modules()`` of the layer to hook.
        config: dict providing ``DEVICE``, ``NUM_DEPTH_BINS``,
            ``MAX_DEPTH_METERS`` and optionally ``ENABLE_SELECTIVITY``.
        valid_bin_ids: tensor of bins used for the unit -> bin assignment.
        num_units_to_show: maximum number of units plotted, spread evenly
            over the layer.
        num_batches: maximum number of batches used for the statistics.
    """
    model.eval()
    device = config['DEVICE']
    num_bins = config['NUM_DEPTH_BINS']
    max_depth = config['MAX_DEPTH_METERS']

    hook_handle = None
    avg_responses = None
    num_units = 0

    try:
        target_layer = dict(model.named_modules())[layer_name]
        num_units = [m.out_channels for m in target_layer.modules() if isinstance(m, nn.Conv2d)][-1]
        hook_handle = target_layer.register_forward_hook(get_activation(layer_name))

        sum_responses = torch.zeros(num_units, num_bins, device=device)
        pixel_counts = torch.zeros(num_bins, device=device)  # the same for every unit

        with torch.no_grad():
            pbar = tqdm(total=min(num_batches, len(dataloader)), desc="Calculating Selectivity")
            for i, (images, gt_depths) in enumerate(dataloader):
                if i >= num_batches:
                    break
                images, gt_depths = images.to(device), gt_depths.to(device)
                _ = model(images)  # forward pass only to trigger the hook

                activations = activations_capture[layer_name]
                activations_resized = F.interpolate(activations, size=gt_depths.shape[-2:], mode='bilinear', align_corners=False)
                gt_bins = discretize_depth(gt_depths, num_bins, max_depth)
                abs_activations = torch.abs(activations_resized)

                for d in range(num_bins):
                    mask_d = (gt_bins == d).float()
                    # Multiply the full activation tensor by the mask so that
                    # broadcasting pairs each image with its own depth mask.
                    # Slicing out one channel first drops the channel axis and
                    # would broadcast every image against every mask.
                    # Assumes gt_depths is (batch, 1, H, W) -- TODO confirm.
                    sum_responses[:, d] += torch.sum(abs_activations * mask_d, dim=(0, 2, 3))
                    pixel_counts[d] += torch.sum(mask_d)
                pbar.update(1)
            pbar.close()

        avg_responses = (sum_responses / (pixel_counts.unsqueeze(0) + 1e-8)).cpu().numpy()

    except (KeyError, AttributeError, IndexError) as e:
        print(f"error in visualization: {e}")
        return

    finally:
        if hook_handle:
            hook_handle.remove()

    if avg_responses is None:
        return

    # Evenly spaced sample of the units.
    indices_to_show = np.linspace(0, num_units - 1, min(num_units_to_show, num_units), dtype=int)
    assigned_depths_map = None
    if config.get("ENABLE_SELECTIVITY", False):
        assigned_depths_tensor = assign_depths_to_units(num_units, valid_bin_ids.cpu())
        assigned_depths_map = assigned_depths_tensor.numpy()

    print(f"Visualize the units: {indices_to_show}")

    fig, axes = plt.subplots(1, len(indices_to_show), figsize=(5 * len(indices_to_show), 4), sharey=True)
    fig.suptitle(f"Avg response of the layer'{layer_name}'at every depth bin", fontsize=16)

    # Depth at the start of each bin, using the same log spacing as discretize_depth.
    log_min = np.log(0.1)
    log_max = np.log(max_depth)
    log_bin_edges = np.linspace(log_min, log_max, num_bins + 1)
    depth_bin_values = np.exp(log_bin_edges[:-1])
    x_positions = np.arange(num_bins)  # one bar per bin index

    # Tick positions and labels are the same for every subplot.
    num_labels = 5
    tick_indices = np.linspace(0, num_bins - 1, num_labels, dtype=int)
    tick_labels = [f"{depth_bin_values[idx]:.1f}" for idx in tick_indices]

    for i, unit_idx in enumerate(indices_to_show):
        # plt.subplots returns a bare Axes when there is a single column.
        ax = axes if len(indices_to_show) == 1 else axes[i]
        responses = avg_responses[unit_idx]
        ax.bar(x_positions, responses, width=0.8)

        title = f"Unit {unit_idx}"
        if assigned_depths_map is not None:
            assigned_bin = assigned_depths_map[unit_idx]
            # -1 means "no bin assigned"; do not index with it.
            if 0 <= assigned_bin < num_bins:
                assigned_depth_val = depth_bin_values[assigned_bin]
                title += f"\n(Target Bin: {assigned_bin} ≈ {assigned_depth_val:.1f}m)"
                if assigned_bin < len(ax.patches):
                    ax.patches[assigned_bin].set_facecolor('orangered')  # highlight the target bin
        ax.set_title(title)
        ax.set_xlabel("Depth (m)")

        ax.set_xticks(tick_indices)
        ax.set_xticklabels(tick_labels)

        if i == 0:
            ax.set_ylabel("Avg response")
        ax.grid(axis='y', linestyle='--', alpha=0.7)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


visualize_neuron_selectivity(model, train_loader, CONFIG['SELECTIVITY_LAYER_NAME'], CONFIG, valid_bin_ids)



### Visualize activation map

In [None]:

# Overlay the activation maps of selected units on the input image.
def _overlay_heatmap(img_rgb, activation_map):
    """Blend ``activation_map`` (2-D array) as a JET heatmap over ``img_rgb`` (H, W, 3 in [0, 1])."""
    h, w, _ = img_rgb.shape
    activation_map_resized = cv2.resize(activation_map, (w, h))  # cv2 expects (width, height)

    heatmap = cv2.normalize(activation_map_resized, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV colour maps are BGR

    return cv2.addWeighted((img_rgb * 255).astype(np.uint8), 0.5, heatmap, 0.5, 0)


def visualize_activation_maps(model, dataloader, layer_name, config, valid_bin_ids, num_images=3, image_idx=None, units_to_show=None):
    """Plot image, ground-truth depth and per-unit activation overlays.

    Args:
        model: network to inspect; it is put in eval mode.
        dataloader: yields ``(images, gt_depths)`` batches on the CPU.
        layer_name: key into ``model.named_modules()`` of the layer to hook.
        config: dict providing ``DEVICE``, ``MAX_DEPTH_METERS`` and
            ``NUM_DEPTH_BINS``.
        valid_bin_ids: tensor of bins used for the unit -> bin assignment.
        num_images: number of leading images to show when ``image_idx`` is None.
        image_idx: if given, show only the image at this position of the loader.
        units_to_show: unit indices to overlay; defaults to ``[1, 24, 28]``.
    """
    if units_to_show is None:
        # Created per call so no list object is shared between calls.
        units_to_show = [1, 24, 28]

    model.eval()
    device = config['DEVICE']
    hook_handle = None

    try:
        target_layer = dict(model.named_modules())[layer_name]
        hook_handle = target_layer.register_forward_hook(get_activation(layer_name))
        num_units = [m.out_channels for m in target_layer.modules() if isinstance(m, nn.Conv2d)][-1]

        assigned_targets = assign_depths_to_units(num_units, valid_bin_ids.cpu())
        # Depth at the log-space centre of each bin, used for the titles.
        log_min = np.log(0.1)
        log_max = np.log(config['MAX_DEPTH_METERS'])
        log_bin_edges = np.linspace(log_min, log_max, config['NUM_DEPTH_BINS'] + 1)
        log_bin_centers = (log_bin_edges[:-1] + log_bin_edges[1:]) / 2
        depth_bin_centers = np.exp(log_bin_centers)

        print(f"Visualize units: {units_to_show}")

        # Statistics used to undo the input normalisation for display.
        mean, std = np.array([0.485, 0.456, 0.406]), np.array([0.229, 0.224, 0.225])

        images_shown = 0
        curr_idx = 0
        with torch.no_grad():
            for images, gt_depths in dataloader:
                for i in range(images.size(0)):
                    if image_idx is not None:
                        show = curr_idx == image_idx
                    else:
                        show = images_shown < num_images

                    if show:
                        print(f"\nVisualize image number {curr_idx}")

                        # Forward a single image so the hook captures its activations.
                        image_tensor = images[i].unsqueeze(0).to(device)
                        _ = model(image_tensor)
                        activations = activations_capture[layer_name]

                        img_rgb_normalized = images[i].permute(1, 2, 0).numpy()
                        img_rgb = np.clip(std * img_rgb_normalized + mean, 0, 1)

                        gt_depth = gt_depths[i].squeeze().numpy()

                        num_plots = 2 + len(units_to_show)
                        fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 6))

                        axes[0].imshow(img_rgb)
                        axes[0].set_title("Image RGB")
                        axes[0].axis('off')

                        axes[1].imshow(gt_depth, cmap='magma', vmin=0, vmax=config['MAX_DEPTH_METERS'])
                        axes[1].set_title("Depth (m) ")
                        axes[1].axis('off')

                        for j, unit_idx in enumerate(units_to_show):
                            activation_map = torch.abs(activations[0, unit_idx, :, :]).cpu().numpy()
                            img_with_heatmap = _overlay_heatmap(img_rgb, activation_map)

                            ax = axes[2 + j]
                            ax.imshow(img_with_heatmap)
                            target_bin = assigned_targets[unit_idx].item()
                            target_depth = depth_bin_centers[target_bin]
                            ax.set_title(f"Unit{unit_idx}\n(Target: Bin {target_bin} ≈ {target_depth:.1f}m)")
                            ax.axis('off')

                        plt.tight_layout()
                        plt.show()

                        images_shown += 1

                    curr_idx += 1

                    # Stop once the requested image, or enough images, were shown.
                    if image_idx is not None:
                        done = images_shown > 0
                    else:
                        done = images_shown >= num_images
                    if done:
                        return  # the finally clause still removes the hook

    except (KeyError, IndexError, AttributeError) as e:
        print(f"error: {e}")
        return
    finally:
        if hook_handle:
            hook_handle.remove()


visualize_activation_maps(model, test_loader, CONFIG['SELECTIVITY_LAYER_NAME'], CONFIG, valid_bin_ids)


### Evaluate selectivity metrics

In [None]:


def _depth_selectivity_index(responses, bin_idx, observed_bins):
    """Return one unit's selectivity score for the bin `bin_idx`.

    The score is (R - R_other) / (R + R_other + 1e-8), where R is
    `responses[bin_idx]` and R_other is the mean of `responses` over every
    other bin flagged in `observed_bins` (0.0 when no such bin exists).

    Args:
        responses: 1-D tensor with the unit's average response per bin.
        bin_idx: index of the bin being scored.
        observed_bins: 1-D bool tensor, True for bins that received pixels.

    Returns:
        The score as a Python float.
    """
    others = observed_bins.clone()
    others[bin_idx] = False  # never compare the bin against itself
    preferred = responses[bin_idx]
    baseline = responses[others].mean() if others.any() else 0.0
    return ((preferred - baseline) / (preferred + baseline + 1e-8)).item()


def evaluate_selectivity(model, dataloader, layer_name, config, valid_bin_ids):
    """Compute depth-selectivity metrics for the units of one layer.

    A forward hook captures the layer's activations for every batch. Their
    absolute values are resized to the ground-truth resolution and averaged
    per (unit, depth bin). From those averages each unit gets two scores:
    a generic one for its most responsive bin and a target one for the bin
    returned by `assign_depths_to_units`.

    Args:
        model: network to evaluate; it is switched to eval mode.
        dataloader: iterable yielding `(images, gt_depths)` batches.
        layer_name: key into `model.named_modules()` selecting the layer.
        config: dict providing 'DEVICE', 'NUM_DEPTH_BINS' and
            'MAX_DEPTH_METERS'.
        valid_bin_ids: moved to the device and forwarded to
            `assign_depths_to_units` (presumably a tensor of usable bin
            ids; confirm against that helper).

    Returns:
        Dict with 'avg_ds_score_generic', 'avg_ds_score_target' and
        'assignment_accuracy'. Units whose average response is zero in
        every bin are excluded. If a KeyError, AttributeError or IndexError
        occurs, the error is printed and all three values are 0.0.
    """
    model.eval()
    device = config['DEVICE']
    num_bins = config['NUM_DEPTH_BINS']
    max_depth = config['MAX_DEPTH_METERS']

    hook_handle = None
    try:
        target_layer = dict(model.named_modules())[layer_name]
        hook_handle = target_layer.register_forward_hook(get_activation(layer_name))
        # The unit count is the width of the last Conv2d found in the layer.
        num_units = [m.out_channels for m in target_layer.modules() if isinstance(m, nn.Conv2d)][-1]

        # float64 accumulators: over a full dataset the per-bin pixel counts
        # can exceed 2**24, beyond which float32 cannot represent every integer.
        sum_responses = torch.zeros(num_units, num_bins, device=device, dtype=torch.float64)
        pixel_counts = torch.zeros(num_bins, device=device, dtype=torch.float64)

        print(f"Calculate selectivity for '{layer_name}' with ({num_units} units)")
        with torch.no_grad():
            for images, gt_depths in tqdm(dataloader, desc="Calculating Selectivity"):
                images, gt_depths = images.to(device), gt_depths.to(device)
                _ = model(images)  # output unused; the hook records the activations

                activations = activations_capture[layer_name]
                activations_resized = F.interpolate(
                    activations, size=gt_depths.shape[-2:], mode='bilinear', align_corners=False
                )
                gt_bins = discretize_depth(gt_depths, num_bins, max_depth)
                abs_activations = torch.abs(activations_resized)

                for d in range(num_bins):
                    mask_d = (gt_bins == d)  # pixels whose depth falls in bin d
                    sum_responses[:, d] += torch.sum(abs_activations * mask_d.float(), dim=(0, 2, 3))
                    pixel_counts[d] += torch.sum(mask_d)

        # Copy the statistics to the host once so the per-unit scoring below
        # operates on CPU tensors.
        avg_responses = (sum_responses / (pixel_counts.unsqueeze(0) + 1e-8)).cpu()
        observed_bins = (pixel_counts > 0).cpu()

        assigned_targets = assign_depths_to_units(num_units, valid_bin_ids.to(device))

        ds_scores, ds_scores_target = [], []
        correct_assignments = 0

        for k in range(num_units):
            responses_k = avg_responses[k]
            target_bin = assigned_targets[k].item()
            if torch.all(responses_k == 0):
                continue  # unit never responded; it carries no selectivity signal

            best_bin = int(torch.argmax(responses_k))
            ds_scores.append(_depth_selectivity_index(responses_k, best_bin, observed_bins))
            ds_scores_target.append(_depth_selectivity_index(responses_k, target_bin, observed_bins))

            # An assignment is correct when the unit's strongest bin is its target bin.
            if best_bin == target_bin:
                correct_assignments += 1

        num_valid_units = max(len(ds_scores), 1)  # avoid dividing by zero

        results = {
            "avg_ds_score_generic": np.mean(ds_scores) if ds_scores else 0.0,
            "avg_ds_score_target": np.mean(ds_scores_target) if ds_scores_target else 0.0,
            "assignment_accuracy": (correct_assignments / num_valid_units)
        }

        print(f" DS Score: {results['avg_ds_score_generic']:.4f}")
        print(f"Ds score target :{results['avg_ds_score_target']:.4f}")
        print(f"Assigning accuracy {results['assignment_accuracy']:.2%}")

        return results

    except (KeyError, AttributeError, IndexError) as e:
        print(f"error {e}")
        return {"avg_ds_score_generic": 0.0, "avg_ds_score_target": 0.0, "assignment_accuracy": 0.0}

    finally:
        # Always detach the hook so later forward passes are not captured.
        if hook_handle is not None:
            hook_handle.remove()

# Compute the selectivity metrics for the configured layer.
# NOTE(review): this is evaluated on train_loader; confirm that test_loader
# is not the intended split for reported metrics.
selectivity_metrics = evaluate_selectivity(model, train_loader, CONFIG['SELECTIVITY_LAYER_NAME'], CONFIG, valid_bin_ids)
