In [None]:

import os
# Debug-oriented CUDA settings; they must be in the environment before CUDA is initialised.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 serialises kernel launches (slow, but gives accurate
# error locations). TORCH_USE_CUDA_DSA appears to be a PyTorch *build-time* flag, so setting
# it here likely has no effect — confirm.
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': '1',
    'TORCH_USE_CUDA_DSA': '1',
})

# Import RAFT-Stereo; if it is not importable from the current working directory, move
# into the checkout and retry. Relative paths used later (e.g. "models/...",
# "checkpoints/...") are then resolved against that directory.
try:
    from core.raft_stereo import RAFTStereo
except ImportError:
    os.chdir("/RAFT-Stereo")
    from core.raft_stereo import RAFTStereo

# Not referenced in the cells shown here.
FRPASS = "frames_cleanpass"
from train_fusion.dataloader import StereoDataset, StereoDatasetArgs

import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader
from fusion_args import FusionArgs
args = FusionArgs()

# RAFT-Stereo architecture options (presumably must match the checkpoint in
# restore_ckpt — confirm).
_model_overrides = {
    "hidden_dims": [128, 128, 128],
    "corr_levels": 4,
    "corr_radius": 4,
    "n_downsample": 3,
    "context_norm": "batch",
    "n_gru_layers": 2,
    "shared_backbone": True,
    "mixed_precision": True,
    "corr_implementation": "reg_cuda",
    "slow_fast_gru": False,
    "restore_ckpt": "models/raftstereo-realtime.pth",
}

# Fusion training schedule and options.
_train_overrides = {
    "lr": 0.001,
    "train_iters": 7,
    "valid_iters": 12,
    "wdecay": 0.0001,
    "num_steps": 100000,
    "valid_steps": 1000,
    "name": "ColorFusion",
    "batch_size": 8,
    "fusion": "AFF",
    "shared_fusion": True,
    "freeze_backbone": [],
    "both_side_train": False,
}

# Apply in declaration order, exactly like a sequence of `args.<name> = <value>` lines.
for _name, _value in {**_model_overrides, **_train_overrides}.items():
    setattr(args, _name, _value)
del _name, _value, _model_overrides, _train_overrides

In [None]:
from color_fusion_model import RGBNIRFusionNet
from rgb_thermal_fusion_net import RGBThermalFusionNet

from encoding.parallel import DataParallelModel, DataParallelCriterion

# Load the RAFT-Stereo checkpoint. The model is wrapped in DataParallelModel only while
# loading — presumably because the checkpoint keys carry the "module." prefix of a
# data-parallel training run (confirm) — and the bare module is kept for inference.
device_ids = (0, 1, 2, 3, 4, 5)
raft_model = DataParallelModel(RAFTStereo(args), device_ids=device_ids, output_device=device_ids[0]).to('cuda')
raft_load_result = raft_model.load_state_dict(torch.load(args.restore_ckpt), strict=False)
# strict=False silently skips keys that do not match. Report them so a checkpoint with a
# mismatched prefix/architecture is not mistaken for a successfully loaded model.
if raft_load_result.missing_keys or raft_load_result.unexpected_keys:
    print(
        f"Warning: partial load from {args.restore_ckpt}: "
        f"{len(raft_load_result.missing_keys)} missing keys "
        f"(e.g. {raft_load_result.missing_keys[:3]}), "
        f"{len(raft_load_result.unexpected_keys)} unexpected keys "
        f"(e.g. {raft_load_result.unexpected_keys[:3]})"
    )
raft_model.eval()
raft_model.module.freeze_bn()  # RAFTStereo's own helper; presumably freezes BatchNorm layers
raft_model = raft_model.module





In [None]:
def compute_disparity(left: torch.Tensor, right: torch.Tensor):
    """Run the global ``raft_model`` on a stereo pair and return its second test-mode output.

    Single-channel views (e.g. NIR) are tiled to 3 channels before inference, presumably
    because the RAFT-Stereo encoders expect 3-channel input (confirm). Each view is
    checked on its own, so a 1-channel view paired with a 3-channel view is still tiled
    correctly instead of tiling (or skipping) both based on ``left`` alone.

    Args:
        left: left image batch, assumed (B, C, H, W) with C in {1, 3}.
        right: right image batch with the same layout.

    Returns:
        The second element returned by ``raft_model(left, right, test_mode=True)``.
        Callers in this notebook negate it to obtain a non-negative disparity map.
    """
    if left.shape[-3] == 1:
        left = left.repeat(1, 3, 1, 1)
    if right.shape[-3] == 1:
        right = right.repeat(1, 3, 1, 1)
    _, flow = raft_model(left, right, test_mode=True)
    return flow

In [None]:
from train_fusion.my_h5_dataloader import MyH5DataSet
from torch.utils.data import DataLoader
import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader
# Build the full dataset once (with frame caching) to enumerate frame ids, then split the
# id list contiguously: first 90% for training, the remaining 10% for validation.
dataset = MyH5DataSet( frame_cache=True)
cnt = len(dataset)
train_cnt = int(cnt * 0.9)
valid_cnt = cnt - train_cnt
print(f"{cnt} frames -> {train_cnt} train / {valid_cnt} valid")
dataset_train = MyH5DataSet(id_list = dataset.frame_id_list[:train_cnt])
dataset_valid = MyH5DataSet(id_list = dataset.frame_id_list[train_cnt:])

# Seed the shuffle so the batch drawn for inspection is reproducible across kernel restarts.
SHUFFLE_SEED = 0
train_loader = DataLoader(
    dataset_train,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=0,
    drop_last=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
# Keep the final partial batch so every held-out frame is evaluated.
valid_loader = DataLoader(
    dataset_valid,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=0,
    drop_last=False,
)

In [None]:
from rgb_thermal_fusion_net import RGBThermalFusionNet
import matplotlib.pyplot as plt
from color_fusion_model import RGBNIRFusionNet
# Inference only needs the bare fusion network, so build it directly on the GPU and load
# the checkpoint into it (the DataParallel wrapper was discarded right after loading anyway).
# Other variants tried: RGBThermalFusionNet(hidden_dim=16) and "interrupted_model.pth".
fusion_model = RGBNIRFusionNet().cuda()
fusion_model.load_state_dict(torch.load("checkpoints/22550_ColorFusion.pth"))

def loss_fn_detph_gt(flow: torch.Tensor, target_gt: torch.Tensor):
    """Per-sample RMSE between ``-flow`` sampled at sparse points and their target values.

    Args:
        flow: (B, C, H, W) prediction; only channel 0 is read and the prediction at each
            point is ``-flow``. The notebook therefore passes the raw (un-negated) RAFT output.
        target_gt: (B, N, 3) points. Column 0 is the pixel column, column 1 the pixel row,
            column 2 the target value compared against ``-flow`` (named "depth", but compared
            to a disparity-like prediction — presumably already disparity; confirm).
            Coordinates are truncated to integers and clamped into the image, so
            out-of-frame points are scored against the nearest border pixel.

    Returns:
        (B,) tensor with the RMSE of each sample.
    """
    height, width = flow.shape[-2], flow.shape[-1]
    rows = torch.clamp(target_gt[:, :, 1].long(), 0, height - 1)
    cols = torch.clamp(target_gt[:, :, 0].long(), 0, width - 1)
    batch, _ = rows.shape
    # (B, 1) batch index that broadcasts against the (B, N) row/col indices.
    batch_indices = torch.arange(batch, device=flow.device).unsqueeze(1)
    # Select channel 0 explicitly rather than calling `.squeeze()`: a bare squeeze also
    # drops the batch/point axes when B == 1 or N == 1, breaking the mean over dim=1.
    target_pred = -flow[:, 0][batch_indices, rows, cols]

    target_depth = target_gt[:, :, 2]
    return torch.sqrt(torch.mean((target_pred - target_depth) ** 2, dim=1))

In [None]:
# Pull one training batch and compare RAFT-Stereo on RGB, NIR and fused inputs.
train_iter = iter(train_loader)
train_input = next(train_iter)
fusion_model.eval()
image1, image2, image3, image4, depth = [batch_item.cuda() for batch_item in train_input]

with torch.no_grad():
    # The fusion net receives both modalities of one view stacked along channels:
    # left = image1 + image3, right = image2 + image4.
    # NOTE(review): image3/image4 are the second modality — NIR judging by RGBNIRFusionNet
    # and the "nir" legend below, although an earlier comment called it Thermal; confirm.
    fused_input1 = torch.cat([image1, image3], dim=1)
    fused_input2 = torch.cat([image2, image4], dim=1)
    image_fusion_1 = fusion_model(fused_input1)
    image_fusion_2 = fusion_model(fused_input2)

print(image_fusion_1.min(), image_fusion_1.max(), image_fusion_1.mean())
# Shift each fused batch so its minimum is 0; no rescaling to 0-255 is applied.
image_fusion_1 = image_fusion_1 - image_fusion_1.min()
image_fusion_2 = image_fusion_2 - image_fusion_2.min()
print(image_fusion_1.min(), image_fusion_1.max(), image_fusion_1.mean())

with torch.no_grad():
    # compute_disparity's output is negated for plotting as non-negative disparity;
    # the loss is given the un-negated output again because it negates internally.
    disparity_rgb, disparity_nir, disparity_fusion = (
        -compute_disparity(left_view, right_view)
        for left_view, right_view in (
            (image1, image2),
            (image3, image4),
            (image_fusion_1, image_fusion_2),
        )
    )
    loss_rgb, loss_nir, loss_fusion = (
        loss_fn_detph_gt(-disparity_map, depth).cpu()
        for disparity_map in (disparity_rgb, disparity_nir, disparity_fusion)
    )

# One curve per input type: per-sample RMSE across the batch.
plt.figure(figsize=(10, 5))
for per_sample_loss in (loss_rgb, loss_nir, loss_fusion):
    plt.plot(per_sample_loss)
plt.legend(["rgb", "nir", "fusion"])
plt.show()



def _to_display_image(image: torch.Tensor) -> np.ndarray:
    """Convert one (C, H, W) image tensor into an (H, W, C) uint8 array for ``plt.imshow``.

    Values are clipped to [0, 255] before the cast. The fused image is only shifted to a
    minimum of 0, not rescaled, so it can exceed 255. Casting out-of-range floats to uint8
    wraps around (or is platform-dependent) instead of saturating.
    """
    return np.clip(image.permute(1, 2, 0).cpu().numpy(), 0, 255).astype(np.uint8)


# Aspect ratio applied to the LiDAR scatter panel.
# NOTE(review): 540/720 presumably relates to image height/width; with u/v in pixel units an
# aspect of 1 would match the imshow panels — confirm which is intended.
LIDAR_PANEL_ASPECT = 540 / 720
# Colour-scale upper limits used to view the disparity maps at several ranges.
DISPARITY_VMAX_LEVELS = (12, 64, 128)

batch_size = image1.shape[0]
for idx in range(batch_size):
    # Row 1: left-view inputs and fused image.
    plt.figure(figsize=(20, 5))
    image_panels = [
        ("RGB (left)", image1[idx], None),
        ("NIR (left)", image3[idx], "gray"),
        ("fused (left)", image_fusion_1[idx], None),
    ]
    for pos, (title, image, cmap) in enumerate(image_panels, start=1):
        plt.subplot(1, 3, pos)
        plt.imshow(_to_display_image(image), cmap=cmap)
        plt.title(title)
    plt.show()

    # Loop-invariant conversions, hoisted out of the vmax loop.
    disparity_panels = [
        ("rgb disparity", disparity_rgb[idx, 0].cpu().numpy()),
        ("nir disparity", disparity_nir[idx, 0].cpu().numpy()),
        ("fusion disparity", disparity_fusion[idx, 0].cpu().numpy()),
    ]
    lidar_depth = depth[idx].cpu().numpy()
    # Column 0 is the pixel column and column 1 the pixel row (as in loss_fn_detph_gt);
    # the row is negated so the scatter is not drawn upside-down.
    u, v = lidar_depth[:, 0], -lidar_depth[:, 1]
    z = lidar_depth[:, 2]

    # Rows 2+: disparity maps next to the sparse LiDAR targets, one figure per colour range.
    for vmax in DISPARITY_VMAX_LEVELS:
        plt.figure(figsize=(20, 5))
        for pos, (title, disparity_map) in enumerate(disparity_panels, start=1):
            plt.subplot(1, 4, pos)
            plt.imshow(disparity_map, cmap="magma", vmin=0, vmax=vmax)
            plt.title(f"{title} (vmax={vmax})")
        ax = plt.subplot(1, 4, 4)
        sc = ax.scatter(u, v, c=z, cmap="magma", vmin=0, vmax=vmax)
        ax.set_aspect(LIDAR_PANEL_ASPECT)
        ax.set_title("LiDAR target")
        plt.colorbar(sc, ax=ax)
        plt.show()


In [None]:
# Make sure the working directory is the RAFT-Stereo checkout before launching training:
# if core.raft_stereo cannot be imported from the current directory, chdir into the
# checkout and retry. torchrun below uses a path relative to that directory.
try:
    from core.raft_stereo import RAFTStereo
except ImportError:
    import os
    os.chdir("/RAFT-Stereo")
    from core.raft_stereo import RAFTStereo
# IPython shell escape: launch color_fusion_train.py with 6 processes
# (presumably one per GPU, matching the 6 device_ids used above — confirm).
!torchrun --nproc_per_node=6 color_fusion_train.py