In [None]:
# Import the fusion model. If `core` is not importable from the current working
# directory, change into the RAFT-Stereo checkout and retry. The chdir is kept for
# the rest of the session, so relative paths used later in this notebook
# (e.g. "models/raftstereo-realtime.pth") resolve from the checkout root.
# The checkout location can be overridden with the RAFT_STEREO_ROOT environment
# variable; it defaults to "/RAFT-Stereo".
try:
    from core.raft_stereo_fusion import RAFTStereoFusion
except ImportError:
    import os

    os.chdir(os.environ.get("RAFT_STEREO_ROOT", "/RAFT-Stereo"))
    from core.raft_stereo_fusion import RAFTStereoFusion

In [None]:
from train_fusion.dataloader import StereoDataset, StereoDatasetArgs

import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader


In [None]:
from fusion_args import FusionArgs
# Run configuration: start from FusionArgs defaults and apply the overrides below
# one attribute at a time, in the order listed.
args = FusionArgs()
_arg_overrides = {
    # model / correlation settings
    "hidden_dims": [128, 128, 128],
    "corr_levels": 4,
    "corr_radius": 4,
    "n_downsample": 3,
    "context_norm": "batch",
    "n_gru_layers": 2,
    "shared_backbone": True,
    "mixed_precision": True,
    "corr_implementation": "reg_cuda",
    "slow_fast_gru": False,
    # checkpoint to start from (alternative: "checkpoints/14000_StereoFusion.pth")
    "restore_ckpt": "models/raftstereo-realtime.pth",
    # optimisation / schedule
    "lr": 0.001,
    "train_iters": 3,
    "valid_iters": 8,
    "wdecay": 0.0001,
    "num_steps": 100000,
    "valid_steps": 1000,
    "name": "StereoFusion",
    "batch_size": 6,
    # fusion / freezing options
    "fusion": "AFF",
    "shared_fusion": False,
    "freeze_backbone": ["BatchNorm", "Extractor"],
    "both_side_train": False,
}
for _arg_name, _arg_value in _arg_overrides.items():
    setattr(args, _arg_name, _arg_value)

In [None]:
# Stereo dataset rooted at /bean/depth; the StereoDatasetArgs flags select which
# data sources / options are used.
dataset = StereoDataset(
    StereoDatasetArgs(
        "/bean/depth",
        flying3d_json=True,
        flow3d_driving_json=False,
        gt_depth=True,
        validate_json=False,
        synth_no_filter=False,
        synth_no_rgb=True,
    )
)
# Alternative data source:
# dataset = StereoDataset("/bean/depth", real_data_json= True,flow3d_driving_json=False, gt_depth=False)

# Index split: [0, train_size) for training, [train_size, total) for validation
# (the validation part is whatever is left after taking 99% for training).
_total_samples = len(dataset)
train_size = int(0.99 * _total_samples)
valid_size = _total_samples - train_size
train_dataset = dataset.partial(0, train_size)
valid_dataset = dataset.partial(train_size, _total_samples)

# Both loaders share batch size and worker count; only shuffling differs.
_loader_options = {"batch_size": args.batch_size, "num_workers": 4}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
valid_loader = DataLoader(valid_dataset, shuffle=False, **_loader_options)

# Alternative single-directory training loader:
# train_loader = DataLoader(StereoDataset("/bean/depth/08-01-19-37-50", flow3d_driving_json=True, gt_depth=True), batch_size=args.batch_size, shuffle=True, num_workers=4)

In [None]:
# Total number of samples before the train/validation split.
dataset_sample_count = len(dataset)
print(dataset_sample_count)

In [None]:
# Build the fusion model, wrap it in DataParallel and move it to the GPU.
# Wrapping happens before loading, so checkpoint keys are matched against the
# DataParallel-prefixed ("module.") names — TODO confirm for each checkpoint used.
model = nn.DataParallel(RAFTStereoFusion(args)).cuda()

# strict=False lets a checkpoint that covers only part of the fusion model load
# (presumably the base RAFT-Stereo weights). Mismatched names are then skipped
# without any error, so report what was skipped to catch a checkpoint that
# silently loaded nothing.
load_report = model.load_state_dict(torch.load(args.restore_ckpt), strict=False)
print(
    f"Loaded {args.restore_ckpt}: "
    f"{len(load_report.missing_keys)} missing keys, "
    f"{len(load_report.unexpected_keys)} unexpected keys"
)
if load_report.missing_keys:
    print("First missing keys:", load_report.missing_keys[:10])
if load_report.unexpected_keys:
    print("First unexpected keys:", load_report.unexpected_keys[:10])

# Trailing semicolon suppresses the (very long) module repr in the cell output.
model.train();

In [None]:
from tqdm.notebook import tqdm
from train_fusion.train import train, self_supervised_real_batch, flow_gt_batch
from train_fusion.loss_function import self_supervised_loss, self_fm_loss, gt_loss

# Train with the ground-truth batch loader (flow_gt_batch) and gt_loss; progress
# bars come from tqdm.notebook.
train(
    args=args,
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    tqdm=tqdm,
    batch_loader_function=flow_gt_batch,
    loss_function=gt_loss,
    spectral_feature=True,
)

In [None]:
from visualize.batch_input_visualize import batch_input_visualize

# Visualize the first batch produced by the validation loader.
iterator = iter(valid_loader)
first_valid_batch = next(iterator)
batch_input_visualize(first_valid_batch)

In [None]:
# Print the state dict (parameters and buffers) of the context network's outputs08 module.
outputs08_state = model.module.cnet.outputs08.state_dict()
print(outputs08_state)

In [None]:
from datastructure.train_input import TrainInput
from visualize.batch_input_visualize import batch_input_visualize
from core.utils.utils import InputPadder
import cv2
from torchvision.transforms.functional import pad
from train_fusion.loss_function import warp_reproject_loss, reproject_disparity
def _attention_debug_input(viz_left, viz_right, nir_left, nir_right):
    """Build a fresh TrainInput for an attention-inspection pass.

    Uses args.train_iters with test_mode=True and attention_out_mode=True.
    A new instance is built on every call so the two uses below never share
    the same object.
    """
    return TrainInput(
        {
            "image_viz_left": viz_left,
            "image_viz_right": viz_right,
            "image_nir_left": nir_left,
            "image_nir_right": nir_right,
            "iters": args.train_iters,
            "test_mode": True,
            "flow_init": None,
            "heuristic_nir": False,
            "attention_out_mode": True,
        }
    )


# Skip the first validation batch and inspect the second one.
iterator = iter(valid_loader)
next(iterator)
inputs = next(iterator)
# inputs[0] is not used here; the six remaining entries are moved to the GPU.
image0, image1, image2, image3, dis1, dis2 = [x.cuda() for x in inputs[1:]]

with torch.no_grad():
    # Full forward pass in attention-output mode, unpacked into fusion / rgb / nir.
    with torch.cuda.amp.autocast(enabled=True):
        fusion, rgb, nir = model(
            _attention_debug_input(image0, image1, image2, image3).data_dict
        )
        print(rgb[0].shape)
        print(
            rgb[0].min(), rgb[0].max(),
            fusion[0].min(), fusion[0].max(),
            nir[0].min(), nir[0].max(),
        )
        print(model.module.cnet.fusion.attention_rgb(rgb[0]))

    # Re-run only the preprocessing step on the same images.
    # NOTE(review): batch_preprocess receives the TrainInput object itself while the
    # model call above receives `.data_dict` — confirm this is the intended contract.
    image0, image1, image2, image3 = model.module.batch_preprocess(
        _attention_debug_input(image0, image1, image2, image3)
    )

    # Call the context network directly. image2 is repeated 3x along dim 1
    # (presumably 1 -> 3 channels). NOTE(review): the meaning of the positional
    # (True, 2, True) arguments is not visible here.
    with torch.cuda.amp.autocast(enabled=True):
        nir_repeated = image2.repeat(1, 3, 1, 1)
        outputs = model.module.cnet(image0, nir_repeated, True, 2, True)

# Bind the last three cnet outputs; this overwrites the rgb / nir obtained from
# the full forward pass above.
_fusion, rgb, nir = outputs[-3], outputs[-2], outputs[-1]


In [None]:
from datastructure.train_input import TrainInput
from visualize.batch_input_visualize import batch_input_visualize
from core.utils.utils import InputPadder
import cv2
from torchvision.transforms.functional import pad
from train_fusion.loss_function import warp_reproject_loss, reproject_disparity
# Run the model on the first validation batch (test_mode=False, args.valid_iters
# iterations) and visualize the last entry of its output.
iterator = iter(valid_loader)
inputs = next(iterator)
image0, image1, image2, image3, dis1, dis2 = [x.cuda() for x in inputs[1:]]

# The model may still be in training mode here (model.train() is called when it is
# built). torch.no_grad() alone does not stop BatchNorm layers from normalizing with
# batch statistics and updating their running stats, so this inference-only pass
# runs in eval mode. Each module's own `training` flag is restored afterwards rather
# than calling model.train(), which would also undo any selective freezing
# (args.freeze_backbone lists "BatchNorm" — those layers may be kept in eval mode
# during training; confirm).
module_training_flags = {module: module.training for module in model.modules()}
model.eval()
try:
    with torch.no_grad():
        flow = model(
            TrainInput(
                {
                    "image_viz_left": image0,
                    "image_viz_right": image1,
                    "image_nir_left": image2,
                    "image_nir_right": image3,
                    "iters": args.valid_iters,
                    "test_mode": False,
                    "flow_init": None,
                    "heuristic_nir": False,
                    "attention_out_mode": False,
                }
            ).data_dict
        )
finally:
    for module, was_training in module_training_flags.items():
        module.training = was_training

batch_input_visualize(inputs, flow[-1])
