In [None]:
import cv2
import numpy as np
import os
import h5py
from tqdm.notebook import tqdm
from typing import Tuple
from PIL import Image
import matplotlib.pyplot as plt


In [None]:
# Import RAFTStereo; if that fails, switch the working directory to the
# RAFT-Stereo checkout and retry the import from there.
try:
    from core.raft_stereo import RAFTStereo
except ImportError:
    import os
    # NOTE(review): os.chdir affects the whole kernel, so relative paths used
    # later (e.g. args.restore_ckpt) presumably resolve against /RAFT-Stereo.
    os.chdir("/RAFT-Stereo")
    from core.raft_stereo import RAFTStereo
    
# Folder-name constant; the cells visible here do not reference FRPASS by name.
FRPASS = "frames_cleanpass"
from train_fusion.dataloader import StereoDataset, StereoDatasetArgs

import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader
from fusion_args import FusionArgs
# Option container handed to RAFTStereo(args) in the model-loading cell.
args = FusionArgs()
args.hidden_dims = [128, 128, 128]
args.corr_levels = 4
args.corr_radius = 4
args.n_downsample = 3
args.context_norm = "batch"
args.n_gru_layers = 2
args.shared_backbone = True
args.mixed_precision = True
args.corr_implementation = "reg_cuda"
args.slow_fast_gru = False
# Checkpoint path is relative to the current working directory.
args.restore_ckpt = "models/raftstereo-realtime.pth"


# NOTE(review): the options below look training-related; the inference code
# visible in this notebook does not read them directly — confirm before removing.
args.lr = 0.001
args.train_iters = 7
args.valid_iters = 12
args.wdecay = 0.0001
args.num_steps = 100000
args.valid_steps = 1000
args.name = "StereoFusion"
args.batch_size = 4
args.fusion = "AFF"
args.shared_fusion = True
args.freeze_backbone = []
args.both_side_train= False

In [None]:
# Build the network wrapped in DataParallel before loading the checkpoint
# (presumably the checkpoint keys carry the DataParallel "module." prefix — TODO confirm).
model = torch.nn.DataParallel(RAFTStereo(args)).cuda()
model.load_state_dict(torch.load(args.restore_ckpt))
model.eval()
# Unwrap DataParallel so later cells call the underlying module directly.
model = model.module

In [None]:
# Module-level disparity bound. store_disparity declares a parameter with the
# same name and the same default, which shadows this value inside that function.
DISPARITY_MAX = 64


def cv2img_to_torch(img: np.ndarray, device="cuda"):
    """Convert an ``(H, W, C)`` image array into a ``(1, C, H, W)`` float tensor.

    Args:
        img: Image array with three axes ordered height, width, channel
            (the layout returned by ``cv2.imread`` for colour images).
        device: Device the tensor is moved to. The default ``"cuda"``
            reproduces the previously hard-coded ``.cuda()`` call; pass
            ``"cpu"`` to run without a GPU.

    Returns:
        A ``torch.float32`` tensor with a leading batch axis of size 1.
        Pixel values are converted to float but not rescaled.
    """
    # HWC -> CHW. torch.tensor copies the data, so the strided view is fine.
    chw = img.transpose(2, 0, 1)
    return torch.tensor(chw).to(device).float().unsqueeze(0)


def process_frame_spectral(frame_folder: str):
    """Run the stereo model on ``left.png`` / ``right.png`` of one channel folder.

    Args:
        frame_folder: Directory that contains ``left.png`` and ``right.png``.

    Returns:
        A NumPy array: the first batch element of the model's second output,
        negated and rearranged from channel-first to channel-last. The rest of
        the notebook treats this array as the disparity map.

    Raises:
        FileNotFoundError: If either image cannot be read by ``cv2.imread``.
    """
    tensors = []
    for side in ("left", "right"):
        path = os.path.join(frame_folder, f"{side}.png")
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Could not read stereo image: {path}")
        tensors.append(cv2img_to_torch(image))
    left, right = tensors

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        _, flow = model(left, right, iters=15, test_mode=True)
    flow = -flow[0].detach().cpu().numpy().transpose(1, 2, 0)
    return flow


def store_disparity(
    frame_folder: str, channel: str, disp: np.ndarray, DISPARITY_MAX=64
):
    """Write a magma-coloured rendering of ``disp`` to ``<frame_folder>/<channel>/disparity.png``.

    Args:
        frame_folder: Frame directory.
        channel: Sub-folder name (the callers in this notebook pass ``"rgb"``
            or ``"nir"``).
        disp: Disparity values; they are clipped to ``[0, DISPARITY_MAX]`` and
            scaled to 0..255 before colour mapping.
        DISPARITY_MAX: Upper bound of the colour scale. Values that are not
            strictly positive (including NaN) fall back to 64, because dividing
            by them would yield NaN/inf pixels.
    """
    if not DISPARITY_MAX > 0:
        DISPARITY_MAX = 64
    scaled = np.clip(disp, 0, DISPARITY_MAX) / DISPARITY_MAX * 255.0
    disparity_color = cv2.applyColorMap(scaled.astype(np.uint8), cv2.COLORMAP_MAGMA)
    cv2.imwrite(os.path.join(frame_folder, channel, "disparity.png"), disparity_color)


def process_frame(frame_folder: str, overwrite=False):
    """Compute RGB and NIR disparity for one frame and store colour previews.

    Args:
        frame_folder: Frame directory containing ``rgb`` and ``nir`` sub-folders.
        overwrite: When False and both ``disparity.png`` previews already
            exist, nothing is computed.

    Returns:
        ``(disparity_rgb, disparity_nir)``, or ``None`` when the frame was
        skipped because both previews already exist.
    """
    rgb_folder = os.path.join(frame_folder, "rgb")
    nir_folder = os.path.join(frame_folder, "nir")
    previews_exist = os.path.exists(
        os.path.join(rgb_folder, "disparity.png")
    ) and os.path.exists(os.path.join(nir_folder, "disparity.png"))
    if not overwrite and previews_exist:
        return None

    disparity_rgb = process_frame_spectral(rgb_folder)
    disparity_nir = process_frame_spectral(nir_folder)

    # Shared colour scale for both channels: the peak rounded down to a
    # multiple of 64, but never below 64. Without the lower bound a peak under
    # 64 rounds to 0 and the preview scaling divides by zero.
    peak = max(disparity_rgb.max(), disparity_nir.max())
    rounded = peak // 64 * 64
    disparity_max = rounded if rounded > 64 else 64

    store_disparity(frame_folder, "rgb", disparity_rgb, disparity_max)
    store_disparity(frame_folder, "nir", disparity_nir, disparity_max)
    return disparity_rgb, disparity_nir


def process_scene(folder: str):
    """Run ``process_frame`` on every frame directory inside ``folder``.

    An entry counts as a frame directory when the part of its name after the
    last underscore consists only of digits. Frames are processed in sorted
    order; a failing frame is reported (path and exception) and skipped.
    """
    frame_folders = []
    for entry in tqdm(os.listdir(folder)):
        if entry.split("_")[-1].isdigit():
            frame_folders.append(os.path.join(folder, entry))
    print(len(frame_folders))

    for frame in tqdm(sorted(frame_folders)):
        try:
            process_frame(frame)
        except Exception as e:
            # Best effort: report the failure and move on to the next frame.
            print(frame)
            print(e)

In [None]:
# Load the "transform" entry of one recording's post.npz archive.
# transform_lidarpoints_to_image reads it through the module-level name
# transform_mtx, so this cell must run before any lidar projection.
PATH = "/bean/depth/09-04-18-40-09/18_40_08_851/post.npz"
post = np.load(PATH)
print(post["transform"])

transform_mtx = post["transform"]


In [None]:
def transform_lidarpoints_to_image(
    lidar_points: np.ndarray,
    width: int,
    height: int,
    focal_length: float,
    cx: float,
    cy: float,
    transform_matrix=None,
):
    """Project lidar points into the camera image and keep the visible ones.

    Args:
        lidar_points: Array reshaped internally to ``(N, 3)``.
        width: Image width in pixels.
        height: Image height in pixels.
        focal_length: Focal length used for both image axes.
        cx: Principal point x coordinate.
        cy: Principal point y coordinate.
        transform_matrix: Optional 4x4 matrix whose pseudo-inverse maps
            homogeneous lidar points into the camera frame. When omitted, the
            notebook-level ``transform_mtx`` is used, as before.

    Returns:
        An ``(M, 3)`` array of ``[u, v, depth]`` rows for the points with
        ``0 < u < width``, ``0 < v < height`` and ``depth > 0``, sorted by
        depth in descending order (farthest first).
    """
    if transform_matrix is None:
        transform_matrix = transform_mtx

    # Coordinates are scaled by 1000 (presumably metres -> millimetres — TODO confirm).
    points = lidar_points.reshape(-1, 3) * 1000

    # Homogeneous coordinates laid out as a 4xN matrix.
    homogeneous = np.concatenate(
        [points, np.ones((points.shape[0], 1))], axis=1
    ).T

    # The stored matrix is applied inverted to reach the camera frame.
    camera_points = np.linalg.pinv(transform_matrix) @ homogeneous
    depth = camera_points[2]

    # Points with zero depth produce inf/NaN pixel coordinates; they are
    # rejected by the mask below, so the warnings carry no information.
    with np.errstate(divide="ignore", invalid="ignore"):
        u = camera_points[0] * focal_length / depth + cx
        v = camera_points[1] * focal_length / depth + cy
        visible = (u > 0) & (u < width) & (v > 0) & (v < height) & (depth > 0)

    csf = np.stack([u, v, depth], axis=1)[visible]
    return csf[np.argsort(csf[:, 2])[::-1]]
    
def render_2dpoint_to_image(
        points: np.ndarray,  width: int, height: int, use_color: bool = True, MAX_DEPTH: float = 10000
    ):
    """Rasterise projected lidar points ``[u, v, depth]`` into one image.

    Args:
        points: ``(N, 3)`` array of pixel coordinates and a depth value.
        width: Canvas width in pixels.
        height: Canvas height in pixels.
        use_color: Selects the output format (see Returns).
        MAX_DEPTH: Depth mapped to the top of the colour scale; only used when
            ``use_color`` is True.

    Returns:
        * ``use_color=False``: a ``(height, width)`` float32 array in which
          each point fills a 3x3 block with its depth; later rows of
          ``points`` overwrite earlier ones.
        * ``use_color=True``: a ``(height, width + 50, 3)`` uint8 image with
          one filled circle per point (centres snapped to a 4-pixel grid) and
          a 50-pixel-wide colour bar appended on the right.
    """
    if not use_color:
        canvas = np.zeros((height, width), dtype=np.float32)
        # Clamp so the 3x3 block written below stays inside the canvas.
        u = np.clip(points[:, 0].astype(int), 0, width - 3)
        v = np.clip(points[:, 1].astype(int), 0, height - 3)
        depth = points[:, 2]
        for i in range(3):
            for j in range(3):
                canvas[v + i, u + j] = depth
        return canvas

    # 256-entry lookup table with shape (256, 1, 3), in OpenCV's channel order.
    colormap = cv2.applyColorMap(
        np.linspace(0, 255, 256).astype(np.uint8), cv2.COLORMAP_MAGMA
    )
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    radius = 3
    for u, v, depth in points:
        # Snap circle centres to a 4-pixel grid.
        u = int(u) // 4 * 4
        v = int(v) // 4 * 4
        depth_color = int(np.clip(depth / MAX_DEPTH * 255, 0, 255))
        color = tuple(map(int, colormap[depth_color][0]))
        cv2.circle(canvas, (u, v), radius, color, -1)

    colorbar = cv2.resize(colormap, (50, height))
    return np.concatenate([canvas, colorbar], axis=1)

        
def disparity_to_depth(disparity: np.ndarray, focal_length: float, baseline: float):
    """Convert a disparity map to depth as ``focal_length * baseline / disparity``.

    The disparity is cast to float32 first. Every result that is negative,
    NaN or infinite (for example from a zero disparity) is replaced by 0.
    """
    disparity_f32 = disparity.astype(np.float32)
    depth = focal_length * baseline / disparity_f32
    # One combined mask: negatives (including -inf), NaN and +inf.
    invalid = (depth < 0) | ~np.isfinite(depth)
    depth[invalid] = 0
    return depth

In [None]:
def process_lidar_frame(lidar_points: np.ndarray, calibration: dict):
    """Project one frame's lidar points using values read from ``calibration``.

    ``calibration["mtx_left"]`` supplies the focal length (element ``[0, 0]``)
    and the principal point (``[0, 2]``, ``[1, 2]``);
    ``calibration["image_size"]`` supplies ``(width, height)``.

    Returns:
        Whatever ``transform_lidarpoints_to_image`` returns for these inputs.
    """
    intrinsics = calibration["mtx_left"]
    focal_length = intrinsics[0, 0]
    principal_x = intrinsics[0, 2]
    principal_y = intrinsics[1, 2]
    image_size = calibration["image_size"]

    return transform_lidarpoints_to_image(
        lidar_points,
        image_size[0],
        image_size[1],
        focal_length,
        principal_x,
        principal_y,
    )

def process_lidar_h5file(h5file: str, overwrite: bool = False):
    """Project every frame's lidar points and render them to ``lidar.png``.

    For each frame the projected points are stored as the dataset
    ``lidar/projected_points``. A colour rendering is written to
    ``<scene>/<frame>/lidar.png``, where ``<scene>`` is the directory that
    contains ``h5file``.

    Args:
        h5file: Path to the HDF5 recording (opened in append mode).
        overwrite: When False, the file is skipped entirely if its last frame
            already has projected points, and existing per-frame results and
            PNGs are reused. When True, projections and renderings are
            regenerated for every frame.
    """
    scene_folder = os.path.dirname(h5file)
    with h5py.File(h5file, "a") as f:
        calibration = f["calibration"].attrs
        frame_ids = list(f["frame"].keys())
        if not frame_ids:
            return
        if not overwrite and "projected_points" in f["frame"][frame_ids[-1]]["lidar"]:
            return

        depth_median = 0.0
        for frame_id in tqdm(frame_ids):
            if "image_size" not in calibration:
                # Cache the image size in the calibration attributes.
                left_png = os.path.join(scene_folder, frame_id, "rgb", "left.png")
                with Image.open(left_png) as image:
                    calibration["image_size"] = image.size
            frame = f["frame"][frame_id]
            if not overwrite and "projected_points" in frame["lidar"]:
                projected = frame["lidar/projected_points"][:]
            else:
                projected = process_lidar_frame(frame["lidar/points"][:], calibration)
                if "projected_points" in frame["lidar"]:
                    del frame["lidar/projected_points"]
                frame.create_dataset("lidar/projected_points", data=projected)
            # Track the largest per-frame median depth over ALL frames, so the
            # colour scale below is valid even when projections were reused.
            if len(projected):
                depth_median = max(depth_median, float(np.median(projected[:, 2])))

        # Without any usable depth keep the renderer's default MAX_DEPTH.
        render_kwargs = {"MAX_DEPTH": depth_median * 5} if depth_median > 0 else {}
        width, height = calibration["image_size"]
        for frame_id in frame_ids:
            png_path = os.path.join(scene_folder, frame_id, "lidar.png")
            if not overwrite and os.path.exists(png_path):
                continue
            projected = f["frame"][frame_id]["lidar/projected_points"][:]
            rendered_image = render_2dpoint_to_image(
                projected, width, height, use_color=True, **render_kwargs
            )
            cv2.imwrite(png_path, rendered_image)
            


In [None]:
def _fig_show_bgr(ax, image_bgr, title):
    """Show an image loaded with cv2.imread on ``ax`` with a title and no axes."""
    ax.imshow(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
    ax.set_title(title)
    ax.axis("off")


def _fig_show_magma(fig, ax, values, vmax, title):
    """Show ``values`` clipped to ``[0, vmax]`` with the magma map and a colour bar."""
    handle = ax.imshow(np.clip(values, 0, vmax), cmap="magma", vmin=0, vmax=vmax)
    ax.set_title(title)
    fig.colorbar(handle, ax=ax)
    ax.axis("off")


def frame_create_fig_png(
    scene_folder: str,
    frame_id: str,
    frame: h5py.Group,
    focal_length: float,
    baseline: float,
    use_numpy=False,
):
    """Write ``<scene_folder>/<frame_id>/fig.png`` summarising one frame.

    The matplotlib path draws a 3x3 grid: RGB left/right/disparity, NIR
    left/right/disparity, then RGB depth, NIR depth and the projected lidar
    depth. With ``use_numpy=True`` the figure is assembled by
    ``frame_create_fig_png_numpy`` and written with ``cv2.imwrite`` instead.

    Frames whose mean RGB disparity is below 64 while the mean NIR disparity
    is above 64 are treated as misaligned and produce no figure.

    Args:
        scene_folder: Directory containing the frame folders.
        frame_id: Name of the frame folder / HDF5 frame group.
        frame: Group providing ``disparity/rgb``, ``disparity/nir`` and
            ``lidar/projected_points``.
        focal_length: Focal length used for disparity-to-depth conversion.
        baseline: Stereo baseline used for disparity-to-depth conversion.
        use_numpy: Select the NumPy/OpenCV renderer instead of matplotlib.
    """
    frame_dir = os.path.join(scene_folder, frame_id)
    rgb_left = cv2.imread(os.path.join(frame_dir, "rgb", "left.png"))
    rgb_right = cv2.imread(os.path.join(frame_dir, "rgb", "right.png"))
    nir_left = cv2.imread(os.path.join(frame_dir, "nir", "left.png"))
    nir_right = cv2.imread(os.path.join(frame_dir, "nir", "right.png"))
    rgb_disparity = frame["disparity/rgb"][:]
    nir_disparity = frame["disparity/nir"][:]

    # Skip frames flagged as misaligned by the mean-disparity heuristic.
    if rgb_disparity.mean() < 64 and nir_disparity.mean() > 64:
        return

    lidar_depth = frame["lidar/projected_points"][:]
    fig_path = os.path.join(frame_dir, "fig.png")

    if use_numpy:
        fig = frame_create_fig_png_numpy(
            (rgb_left, rgb_right, rgb_disparity),
            (nir_left, nir_right, nir_disparity),
            lidar_depth,
            focal_length, baseline
        )
        cv2.imwrite(fig_path, fig)
        return

    # Depth maps are only needed by the matplotlib layout.
    rgb_depth = disparity_to_depth(rgb_disparity, focal_length, baseline)
    nir_depth = disparity_to_depth(nir_disparity, focal_length, baseline)
    depth_max = min(
        max(20000, max(np.median(rgb_depth), np.median(nir_depth)) * 3), 50000
    )
    disparity_max = min(rgb_disparity.max(), nir_disparity.max()) * 0.8

    fig, axs = plt.subplots(3, 3, figsize=(25, 15))
    try:
        # First row: RGB pair and disparity.
        _fig_show_bgr(axs[0, 0], rgb_left, "RGB Left")
        _fig_show_bgr(axs[0, 1], rgb_right, "RGB Right")
        _fig_show_magma(fig, axs[0, 2], rgb_disparity, disparity_max, "RGB Disparity")

        # Second row: NIR pair and disparity.
        _fig_show_bgr(axs[1, 0], nir_left, "NIR Left")
        _fig_show_bgr(axs[1, 1], nir_right, "NIR Right")
        _fig_show_magma(fig, axs[1, 2], nir_disparity, disparity_max, "NIR Disparity")

        # Third row: depth maps, blanked where the disparity is below 1.
        rgb_depth[rgb_disparity < 1] = 0
        nir_depth[nir_disparity < 1] = 0
        _fig_show_magma(fig, axs[2, 0], rgb_depth, depth_max, "RGB Depth")
        _fig_show_magma(fig, axs[2, 1], nir_depth, depth_max, "NIR Depth")

        # Lidar points as a scatter plot; v is negated so the plot is not
        # vertically mirrored relative to the images.
        u, v = lidar_depth[:, 0], -lidar_depth[:, 1]
        z = lidar_depth[:, 2]
        sc = axs[2, 2].scatter(
            u, v, c=np.clip(z, 0, depth_max), cmap="magma", vmin=0, vmax=depth_max
        )
        axs[2, 2].set_title("Lidar Depth")
        fig.colorbar(sc, ax=axs[2, 2])
        axs[2, 2].axis("off")

        # Save through the figure handle rather than pyplot's current figure.
        fig.savefig(fig_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)


def process_h5file_create_figs(h5path: str):
    """Create ``fig.png`` for every frame of the HDF5 file at ``h5path``.

    The focal length is element ``[0, 0]`` of the ``mtx_left`` calibration
    attribute and the baseline is the norm of the ``T`` attribute. Figures are
    written below the directory that contains ``h5path``.
    """
    scene_folder = os.path.dirname(h5path)
    with h5py.File(h5path, "r") as f:
        calibration = f["calibration"].attrs
        focal_length = calibration["mtx_left"][0, 0]
        baseline = np.linalg.norm(calibration["T"][:])
        frames = f["frame"]
        for frame_id in frames:
            frame_create_fig_png(
                scene_folder,
                frame_id,
                frames.require_group(frame_id),
                focal_length,
                baseline,
            )

In [None]:
import time


# Module-level cache for the header strips built by frame_create_fig_png_numpy,
# which declares rgb_label / nir_label as globals and fills them on first use.
rgb_label = None
nir_label = None
# frame_create_fig_png_numpy assigns its own local of this name; this
# module-level value is not read by the code visible in this notebook.
colorbar_disparity = None

def _magma_from_disparity(disparity, disparity_max):
    """Clip ``disparity`` to ``[0, disparity_max]`` and colour it with the magma map."""
    scaled = np.clip(disparity, 0, disparity_max) / disparity_max * 255.0
    return cv2.applyColorMap(scaled.astype(np.uint8), cv2.COLORMAP_MAGMA)


def _white_strip(height, width):
    """Return a white ``(height, width, 3)`` uint8 image."""
    return np.full((height, width, 3), 255, dtype=np.uint8)


def frame_create_fig_png_numpy(
    rgb_tuple: Tuple[np.ndarray, np.ndarray, np.ndarray],
    nir_tuple: Tuple[np.ndarray, np.ndarray, np.ndarray],
    lidar_depth: np.ndarray,
    focal_length : float, baseline: float,
):
    """Assemble the frame overview figure with NumPy/OpenCV only.

    Layout: an RGB row (left, right, disparity, colour bar, lidar disparity,
    colour bar with depth ticks) above a NIR row (left, right, disparity,
    colour bar, white padding), each preceded by a 100-pixel label strip.

    Args:
        rgb_tuple: ``(left image, right image, disparity)`` for RGB.
        nir_tuple: ``(left image, right image, disparity)`` for NIR.
        lidar_depth: ``(N, 3)`` array of ``[u, v, depth]``; it is not modified.
        focal_length: Focal length used to convert lidar depth to disparity.
        baseline: Stereo baseline used for the same conversion.

    Returns:
        The composed figure as a uint8 image array.
    """
    global rgb_label, nir_label

    rgb_left, rgb_right, rgb_disparity = rgb_tuple
    nir_left, nir_right, nir_disparity = nir_tuple
    H, W = rgb_left.shape[:2]
    # Crop the disparity maps to the image size.
    rgb_disparity = rgb_disparity[:H, :W]
    nir_disparity = nir_disparity[:H, :W]

    disparity_max = min(rgb_disparity.max(), nir_disparity.max()) * 0.8
    # A non-positive (or NaN) scale would make every division below invalid.
    if not disparity_max > 0:
        disparity_max = 1.0

    rgb_concat = np.concatenate(
        [rgb_left, rgb_right, _magma_from_disparity(rgb_disparity, disparity_max)], axis=1
    )
    nir_concat = np.concatenate(
        [nir_left, nir_right, _magma_from_disparity(nir_disparity, disparity_max)], axis=1
    )

    # Vertical colour bar, 50 pixels wide and as tall as the images.
    colorbar_disparity = np.linspace(0, 255, 256).reshape(256, 1).astype(np.uint8)
    colorbar_disparity = cv2.applyColorMap(colorbar_disparity, cv2.COLORMAP_MAGMA)
    colorbar_disparity = cv2.resize(colorbar_disparity, (50, H))

    # Tick labels every 100 rows: disparity values ...
    colorbar_disparity_text = _white_strip(H, 100)
    for i in range(0, H, 100):
        cv2.putText(
            colorbar_disparity_text,
            str(int(i / H * disparity_max)),
            (10, i + 15),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),
            1,
        )
    # ... and the corresponding depth, divided by 1000 and suffixed with "m".
    # Row 0 is skipped because its disparity is 0.
    colorbar_depth_text = _white_strip(H, 100)
    for i in range(100, H, 100):
        depth_label = focal_length * baseline / (i / H * disparity_max) / 1000
        cv2.putText(
            colorbar_depth_text,
            str(round(depth_label, 1)) + "m",
            (10, i + 15),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),
            1,
        )

    rgb_concat = np.concatenate([rgb_concat, colorbar_disparity, colorbar_disparity_text], axis=1)
    nir_concat = np.concatenate([nir_concat, colorbar_disparity, colorbar_disparity_text], axis=1)

    # Convert lidar depth to disparity on a copy, leaving the caller's array intact.
    lidar_points = lidar_depth.copy()
    lidar_points[:, 2] = focal_length * baseline / lidar_points[:, 2]
    lidar_disparity = render_2dpoint_to_image(lidar_points, W, H, use_color=False)
    lidar_disparity_color = _magma_from_disparity(lidar_disparity, disparity_max)

    rgb_concat = np.concatenate(
        [rgb_concat, lidar_disparity_color, colorbar_disparity, colorbar_depth_text], axis=1
    )
    nir_padding = _white_strip(nir_concat.shape[0], rgb_concat.shape[1] - nir_concat.shape[1])
    nir_concat = np.concatenate([nir_concat, nir_padding], axis=1)

    # Label strips are cached across calls; rebuild them when the figure width
    # changes, otherwise the final concatenation would fail.
    total_width = rgb_concat.shape[1]
    if rgb_label is None or nir_label is None or rgb_label.shape[1] != total_width:
        new_rgb_label = _white_strip(100, total_width)
        new_nir_label = _white_strip(100, total_width)
        rgb_titles = ("RGB Left", "RGB Right", "RGB Disparity", "Lidar Depth")
        nir_titles = ("NIR Left", "NIR Right", "NIR Disparity")
        for strip, titles in ((new_rgb_label, rgb_titles), (new_nir_label, nir_titles)):
            for column, title in enumerate(titles):
                cv2.putText(
                    strip, title, (W // 2 * (2 * column + 1), 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 1,
                )
        rgb_label, nir_label = new_rgb_label, new_nir_label

    fig = np.concatenate([rgb_label, rgb_concat, nir_label, nir_concat], axis=0)
    return fig
    
    
    
    
    

In [None]:
import threading
import time
from typing import Union
from IPython.display import Video, display

def save_video(image_paths: list[str], video_path, fps=5):
    '''
    Encode the images at ``image_paths``, in order, into an mp4v video.

    image_paths: list of image file paths; the first one defines the frame size.
    video_path: output video path; its parent directory is created if needed.
    fps: frames per second.

    Nothing is written when ``image_paths`` is empty. Encoding stops at the
    first path that does not exist; images that cannot be decoded are skipped.
    '''
    if not image_paths:
        print(f"No images to encode, skipping: {video_path}")
        return

    print(image_paths[0])
    first_image = cv2.imread(image_paths[0])
    if first_image is None:
        raise FileNotFoundError(f"Could not read first frame: {image_paths[0]}")
    height, width = first_image.shape[:2]

    fourcc = cv2.VideoWriter.fourcc(*'mp4v')  # codec
    video_dir = os.path.dirname(video_path)
    # os.makedirs("") raises, so only create a directory when there is one.
    if video_dir:
        os.makedirs(video_dir, exist_ok=True)
    out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))

    try:
        for img in tqdm(image_paths):
            # NOTE(review): paths containing "frames_cleanpass" are redirected to
            # "plot"; the paths built in this notebook do not appear to contain
            # it — confirm whether this is still needed.
            img = img.replace("frames_cleanpass", "plot")
            if not os.path.exists(img):
                break
            image = cv2.imread(img)
            if image is None:
                continue
            out.write(image)
    finally:
        # Release the writer even if a frame fails, so the file is finalised.
        out.release()
    print(f"비디오가 저장되었습니다: {video_path}")


def process_h5_plt_video(h5file: str, create_fig=False):
    '''
    Combine the ``fig.png`` images of every frame in ``h5file`` into one video.

    h5file: path to the HDF5 recording; frame folders are expected next to it.
    create_fig: when True, ``fig.png`` is (re)generated for every frame first,
        using up to 6 worker threads at a time.

    The video is written next to the HDF5 file with the ``.hdf5`` suffix
    replaced by ``.mp4``. Frames without a readable ``fig.png`` are left out.
    '''
    scene_folder = os.path.dirname(h5file)
    with h5py.File(h5file, "r") as f:
        frame_ids = list(f["frame"].keys())
        if create_fig:
            focal_length = f["calibration"].attrs["mtx_left"][0, 0]
            baseline = np.linalg.norm(f["calibration"].attrs["T"][:])
            workers = []
            for frame_id in tqdm(frame_ids):
                worker = threading.Thread(
                    target=frame_create_fig_png,
                    args=(scene_folder, frame_id, f["frame"][frame_id], focal_length, baseline, True),
                )
                worker.start()
                workers.append(worker)
                # Run in batches of 6 threads.
                if len(workers) >= 6:
                    for running in workers:
                        running.join()
                    workers = []
            # Wait for the final, possibly partial, batch before closing the file.
            for running in workers:
                running.join()

    # The HDF5 file is no longer needed once the figures exist.
    image_paths = []
    for frame_id in frame_ids:
        fig_path = os.path.join(scene_folder, frame_id, "fig.png")
        if not os.path.exists(fig_path):
            continue
        # Decode once to make sure the file is a readable image.
        if cv2.imread(fig_path) is None:
            continue
        image_paths.append(fig_path)
    save_video(image_paths, h5file.replace(".hdf5", ".mp4"))

In [None]:
def stereo_depth_lidar_loss(frame: h5py.Group, focal_length: float, baseline: float):
    """Compare stereo disparity with lidar-derived disparity at the lidar pixels.

    The lidar depth (third column of ``lidar/projected_points``) is converted
    to disparity as ``focal_length * baseline / depth``. The RGB and NIR
    disparity maps are sampled at the truncated ``(v, u)`` pixel positions.

    Returns:
        ``((rmse_rgb, rmse_nir), (mae_rgb, mae_nir))``.
    """
    rgb_map = frame["disparity/rgb"][:]
    nir_map = frame["disparity/nir"][:]
    points = frame["lidar/projected_points"][:]

    cols = points[:, 0].astype(int)
    rows = points[:, 1].astype(int)
    lidar_disparity = focal_length * baseline / points[:, 2]

    def _errors(disparity_map):
        # squeeze() drops a trailing channel axis so the subtraction is element-wise.
        residual = disparity_map[rows, cols].squeeze() - lidar_disparity
        return np.sqrt(np.mean(residual ** 2)), np.mean(np.abs(residual))

    rsme_rgb, mae_rgb = _errors(rgb_map)
    rsme_nir, mae_nir = _errors(nir_map)
    return (rsme_rgb, rsme_nir), (mae_rgb, mae_nir)


def process_depth_loss_h5file(h5file: str, overwrite_loss=False):
    """
    Compute, for every frame of ``h5file``, the error between stereo disparity
    and lidar-derived disparity.

    Per-frame values are stored as the attributes ``rsme`` and ``mae`` on the
    ``disparity/rgb`` and ``disparity/nir`` datasets. Their means over all
    evaluated frames are stored as the file attributes ``rsme_rgb``,
    ``rsme_nir``, ``mae_rgb`` and ``mae_nir``.

    Frames lacking projected lidar points or either disparity dataset, and
    frames with zero projected points, are skipped. When no frame could be
    evaluated the file attributes are left untouched, so a later run retries.

    Args:
        h5file: Path to the HDF5 recording (opened in append mode).
        overwrite_loss: When False and ``rsme_rgb`` already exists as a file
            attribute, nothing is recomputed.
    """
    required = ("lidar/projected_points", "disparity/rgb", "disparity/nir")
    with h5py.File(h5file, "a") as f:
        if not overwrite_loss and "rsme_rgb" in f.attrs:
            return
        frame_ids = list(f["frame"].keys())
        focal_length = f["calibration"].attrs["mtx_left"][0, 0]
        baseline = np.linalg.norm(f["calibration"].attrs["T"][:])
        totals = {"rsme_rgb": [], "rsme_nir": [], "mae_rgb": [], "mae_nir": []}

        for frame_id in tqdm(frame_ids):
            frame = f["frame"][frame_id]
            if any(name not in frame for name in required):
                continue
            # An empty point set would only contribute NaN to the averages.
            if frame["lidar/projected_points"].shape[0] == 0:
                continue
            (rsme_rgb, rsme_nir), (mae_rgb, mae_nir) = stereo_depth_lidar_loss(
                frame, focal_length, baseline
            )
            frame["disparity/rgb"].attrs["rsme"] = rsme_rgb
            frame["disparity/nir"].attrs["rsme"] = rsme_nir
            frame["disparity/rgb"].attrs["mae"] = mae_rgb
            frame["disparity/nir"].attrs["mae"] = mae_nir
            totals["rsme_rgb"].append(rsme_rgb)
            totals["rsme_nir"].append(rsme_nir)
            totals["mae_rgb"].append(mae_rgb)
            totals["mae_nir"].append(mae_nir)
            print(f"{frame_id} RSME RGB: {rsme_rgb:.2f} RSME NIR: {rsme_nir:.2f}")
            print(f"{frame_id} MAE RGB: {mae_rgb:.2f} MAE NIR: {mae_nir:.2f}")

        if not totals["rsme_rgb"]:
            return
        for name, values in totals.items():
            f.attrs[name] = np.mean(values)

In [None]:
def process_disparity_h5file(h5file: str, overwrite=False):
    '''
    Compute disparity for every frame of ``h5file`` and store it in the file.

    For each processed frame the datasets ``disparity/rgb`` and
    ``disparity/nir`` are written, colour previews are saved by
    ``process_frame``, and the frame attribute ``align_error`` records whether
    the mean NIR disparity is above 64 while the mean RGB disparity is below 64.

    h5file: path to the HDF5 recording (opened in append mode); frame folders
        are expected next to it.
    overwrite: when False, the file is skipped if its last frame already has a
        ``disparity`` group, and frames that already have one are left alone.
        When True, existing datasets are replaced.
    '''
    scene_folder = os.path.dirname(h5file)
    with h5py.File(h5file, "a") as f:
        frame_ids = list(f["frame"].keys())
        if not frame_ids:
            return
        if not overwrite and "disparity" in f["frame"][frame_ids[-1]]:
            return

        for frame_id in tqdm(frame_ids):
            frame = f["frame"][frame_id]
            if not overwrite and "disparity" in frame:
                continue
            # The datasets are missing or being replaced, so force inference
            # even when disparity.png previews already exist on disk.
            output = process_frame(os.path.join(scene_folder, frame_id), overwrite=True)
            if output is None:
                continue
            disparity_rgb, disparity_nir = output

            for channel, disparity in (("rgb", disparity_rgb), ("nir", disparity_nir)):
                dataset_name = f"disparity/{channel}"
                # create_dataset refuses to replace an existing dataset.
                if dataset_name in frame:
                    del frame[dataset_name]
                frame.create_dataset(dataset_name, data=disparity)

            align_error = bool(disparity_nir.mean() > 64 and disparity_rgb.mean() < 64)
            frame.attrs["align_error"] = align_error
            if align_error:
                print(f"A frame {frame_id} has an alignment error")

In [None]:
def process_h5file_all(
    h5file: str,
    overwrite_disparity=False,
    overwrite_loss=False,
    overwrite_plot=False,
    overwrite_lidar=False,
):
    """
    Run the whole pipeline on one HDF5 file: disparity, lidar projection,
    depth loss, then the overview video.

    The video step runs when ``overwrite_plot`` is set or when no ``.mp4``
    exists next to the HDF5 file yet.
    """
    process_disparity_h5file(h5file, overwrite=overwrite_disparity)
    process_lidar_h5file(h5file, overwrite=overwrite_lidar)
    process_depth_loss_h5file(h5file, overwrite_loss=overwrite_loss)

    video_path = h5file.replace(".hdf5", ".mp4")
    needs_video = overwrite_plot or not os.path.exists(video_path)
    if needs_video:
        process_h5_plt_video(h5file, create_fig=True)

    print(f"Finished processing {h5file}")

In [None]:
# Run the lidar projection step for every .hdf5 file inside the /bean/depth
# entries whose name starts with "09-10".
folders = [x for x in os.listdir("/bean/depth") if x.startswith("09-10")]
for folder in folders:
    print(f"Processing {folder}")
    h5files = [os.path.join("/bean/depth", folder, x) for x in os.listdir(os.path.join("/bean/depth", folder)) if x.endswith(".hdf5")]
    # File names are expected to be "<integer>.hdf5"; sort them numerically
    # (int() raises ValueError for any other name).
    h5files.sort(key=lambda x: int(os.path.basename(x).split(".")[0]))
    for h5file in h5files:
        print(f"Processing {h5file}")
        process_lidar_h5file(h5file)

In [None]:

# Collect every .hdf5 file of the "09-10*" recording directories and create the
# overview video for each file that does not have an .mp4 next to it yet.
FOLDERS = [x for x in os.listdir("/bean/depth") if os.path.isdir( os.path.join("/bean/depth",x))]
FOLDERS = [x for x in FOLDERS if x.startswith("09-10")]
h5files_pending = []
for folder in FOLDERS:
    print(folder)
    h5files = os.listdir(os.path.join("/bean/depth",folder))
    h5files = [x for x in h5files if x.endswith(".hdf5")]
    # Numeric sort; file names are expected to be "<integer>.hdf5".
    h5files.sort(key=lambda x : int(x.split('.')[0]) )
    for h5 in h5files:
        h5files_pending.append(os.path.join("/bean/depth",folder,h5))
print(len(h5files_pending))
for h5 in tqdm(h5files_pending):
    # Skip recordings whose video already exists.
    if os.path.exists(h5.replace(".hdf5", ".mp4")):
        continue
    process_h5_plt_video(h5, create_fig=True)

2310 - 2160 150

In [None]:
# Run the full pipeline (disparity, lidar, loss, video) on a single recording.
process_h5file_all("/bean/depth/09-09-20-04-34/0.hdf5")