In [1]:

import os

# Debug settings. CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, so errors surface
# at the failing call (slower). Set before torch is imported below.
# NOTE(review): PyTorch's own error message describes TORCH_USE_CUDA_DSA as a *compile-time* flag
# ("Compile with TORCH_USE_CUDA_DSA ..."); setting it at runtime likely has no effect — confirm.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

# Import RAFT-Stereo from the current directory if possible, otherwise from the /RAFT-Stereo
# checkout (the notebook's working directory is changed to it).
try:
    from core.raft_stereo import RAFTStereo
except ImportError:
    os.chdir("/RAFT-Stereo")
    from core.raft_stereo import RAFTStereo

FRPASS = "frames_cleanpass"
from train_fusion.dataloader import StereoDataset, StereoDatasetArgs

import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader
from fusion_args import FusionArgs

# Configuration for the RAFT-Stereo model built in the next cell (restored from restore_ckpt).
args = FusionArgs()
args.hidden_dims = [128, 128, 128]
args.corr_levels = 4
args.corr_radius = 4
args.n_downsample = 3
args.context_norm = "batch"
args.n_gru_layers = 2
args.mixed_precision = True
args.corr_implementation = "reg_cuda"
args.slow_fast_gru = False
args.restore_ckpt = "models/raftstereo-realtime.pth"
args.shared_backbone = True

# Training-related settings. Of these, only valid_iters is read by the visible cells
# (plot_raft_model passes it to the fusion models).
args.lr = 0.001
args.train_iters = 15
args.valid_iters = 24
args.wdecay = 0.0001
args.num_steps = 100000
args.valid_steps = 1000
args.name = "ColorFusion"
args.batch_size = 8
args.fusion = "AFF"
args.shared_fusion = True
args.freeze_backbone = []
args.both_side_train= False

# NOTE(review): not referenced by the visible cells — models below are moved with .cuda(),
# i.e. to the current default CUDA device, not necessarily cuda:5.
device = "cuda:5"

In [2]:
from color_fusion_model import RGBNIRFusionNet
from rgb_thermal_fusion_net import RGBThermalFusionNet

from encoding.parallel import DataParallelModel, DataParallelCriterion
args.input_channel = 3  # 3-channel (RGB) RAFT-Stereo

# Load the checkpoint through a DataParallelModel wrapper (presumably the checkpoint keys carry
# the "module." prefix of a wrapped save — confirm), then keep only the wrapped network.
_raft_wrapper = DataParallelModel(RAFTStereo(args)).cuda()
_raft_wrapper.load_state_dict(torch.load(args.restore_ckpt))
_raft_wrapper.eval()
raft_model = _raft_wrapper.module
raft_model.freeze_bn()
del _raft_wrapper





In [3]:
def compute_disparity(left: torch.Tensor, right: torch.Tensor):
    """Run the notebook's RAFT-Stereo model (``raft_model``) in test mode on a stereo pair.

    Inputs whose channel dimension (dim -3) has size 1 are tiled to 3 channels first, because
    ``raft_model`` is built with ``input_channel = 3``. Returns the second value produced by
    ``raft_model`` (the flow; elsewhere in this notebook disparity is taken as ``-flow``).
    """
    if left.shape[-3] != 1:
        stereo_pair = (left, right)
    else:
        stereo_pair = (left.repeat(1, 3, 1, 1), right.repeat(1, 3, 1, 1))
    _coords, flow_pred = raft_model(*stereo_pair, test_mode=True)
    return flow_pred

In [None]:
from train_fusion.my_h5_dataloader import MyH5DataSet
from train_fusion.dataloader import StereoDataset, StereoDatasetArgs, EntityDataSet
from torch.utils.data import DataLoader
import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader
# NOTE(review): dataset_real is constructed twice. The first instance (frame_cache=True) is only
# used for the printed length and is then replaced by the frame_cache=False instance. Presumably
# the first construction exists to build/warm a frame cache — confirm, otherwise drop it.
dataset_real = MyH5DataSet( frame_cache=True,use_right_shift = True)
print(len(dataset_real))
dataset_real = MyH5DataSet( frame_cache=False,use_right_shift = True)
print(len(dataset_real))


# Synthetic subsets, presumably the Driving and FlyingThings3D parts of SceneFlow (judging by the
# *_json flag names) — both with noised inputs and the shift filter enabled.
dataset_drive = StereoDataset(StereoDatasetArgs(flow3d_driving_json=True, noised_input=True, shift_filter=True, vertical_scale=True, fast_test=True))
dataset_flying = StereoDataset(StereoDatasetArgs(flying3d_json=True, noised_input=True, shift_filter=True, fast_test=True))
# Pool real + synthetic samples (real first) and split 90/10 by list position, without shuffling:
# the validation tail therefore comes from the last-listed sources (dataset_flying first).
dataset = EntityDataSet(dataset_real.input_list + dataset_drive.input_list + dataset_flying.input_list)
cnt = len(dataset)
train_cnt = int(cnt * 0.9)
valid_cnt = cnt - train_cnt  # NOTE(review): not used by the visible cells
print(cnt)
dataset_train = EntityDataSet(dataset.input_list[:train_cnt])
# NOTE(review): dataset_valid is not used by the visible cells; valid_loader iterates dataset_real.
dataset_valid =EntityDataSet(dataset.input_list[train_cnt:])
train_loader = DataLoader(dataset_train, batch_size=1, shuffle=True, num_workers=0, drop_last=True)


In [None]:

# Stand-alone loaders over the two synthetic subsets, one shuffled sample per batch.
driver_loader, flying_loader = (
    DataLoader(subset, batch_size=1, shuffle=True, num_workers=0, drop_last=True)
    for subset in (dataset_drive, dataset_flying)
)
print(len(dataset_drive), len(dataset_flying), len(dataset_real), sep="\n")

In [6]:
# Loader over the real recordings (dataset_real), one shuffled sample per batch.
valid_loader = DataLoader(
    dataset_real,
    batch_size=1,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

In [7]:
import torch
from train_fusion.loss_function import reproject_disparity

def input_reduce_disparity(inputs: list[torch.Tensor]):
    """Shift the right images so that every disparity shrinks by a multiple of 16 pixels.

    ``inputs`` is indexed as ``[rgb_left, rgb_right, nir_left, nir_right, ..., lidar_points,
    disparity_gt]``: elements 0-3 are the images, ``inputs[-2]`` the sparse points whose last
    axis holds (u, v, disparity) and ``inputs[-1]`` the dense disparity ground truth
    (layout assumed from how the rest of this notebook uses these tensors — TODO confirm).

    The shift is ``min(disparity_gt)`` rounded down to a multiple of 16. The right images are
    rolled right by ``shift`` pixels; their first ``shift`` columns (which wrapped around from the
    right edge) are filled by warping the matching left image with the reduced disparity.

    Returns:
        ``inputs`` itself when the shift would be <= 0, otherwise a new 6-element list
        ``[rgb_left, rgb_right', nir_left, nir_right', lidar_points', disparity_gt']``.
        The caller's tensors are never modified.
    """
    disparity_gt = inputs[-1]
    shift = int(disparity_gt.min() // 16) * 16
    if shift <= 0:
        return inputs

    reduced_disparity = disparity_gt - shift
    # reproject_disparity takes a flow, i.e. the negated disparity (same convention as
    # plot_raft_model). Warp the *left* images of this sample, not any global variable.
    warp_rgb_right = reproject_disparity(-reduced_disparity, inputs[0])
    warp_nir_right = reproject_disparity(-reduced_disparity, inputs[2])

    # torch.roll returns new tensors, so the column fill below does not touch the inputs.
    rolled_rgb_right = torch.roll(inputs[1], shifts=shift, dims=-1)
    rolled_nir_right = torch.roll(inputs[3], shifts=shift, dims=-1)
    rolled_rgb_right[..., :shift] = warp_rgb_right[..., :shift]
    rolled_nir_right[..., :shift] = warp_nir_right[..., :shift]

    # Only the disparity column of the sparse points changes; pixel coordinates (u, v) stay put.
    lidar_points = inputs[-2].clone()
    lidar_points[..., 2] -= shift

    return [
        inputs[0], rolled_rgb_right, inputs[2], rolled_nir_right, lidar_points, reduced_disparity
    ]

In [8]:
import torch.nn.functional as F
def crop_and_resize_height(image: torch.Tensor, h_to=360):
    """Keep the top ``h_to`` rows of a batched image and stretch them back to the original size.

    Args:
        image: tensor of shape (B, C, H, W).
        h_to: number of rows (from the top) to keep before resizing.

    Returns:
        Tensor of shape (B, C, H, W), bilinearly resized from the cropped (B, C, h_to, W) slice.
    """
    height, width = image.shape[-2:]
    cropped = image[:, :, :h_to, :]
    # align_corners=False is PyTorch's default for bilinear; stated explicitly for clarity.
    return F.interpolate(cropped, size=(height, width), mode="bilinear", align_corners=False)

In [None]:
from train_fusion.loss_function import reproject_disparity
from train_fusion.ssim.utils import SSIM, warp

import matplotlib.pyplot as plt

# Inspect one real sample: ground truth, the four input images, the disparity RAFT-Stereo
# predicts from the NIR pair, and a photometric check (SSIM + L1) of warping the left images
# onto the right view with that prediction.
idx = 33
data = [x.unsqueeze(0) for x in dataset_real.input_list[idx].get_item()]

with torch.no_grad():
    # NIR is single-channel; tile it to 3 channels for the 3-channel RAFT-Stereo model.
    # NOTE(review): the prediction is cropped to 540x720 — presumably the model output is padded
    # beyond the image size; confirm it matches the input resolution.
    flow = raft_model(
        data[2].repeat(1, 3, 1, 1).cuda(),
        data[3].repeat(1, 3, 1, 1).cuda(),
        test_mode=True,
    )[1][:, :, :540, :720]
    disp = -flow[0, 0]
    right = torch.concat((data[1], data[3]), dim=1).cuda() / 255
    warp_right = warp(torch.concat([data[0], data[2]], dim=1).cuda() / 255, flow)
    # Presumably marks pixels whose warp samples inside the left image (zero padding elsewhere).
    mask_right = warp(torch.ones_like(right).cuda(), flow, padding_mode="zeros")
    ssim_loss = SSIM()(warp_right, right)
    l1_loss = torch.abs(warp_right - right)
    loss = (ssim_loss * 0.85 + 0.15 * l1_loss.mean(1, True))[mask_right > 0]
    print(loss.mean())

# Explicit 4x2 grid: every panel gets its own axes, so the warp no longer overwrites "Left NIR".
fig, axs = plt.subplots(4, 2, figsize=(15, 20))

sc = axs[0, 0].imshow(data[-1][0, 0].numpy(), vmin=0, vmax=64, cmap="rainbow")
axs[0, 0].set_title("GT Disparity")
fig.colorbar(sc, ax=axs[0, 0])

sc = axs[1, 0].imshow(disp.cpu().numpy(), vmin=0, vmax=64, cmap="rainbow")
axs[1, 0].set_title("RAFT Stereo")
fig.colorbar(sc, ax=axs[1, 0])

for panel_ax, panel_title, panel_img, panel_cmap in [
    (axs[0, 1], "Left RGB", data[0], None),
    (axs[1, 1], "Right RGB", data[1], None),
    (axs[2, 0], "Left NIR", data[2], "gray"),
    (axs[2, 1], "Right NIR", data[3], "gray"),
]:
    panel_ax.imshow(panel_img[0].permute(1, 2, 0).numpy().astype(np.uint8), cmap=panel_cmap)
    panel_ax.set_title(panel_title)

# First three channels of the warped [RGB, NIR] stack, i.e. the warped RGB image (values /255).
axs[3, 0].imshow(np.clip(warp_right[0].permute(1, 2, 0).cpu().numpy()[..., :3], 0, 1))
axs[3, 0].set_title("Warped Left RGB")
axs[3, 1].axis("off")

plt.tight_layout()
plt.show()

In [11]:
from rgb_thermal_fusion_net import RGBThermalFusionNet
import matplotlib.pyplot as plt
from color_fusion_model import RGBNIRFusionNet


def loss_fn_detph_gt(flow: torch.Tensor, target_gt: torch.Tensor):
    """Mean absolute error between ``-flow`` and sparse target values, per batch element.

    Args:
        flow: predicted flow of shape (B, C, H, W). Only channel 0 is used and the prediction is
            ``-flow[:, 0]`` (the notebook's disparity sign convention).
        target_gt: (B, N, 3) points laid out as (column, row, target disparity). Coordinates are
            truncated to integers and clamped into the image.

    Returns:
        Tensor of shape (B,): mean |prediction - target| over the N points of each sample.
    """
    rows = torch.clamp(target_gt[:, :, 1].long(), 0, flow.shape[-2] - 1)
    cols = torch.clamp(target_gt[:, :, 0].long(), 0, flow.shape[-1] - 1)
    batch_idx = torch.arange(rows.shape[0], device=flow.device).unsqueeze(1).expand_as(rows)
    # Index channel 0 explicitly: the previous ``.squeeze()`` dropped the batch or point axis
    # when B == 1 or N == 1 (silently broadcasting to (B, B) for N == 1) and failed for C > 1.
    target_pred = -flow[batch_idx, 0, rows, cols]  # (B, N)

    target_depth = target_gt[:, :, 2]
    return torch.mean(torch.abs(target_pred - target_depth), dim=1)

In [None]:
from core.raft_stereo import RAFTStereo
# Baseline: one RAFT-Stereo network fed 4-channel stacks (plot_raft_model concatenates the RGB
# and NIR images for it).
fusion_args = FusionArgs()
fusion_args.batch_size, fusion_args.input_channel = 3, 4
fusion_args.corr_levels, fusion_args.corr_radius = 4, 4
fusion_args.n_downsample = 2
fusion_args.context_norm = "batch"
fusion_args.n_gru_layers = 3
fusion_args.shared_backbone = False
fusion_args.train_iters, fusion_args.valid_iters = 15, 20

RAFT_4CH_CKPT = "checkpoints/2664_Raft4ChannelRaftLoss2.pth"
model_RAFT = RAFTStereo(fusion_args).cuda()
model_RAFT.load_state_dict(torch.load(RAFT_4CH_CKPT))
model_RAFT.eval()





In [None]:
from color_fusion_model import RGBNIRFusionNet
# RGB+NIR fusion network, intended for the "Color" branch of plot_raft_model.
model = RGBNIRFusionNet().cuda()
model.load_state_dict(torch.load("checkpoints/2200_ColorChannel4.pth"))
# Inference only: switch to eval mode like the other models in this notebook (model_RAFT,
# model_AFF, model_AFF_F), so normalisation/dropout layers (if any) behave deterministically.
model.eval()

In [None]:
from core.raft_stereo_fusion import RAFTStereoFusion
# Two RAFTStereoFusion variants. No weights are loaded in this cell: the networks keep their
# initial parameters until a checkpoint is loaded (e.g. via plot_raft_model's `ckpoints`).

# "AFF": n_downsample=2, n_gru_layers=3, separate backbones.
fusion_args = FusionArgs()
fusion_args.n_downsample, fusion_args.n_gru_layers = 2, 3
fusion_args.shared_backbone, fusion_args.shared_fusion = False, True
model_AFF = RAFTStereoFusion(fusion_args).cuda().eval()

# "AFF_F": n_downsample=3, n_gru_layers=2, shared backbone.
fusion_args = FusionArgs()
fusion_args.n_downsample, fusion_args.n_gru_layers = 3, 2
fusion_args.shared_backbone, fusion_args.shared_fusion = True, True
model_AFF_F = RAFTStereoFusion(fusion_args).cuda()
model_AFF_F.eval()



In [23]:
import torch
import torch.nn.functional as F
def create_overlay_image(original_image, feature_map, num_channels=8, alpha=0.6, cmap='jet', vmin=0, vmax=1):
    """Blend a channel-averaged feature map, rendered as a heatmap, over an image.

    Args:
        original_image: tensor of shape (C, H, W) with values in [0, 255]. C is 3 for colour
            images; a single-channel image broadcasts against the RGB heatmap. Integer dtypes
            are accepted. The tensor is never modified.
        feature_map: tensor of shape (K, h, w), e.g. (256, H // 8, W // 8). The first
            ``num_channels`` channels are averaged.
        num_channels: number of leading feature channels to average (capped at K).
        alpha: heatmap opacity in [0, 1].
        cmap: Matplotlib colormap name for the heatmap.
        vmin, vmax: feature values mapped to the low / high ends of the colormap (must differ).

    Returns:
        overlay_image: NumPy array of shape (H, W, 3) with values in [0, 1].
    """
    # pyplot.get_cmap exists on old and new Matplotlib; matplotlib.cm.get_cmap was removed in 3.9.
    import matplotlib.pyplot as plt

    # Average the first num_channels channels into one map (a smarter channel-selection rule
    # could be plugged in here).
    num_channels = min(num_channels, feature_map.shape[0])
    aggregated = feature_map[:num_channels].mean(dim=0, keepdim=True)  # (1, h, w)

    # Upsample to the image resolution. Index instead of squeeze() so a height or width of 1
    # does not collapse the 2-D map.
    height, width = original_image.shape[1], original_image.shape[2]
    upsampled = F.interpolate(
        aggregated.unsqueeze(0).float(), size=(height, width), mode="bilinear", align_corners=False
    )
    feature_np = upsampled[0, 0].detach().cpu().numpy()  # (H, W)
    feature_np = (feature_np - vmin) / (vmax - vmin)

    heatmap = plt.get_cmap(cmap)(feature_np)[..., :3]  # (H, W, 3) RGB

    # .numpy() shares memory with a CPU tensor, so copy (astype) before scaling; this also makes
    # integer images work instead of failing on an in-place float division.
    image_np = original_image.detach().cpu().numpy().transpose(1, 2, 0).astype(np.float32) / 255.0

    overlay_image = (1 - alpha) * image_np + alpha * heatmap
    return np.clip(overlay_image, 0, 1)

In [64]:
import matplotlib.pyplot as plt
import cv2
import torch
from myutils.hy5py import calibration_property, get_frame_by_path, read_calibration
from myutils.matrix import rmse_loss
from train_fusion.loss_function import (
    self_supervised_loss,
    reproject_disparity,
    disparity_smoothness,
    ssim as ssim_torch,
)
from myutils.points import (
    disparity_image_edge_eval,
    project_points_on_camera,
    refine_disparity_points,
    transform_point_inverse,
)
from skimage.metrics import structural_similarity as ssim

# Colormap shared by every disparity panel.
cmap_disparity = "rainbow"
# Display names of the three modalities, in RGB / NIR / fused order.
spectrum = "RGB NIR FUSION".split()


def _plot_disparity_panel(ax, disparity_map, title, vmax):
    """Draw one disparity map with the notebook's disparity colormap and a colorbar."""
    image = ax.imshow(disparity_map, cmap=cmap_disparity, vmin=0, vmax=vmax)
    ax.set_title(title)
    plt.colorbar(image, ax=ax)


def _plot_attention_panel(ax, image, feature_map, title):
    """Overlay ``feature_map`` on ``image`` via create_overlay_image and draw it on ``ax``."""
    ax.imshow(create_overlay_image(image, feature_map, alpha=0.7, vmin=0, vmax=1))
    ax.set_title(title)


def _to_cuda_batch(image_hwc):
    """Convert an (H, W, C) NumPy image into a float32 (1, C, H, W) CUDA tensor."""
    tensor = torch.from_numpy(np.ascontiguousarray(image_hwc))
    return tensor.permute(2, 0, 1).unsqueeze(0).float().cuda()


def plot_raft_model(x, model_type="Raft", ckpoints=None, fusion_only=False, C=0):
    """Run a stereo model on one batch and plot disparities, attention maps and LiDAR checks.

    Args:
        x: sequence of batched tensors ``[rgb_left, rgb_right, nir_left, nir_right,
            (lidar_points), (disparity_gt)]``. ``lidar_points`` is indexed as (u, v, disparity)
            on its last axis; the last two entries are optional.
        model_type: case-insensitive. "AFF" / "AFF_F" use model_AFF / model_AFF_F; "RAFT" uses
            the 4-channel model_RAFT; "Color" (also the default "Raft" is NOT valid — see Raises)
            fuses RGB+NIR with the global ``model`` and feeds the result to ``raft_model``.
        ckpoints: optional checkpoint loaded into the selected network before inference
            (AFF/AFF_F with strict=False, RAFT strictly; ignored for "Color").
        fusion_only: skip the RGB-only / NIR-only RAFT-Stereo baselines.
        C: pixels the right images are rolled left by before inference. Every predicted
            disparity is shifted back by C so it is comparable with the unshifted images/LiDAR.

    Raises:
        ValueError: if ``model_type`` is not one of AFF, AFF_F, RAFT or Color.

    Figure per batch element (4 x 4 grid; panels without data are hidden):
        row 0: RGB left | RGB right | RGB-only disparity | RGB attention
        row 1: NIR left | NIR right | NIR-only disparity | NIR attention
        row 2: LiDAR points coloured from the right image | GT disparity (or LiDAR points) |
               fused disparity | fused attention
        row 3: HSV-fused left | HSV-fused right | disparity on the HSV fusion |
               Color-model fused image ("Color" only)
    """
    inputs = [t.cuda() for t in x]
    image1, image2, image3, image4 = inputs[:4]
    depth = inputs[4] if len(inputs) > 4 else None
    dis_gt = inputs[5] if len(inputs) > 5 else None

    # Only the network inputs are shifted; the panels and LiDAR use the unshifted data, and the
    # caller's tensors are left untouched.
    if C > 0:
        image2_in = torch.roll(image2, shifts=-C, dims=-1)
        image4_in = torch.roll(image4, shifts=-C, dims=-1)
    else:
        image2_in, image4_in = image2, image4
    nir_left_3ch = image3.repeat(1, 3, 1, 1)
    nir_right_3ch = image4_in.repeat(1, 3, 1, 1)

    kind = model_type.upper()
    fused_image1 = None
    with torch.no_grad():
        # The network is bound to `net`, never to `model`: assigning `model` anywhere in this
        # function would make it local and break the global lookup in the "Color" branch.
        if kind in ("AFF", "AFF_F"):
            net = model_AFF if kind == "AFF" else model_AFF_F
            if ckpoints is not None:
                net.load_state_dict(torch.load(ckpoints), strict=False)
            input_dict = {
                "image_viz_left": image1,
                "image_viz_right": image2_in,
                "image_nir_left": image3,
                "image_nir_right": image4_in,
                "iters": args.valid_iters,
                "test_mode": True,
                "flow_init": None,
                "heuristic_nir": False,
                "attention_out_mode": False,
            }
            _, flow = net(input_dict)
            input_dict["attention_out_mode"] = True
            fmap1, fmap1_rgb, fmap1_nir = net(input_dict)
        elif kind == "RAFT":
            net = model_RAFT
            if ckpoints is not None:
                net.load_state_dict(torch.load(ckpoints))
            fused_input1 = torch.cat([image1, image3], dim=1)  # RGB + NIR
            fused_input2 = torch.cat([image2_in, image4_in], dim=1)
            _, flow = net(fused_input1, fused_input2, test_mode=True, iters=24)
            fmap1, _ = net.fnet([fused_input1, fused_input2])
        elif kind == "COLOR":
            fused_image1 = model(torch.cat([image1, image3], dim=1))
            fused_image2 = model(torch.cat([image2_in, image4_in], dim=1))
            fmap1, _ = raft_model.fnet([fused_image1, fused_image2])
            _, flow = raft_model(fused_image1, fused_image2, test_mode=True, iters=24)
        else:
            raise ValueError(
                f"unknown model_type {model_type!r}; expected 'AFF', 'AFF_F', 'RAFT' or 'Color'"
            )

        if kind not in ("AFF", "AFF_F"):
            fmap1_rgb, _ = raft_model.fnet([image1, image2_in])
            fmap1_nir, _ = raft_model.fnet([nir_left_3ch, nir_right_3ch])

        # flow is the negative disparity; adding C removes the C pixels introduced by the roll.
        disparity = -(flow + C)[:, 0].cpu().numpy()
        if not fusion_only:
            _, flow_rgb = raft_model(image1, image2_in, test_mode=True, iters=24)
            _, flow_nir = raft_model(nir_left_3ch, nir_right_3ch, test_mode=True, iters=24)
            disparity_rgb = -(flow_rgb + C)[:, 0].cpu().numpy()
            disparity_nir = -(flow_nir + C)[:, 0].cpu().numpy()

    for b in range(image1.shape[0]):
        np_rgb_left, np_rgb_right, np_nir_left, np_nir_right = [
            t[b].permute(1, 2, 0).cpu().numpy().astype(np.uint8) for t in inputs[:4]
        ]

        # Image-level fusion baseline (presumably NIR replaces the HSV V channel), refined with a
        # guided filter, then run through the RGB RAFT-Stereo model.
        np_fusion_left = guided_filter(
            np_nir_left.squeeze(),
            modify_v_channel_numpy_opencv(np_rgb_left.copy(), np_nir_left.copy()) * 255,
            5,
        )
        np_fusion_right = guided_filter(
            np_nir_right.squeeze(),
            modify_v_channel_numpy_opencv(np_rgb_right.copy(), np_nir_right.copy()) * 255,
            5,
        )
        with torch.no_grad():
            disparity_hsv = -raft_model(
                _to_cuda_batch(np_fusion_left),
                _to_cuda_batch(np_fusion_right),
                test_mode=True,
            )[1][0, 0]

        vmax = min(128, disparity[b].mean() * 2.2)
        fig, axs = plt.subplots(4, 4, figsize=(32, 20))

        axs[0, 0].imshow(np_rgb_left)
        axs[0, 0].set_title("RGB_Left")
        axs[1, 0].imshow(np_nir_left, cmap="gray")
        axs[1, 0].set_title("NIR_Left")
        axs[0, 1].imshow(np_rgb_right)
        axs[0, 1].set_title("RGB Right")
        axs[1, 1].imshow(np_nir_right, cmap="gray")
        axs[1, 1].set_title("NIR Right")

        if not fusion_only:
            _plot_disparity_panel(axs[0, 2], disparity_rgb[b], "Disparity_RGB_RaftStereo", vmax)
            _plot_disparity_panel(axs[1, 2], disparity_nir[b], "Disparity_NIR_RaftStereo", vmax)
        _plot_disparity_panel(axs[2, 2], disparity[b], "Disparity_Fused", vmax)

        _plot_attention_panel(axs[0, 3], image1[b], fmap1_rgb[b], "RGB Attention")
        _plot_attention_panel(axs[1, 3], image3[b], fmap1_nir[b], "NIR Attention")
        _plot_attention_panel(axs[2, 3], image1[b], fmap1[b], "Fused Attention")

        if depth is not None:
            lidar = depth[b].cpu().numpy()
            u = lidar[:, 0].astype(np.int32)
            v = lidar[:, 1].astype(np.int32)
            z = lidar[:, 2]
            height, width = np_rgb_right.shape[:2]
            # Right-image column of each point (u - disparity), clamped inside the image.
            u_r = np.clip((u - z).astype(np.int32), 0, width - 1)
            color_sampled = np.clip(np_rgb_right[np.clip(v, 0, height - 1), u_r] / 255.0, 0, 1)
            axs[2, 0].imshow(np.zeros_like(disparity[b]), cmap="gray")
            axs[2, 0].scatter(u, v, c=color_sampled, s=10)
            axs[2, 0].set_title("Lidar Warped")

        if dis_gt is not None and dis_gt.max() > 5:
            sc = axs[2, 1].imshow(
                dis_gt[b, 0].cpu(), cmap=cmap_disparity, vmin=0, vmax=vmax
            )
            axs[2, 1].set_title("Disparity GT")
            plt.colorbar(sc, ax=axs[2, 1])
        elif depth is not None:
            axs[2, 1].imshow(np_nir_left, cmap="gray")
            sc = axs[2, 1].scatter(u, v, c=z, s=0.1, cmap=cmap_disparity)
            plt.colorbar(sc, ax=axs[2, 1])
            axs[2, 1].set_title("Lidar Points")

        axs[3, 0].imshow(np_fusion_left.astype(np.uint8))
        axs[3, 0].set_title("HSV Fusion Left")
        axs[3, 1].imshow(np_fusion_right.astype(np.uint8))
        axs[3, 1].set_title("HSV Fusion Right")
        axs[3, 2].imshow(
            disparity_hsv.cpu().numpy(), vmin=0, vmax=vmax, cmap=cmap_disparity
        )
        axs[3, 2].set_title("Disparity_HSV_Fusion")

        if fused_image1 is not None:
            # .item() keeps the normalisation on the CPU side (no NumPy / CUDA-tensor mixing).
            fused = (
                fused_image1[b].permute(1, 2, 0).cpu().numpy()
                / fused_image1[b].max().item()
                * 255
            )
            axs[3, 3].imshow(fused.astype(np.uint8))
            axs[3, 3].set_title("Fused Image")

        for ax in axs.flat:
            if not ax.has_data():
                ax.axis("off")
        plt.tight_layout()
        plt.show()


def imread_tensor(image_path: str):
    """Read an image file into a float tensor of shape (1, C, H, W).

    Paths containing "nir" are read as single-channel grayscale (C == 1). Anything else is read
    with ``cv2.IMREAD_ANYCOLOR`` and converted from OpenCV's BGR channel order to RGB.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path`` (``cv2.imread`` returns None
            instead of raising, which previously surfaced as an obscure cvtColor/NumPy error).
    """
    nir = "nir" in image_path
    flags = cv2.IMREAD_GRAYSCALE if nir else cv2.IMREAD_ANYCOLOR
    image = cv2.imread(image_path, flags)
    if image is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {image_path}")
    if not nir:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    tensor = torch.from_numpy(image)
    if nir:
        tensor = tensor.unsqueeze(-1)  # (H, W) -> (H, W, 1)
    return tensor.permute(2, 0, 1).unsqueeze(0).float()


from myutils.image_process import (
    guided_filter,
    modify_v_channel_numpy_opencv,
    read_image_pair,
    cv2toTensor,
)


def get_valid_input_from_path(
    frame_path: str,
    calibration_path: str = "/bean/depth/09-08-17-27-33/0.hdf5",
    transform_path: str = "jai_transform.npy",
):
    """Load one recorded frame plus its projected LiDAR points as CUDA tensors.

    Args:
        frame_path: path of the recorded frame (read by read_image_pair and get_frame_by_path).
        calibration_path: calibration file providing fx, baseline, cx, cy. The default is the
            09-08-17-27-33 recording; pass the matching file for frames from other recordings.
        transform_path: .npy matrix passed to transform_point_inverse for the LiDAR points.

    Returns:
        List of the tensors from read_image_pair (presumably rgb_left, rgb_right, nir_left,
        nir_right — the order plot_raft_model consumes) followed by a (1, N, 3) tensor of
        LiDAR points laid out as (u, v, disparity).
    """
    transform_mtx = np.load(transform_path)
    images = [cv2toTensor(img).cuda() for img in read_image_pair(frame_path)]
    fx, bs, cx, cy = calibration_property(read_calibration(calibration_path))
    with get_frame_by_path(frame_path) as frame:
        # NOTE(review): the x1000 presumably converts metres to millimetres — confirm units.
        lidar_points = (frame["lidar/points"][:] * 1000).reshape(-1, 3)
        lidar_points = transform_point_inverse(lidar_points, transform_mtx)
        lidar_points = project_points_on_camera(lidar_points, fx, cx, cy, 720, 540)
        # Depth -> disparity as bs * fx / Z (bs presumably the stereo baseline).
        # NOTE(review): the trailing "- 1" offset is unexplained — confirm it is intended.
        lidar_points[:, 2] = bs * fx / lidar_points[:, 2] - 1
        lidar_points = refine_disparity_points(torch.from_numpy(lidar_points)).numpy()
        image_attrs = frame["image"].attrs
        if "rgb_exposure_left" in image_attrs:
            print(
                image_attrs["rgb_exposure_left"],
                image_attrs["rgb_exposure_right"],
                image_attrs["nir_exposure_left"],
                image_attrs["nir_exposure_right"],
            )
    images.append(torch.from_numpy(lidar_points).cuda().unsqueeze(0))
    return images

In [None]:
# Inspect real sample 466 with the "AFF" fusion model and its checkpoint.
# Other variants tried here:
#   plot_raft_model(train_input, "AFF_F", "checkpoints/10292_RaftRFusion2SynthFast.pth")
#   plot_raft_model(train_input, "RAFT")
SAMPLE_INDEX = 466
train_input = [sample.unsqueeze(0) for sample in dataset_real[SAMPLE_INDEX]]
plot_raft_model(
    train_input,
    model_type="AFF",
    ckpoints="checkpoints/10323_RaftRFusion2NoiseReal.pth",
)

In [None]:
from myutils.widget import FrameExplorer

# Browse recorded frames with the "AFF" fusion model. No checkpoint is passed, so model_AFF keeps
# whatever weights were last loaded into it (e.g. by the sample-plot cell above).
FrameExplorer(
    lambda path: plot_raft_model(
        get_valid_input_from_path(path),
        model_type="AFF",
        C=0,
    )
)

In [None]:
# model_AFF_F is constructed without loading any weights, and plot_raft_model only loads a
# checkpoint when `ckpoints` is given, so load one explicitly before browsing; otherwise a fresh
# kernel explores an untrained network.
# NOTE(review): checkpoint taken from the commented-out AFF_F call in the sample-plot cell —
# confirm it is the intended one.
model_AFF_F.load_state_dict(
    torch.load("checkpoints/10292_RaftRFusion2SynthFast.pth"), strict=False
)
FrameExplorer(
    lambda x: plot_raft_model(get_valid_input_from_path(x), model_type="AFF_F", C=0)
)

In [None]:
# Other frames inspected with this cell:
#   /bean/depth/09-08-17-27-33/17_33_16_354
#   /bean/depth/09-28-21-15-50/21_34_03_125
# Other models tried on it:
#   model_type="AFF",  ckpoints="checkpoints/4000_RaftFusion22EncoderLarge.pth"
#   model_type="RAFT", ckpoints="checkpoints/666_Raft4ChannelRaftSynth.pth"
frame_path = "/bean/depth/09-08-17-27-33/17_32_56_438"
plot_raft_model(
    get_valid_input_from_path(frame_path),
    model_type="AFF",
    C=0,
)