In [1]:
from torch.utils.data import DataLoader
from statistics import mode
import matplotlib.pyplot as plt
import numpy as np

import os
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
import torch.optim as optim

# Preliminaries

### Dataset

In [2]:
# import os, random, glob

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torch_geometric.data import Data
import h5py, torch, cv2
import numpy as np
from statistics import mode
import torch.nn.functional as F

from torch import nn, optim
import glob



class DepthDataset(Dataset):
    
    def __init__(self, data_path = "/home3/fsml62/LLM_and_SGG_for_MDE/dataset/nyu_depth_v2/official_splits", transform=None, ext="jpg", mode='train', threshold=0.06):
        
        self.data_path = data_path
        
        self.filenames = glob.glob(os.path.join(self.data_path, mode, '**', '*.{}'.format(ext)), recursive=True)
        self.pt_path = "/home3/fsml62/LLM_and_SGG_for_MDE/GNN_for_MDE/results/depth_embedding/nyu_depth_v2/official_splits"
        self.depth_map_path = "/home3/fsml62/LLM_and_SGG_for_MDE/GNN_for_MDE/results/depth_map/nyu_depth_v2/official_splits"
#         self.depth_map_path = "/home3/fsml62/LLM_and_SGG_for_MDE/GNN_for_MDE/results/adabins/depth_map/nyu_depth_v2/official_splits"
        self.sg_path = "/home3/fsml62/LLM_and_SGG_for_MDE/GNN_for_MDE/results/SGG/nyu_depth_v2/official_splits"

        self.transform = transform if transform else transforms.ToTensor()
        self.mode = mode
        self.threshold = threshold
        
        self.cache = {}
        
        



    def __getitem__(self, idx):
        
        
        if idx in self.cache:
            return self.cache[idx]
        

        # image path
        img_path = self.filenames[idx]
        # get the image
        img = Image.open(img_path)

        # get the relative path
        relative_path = os.path.relpath(img_path, self.data_path)
        # get depth embedding path
        depth_emb_path = os.path.join(self.pt_path, '{}.pt'.format(relative_path.split('.')[0]))
        # get depth map path
        depth_path = os.path.join(self.mode, self.depth_map_path, '{}.pt'.format(relative_path.split('.')[0]))

        #get the scene graph path
        scenegraph_path = os.path.join(self.sg_path, '{}.h5'.format(relative_path.split('.')[0]))



        ## Depth Embedding
        depth_emb = self.normalize(torch.load(depth_emb_path))

        ## depth map
        depth_map = self.normalize(torch.load(depth_path))

        ## get the actual depth
        actual_depth_path = img_path.replace("rgb", "sync_depth").replace('.jpg', '.png')
        actual_depth = Image.open(actual_depth_path)

        
        # the resize size
        target_size = (25, 25)


        with h5py.File(scenegraph_path, 'r') as h5_file:
            loaded_output_dict = {key: torch.tensor(np.array(h5_file[key])) for key in h5_file.keys()}


        probas = loaded_output_dict['rel_logits'].softmax(-1)[0, :, :-1]
        probas_sub = loaded_output_dict['sub_logits'].softmax(-1)[0, :, :-1]
        probas_obj = loaded_output_dict['obj_logits'].softmax(-1)[0, :, :-1]
        
        
        keep = torch.logical_and(probas.max(-1).values > self.threshold, 
                                torch.logical_and(probas_sub.max(-1).values > self.threshold, probas_obj.max(-1).values > self.threshold))
        
        
        
      
        
        sub_bboxes_scaled = self.rescale_bboxes(loaded_output_dict['sub_boxes'][0, keep], img.size)
        obj_bboxes_scaled = self.rescale_bboxes(loaded_output_dict['obj_boxes'][0, keep], img.size)
        relations = loaded_output_dict['rel_logits'][0, keep]

        valid_sub_bboxes = self.validate_bounding_boxes(sub_bboxes_scaled, img.size, target_size)
        valid_obj_bboxes = self.validate_bounding_boxes(obj_bboxes_scaled, img.size, target_size)

        # Combine validity of subject and object bounding boxes
        valid_pairs = torch.tensor([vs and vo for vs, vo in zip(valid_sub_bboxes, valid_obj_bboxes)], dtype=torch.bool)

        # Apply the updated keep mask
        sub_bboxes_scaled = sub_bboxes_scaled[valid_pairs]
        obj_bboxes_scaled = obj_bboxes_scaled[valid_pairs]
        relations = relations[valid_pairs]
        
        
        
        sub_idxs, nodes1 = self.assign_index(sub_bboxes_scaled, [], threshold=0.7)
        obj_idxs, nodes2 = self.assign_index(obj_bboxes_scaled, nodes1, threshold=0.7)
        
        all_idxs = sub_idxs + obj_idxs
        bbox_lists = torch.concat((sub_bboxes_scaled, obj_bboxes_scaled), dim=0)
        
        unique_idxs = set()
        filtered_idxs = []
        filtered_bboxes = []

        for idx, bbox in zip(all_idxs, bbox_lists):   
            if idx not in unique_idxs:
                unique_idxs.add(idx)
                filtered_idxs.append(idx)
                filtered_bboxes.append(bbox.tolist())  

        # Sort the filtered indices along with their corresponding bounding boxes
        sorted_wrapped = sorted(zip(filtered_idxs, filtered_bboxes), key=lambda x: x[0])
        sorted_idxs, sorted_bboxes = zip(*sorted_wrapped)

        # Convert them back to torch tensors
        sorted_idxs = torch.tensor(sorted_idxs, dtype=torch.long)
        sorted_bboxes = torch.tensor(sorted_bboxes, dtype=torch.int32)        
        

        
        
        # Apply transform to image and depth
        img = self.transform(img)
        img = self.normalize(img)
        actual_depth = self.transform(actual_depth).float()
        actual_depth = self.normalize(actual_depth)
        
        device = 'cpu'
        
#         pooled_images = self.pool_visual_content_and_depth(
#             sorted_bboxes=sorted_bboxes,
#             embedding=img,
#             target_size=target_size).to(device)
        
#         pooled_depths = self.pool_visual_content_and_depth(
#             sorted_bboxes=sorted_bboxes,
#             embedding=depth_emb[0],
#             target_size=target_size).to(device)
        
#         pooled_act_depths = self.resize_depth_map(
#             sorted_bboxes=sorted_bboxes,
#             embedding=actual_depth.squeeze(0),
#             target_size=target_size,
#             method='mode').to(device)

        pooled_images = self.linear_visual_content_and_depth(
            sorted_bboxes=sorted_bboxes,
            embedding=img,
            target_size=target_size).to(device)
        
        pooled_depths = self.linear_visual_content_and_depth(
            sorted_bboxes=sorted_bboxes,
            embedding=depth_emb[0],
            target_size=target_size).to(device)
        
        
#         pooled_depth_maps = self.resize_depth_map(
#             sorted_bboxes=sorted_bboxes,
#             embedding=depth_map.squeeze(0),
#             target_size=target_size,
#             method='linear').to(device)

    ## Adabins
        pooled_depth_maps = self.resize_depth_map(
            sorted_bboxes=sorted_bboxes,
            embedding=depth_map,
            target_size=target_size,
            method='linear').to(device)
        
        
        
        pooled_act_depths = self.resize_depth_map(
            sorted_bboxes=sorted_bboxes,
            embedding=actual_depth.squeeze(0),
            target_size=target_size,
            method='linear').to(device)
        
        act_depths = self.resize_depth_map(
            sorted_bboxes=sorted_bboxes,
            embedding=actual_depth.squeeze(0),
            target_size=target_size,
            method='real')
        
#         act_depths = torch.tensor(act_depths).to(device)
        
        
        
        
        
        if isinstance(depth_map, np.ndarray):
            depth_map = torch.from_numpy(depth_map)
        
        
        # node embedding
        node_embeddings = torch.cat([pooled_images, pooled_depths], dim=1).to(device)
        
        edge_index = torch.tensor([sub_idxs, obj_idxs]).to(device)
        
#         edge_embeddings = torch.tensor(relations, dtype=torch.float).to(device)
        edge_embeddings = relations.clone().detach().float().to(device)
        
        
        
        gnndata = Data(x=node_embeddings, edge_index=edge_index, edge_attr=edge_embeddings)


        ## return data
        data = {
            'img_path': img_path,
            'image': img,
            'depth_emb': depth_emb,
            'depth_map': depth_map,
            'depth': actual_depth,
            'pooled_depth_maps': pooled_depth_maps,
            'pooled_act_depths': pooled_act_depths,
            'act_depths': act_depths,
            'bboxs': filtered_bboxes,
            'gnndata': gnndata
        }
        
        self.cache[idx] = data
        
        return data

    def __len__(self):
        return len(self.filenames)
    
    def box_cxcywh_to_xyxy(self, x):

        x_c, y_c, w, h = x.unbind(1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    
        return torch.stack(b, dim=1)
    
    def rescale_bboxes(self, out_bbox, size):

        img_w, img_h = size
        b = self.box_cxcywh_to_xyxy(out_bbox)
        b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
        
        b = torch.round(b).int()

        b[:, 0] = torch.clamp(b[:, 0], min=0, max=img_w)
        b[:, 1] = torch.clamp(b[:, 1], min=0, max=img_h)
        b[:, 2] = torch.clamp(b[:, 2], min=0, max=img_w)
        b[:, 3] = torch.clamp(b[:, 3], min=0, max=img_h)

        return b

    def validate_bounding_boxes(self, bboxes, img_size, target_size):
        """Return a list of booleans indicating whether each bounding box is valid."""
        valid_bboxes = []
        img_w, img_h = img_size
        
        t_w, t_h = target_size

        for bbox in bboxes:
            x1, y1, x2, y2 = bbox
            if x1 < 0: x1 = 0
            if y1 < 0: y1 = 0
            if x2 > img_w: x2 = img_w
            if y2 > img_h: y2 = img_h
            
            if (x2 - x1) >= t_w and (y2 - y1) >= t_h:
                valid_bboxes.append(True)
            else:
                valid_bboxes.append(False)

        return valid_bboxes


    
    def pool_visual_content_and_depth(self, sorted_bboxes, embedding, target_size=(50, 50)):

        pool = nn.AdaptiveAvgPool2d(target_size)

        pooled_embs = []

        for bbox in sorted_bboxes:
            cropped_emb = embedding[:, bbox[1]:bbox[3], bbox[0]:bbox[2]]

            if cropped_emb.dim() == 2:
                cropped_emb = cropped_emb.unsqueeze(0)  

            pooled_emb = pool(cropped_emb.unsqueeze(0)).squeeze(0)

            flattened_emb = pooled_emb.view(-1)
            
            pooled_embs.append(flattened_emb)


        pooled_embs = torch.stack(pooled_embs) if pooled_embs else torch.empty(0)

        return pooled_embs


    def linear_visual_content_and_depth(self, sorted_bboxes, embedding, target_size=(50, 50)):

        pooled_embs = []

        for bbox in sorted_bboxes:
            # Crop the embedding to the bounding box
            cropped_emb = embedding[:, bbox[1]:bbox[3], bbox[0]:bbox[2]]

            # If the cropped embedding is 2D, add a channel dimension
            if cropped_emb.dim() == 2:
                cropped_emb = cropped_emb.unsqueeze(0)

            # Bilinear downsampling using F.interpolate
            pooled_emb = F.interpolate(cropped_emb.unsqueeze(0), size=target_size, mode='bilinear', align_corners=False).squeeze(0)

            # Flatten the downsampled embedding
            flattened_emb = pooled_emb.view(-1)

            # Append the flattened embedding to the list
            pooled_embs.append(flattened_emb)

        # Stack all the pooled embeddings or return an empty tensor if none
        pooled_embs = torch.stack(pooled_embs) if pooled_embs else torch.empty(0)

        return pooled_embs
    
    
    def calculate_iou(self, box1, box2):

        x1, y1, x2, y2 = box1
        x3, y3, x4, y4 = box2

        # Calculate intersection coordinates
        x_inter1 = max(x1, x3)
        y_inter1 = max(y1, y3)
        x_inter2 = min(x2, x4)
        y_inter2 = min(y2, y4)

        # Calculate intersection dimensions
        width_inter = max(0, x_inter2 - x_inter1)
        height_inter = max(0, y_inter2 - y_inter1)

        # Calculate intersection area
        area_inter = width_inter * height_inter

        # Calculate areas of the input boxes
        width_box1 = abs(x2 - x1)
        height_box1 = abs(y2 - y1)
        area_box1 = width_box1 * height_box1

        width_box2 = abs(x4 - x3)
        height_box2 = abs(y4 - y3)
        area_box2 = width_box2 * height_box2

        # Calculate union area
        area_union = area_box1 + area_box2 - area_inter

        # Calculate IoU
        if area_union == 0:
            return 0  # avoid division by zero
        iou = area_inter / area_union

        return iou
    
    
    def assign_index(self, bounding_boxes, nodes, threshold=0.5):
        indices = []
        existing_boxes = nodes

        for box in bounding_boxes:
            found_match = False
            for idx, existing_box in enumerate(existing_boxes):
                if self.calculate_iou(box, existing_box) > threshold:
                    indices.append(idx)
                    found_match = True
                    break

            if not found_match:
                existing_boxes.append(box)
                indices.append(len(existing_boxes) - 1)

        return indices, existing_boxes
    
    
    def normalize(self, tensor):
        tensor_min = tensor.min()
        tensor_max = tensor.max()
        normalized_tensor = (tensor - tensor_min) / ((tensor_max - tensor_min) + 1e-8)
        return normalized_tensor
    
    
### here
    def resize_depth_map(self, sorted_bboxes, embedding, target_size=(25, 25), method='linear'):

        pooled_embs = []

        for bbox in sorted_bboxes:
            cropped_emb = embedding[bbox[1]:bbox[3], bbox[0]:bbox[2]]

            if method == 'linear':
                pooled_emb = self.downsize_depth_map_bilinear(cropped_emb, target_size)
                flattened_emb = pooled_emb.view(-1)
            elif method == 'mode':
                pooled_emb = self.downsize_depth_map_mode(cropped_emb, target_size)
                flattened_emb = pooled_emb.view(-1)
            else:
                pooled_emb = cropped_emb
                flattened_emb = pooled_emb.reshape(-1)

#             flattened_emb = pooled_emb.view(-1)
            
            pooled_embs.append(flattened_emb)
        
        
        if method == 'real':
            return pooled_embs

        pooled_embs = torch.stack(pooled_embs) if pooled_embs else torch.empty(0)

        return pooled_embs
    
    
    def downsize_depth_map_mode(self, depth_map, new_size):
        
        original_height, original_width = depth_map.shape[-2], depth_map.shape[-1]
        new_height, new_width = new_size

        window_height = original_height // new_height
        window_width = original_width // new_width

        downsized_map = torch.zeros((new_height, new_width), dtype=depth_map.dtype, device=depth_map.device)

        for i in range(new_height):
            for j in range(new_width):
                # Define the window boundaries
                start_row = i * window_height
                end_row = start_row + window_height
                start_col = j * window_width
                end_col = start_col + window_width

                # Extract the window
                window = depth_map[start_row:end_row, start_col:end_col]

                # Flatten the window
                flat_window = window.flatten().cpu().numpy()  # Convert to numpy for mode calculation

                downsized_map[i, j] = mode(flat_window)

        return downsized_map
    
    def downsize_depth_map_bilinear(self, depth_map, new_size):

        if isinstance(depth_map, np.ndarray):
            depth_map = torch.from_numpy(depth_map)

        if depth_map.dim() == 2:
            depth_map = depth_map.unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions

        new_height, new_width = new_size

        resized_map = F.interpolate(depth_map, size=(new_height, new_width), mode='bilinear', align_corners=False)

        return resized_map.squeeze(0).squeeze(0)

In [3]:
import torch
from torch_geometric.data import Batch, Data


def custom_collate(batch):
    
    
    batch = [item for item in batch if item is not None]

    if len(batch) == 0:
        return None
    img_path = [item['img_path'] for item in batch]
    
    # Handle images, depth maps, and other tensors separately
    images = torch.stack([item['image'] for item in batch])
    depth_embs = torch.stack([item['depth_emb'] for item in batch])
    depth_maps = torch.stack([item['depth_map'] for item in batch])
    depths = torch.stack([item['depth'] for item in batch])
    
#     pooled_act_depths = [torch.tensor(item['pooled_act_depths'], dtype=torch.int) for item in batch]
    pooled_act_depths = [item['pooled_act_depths'].clone().detach().float() for item in batch]
    pooled_depth_maps = [item['pooled_depth_maps'].clone().detach().float() for item in batch]
#     act_depths = [item['act_depths'].clone().detach().float() for item in batch]
    act_depths = []
    for item in batch:
        if isinstance(item['act_depths'], list):
            act_depths.append([x.clone().detach().float() for x in item['act_depths']])
        else:
            act_depths.append(item['act_depths'].clone().detach().float())


#     bboxs = torch.stack([item['bboxs'] for item in batch])
    bboxs = [torch.tensor(item['bboxs'], dtype=torch.int) for item in batch]

    gnndata_list = [item['gnndata'] for item in batch]

    # Batch the graph data using PyTorch Geometric's Batch
    graph_batch = Batch.from_data_list(gnndata_list)

    return {
        'img_path': img_path,
        'image': images,
        'depth_emb': depth_embs,
        'depth_map': depth_maps,
        'depth': depths,
        'pooled_act_depths': pooled_act_depths,
        'act_depths': act_depths,
        'pooled_depth_maps': pooled_depth_maps,
        'bboxs': bboxs,
        'gnndata': graph_batch,  # Batched graph data
    }

### GNN

In [4]:
from torch_geometric.nn import MessagePassing
import torch.nn as nn

class DepthGNNModel(MessagePassing):
    def __init__(self, node_features_size, edge_features_size, hidden_channels, output_size):
        super(DepthGNNModel, self).__init__(aggr='add')

        self.message_mlp = nn.Sequential(
            nn.Linear(2 * node_features_size + edge_features_size, 1024),  # Reduced size
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(1024, hidden_channels)
        )

        self.node_mlp = nn.Sequential(
            nn.Linear(node_features_size + hidden_channels, 2048),  # Reduced size
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(2048, 1024),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(1024, output_size)  # Output a flattened 25x25 depth map
        )

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        message_input = torch.cat([x_i, edge_attr, x_j], dim=-1)
        return self.message_mlp(message_input)

    def update(self, aggr_out, x):
        updated_node_features = torch.cat([x, aggr_out], dim=-1)
        return self.node_mlp(updated_node_features)


In [5]:
checkpoint_path = './model_weights/current_linear_checkpoint.pth'


def load_checkpoint(model, optimizer, path):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

In [6]:
### load model
node_features_size =  41875 # 82075 # 
edge_features_size = 52
hidden_channels = 728
output_size = 625 #1225 # 

gnn_model = DepthGNNModel(node_features_size, edge_features_size, hidden_channels, output_size).cuda()
optimizer = optim.Adam(gnn_model.parameters(), lr=1e-4)

model_check_point = load_checkpoint(gnn_model, optimizer, checkpoint_path)
print(f"Model loaded from checkpoint, starting evaluation from epoch {model_check_point + 1}")

Model loaded from checkpoint, starting evaluation from epoch 48


# Main part

### loading data

In [7]:
test_dataset = DepthDataset(mode='test')
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=0, collate_fn=custom_collate)


# Evaluation

## RUNNING ALL TESTDATASET

In [8]:
def normalize_tensor(tensor):

    epsilon = 1e-8
#     epsilon = 0
    
    min_val = torch.min(tensor)
    max_val = torch.max(tensor)
    normalized_tensor = (tensor - min_val) / (max_val - min_val + epsilon)
    
    return normalized_tensor

In [9]:
import torch
import torch.nn.functional as F


def from_flatten_map(flatten_map, depth_map_size, target_size, method='nearest'):
    
    # Ensure that flatten_map is a torch tensor
    if not isinstance(flatten_map, torch.Tensor):
        raise TypeError("flatten_map should be a PyTorch tensor.")
    
    # Reshape flatten map to depth map size
    depth_map = flatten_map.view(depth_map_size)  # Convert 1D to desired depth map size (H, W)

    # Add batch and channel dimensions for interpolation (N, C, H, W)
    depth_map = depth_map.unsqueeze(0).unsqueeze(0)  # Add batch dimension and channel dimension

    # Perform interpolation to the target size
    # Note: F.interpolate expects the input tensor format (N, C, H, W)
    upsampled_map = F.interpolate(depth_map, size=target_size, mode=method)
    
    # Remove batch and channel dimensions
    upsampled_map = upsampled_map.squeeze(0).squeeze(0)  # Remove batch and channel dimensions

    return upsampled_map
    

In [12]:
import torch

def mae(pred, target):
    """Calculate Mean Absolute Error (MAE) between prediction and target."""
    return torch.mean(torch.abs(pred - target))

def rmse(pred, target):
    """Calculate Root Mean Squared Error (RMSE) between prediction and target."""
    return torch.sqrt(torch.mean((pred - target) ** 2))

def threshold_accuracy(pred, target, threshold):
    """Calculate the percentage of pixels where the prediction is within a certain threshold of the target."""
    ratio = torch.max(pred / target, target / pred)
    return (ratio < threshold).float().mean()


In [13]:
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def check_for_nan(values, name=""):
    """Check if any of the values in the list are NaN."""
    tensor_values = torch.tensor(list(values), dtype=torch.float32, device=device)
    if torch.isnan(tensor_values).any():
        print(f"Warning: {name} contains NaN values!")
        return True
    return False

def compute_metrics(pred, target):
    """Compute selected metrics between prediction and target with additional checks."""
    # Avoid division by zero and ensure positive values for log calculation
    pred = torch.clamp(pred, max=1, min=1e-6)
    target = torch.clamp(target, max=1, min=1e-6)
    
    metrics = {
        'mae': mae(pred, target).item(),
        'rmse': rmse(pred, target).item(),
        'delta_1': threshold_accuracy(pred, target, 1.25).item(),
        'delta_2': threshold_accuracy(pred, target, 1.25 ** 2).item(),
        'delta_3': threshold_accuracy(pred, target, 1.25 ** 3).item()
    }

    return metrics

with torch.no_grad():
    num_metrics = 5  # Updated number of metrics being calculated
    total_metrics = {
        'reshaped_ground_truth': torch.zeros(num_metrics, device=device),
        'node_output': torch.zeros(num_metrics, device=device),
        'ada_output': torch.zeros(num_metrics, device=device),
        'ada_pool_output': torch.zeros(num_metrics, device=device)
    }
    valid_samples_count = {
        'reshaped_ground_truth': 0,
        'node_output': 0,
        'ada_output': 0,
        'ada_pool_output': 0
    }
    
    bad_imgs = []
    
    gnn_model.eval()

    for single_batch in tqdm(test_dataloader, desc="Evaluating"):
        gnndata = single_batch['gnndata'].to(device)
        output = gnn_model(gnndata.x, gnndata.edge_index, gnndata.edge_attr)
        count = len(single_batch['act_depths'][0])
        
        stored_img = False
        
        for img_idx in range(count):

            depth_map_size = (25, 25)  
            bbox = single_batch['bboxs'][0][img_idx]

            target_size = (
                int(bbox[3] - bbox[1]),  # Height as integer
                int(bbox[2] - bbox[0])   # Width as integer
            )

            # Ground truth
            ground_truth = single_batch['act_depths'][0][img_idx]
            ground_truth = ground_truth.reshape(target_size).to(device)

            # Reshaped ground truth
            reshaped_ground_truth = single_batch['pooled_act_depths'][0][img_idx]
            reshaped_ground_truth = from_flatten_map(reshaped_ground_truth, depth_map_size, target_size, 'nearest').to(device)

            # Node output
            node_output = output[img_idx]
            node_output = from_flatten_map(node_output, depth_map_size, target_size, 'nearest').to(device)

            # Adabins output
            ada_output = single_batch['depth_map'][0][bbox[1]:bbox[3], bbox[0]:bbox[2]].to(device)

            # Downsampled Adabins output
            ada_pool_output = single_batch['pooled_depth_maps'][0][img_idx]
            ada_pool_output = from_flatten_map(ada_pool_output, depth_map_size, target_size, 'nearest').to(device)
            
            # Normalize tensors
            ground_truth = normalize_tensor(ground_truth.float())
            reshaped_ground_truth = normalize_tensor(reshaped_ground_truth.float())
            node_output = normalize_tensor(node_output.float())
            ada_output = normalize_tensor(ada_output.float())
            ada_pool_output = normalize_tensor(ada_pool_output.float())
            
            # Compute metrics
            metrics_reshaped_ground_truth = compute_metrics(reshaped_ground_truth, ground_truth)
            metrics_node_output = compute_metrics(node_output, ground_truth)
            metrics_ada_output = compute_metrics(ada_output, ground_truth)
            metrics_ada_pool_output = compute_metrics(ada_pool_output, ground_truth)

            # Only accumulate metrics if they do not contain NaN values
            if not check_for_nan(metrics_reshaped_ground_truth.values(), "reshaped_ground_truth"):
                total_metrics['reshaped_ground_truth'] += torch.tensor(list(metrics_reshaped_ground_truth.values()), device=device)
                valid_samples_count['reshaped_ground_truth'] += 1
                stored_img = True
            
            if not check_for_nan(metrics_node_output.values(), "node_output"):
                total_metrics['node_output'] += torch.tensor(list(metrics_node_output.values()), device=device)
                valid_samples_count['node_output'] += 1
                stored_img = True

            if not check_for_nan(metrics_ada_output.values(), "ada_output"):
                total_metrics['ada_output'] += torch.tensor(list(metrics_ada_output.values()), device=device)
                valid_samples_count['ada_output'] += 1
                stored_img = True
                
            if not check_for_nan(metrics_ada_pool_output.values(), "ada_pool_output"):
                total_metrics['ada_pool_output'] += torch.tensor(list(metrics_ada_pool_output.values()), device=device)
                valid_samples_count['ada_pool_output'] += 1
                stored_img = True
                
        if stored_img:
            bad_imgs.append(single_batch['img_path'])
            
            # Clear GPU memory
            del ground_truth, reshaped_ground_truth, node_output, ada_output, ada_pool_output

# Compute average metrics across all valid samples
avg_metrics = {}
for key in total_metrics:
    if valid_samples_count[key] > 0:
        avg_metrics[key] = total_metrics[key] / valid_samples_count[key]
    else:
        avg_metrics[key] = torch.tensor([float('nan')] * num_metrics, device=device)  # Set to NaN if no valid samples

# Print the results
metric_names = ['MAE', 'RMSE', 'δ1', 'δ2', 'δ3']

print("Average Metrics for Each Output (excluding NaN values):")
for key, value in avg_metrics.items():
    print(f"{key}:")
    for name, metric_value in zip(metric_names, value):
        print(f"  {name} = {metric_value.item():.4f}")


Evaluating: 100%|██████████| 654/654 [05:00<00:00,  2.18it/s]

Average Metrics for Each Output (excluding NaN values):
reshaped_ground_truth:
  MAE = 0.0464
  RMSE = 0.1341
  δ1 = 0.9152
  δ2 = 0.9337
  δ3 = 0.9414
node_output:
  MAE = 0.2533
  RMSE = 0.3035
  δ1 = 0.3358
  δ2 = 0.5770
  δ3 = 0.6572
ada_output:
  MAE = 0.3741
  RMSE = 0.4416
  δ1 = 0.1969
  δ2 = 0.3263
  δ3 = 0.4148
ada_pool_output:
  MAE = 0.3750
  RMSE = 0.4447
  δ1 = 0.2014
  δ2 = 0.3321
  δ3 = 0.4206



