In [41]:
import sys,os
# Restrict CUDA to GPU 0. This environment variable is generally only honoured
# if set before a CUDA-using library initialises the GPU; nothing in the cells
# shown here uses CUDA, so NOTE(review): this may be vestigial — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np

import matplotlib.pyplot as plt
# NOTE(review): cv2 and sys are not referenced in the visible cells, and os is
# imported twice (here and on the first line of this cell).
import cv2
import os

In [42]:
import os
import matplotlib.pyplot as plt
import numpy as np

def _prepare_rgb_for_imshow(rgb):
    """Squeeze an RGB/gray array into imshow layout and clip it to imshow's displayable range."""
    img = np.asarray(rgb).squeeze()
    # Channel-first (3, h, w) -> channel-last (h, w, 3), which imshow expects.
    if img.ndim == 3 and img.shape[0] == 3:
        img = img.transpose(1, 2, 0)
    if img.ndim == 3:
        # imshow accepts RGB floats in [0, 1] and ints in [0, 255]; out-of-range
        # values are clipped (not rescaled), matching the original behaviour.
        if np.issubdtype(img.dtype, np.floating):
            img = np.clip(img, 0.0, 1.0)
        elif np.issubdtype(img.dtype, np.integer):
            img = np.clip(img, 0, 255)
    return img


def save_vis_from_lists(rgb_list, ours_list, depth_anything_list, gt_depth_list, dir_name='default',
                        outdir='/home/ashkanganj/workspace/2023-HybridDepth-DepthProject/results/imgs'):
    """
    Save a grid figure comparing depth predictions, one sample per row.

    Columns (left to right): RGB input, Ours (prediction), Depth Anything, GT depth.
    All three depth panels of a row share the colour range [nanmin(GT), nanmax(GT)],
    so colours are directly comparable within a row (the range is per sample).

    Parameters:
    - rgb_list: List of RGB images, shape (3, h, w), (h, w, 3) or (1, 1, h, w).
      3-channel float images are clipped to [0, 1], integer images to [0, 255];
      single-channel images are shown in grayscale.
    - ours_list: List of model prediction depth maps (shape (1, 1, h, w) or (h, w)).
    - depth_anything_list: List of Depth Anything depth maps (same shapes).
    - gt_depth_list: List of ground-truth depth maps (same shapes).
    - dir_name: Sub-directory of ``outdir`` to save into; also used as the file-name prefix.
    - outdir: Root results directory (defaults to the previously hard-coded location).

    Returns:
    - Path of the saved PNG: ``<outdir>/<dir_name>/<dir_name>_pred_viz_diff.png``.

    Raises:
    - ValueError: if the input lists are empty or have different lengths.
    """
    num_images = len(rgb_list)
    if num_images == 0:
        raise ValueError('save_vis_from_lists: received no images to plot')
    if not (len(ours_list) == len(depth_anything_list) == len(gt_depth_list) == num_images):
        raise ValueError(
            'save_vis_from_lists: input lists must have equal length, got '
            f'rgb={num_images}, ours={len(ours_list)}, '
            f'depth_anything={len(depth_anything_list)}, gt={len(gt_depth_list)}'
        )

    img_save_pth = os.path.join(os.path.abspath(outdir), dir_name)
    os.makedirs(img_save_pth, exist_ok=True)

    plt.rcParams['font.family'] = 'Times New Roman'
    plt.rcParams['font.size'] = 11

    # squeeze=False keeps axs 2-D even for a single row, so axs[i, j] always works.
    # Height scales with the number of rows (7 rows -> the original (5, 7)).
    fig, axs = plt.subplots(num_images, 4, figsize=(5, num_images), squeeze=False)
    try:
        fig.subplots_adjust(wspace=0.05, hspace=0.0)  # tight rows

        rows = zip(rgb_list, ours_list, depth_anything_list, gt_depth_list)
        for i, (rgb, ours, depth_any, gt) in enumerate(rows):
            rgb_img = _prepare_rgb_for_imshow(rgb)
            gt_img = np.asarray(gt).squeeze()

            # Per-sample colour range from GT; nan-aware so invalid GT pixels
            # do not turn vmin/vmax into NaN.
            vmin = float(np.nanmin(gt_img))
            vmax = float(np.nanmax(gt_img))

            if rgb_img.ndim == 2:
                axs[i, 0].imshow(rgb_img, cmap='gray')
            else:
                axs[i, 0].imshow(rgb_img)

            depth_panels = (np.asarray(ours).squeeze(), np.asarray(depth_any).squeeze(), gt_img)
            for ax, depth in zip(axs[i, 1:], depth_panels):
                ax.imshow(depth, cmap='plasma', vmin=vmin, vmax=vmax)

            for ax in axs[i]:
                ax.axis('off')

        out_path = os.path.join(img_save_pth, f'{dir_name}_pred_viz_diff.png')
        fig.savefig(out_path, bbox_inches='tight', pad_inches=0, dpi=300)
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close(fig)
    return out_path


In [43]:
# Root folder holding the per-method .npy dumps used for the NYU visualisation.
NYU_VIS_ROOT = '/home/ashkanganj/workspace/2023-HybridDepth-DepthProject/results/imgs/NyuVis'

# Sample indices shown in the figure, top row first.
# Other dumped samples that are currently excluded: 0, 380, 400.
SAMPLE_IDS = [10, 20, 30, 100, 110, 160, 180]


def _nyu_vis_paths(subdir, suffix):
    """Return ``<NYU_VIS_ROOT>/<subdir>/<id>_<suffix>.npy`` for every id in SAMPLE_IDS, in order."""
    return [os.path.join(NYU_VIS_ROOT, subdir, f'{sample_id}_{suffix}.npy') for sample_id in SAMPLE_IDS]


gt_depth_list_pth = _nyu_vis_paths('gt_data', 'gt')
rgb_list_pth = _nyu_vis_paths('rgb_data', 'rgb')
ours_list_pth = _nyu_vis_paths('ours_data', 'depth')
depthanything_list_pth = _nyu_vis_paths('depthanything_data', 'depthanything')

gt_depth = [np.load(pth) for pth in gt_depth_list_pth]
rgb = [np.load(pth) for pth in rgb_list_pth]
ours = [np.load(pth) for pth in ours_list_pth]
depthanything = [np.load(pth) for pth in depthanything_list_pth]

In [44]:
save_vis_from_lists(
    rgb_list=rgb,
    ours_list=ours,
    depth_anything_list=depthanything,
    gt_depth_list=gt_depth,
    dir_name='NyuVis',
)