In [30]:
import os
import sys

# CUDA_VISIBLE_DEVICES is read when CUDA is first initialised, so restrict the
# process to GPU 0 *before* importing any third-party library that might
# initialise CUDA on import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import cv2
import matplotlib.pyplot as plt
import numpy as np

In [31]:
import os
import matplotlib.pyplot as plt
import numpy as np

# Default root for saved figures; callers can override it via ``outdir``.
_DEFAULT_VIS_OUTDIR = '/home/ashkanganj/workspace/2023-HybridDepth-DepthProject/results/imgs'


def _to_display_rgb(rgb):
    """Squeeze one RGB sample and, if it is channels-first, make it ``imshow``-ready.

    A ``(3, h, w)`` array (possibly with extra singleton axes) is transposed to
    ``(h, w, 3)``. It is then clipped to the value range ``imshow`` expects for
    its dtype ([0, 1] for floats, [0, 255] for integers), which avoids
    matplotlib's clipping warning. Any other shape, e.g. ``(1, 1, h, w)`` squeezed
    to ``(h, w)``, is returned squeezed but otherwise unchanged.
    """
    img = np.asarray(rgb).squeeze()
    # ndim check: a squeezed 2-D image whose height happens to be 3 must not be transposed.
    if img.ndim == 3 and img.shape[0] == 3:
        img = img.transpose(1, 2, 0)  # (3, h, w) -> (h, w, 3)
        if np.issubdtype(img.dtype, np.floating):
            img = np.clip(img, 0.0, 1.0)
        elif np.issubdtype(img.dtype, np.integer) and img.dtype.itemsize > 1:
            # 8-bit types are skipped: uint8 already lies in [0, 255], and
            # 255 is not representable in int8.
            img = np.clip(img, 0, 255)
    return img


def save_vis_from_lists(rgb_list, ours_list, dfv_list, depth_anything_list, gt_depth_list,
                        dir_name='default', outdir=_DEFAULT_VIS_OUTDIR, rotate_last=True):
    """Save a grid comparing RGB, our prediction, DFV, Depth Anything and GT depth.

    Each sample becomes one row with five columns, in this order: RGB, Ours,
    DFV, Depth Anything, GT. The four depth columns share a ``plasma`` colour
    range taken from that row's GT (NaN-aware min/max). Predictions are
    therefore comparable to GT within a row, but the colour scale differs
    between rows.

    Parameters:
    - rgb_list: RGB images, ``(3, h, w)`` or ``(1, 1, h, w)``; singleton axes are squeezed.
    - ours_list: Model predictions, ``(1, 1, h, w)``.
    - dfv_list: DFV depth predictions, ``(1, 1, h, w)``.
    - depth_anything_list: Depth Anything predictions, ``(1, 1, h, w)``.
    - gt_depth_list: Ground-truth depth maps, ``(1, 1, h, w)``.
    - dir_name: Sub-directory of ``outdir`` to write into; also used as the file-name prefix.
    - outdir: Root output directory (defaults to the original hard-coded ``results/imgs`` path).
    - rotate_last: If True (the original behaviour), every image in the last row
      is rotated by 180 degrees before plotting.

    Returns:
    - Path of the written PNG: ``<outdir>/<dir_name>/<dir_name>_pred_viz_diff.png``.

    Raises:
    - ValueError: if the lists are empty or do not all have the same length.
    """
    columns = (rgb_list, ours_list, dfv_list, depth_anything_list, gt_depth_list)
    num_images = len(rgb_list)
    if num_images == 0:
        raise ValueError('save_vis_from_lists needs at least one sample')
    lengths = [len(col) for col in columns]
    if any(n != num_images for n in lengths):
        raise ValueError(f'all input lists must have the same length, got {lengths}')

    img_save_pth = os.path.join(os.path.abspath(outdir), dir_name)
    os.makedirs(img_save_pth, exist_ok=True)
    save_path = os.path.join(img_save_pth, f'{dir_name}_pred_viz_diff.png')

    # Scoped font settings instead of permanently mutating the global rcParams.
    with plt.rc_context({'font.family': 'Times New Roman', 'font.size': 11}):
        # squeeze=False keeps ``axs`` 2-D even for a single row, so ``axs[i, j]`` always works.
        # Height scales with the row count; three rows reproduce the original 9 x 4 inch figure.
        fig, axs = plt.subplots(num_images, 5, figsize=(9, 4 * num_images / 3), squeeze=False)
        try:
            fig.subplots_adjust(wspace=0.0, hspace=0.04)  # small vertical gap between rows
            for i, (rgb, *depths) in enumerate(zip(*columns)):
                rgb_img = _to_display_rgb(rgb)
                # Order: Ours, DFV, Depth Anything, GT (matches columns 1..4).
                depth_imgs = [np.asarray(d).squeeze() for d in depths]

                # Shared colour range from this row's GT; nanmin/nanmax ignore NaN pixels.
                gt_img = depth_imgs[-1]
                vmin = np.nanmin(gt_img)
                vmax = np.nanmax(gt_img)

                if rotate_last and i == num_images - 1:
                    rgb_img = np.rot90(rgb_img, 2)
                    depth_imgs = [np.rot90(d, 2) for d in depth_imgs]

                axs[i, 0].imshow(rgb_img)
                for col, depth_img in enumerate(depth_imgs, start=1):
                    axs[i, col].imshow(depth_img, cmap='plasma', vmin=vmin, vmax=vmax)
                for ax in axs[i]:
                    ax.axis('off')

            fig.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=300)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
    return save_path


In [32]:
# Load the saved per-sample arrays for the ARKitScenes visualisation.
# Each method stores one .npy per sample ID under its own sub-directory of
# ARKIT_VIS_DIR; only the file-name pattern differs between methods.
ARKIT_VIS_DIR = '/home/ashkanganj/workspace/2023-HybridDepth-DepthProject/results/imgs/ARKitSceneVis'
SAMPLE_IDS = (260, 381, 1700)  # also the row order of the figure


def load_sample_arrays(subdir, filename_template, sample_ids=SAMPLE_IDS, root=ARKIT_VIS_DIR):
    """Load one .npy array per sample ID and return them as a list.

    The file for ID ``i`` is ``<root>/<subdir>/<filename_template.format(idx=i)>``.
    The returned list follows the order of ``sample_ids``.
    """
    return [
        np.load(os.path.join(root, subdir, filename_template.format(idx=idx)))
        for idx in sample_ids
    ]


depthAnything = load_sample_arrays('depthAnything_data', 'learned_pred_{idx}.npy')
ours = load_sample_arrays('ours_data', '{idx}_depth.npy')
dfs = load_sample_arrays('dfv_data', '{idx}_DFV_prd.npy')  # DFV predictions
gts = load_sample_arrays('gt_data', '{idx}_gt.npy')
rgbs = load_sample_arrays('rgb_data', '{idx}_rgb.npy')

# Per-sample names kept so any cell that refers to them directly still works.
depthAnything_260, depthAnything_381, depthAnything_1700 = depthAnything
ours_260, ours_381, ours_1700 = ours
dfv_260, dfv_381, dfv_1700 = dfs
gt_260, gt_381, gt_1700 = gts
rgb_260, rgb_381, rgb_1700 = rgbs

In [34]:
# Render the five-column comparison grid (RGB | Ours | DFV | Depth Anything | GT)
# for the loaded ARKitScenes samples and write it under the 'ARKitSceneVis' folder.
save_vis_from_lists(
    rgb_list=rgbs,
    ours_list=ours,
    dfv_list=dfs,
    depth_anything_list=depthAnything,
    gt_depth_list=gts,
    dir_name='ARKitSceneVis',
)