In [None]:
# Import RAFT-Stereo. If the `core` package cannot be imported from the current
# working directory, change into the repository checkout and retry the import.
# NOTE(review): the retry presumably relies on the working directory being on
# sys.path (as it is in a notebook kernel) — confirm.
try:
    from core.raft_stereo import RAFTStereo
except ImportError:
    import os
    os.chdir("/RAFT-Stereo")
    from core.raft_stereo import RAFTStereo
    
# Path component that later cells replace to derive sibling output directories
# (losses, disparities, filtered frames) from an input image path.
FRPASS = "frames_cleanpass"

In [None]:
from train_fusion.dataloader import StereoDataset, StereoDatasetArgs

import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader


In [None]:
from fusion_args import FusionArgs
args = FusionArgs()

# Settings describing the network that is instantiated and the checkpoint it loads.
_MODEL_SETTINGS = {
    "hidden_dims": [128, 128, 128],
    "corr_levels": 4,
    "corr_radius": 4,
    "n_downsample": 3,
    "context_norm": "batch",
    "n_gru_layers": 2,
    "shared_backbone": True,
    "mixed_precision": True,
    "corr_implementation": "reg_cuda",
    "slow_fast_gru": False,
    "restore_ckpt": "models/raftstereo-realtime.pth",
}

# Optimisation / fusion settings. Only `valid_iters` is read by the cells in
# this notebook; the rest are kept so the args object stays complete.
_RUN_SETTINGS = {
    "lr": 0.001,
    "train_iters": 7,
    "valid_iters": 12,
    "wdecay": 0.0001,
    "num_steps": 100000,
    "valid_steps": 1000,
    "name": "StereoFusion",
    "batch_size": 4,
    "fusion": "AFF",
    "shared_fusion": True,
    "freeze_backbone": [],
    "both_side_train": False,
}

# Apply the overrides in the same order as plain attribute assignments would.
for _settings in (_MODEL_SETTINGS, _RUN_SETTINGS):
    for _name, _value in _settings.items():
        setattr(args, _name, _value)

In [None]:
# Dataset selection flags, kept in a named object so they can be inspected later.
dataset_args = StereoDatasetArgs(
    "/bean/depth",
    flying3d_json=True,
    flow3d_driving_json=False,
    gt_depth=True,
    validate_json=True,
    synth_no_filter=True,
)
dataset = StereoDataset(dataset_args)

# One sample per batch, in dataset order, so every batch maps to exactly one
# image path when results are written back to disk.
valid_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Sanity check: number of samples exposed by the validation dataset.
print(len(dataset))

In [None]:
# Build the network inside a DataParallel wrapper, load the checkpoint into the
# wrapper, then keep only the underlying module for inference.
# NOTE(review): presumably the checkpoint keys carry the wrapper's prefix, which
# is why the weights are loaded before unwrapping — confirm.
parallel_model = torch.nn.DataParallel(RAFTStereo(args)).cuda()
checkpoint_state = torch.load(args.restore_ckpt)
parallel_model.load_state_dict(checkpoint_state)
parallel_model.eval()
model = parallel_model.module

In [None]:
import cv2
import matplotlib.pyplot as plt
from train_fusion.noise_generator import NoiseGenerator

# Fetch the first validation batch: element 0 carries the file paths, the
# remaining elements are tensors that are moved to the GPU.
iterator = iter(valid_loader)
inputs = next(iterator)
image0, image1, image2, image3, dis1, dis2 = [x.cuda() for x in inputs[1:]]
image_path = inputs[0][0][0][0]

# Colour-coded disparity images stored next to the input frame; the directory
# name is derived by replacing the FRPASS path component.
disparity_path_color = [
    image_path.replace(FRPASS, f"flow_raft_rgb_env_color_{i}") for i in range(4)
]
disparity_rgb = [cv2.imread(x) for x in disparity_path_color]
image_path_nir = [
    image_path.replace(FRPASS, f"flow_raft_nir_env_color_{i}") for i in range(4)
]
disparity_nir = [cv2.imread(x) for x in image_path_nir]

# Original stereo pair read straight from disk (cv2.imread yields BGR arrays).
# No standalone plt.figure() is created here: the plotting cell builds its own
# figure, so an extra one would only render as an empty output.
o_img_l = cv2.imread(image_path)
o_img_r = cv2.imread(image_path.replace("left", "right"))
if o_img_l is None or o_img_r is None:
    # cv2.imread returns None instead of raising when a file cannot be read.
    raise FileNotFoundError(f"Could not read stereo pair for {image_path}")

noise = NoiseGenerator()

# Fully noised versions of both views, reused by several entries further down.
al_left = noise.all_noise(o_img_l)
al_right = noise.all_noise(o_img_r)
def _on_both_views(transform, left_view, right_view, *params):
    """Apply `transform` to the left view first, then the right; return both results."""
    return transform(left_view, *params), transform(right_view, *params)


# Maps a row title to the (left, right) pair that is fed to the model.
# Entries are NumPy images except for the NIR rows, which are tensors taken
# from the loaded batch. Earlier experiments with other NoiseGenerator variants
# (read/shot/thermal noise, larger_intensity, darker_image) are no longer listed;
# add an entry here to compare a new variant.
image_dict = {
    "origin": (o_img_l, o_img_r),
    "noised": (al_left, al_right),
    "burned_v2_image_64100": _on_both_views(noise.burnt_effect, o_img_l, o_img_r, 64, 100),
    "burned_v2_image_64500": _on_both_views(noise.burnt_effect, o_img_l, o_img_r, 64, 500),
    "burned_v2_image_641000": _on_both_views(noise.burnt_effect, o_img_l, o_img_r, 64, 1000),
    "noised_burned_v2_image_64100": _on_both_views(noise.burnt_effect, al_left, al_right, 64, 100),
    "noised_burned_v2_image_64500": _on_both_views(noise.burnt_effect, al_left, al_right, 64, 500),
    "noised_burned_v2_image_641000": _on_both_views(noise.burnt_effect, al_left, al_right, 64, 1000),
    "darker_shadow": _on_both_views(noise.darken_shadows, o_img_l, o_img_r, 64, 500),
    "darker_shadow_burnt": _on_both_views(
        # Burn first, then darken the shadows of the burnt image.
        lambda view: noise.darken_shadows(noise.burnt_effect(view, 64, 500), 64, 500),
        o_img_l,
        o_img_r,
    ),
    "nir_rendered": (image2[0], image3[0]),
    # NOTE(review): valid_loader uses batch_size=1, so index 1 along the first
    # axis only exists if the dataset packs several NIR variants per sample — confirm.
    "nir_ambient": (image2[1], image3[1]),
}

# Colour range used for every predicted-disparity panel.
DISP_MAX = 128
DISP_MIN = 16

# One row per input variant: left input, right input, predicted disparity, GT error.
# squeeze=False keeps `axs` two-dimensional even when there is a single row.
n_rows = len(image_dict)
_, axs = plt.subplots(n_rows, 4, figsize=(20, 5 * n_rows), squeeze=False)

# Ground-truth disparity of the loaded sample; identical for every row.
disparity_gt = dis1[0].permute(1, 2, 0).cpu().numpy()

for i, (key, image) in enumerate(image_dict.items()):
    left, right = image
    # NumPy entries come from cv2.imread (possibly filtered) and are therefore BGR.
    from_cv2 = isinstance(left, np.ndarray)
    if from_cv2:
        left = torch.from_numpy(left).permute(2, 0, 1).unsqueeze(0).float().cuda()
        right = torch.from_numpy(right).permute(2, 0, 1).unsqueeze(0).float().cuda()
    else:
        left = left.unsqueeze(0).float()
        right = right.unsqueeze(0).float()
    if left.shape[1] == 1:
        # Single-channel input: replicate to three channels for the model.
        left = torch.cat([left, left, left], dim=1)
        right = torch.cat([right, right, right], dim=1)

    # NOTE(review): cv2-derived inputs reach the model in BGR order — confirm
    # that this matches the channel order the checkpoint was trained with.
    with torch.no_grad():
        _, flow = model(left, right, iters=args.valid_iters, test_mode=True)
    # The model output is negated to obtain the disparity shown and scored below.
    flow = -flow[0].permute(1, 2, 0).cpu().numpy()

    for col, (side, tensor) in enumerate((("Left", left), ("Right", right))):
        shown = tensor[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)
        if from_cv2:
            # imshow interprets 3-channel arrays as RGB; flip BGR for display only.
            shown = np.ascontiguousarray(shown[..., ::-1])
        axs[i, col].imshow(shown)
        axs[i, col].set_title(f"{key} {side} Image")

    et = axs[i, 2].imshow(flow, vmin=DISP_MIN, vmax=DISP_MAX, cmap="magma")
    axs[i, 2].set_title(f"{key} Flow")
    axs[i, 2].axis("off")
    plt.colorbar(et, ax=axs[i, 2])

    # Crop the prediction to the ground-truth extent before comparing.
    flow = flow[: disparity_gt.shape[0], : disparity_gt.shape[1]]
    error = noise.compute_disparity_gt_error(disparity_gt, flow, False)

    et = axs[i, 3].imshow(255 - error, vmin=0, vmax=255, cmap="jet")
    axs[i, 3].set_title(f"{key} GT Error")
    axs[i, 3].axis("off")
    plt.colorbar(et, ax=axs[i, 3])

plt.show()

In [None]:
import dis
import json
from typing import List
import cv2
import os
from tqdm.notebook import tqdm
from train_fusion.loss_function import warp_reproject_loss, gt_loss
from train_fusion.noise_generator import NoiseGenerator

# Shared generator whose filter_* methods produce the degraded RGB variants.
noise = NoiseGenerator()

# Number of validation batches; used to turn per-sample losses into a mean.
batch_len = len(valid_loader)

# One title per evaluated input variant. The order must match the stacking
# order produced by unpack_batch_create_pair_arr (five RGB variants, then NIR).
input_title = ["rgb", "burnt", "burnt_light", "darken", "darken_gain", "nir_rendered"]
mode_len = len(input_title)

# One accumulator (metric name -> running value) per input variant.
total_loss_dict = [{} for _ in input_title]


def numpy_to_torch(imgs: list[np.ndarray]):
    """Convert 3-D NumPy arrays to float CUDA tensors with the last axis moved first.

    Args:
        imgs: arrays to convert; each is permuted with ``(2, 0, 1)``.

    Returns:
        A list of tensors in the same order as `imgs` (empty for empty input).
    """
    converted = []
    for array in imgs:
        on_gpu = torch.from_numpy(array).float().cuda()
        converted.append(on_gpu.permute(2, 0, 1))
    return converted


def rgb_noised_input_pairs(rgb_left: np.ndarray, rgb_right: np.ndarray):
    """Stack the clean RGB pair together with four filtered variants of it.

    The variants are produced by the notebook-level `noise` generator, in this
    order: burn, burn_light, dark, dark_high_gain. Each filter is applied to
    the left image first and then to the right image.

    Args:
        rgb_left: left image as a NumPy array.
        rgb_right: right image as a NumPy array.

    Returns:
        ``(left_arr, right_arr)``: two stacked tensors with five entries each
        (clean image first, then the four filtered variants).
    """
    filters = (
        noise.filter_image_burn,
        noise.filter_image_burn_light,
        noise.filter_image_dark,
        noise.filter_image_dark_high_gain,
    )

    # Interleave left/right so even positions are left views, odd are right views.
    interleaved = [rgb_left, rgb_right]
    for apply_filter in filters:
        interleaved.append(apply_filter(rgb_left))
        interleaved.append(apply_filter(rgb_right))

    tensors = numpy_to_torch(interleaved)
    left_arr = torch.stack(tensors[0::2])
    right_arr = torch.stack(tensors[1::2])
    return left_arr, right_arr


def unpack_batch_create_pair_arr(img_cuda: List[torch.Tensor]):
    """Build the stacked left/right model inputs for the first sample of a batch.

    Entries 0 and 1 of `img_cuda` are used as the RGB left/right images and
    entries 2 and 3 as the NIR left/right images. The NIR tensors are repeated
    three times along dim 1 so they can be stacked with the RGB variants.

    Unlike the earlier version, `img_cuda` is not modified: the repeated NIR
    tensors are kept local, so calling this twice on the same batch no longer
    triples the NIR channels a second time.

    Args:
        img_cuda: batch tensors; only index 0 along the first axis is used.

    Returns:
        ``(left_arr, right_arr)``: the five RGB variants from
        `rgb_noised_input_pairs` followed by the NIR view, for each side.
    """
    nir_rendered_left = img_cuda[2].repeat(1, 3, 1, 1)[0]
    nir_rendered_right = img_cuda[3].repeat(1, 3, 1, 1)[0]

    # The noise filters operate on NumPy images, so the RGB tensors are moved
    # to the CPU and cast to uint8 with the channel axis last.
    rgb_left = img_cuda[0][0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    rgb_right = img_cuda[1][0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)

    left_arr, right_arr = rgb_noised_input_pairs(rgb_left, rgb_right)
    left_arr = torch.cat([left_arr, nir_rendered_left.unsqueeze(0)])
    right_arr = torch.cat([right_arr, nir_rendered_right.unsqueeze(0)])
    return left_arr, right_arr


def compute_absolute_error(flow, flow_gt):
    """Compute RMSE, MAE and mean relative error on the central crop of a prediction.

    The prediction is negated before it is compared with `flow_gt`, and only
    the central region (the middle half of the height and of the width) is
    evaluated.

    The earlier version ignored the `flow` argument and read a notebook-level
    `_flow` variable instead; the argument is now used.

    Args:
        flow: predicted tensor, indexed as ``[:, :, h, w]``.
        flow_gt: ground-truth tensor, indexed the same way.

    Returns:
        A dict with float entries ``"rmse"``, ``"mae"`` and ``"ard"``
        (absolute error divided by the ground truth, averaged).
    """
    h, w = flow.shape[-2:]
    hf = int(h / 2 - h / 4)
    ht = int(h / 2 + h / 4)
    wf = int(w / 2 - w / 4)
    wt = int(w / 2 + w / 4)

    # Signed error on the central crop, computed once and reused by all metrics.
    diff = (-flow - flow_gt)[:, :, hf:ht, wf:wt]
    abs_diff = torch.abs(diff)
    gt_crop = flow_gt[:, :, hf:ht, wf:wt]

    rmse_rgb = torch.sqrt(torch.mean(diff**2))
    mae_rgb = torch.mean(abs_diff)
    # NOTE(review): zero ground-truth values make this ratio infinite — confirm
    # the ground truth is strictly positive inside the crop.
    ard_rgb = torch.mean(abs_diff / gt_crop)

    return {
        "rmse": rmse_rgb.item(),
        "mae": mae_rgb.item(),
        "ard": ard_rgb.item(),
    }


def save_disparity(image_path: str, disparity: np.ndarray):
    """Save a disparity map as a compressed ``.npz`` and as a colour-mapped image.

    Both output paths are derived from `image_path` by replacing the FRPASS
    path component with a directory name that contains the current variant.

    NOTE(review): the variant name is read from the notebook-level `env_title`
    variable, so the caller must set it before calling this function.

    Args:
        image_path: path of the input frame the disparity belongs to.
        disparity: disparity values; stored unchanged in the ``.npz`` under the
            key ``"disparity"``.
    """
    # Every "png" substring is replaced (not only the extension); the reader in
    # get_scene_prop_path derives the same path the same way.
    disparity_path_npz = image_path.replace(
        FRPASS, f"disparity_raft_env_{env_title}"
    ).replace("png", "npz")
    os.makedirs(os.path.dirname(disparity_path_npz), exist_ok=True)
    np.savez_compressed(disparity_path_npz, disparity=disparity)

    disparity_path_color = image_path.replace(
        FRPASS, f"disparity_raft_env_color_{env_title}"
    )
    os.makedirs(os.path.dirname(disparity_path_color), exist_ok=True)

    # Clip to the uint8 range before casting: an upper bound of 256 would let
    # the largest values wrap around to 0 in the colour image.
    disparity_color = cv2.applyColorMap(
        np.clip(disparity, 0, 255).astype(np.uint8), cv2.COLORMAP_MAGMA
    )

    cv2.imwrite(disparity_path_color, disparity_color)


def save_image_pair(
    image_path_rgb: str, left_tensor: torch.Tensor, right_tensor: torch.Tensor
):
    """Write a left/right tensor pair to disk as uint8 images.

    The left image is written to `image_path_rgb`; the right image goes to the
    same path with every ``"left"`` replaced by ``"right"``. Missing parent
    directories are created. The left path and image shape are printed.

    Args:
        image_path_rgb: destination path of the left image.
        left_tensor: left image tensor, permuted with ``(1, 2, 0)`` before saving.
        right_tensor: right image tensor, handled the same way.
    """

    def _as_uint8_image(tensor):
        # Move the first axis last, bring to the CPU and cast for cv2.imwrite.
        return tensor.permute(1, 2, 0).cpu().numpy().astype(np.uint8)

    left_path = image_path_rgb
    os.makedirs(os.path.dirname(left_path), exist_ok=True)
    left_image = _as_uint8_image(left_tensor)
    print(left_path, left_image.shape)
    cv2.imwrite(left_path, left_image)

    right_path = left_path.replace("left", "right")
    os.makedirs(os.path.dirname(right_path), exist_ok=True)
    cv2.imwrite(right_path, _as_uint8_image(right_tensor))




In [None]:
# Write the filtered RGB variants of every validation sample to disk.
# Variants that are not produced by a filter (plain RGB and NIR) are skipped.
for i_batch, input_valid in enumerate(tqdm(valid_loader)):

    image_list, *blob = input_valid
    img_cuda = [img.cuda() for img in blob]
    left_arr, right_arr = unpack_batch_create_pair_arr(img_cuda)
    # The path string sits three levels deep, as in the other cells that read it;
    # indexing only two levels yields a container without `.replace`.
    image_path = image_list[0][0][0]
    for i in range(mode_len):
        env_title = input_title[i]
        if env_title not in ["rgb", "nir_rendered", "nir_ambient"]:
            save_image_pair(
                image_path.replace(FRPASS, f"frame_{env_title}_filtered"),
                left_arr[i],
                right_arr[i],
            )


In [None]:
# Start from empty accumulators so re-running this cell does not add a second
# pass on top of the first, and take the batch count from the loader itself
# rather than from a notebook variable that another cell resets.
total_loss_dict = [{} for _ in range(mode_len)]
num_batches = len(valid_loader)

for i_batch, input_valid in enumerate(tqdm(valid_loader)):
    with torch.no_grad():
        image_list, *blob = input_valid
        img_cuda = [img.cuda() for img in blob]
        left_arr, right_arr = unpack_batch_create_pair_arr(img_cuda)

        # All input variants of the sample are evaluated as one batch.
        _, flows = model(left_arr, right_arr, iters=args.valid_iters, test_mode=True)

        # Crop the prediction back to the spatial size of the input tensors.
        flows = flows[:, :, : img_cuda[0].shape[2], : img_cuda[0].shape[3]]
        image_path = image_list[0][0][0]
        flow_gt = img_cuda[4][0:1]
        for i in range(mode_len):
            # Kept as a notebook-level name: save_disparity reads `env_title`.
            env_title = input_title[i]

            _flow = flows[i : i + 1]
            _, loss_gt = gt_loss(None, flow_gt, [_flow])
            _, loss_self = warp_reproject_loss(
                [_flow], left_arr[i : i + 1], right_arr[i : i + 1]
            )
            ab_loss = compute_absolute_error(_flow, flow_gt)
            loss = {
                **loss_gt,
                **loss_self,
                **ab_loss,
            }

            disparity = -_flow[0].permute(1, 2, 0).cpu().numpy()

            # Running mean over the whole loader.
            for key in loss:
                if key not in total_loss_dict[i]:
                    total_loss_dict[i][key] = 0
                total_loss_dict[i][key] += loss[key] / num_batches

            # Per-sample losses are also written next to the data.
            loss_dict_path = image_path.replace(
                FRPASS, f"loss_raft_env_{env_title}"
            ).replace(".png", ".json")
            os.makedirs(os.path.dirname(loss_dict_path), exist_ok=True)
            with open(loss_dict_path, "w") as loss_file:
                json.dump(loss, loss_file)

            # Save the disparity and, for filtered variants, the input images.
            save_disparity(image_path, disparity)

            if env_title not in ["rgb", "nir_rendered", "nir_ambient"]:
                save_image_pair(
                    image_path.replace(FRPASS, f"frame_{env_title}_filtered"),
                    left_arr[i],
                    right_arr[i],
                )


print(total_loss_dict)
with open("loss_raft_env.json", "w") as summary_file:
    json.dump(total_loss_dict, summary_file)

In [None]:
# Rebuild the per-variant loss sums from the per-sample JSON files on disk.
# Aggregation stops at the first sample for which any variant's file is missing;
# `batch_len` then holds the number of samples that were fully aggregated.
total_loss_dict = [{} for _ in range(mode_len)]
batch_len = 0
for i_batch, input_valid in enumerate(tqdm(valid_loader)):
    # Read the path from the current batch; a leftover `image_list` from
    # another cell would make every iteration re-read the same sample.
    image_list = input_valid[0]
    image_path = image_list[0][0][0]

    loss_paths = [
        image_path.replace(FRPASS, f"loss_raft_env_{title}").replace(".png", ".json")
        for title in input_title
    ]
    if not all(os.path.exists(loss_path) for loss_path in loss_paths):
        break

    for i, loss_dict_path in enumerate(loss_paths):
        with open(loss_dict_path) as loss_file:
            loss_dict = json.load(loss_file)
        for key, value in loss_dict.items():
            total_loss_dict[i][key] = total_loss_dict[i].get(key, 0) + value
    batch_len += 1
            


In [None]:
# Turn the per-variant sums into means and persist them.
# NOTE(review): 1318 is a fixed sample count, not derived from the aggregation
# cell — confirm it matches the number of samples that were summed. Re-running
# this cell divides the already-averaged values again.
for mode_idx in range(mode_len):
    mode_totals = total_loss_dict[mode_idx]
    for metric_name in mode_totals:
        mode_totals[metric_name] /= 1318

print(total_loss_dict)
with open("loss_raft_env.json", "w") as summary_file:
    json.dump(total_loss_dict, summary_file)

In [None]:
# Rescaled copy of the averaged losses; `total_loss_dict` itself is left untouched.
# NOTE(review): the factor 26694 / 4369 is a fixed ratio of two counts whose
# origin is not visible in this notebook — confirm before reusing.
# The loop variable is no longer called `dict`, which used to rebind the
# builtin for the rest of the kernel session.
total_loss_dict_adjusted = [
    {key: value * 26694 / 4369 for key, value in loss_dict.items()}
    for loss_dict in total_loss_dict
]
print(total_loss_dict_adjusted)

with open("loss_raft_env_adjusted3.json", "w") as adjusted_file:
    json.dump(total_loss_dict_adjusted, adjusted_file)

In [None]:
# Print one adjusted loss dict per line (the loop variable must not be named
# `dict`, or the builtin is rebound at notebook level).
for adjusted_losses in total_loss_dict_adjusted:
    print(adjusted_losses)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import json
with open("loss_raft_env_adjusted3.json") as adjusted_file:
    total_loss_dict_adjusted = json.load(adjusted_file)

# Metric names are taken from the first variant's dict; metrics whose name
# contains "ssim" or "l1" are not plotted.
keys = [
    key
    for key in total_loss_dict_adjusted[0]
    if "ssim" not in key and "l1" not in key
]

width = 0.2  # the width of the bars

# squeeze=False keeps `axs` two-dimensional even when a single metric remains.
fig, axs = plt.subplots(len(keys), 1, figsize=(10, 30), squeeze=False)

for i, key in enumerate(keys):
    values = [env_losses[key] for env_losses in total_loss_dict_adjusted]
    labels = input_title[: len(values)]
    if len(labels) != len(values):
        raise ValueError(
            f"{len(values)} loss entries but only {len(input_title)} input titles"
        )
    ax = axs[i, 0]
    ax.bar(labels, values, width)
    ax.set_ylabel(key)
    ax.set_title(key)

plt.show()

In [None]:
import os
import cv2


def get_scene_list(path: str):
    """Return the sorted paths of all entries in `path` whose name ends with ``.png``.

    Args:
        path: directory to list (not searched recursively).

    Returns:
        A sorted list of ``os.path.join(path, name)`` strings.
    """
    return sorted(
        os.path.join(path, name) for name in os.listdir(path) if name.endswith(".png")
    )


def get_scene_prop_path(path: str):
    """Derive every file path that belongs to one scene from its left RGB frame path.

    All paths are obtained by string replacement on `path`: the FRPASS path
    component selects the data kind and ``"left"`` -> ``"right"`` selects the view.

    Args:
        path: path of the left RGB frame.

    Returns:
        A dict with
        ``"input"``: variant title -> ``(left_path, right_path)``,
        ``"output"``: variant title -> path of the saved ``.npz`` disparity,
        ``"gt"``: path of the ground-truth ``.pfm`` disparity.
    """
    titles = ("rgb", "burnt", "burnt_light", "darken", "darken_gain", "nir_rendered")

    def _with_right_view(left_path):
        return left_path, left_path.replace("left", "right")

    # The plain RGB pair is the given path itself; every other variant lives in
    # its own directory next to it.
    inputs = {"rgb": _with_right_view(path)}
    for title in titles[1:]:
        if title == "nir_rendered":
            variant_dir = "nir_rendered"
        else:
            variant_dir = f"frame_{title}_filtered"
        inputs[title] = _with_right_view(path.replace(FRPASS, variant_dir))

    outputs = {
        title: path.replace(FRPASS, f"disparity_raft_env_{title}").replace("png", "npz")
        for title in titles
    }

    ground_truth = path.replace(FRPASS, "disparity").replace(".png", ".pfm")
    return {
        "input": inputs,
        "output": outputs,
        "gt": ground_truth,
    }
        
        


# Left-view frames of one training scene; used as the default scene list.
# NOTE(review): absolute, machine-specific path — adjust for other environments.
folder = "/bean/flyingthings3d/frames_cleanpass/TRAIN/A/0000/left"
scene_root_list = get_scene_list(folder)

In [None]:

import cv2

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from train_fusion.noise_generator import NoiseGenerator



def _draw_color_strip(ax, cmap, tick_labels):
    """Draw a vertical 256-step colour strip on `ax` with five labelled ticks."""
    gradient = np.linspace(0, 1, 256).reshape(-1, 1)
    gradient = np.repeat(gradient, 10, axis=1)  # 10 columns wide, 256 rows tall
    ax.imshow(gradient, aspect="auto", cmap=cmap, vmin=0, vmax=1)
    ax.set_aspect(1)
    ax.set_yticks(np.linspace(0, 255, num=5))
    ax.set_yticklabels(tick_labels)
    ax.set_xticks([])


def _read_image_for_display(image_path):
    """Read an image with OpenCV and return it in RGB order for matplotlib."""
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    if image.ndim == 3 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image


def plot_disparity_grid(scenes, keys = ["rgb", "burnt", "burnt_light", "darken", "darken_gain", "nir_rendered"]):
    """Render one row per input variant and return the saved figure as an image.

    Each row shows the left input, the right input, the stored disparity with a
    colour strip, and the ground-truth error with a colour strip.

    Args:
        scenes: dict as returned by `get_scene_prop_path`, with ``"input"``
            (title -> (left_path, right_path)), ``"output"`` (title -> ``.npz``
            path holding a ``"disparity"`` array) and ``"gt"``.
        keys: variant titles to plot, one row each. The default list is never
            modified.

    Returns:
        The figure written to ``disparity_grid.png``, re-read with ``cv2.imread``.
    """
    num_rows = len(keys)

    fig = plt.figure(figsize=(28, 4 * num_rows))
    # The colour-strip columns (4th and 6th) are narrower than the image columns.
    gs = gridspec.GridSpec(num_rows, 6, width_ratios=[5, 5, 5, 1, 5, 1], figure=fig)

    # One generator is enough for all rows.
    error_calculator = NoiseGenerator()
    gt = scenes["gt"]

    for row_idx, key in enumerate(keys):
        left_path, right_path = scenes["input"][key]
        disparity = np.load(scenes["output"][key])["disparity"].squeeze()

        img_left = _read_image_for_display(left_path)
        img_right = _read_image_for_display(right_path)
        # NOTE(review): `gt` is passed as a file path here — presumably
        # compute_disparity_gt_error loads it itself; confirm.
        disparity_error = error_calculator.compute_disparity_gt_error(gt, disparity)

        if disparity.ndim == 3 and disparity.shape[2] == 3:
            disparity = cv2.cvtColor(disparity, cv2.COLOR_BGR2RGB)

        ax1 = fig.add_subplot(gs[row_idx, 0])
        ax1.imshow(img_left, cmap="gray" if img_left.ndim == 2 else None)
        ax1.axis("off")
        ax1.set_title("Left Image", fontsize=12)

        ax2 = fig.add_subplot(gs[row_idx, 1])
        ax2.imshow(img_right, cmap="gray" if img_right.ndim == 2 else None)
        ax2.axis("off")
        ax2.set_title("Right Image", fontsize=12)

        ax3 = fig.add_subplot(gs[row_idx, 2])
        if disparity.ndim == 2:
            # Use the same colour map and 0-255 range as the strip next to it,
            # so the strip actually describes this panel.
            ax3.imshow(disparity, cmap="magma", vmin=0, vmax=255)
        else:
            ax3.imshow(disparity)
        ax3.axis("off")
        ax3.set_title("Disparity", fontsize=12)

        ax4 = fig.add_subplot(gs[row_idx, 3])
        _draw_color_strip(ax4, "magma", np.linspace(0, 255, num=5, dtype=int))

        ax5 = fig.add_subplot(gs[row_idx, 4])
        ax5.imshow(disparity_error)
        ax5.axis("off")
        ax5.set_title("Gt Error", fontsize=12)

        ax6 = fig.add_subplot(gs[row_idx, 5])
        _draw_color_strip(ax6, "jet", np.linspace(0, 4, num=5, dtype=int))

        # Row title, rotated and placed to the left of the first panel.
        ax1.text(
            -0.5,
            0.5,
            key,
            fontsize=14,
            fontweight="bold",
            va="center",
            ha="right",
            rotation=90,
            transform=ax1.transAxes,
        )

    fig.tight_layout()
    # Save before showing: once plt.show() has run, the figure may already be
    # closed and a later savefig would write an empty image.
    fig.savefig("disparity_grid.png")
    plt.show()
    plt.close(fig)
    return cv2.imread("disparity_grid.png")







In [None]:
from tqdm.notebook import tqdm
# Render one comparison grid per test-scene frame and store it under a
# parallel "plot" directory tree.
scene_root_list = get_scene_list("/bean/flyingthings3d/frames_cleanpass/TEST/A/0000/left")

for scene in tqdm(scene_root_list):
    scene_prop = get_scene_prop_path(scene)
    plot_image = plot_disparity_grid(scene_prop)

    # Same file name as the input frame, with the data directory swapped.
    output_path = scene.replace("frames_cleanpass", "plot")
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    cv2.imwrite(output_path, plot_image)





In [None]:
from IPython.display import Video, display
# Function that saves the rendered plots as a video file
def save_video(image_paths: list[str], video_path, fps=2):
    """Encode the plot images belonging to `image_paths` into an mp4 video.

    Each entry of `image_paths` is mapped to its plot image by replacing
    ``"frames_cleanpass"`` with ``"plot"``. Frames are written in order and
    writing stops at the first plot image that does not exist.

    Args:
        image_paths: input frame paths whose plot images should be encoded.
        video_path: destination video file.
        fps: frames per second of the output video.

    Raises:
        ValueError: if `image_paths` is empty.
        FileNotFoundError: if the first plot image cannot be read (it
            determines the video frame size).
    """
    if not image_paths:
        raise ValueError("image_paths is empty; there is nothing to encode")

    first_plot_path = image_paths[0].replace("frames_cleanpass", "plot")
    first_frame = cv2.imread(first_plot_path)
    if first_frame is None:
        # cv2.imread returns None instead of raising on an unreadable file.
        raise FileNotFoundError(f"Could not read first plot image: {first_plot_path}")
    height, width = first_frame.shape[:2]
    print(width, height)

    # Video writer setup
    fourcc = cv2.VideoWriter.fourcc(*'mp4v')  # codec
    out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))

    try:
        for img in tqdm(image_paths):
            img = img.replace("frames_cleanpass", "plot")
            if not os.path.exists(img):
                break
            image = cv2.imread(img)
            out.write(image)
    finally:
        # Always finalise the file, even if reading or writing a frame fails.
        out.release()
    print(f"비디오가 저장되었습니다: {video_path}")

# Save the plots as a video file
video_path = "output_video.mp4"
save_video(scene_root_list, video_path)

# Play the video inside the Jupyter notebook
display(Video(video_path, embed=True))