In [None]:

# --- Environment and project imports -------------------------------------
import os
# CUDA debugging switches. They are written into the environment before
# `torch` is imported further down in this cell.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'
try:
    from core.raft_stereo import RAFTStereo
except ImportError:
    # `core` is not importable from the current location: switch the working
    # directory to the RAFT-Stereo checkout and retry the import.
    # NOTE(review): this relies on the working directory being on sys.path
    # and on the checkout living at /RAFT-Stereo — confirm for your machine.
    import os
    os.chdir("/RAFT-Stereo")
    from core.raft_stereo import RAFTStereo
    
# Not referenced by any other cell visible in this notebook.
FRPASS = "frames_cleanpass"
from train_fusion.dataloader import StereoDataset, StereoDatasetArgs

import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader
from fusion_args import FusionArgs

# --- Model configuration ---------------------------------------------------
# `args` is handed to RAFTStereo(args) in the next cell; the fields below are
# presumably the backbone hyper-parameters that constructor reads.
args = FusionArgs()
args.hidden_dims = [128, 128, 128]
args.corr_levels = 4
args.corr_radius = 4
args.n_downsample = 3
args.context_norm = "batch"
args.n_gru_layers = 2
args.shared_backbone = True
# Also used later to enable/disable the GradScaler in the training cell.
args.mixed_precision = True
args.corr_implementation = "reg_cuda"
args.slow_fast_gru = False
# Checkpoint loaded into the stereo network (path is relative to the
# working directory, which the import fallback above may have changed).
args.restore_ckpt = "models/raftstereo-realtime.pth"


# --- Training configuration ------------------------------------------------
# lr / wdecay / num_steps feed the optimizer and LR scheduler, batch_size the
# data loaders, and name the checkpoint file names. The remaining fields
# (train_iters, valid_iters, valid_steps, fusion, shared_fusion,
# freeze_backbone, both_side_train) are not read by the cells visible here.
args.lr = 0.001
args.train_iters = 7
args.valid_iters = 12
args.wdecay = 0.0001
args.num_steps = 100000
args.valid_steps = 1000
args.name = "ColorFusion"
args.batch_size = 4
args.fusion = "AFF"
args.shared_fusion = True
args.freeze_backbone = []
args.both_side_train= False

In [None]:
# Build the stereo network inside a DataParallel wrapper, load the checkpoint
# through the wrapper, then keep only the bare module for later use.
# NOTE(review): presumably the checkpoint keys carry the "module." prefix that
# DataParallel adds — confirm against the checkpoint file.
wrapped_raft = torch.nn.DataParallel(RAFTStereo(args)).cuda()
wrapped_raft.load_state_dict(torch.load(args.restore_ckpt))
wrapped_raft.eval()

# Unwrap; the stereo network is only ever run in eval mode in this notebook.
raft_model = wrapped_raft.module
raft_model.eval()

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
def fusion_rgb_nir(rgb: np.ndarray, nir: np.ndarray):
    """Blend a BGR colour image with a NIR image in YCrCb space.

    The luma channel is replaced by a per-pixel mix of the original luma and
    the NIR intensity, weighted by ``(gray - nir) / 255``; both chroma
    channels are then scaled by the relative luma change.

    Args:
        rgb: Colour image in OpenCV BGR channel order.
        nir: Single-channel NIR image; assumed to have the same height and
            width as ``rgb`` (the arrays are combined element-wise).

    Returns:
        The fused image as a ``uint8`` BGR array.
    """
    # Do all arithmetic in float. With uint8 inputs, ``gray - nir`` would wrap
    # around modulo 256 wherever nir > gray and produce a bogus weight.
    nir_f = nir.astype(np.float64)
    gray = cv2.cvtColor(rgb, cv2.COLOR_BGR2GRAY).astype(np.float64)
    luminance_weight = (gray - nir_f) / 255

    ycrcb = cv2.cvtColor(rgb, cv2.COLOR_BGR2YCrCb).astype(np.float64)
    ycrcb_l = ycrcb[:, :, 0]
    ycrcb_a = ycrcb[:, :, 1]
    ycrcb_b = ycrcb[:, :, 2]

    ycrcb_l_sum = ycrcb_l * luminance_weight + nir_f * (1 - luminance_weight)

    # Relative luma change. Pixels whose original luma is 0 keep m = 0 (chroma
    # left unscaled) instead of dividing by zero and propagating inf/NaN into
    # the uint8 cast below.
    m = np.divide(
        ycrcb_l - ycrcb_l_sum,
        ycrcb_l,
        out=np.zeros_like(ycrcb_l),
        where=ycrcb_l > 0,
    )

    ycrcb_a_new = ycrcb_a * (1 + m)
    ycrcb_b_new = ycrcb_b * (1 + m)
    ycrcb_new = np.stack([ycrcb_l_sum, ycrcb_a_new, ycrcb_b_new], axis=-1)
    ycrcb_new = np.clip(ycrcb_new, 0, 255).astype(np.uint8)
    fusion = cv2.cvtColor(ycrcb_new, cv2.COLOR_YCrCb2BGR)
    return fusion
    

In [None]:

# Preview the hand-crafted RGB/NIR fusion on one captured frame.
frame = "/bean/depth/09-10-15-21-40/15_21_43_545"
left = cv2.imread(f"{frame}/rgb/left.png")
left_nir = cv2.imread(f"{frame}/nir/left.png", cv2.IMREAD_GRAYSCALE)
# cv2.imread signals a missing/unreadable file by returning None, not by
# raising; fail here with a clear message rather than deep inside cvtColor.
if left is None or left_nir is None:
    raise FileNotFoundError(f"Could not read rgb/left.png and nir/left.png under {frame}")
fusion = fusion_rgb_nir(left, left_nir)

plt.figure(figsize=(10, 10))
plt.subplot(131)
plt.imshow(cv2.cvtColor(left, cv2.COLOR_BGR2RGB))
plt.subplot(132)
# The NIR image was read as single-channel, so it needs GRAY2RGB;
# BGR2RGB only accepts 3- or 4-channel input.
plt.imshow(cv2.cvtColor(left_nir, cv2.COLOR_GRAY2RGB))
plt.subplot(133)
plt.imshow(cv2.cvtColor(fusion, cv2.COLOR_BGR2RGB))

plt.show()


In [None]:
from train_fusion.my_h5_dataloader import MyH5DataSet
from torch.utils.data import DataLoader
# Open the full dataset once to discover every frame id, then split the id
# list 90 % / 10 % (in stored order) into training and validation subsets.
dataset = MyH5DataSet(frame_cache=True)
cnt = len(dataset)
train_cnt = int(cnt * 0.9)
valid_cnt = cnt - train_cnt
print(cnt)

all_frame_ids = dataset.frame_id_list
dataset_train = MyH5DataSet(id_list=all_frame_ids[:train_cnt])
dataset_valid = MyH5DataSet(id_list=all_frame_ids[train_cnt:])

# Both loaders share batch size and worker count; only training is shuffled.
loader_options = dict(batch_size=args.batch_size, num_workers=4)
train_loader = DataLoader(dataset_train, shuffle=True, **loader_options)
valid_loader = DataLoader(dataset_valid, shuffle=False, **loader_options)

In [None]:
from color_fusion_model import RGBNIRFusionNet

# The network trained in this notebook, placed on the GPU. Later cells call
# it with two image batches (first the RGB image, then the NIR image of the
# same view) and pass its parameters to the optimizer.
fusion_model = RGBNIRFusionNet().cuda()



In [None]:
from typing import Tuple
from train_fusion.loss_function import warp_reproject_loss, reproject_disparity

def compute_disparity(left: torch.Tensor, right: torch.Tensor):
    """Run the stereo network on an image pair and return its second output.

    When ``left`` has a channel dimension (third from last) of size 1, both
    images are tiled to three channels before being passed to ``raft_model``.
    The network is called with ``test_mode=True`` and must return exactly
    two values, of which the second is returned.
    """
    needs_three_channels = left.shape[-3] == 1
    if needs_three_channels:
        left, right = (image.repeat(1, 3, 1, 1) for image in (left, right))

    outputs = raft_model(left, right, test_mode=True)
    _, flow = outputs
    return flow


def loss_fn_detph_gt(flow: torch.Tensor, target_gt: torch.Tensor):
    """MSE between the negated flow sampled at sparse points and their targets.

    Args:
        flow: 4-D tensor indexed as ``(batch, channel, row, col)``.
            Assumed to have a single channel — TODO confirm against the
            stereo network's output.
        target_gt: 3-D tensor indexed as ``(batch, point, 3)``. Column 0 is
            used as the index into the last axis of ``flow``, column 1 as the
            index into the second-to-last axis, and column 2 is the target
            value. Indices are truncated to integers and clamped into range.

    Returns:
        Scalar tensor: mean squared error over all ``batch x point`` samples.
    """
    gt_u = torch.clamp(target_gt[:, :, 1].long(), 0, flow.shape[-2] - 1)
    gt_v = torch.clamp(target_gt[:, :, 0].long(), 0, flow.shape[-1] - 1)
    B, N = gt_u.shape
    # Create the batch index on flow's device so indexing needs no host copy.
    batch_indices = torch.arange(B, device=flow.device).view(B, 1).expand(B, N)

    # The index tensors are separated by a slice, so the result is laid out
    # as (B, N, channel). Drop ONLY that trailing channel axis: a bare
    # squeeze() would also remove B or N when either equals 1, after which
    # MSELoss silently broadcasts (B,) against (B, 1) and returns a wrong loss.
    target_pred = -flow[batch_indices, :, gt_u, gt_v].squeeze(-1)

    target_depth = target_gt[:, :, 2]
    depth_loss = nn.MSELoss()(target_pred, target_depth)

    return depth_loss
def loss_fn(pred: Tuple[torch.Tensor, torch.Tensor], target: Tuple[Tuple[torch.Tensor,torch.Tensor],Tuple[torch.Tensor, torch.Tensor]], target_gt: torch.Tensor):
    """Total training loss for a fused left/right image pair.

    Args:
        pred: ``(fused_left, fused_right)`` images fed to ``compute_disparity``.
        target: ``((rgb_left, rgb_right), (nir_left, nir_right))``; each pair
            is forwarded to ``warp_reproject_loss`` together with the flow.
        target_gt: Sparse ground truth forwarded to ``loss_fn_detph_gt``.

    Returns:
        ``(total_loss, loss_dict)``. ``total_loss`` is the mean RGB warp loss
        plus the mean NIR warp loss plus the depth loss. ``loss_dict`` holds
        the RGB warp metrics, the NIR warp metrics with an ``_nir`` suffix,
        and ``depth_loss``.
    """
    fused_left, fused_right = pred[0], pred[1]
    height, width = fused_left.shape[-2], fused_left.shape[-1]

    # Crop the predicted flow to the spatial size of the fused input.
    flow = compute_disparity(fused_left, fused_right)
    flow = flow[:, :, :height, :width]

    rgb_pair, nir_pair = target[0], target[1]
    warp_loss_rgb, warp_metric_rgb = warp_reproject_loss([flow], *rgb_pair)
    warp_loss_nir, warp_metric_nir = warp_reproject_loss([flow], *nir_pair)

    depth_loss = loss_fn_detph_gt(flow, target_gt)

    loss_dict = dict(warp_metric_rgb)
    loss_dict.update({f"{k}_nir": v for k, v in warp_metric_nir.items()})
    loss_dict["depth_loss"] = depth_loss

    total_loss = warp_loss_rgb.mean() + warp_loss_nir.mean() + depth_loss
    return total_loss, loss_dict
    


In [None]:
from typing import Dict


def validate_things(
    model,
    valid_loader: DataLoader,
):
    """Evaluate ``model`` on every batch of ``valid_loader``.

    Args:
        model: Fusion network; called as ``model(rgb, nir)`` for the left and
            the right view. It is switched to eval mode and left that way, so
            the caller must restore train mode afterwards.
        valid_loader: Iterable of 5-element batches
            ``(image1, image2, image3, image4, depth)``; each element is moved
            to the GPU with ``.cuda()``.

    Returns:
        ``(mean_loss, metrics)``: ``mean_loss`` is the float average of the
        per-batch losses and ``metrics`` maps each metric name reported by
        ``loss_fn`` to its average over ``len(valid_loader)`` batches, as a
        CPU tensor.
    """
    model.eval()
    metrics: Dict[str, torch.Tensor] = {}
    losses = []
    num_batches = len(valid_loader)
    with torch.no_grad():
        for i_batch, input_valid in enumerate(valid_loader):
            # Unpack the batch produced by THIS loop, not a leftover global
            # batch from the training cell.
            image1, image2, image3, image4, depth = [x.cuda() for x in input_valid]

            # Evaluate the model that was passed in, not a notebook global.
            image_fusion_1 = model(image1, image3)
            image_fusion_2 = model(image2, image4)

            loss, metric = loss_fn((image_fusion_1, image_fusion_2), ((image1, image2), (image3, image4)), depth)

            print(f"Batch {i_batch} Loss {loss}")
            for k, v in metric.items():
                # Bring every metric to the CPU before accumulating, so the
                # running sums never mix devices with GPU-resident values.
                value = torch.as_tensor(v).detach().float().cpu()
                metrics[k] = metrics.get(k, torch.tensor(0.0)) + value / num_batches
            losses.append(loss.item())

    loss = sum(losses) / len(losses)

    return loss, metrics


In [None]:
from train_stereo import Logger
from torch.cuda.amp import GradScaler
from tqdm import tqdm
import logging
from pathlib import Path
import torch
from torch import optim


# --- Optimizer and learning-rate schedule ----------------------------------
# Only the fusion network is optimized.
optimizer = optim.AdamW(
    fusion_model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8
)

scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    args.lr,
    args.num_steps + 100,
    pct_start=0.01,
    cycle_momentum=False,
    anneal_strategy="linear",
)

total_steps = 0
logger = Logger(fusion_model, scheduler)

fusion_model.train()


validation_frequency = 10000

scaler = GradScaler(enabled=args.mixed_precision)

should_keep_training = True
global_batch_num = 0

# The stereo network only provides the training signal; freeze its weights.
for param in raft_model.parameters():
    param.requires_grad = False

# torch.save does not create missing directories.
Path('checkpoints').mkdir(parents=True, exist_ok=True)

while should_keep_training:
    raft_model.eval()
    for i_batch, data_blob in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        image1, image2, image3, image4, depth = [x.cuda() for x in data_blob]

        assert fusion_model.training

        # Fuse (RGB, NIR) separately for the left and the right view.
        image_fusion_1 = fusion_model(image1, image3)
        image_fusion_2 = fusion_model(image2, image4)

        assert fusion_model.training

        loss, metrics = loss_fn((image_fusion_1, image_fusion_2), ((image1, image2), (image3, image4)), depth)
        logger.writer.add_scalar("live_loss", loss.item(), global_batch_num)
        logger.writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_batch_num)
        global_batch_num += 1

        # Unscale before clipping so the norm threshold applies to the real
        # gradients rather than the loss-scaled ones.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(fusion_model.parameters(), 1.0)

        scaler.step(optimizer)
        scheduler.step()
        scaler.update()

        logger.push(metrics)

        if total_steps % validation_frequency == validation_frequency - 1:
            save_path = Path('checkpoints/%d_%s.pth' % (total_steps + 1, args.name))
            logging.info(f"Saving file {save_path.absolute()}")
            torch.save(fusion_model.state_dict(), save_path)

            # fusion_model is a plain module (never wrapped in DataParallel),
            # so it has no `.module` attribute; pass it directly.
            valid_loss, valid_metrics = validate_things(fusion_model, valid_loader)

            # validate_things returns a (mean_loss, metrics) tuple; flatten it
            # into a single name -> scalar mapping for the logger.
            # NOTE(review): assumes Logger.write_dict expects such a mapping —
            # confirm against train_stereo.Logger.
            results = {"valid_loss": valid_loss}
            results.update({f"valid_{k}": float(v) for k, v in valid_metrics.items()})
            logger.write_dict(results)

            # validate_things leaves the model in eval mode.
            fusion_model.train()

        total_steps += 1

        if total_steps > args.num_steps:
            should_keep_training = False
            break

    # Extra per-epoch checkpoint, written only when an epoch is long.
    if len(train_loader) >= 10000:
        save_path = Path('checkpoints/%d_epoch_%s.pth.gz' % (total_steps + 1, args.name))
        torch.save(fusion_model.state_dict(), save_path)

print("FINISHED TRAINING")
logger.close()

In [None]:
# Qualitative check on one training batch: show the inputs next to the fused
# image, then the disparity maps predicted from RGB, NIR and fused pairs.
train_iter = iter(train_loader)
train_input = next(train_iter)
fusion_model.eval()
image1, image2, image3, image4, depth = (x.cuda() for x in train_input)
with torch.no_grad():
    image_fusion_1 = fusion_model(image1, image3)
    image_fusion_2 = fusion_model(image2, image4)
print(image_fusion_1.shape, image_fusion_1.max(), image3.shape)

# First sample of the batch: RGB input, NIR input (gray colormap), fused output.
plt.figure(figsize=(10, 20))
for position, (batch, colormap) in enumerate(
    ((image1, None), (image3, "gray"), (image_fusion_1, None)), start=1
):
    plt.subplot(1, 3, position)
    plt.imshow(batch[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8), cmap=colormap)
plt.show()

# The network output is negated before being displayed as disparity.
with torch.no_grad():
    disparity_rgb = -compute_disparity(image1, image2)
    disparity_nir = -compute_disparity(image3, image4)
    disparity_fusion = -compute_disparity(image_fusion_1, image_fusion_2)

plt.figure(figsize=(10, 20))
for position, disparity in enumerate(
    (disparity_rgb, disparity_nir, disparity_fusion), start=1
):
    plt.subplot(1, 3, position)
    plt.imshow(disparity[0, 0].cpu().numpy())
plt.show()