In [None]:
# Import the stereo-fusion model. If the kernel was not started inside the
# RAFT-Stereo checkout, `core` is not importable from the current directory, so
# move into the checkout and retry. chdir (rather than only extending sys.path)
# also makes later relative paths such as "models/raftstereo-realtime.pth"
# resolve against the repo root.
try:
    from core.raft_stereo_fusion import RAFTStereoFusion
except ImportError:
    import os

    # Overridable so the notebook works outside the original container layout;
    # defaults to the previously hard-coded location.
    RAFT_STEREO_ROOT = os.environ.get("RAFT_STEREO_ROOT", "/RAFT-Stereo")
    os.chdir(RAFT_STEREO_ROOT)
    from core.raft_stereo_fusion import RAFTStereoFusion

In [None]:
from train_fusion.dataloader import StereoDataset

import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader


In [None]:
from fusion_args import FusionArgs

# Network / checkpoint options. restore_ckpt is loaded (strict=False) when the
# model is built below.
MODEL_OPTIONS = {
    "hidden_dims": [128, 128, 128],
    "corr_levels": 4,
    "corr_radius": 4,
    "n_downsample": 3,
    "context_norm": "batch",
    "n_gru_layers": 2,
    "shared_backbone": True,
    "mixed_precision": True,
    "corr_implementation": "reg_cuda",
    "slow_fast_gru": False,
    "restore_ckpt": "models/raftstereo-realtime.pth",
}

# Optimisation, schedule and fusion-training options.
TRAIN_OPTIONS = {
    "lr": 0.001,
    "train_iters": 7,
    "valid_iters": 12,
    "wdecay": 0.0001,
    "num_steps": 100000,
    "valid_steps": 1000,
    "name": "StereoFusion",
    "batch_size": 4,
    "fusion": "AFF",
    "shared_fusion": True,
    "freeze_backbone": [],
    "both_side_train": True,
}

# Apply every option onto a fresh FusionArgs, in the order listed above.
args = FusionArgs()
for option_name, option_value in {**MODEL_OPTIONS, **TRAIN_OPTIONS}.items():
    setattr(args, option_name, option_value)

In [None]:
dataset = StereoDataset("/bean/depth", gt_depth=True, flying3d_json=True)
#dataset = StereoDataset("/bean/depth", real_data_json= True,flow3d_driving_json=False, gt_depth=False)

# Split the dataset into a leading training part and a trailing validation part.
# NOTE(review): assumes `partial(start, end)` selects samples [start, end), making
# this a contiguous (unshuffled) split — confirm against StereoDataset.
TRAIN_FRACTION = 0.95
NUM_WORKERS = 4

train_size = int(TRAIN_FRACTION * len(dataset))
valid_size = len(dataset) - train_size
# Fail fast on a tiny or empty dataset instead of silently producing an empty
# train or validation loader.
assert train_size > 0 and valid_size > 0, (
    f"dataset too small to split: {len(dataset)} samples -> "
    f"{train_size} train / {valid_size} valid"
)
train_dataset = dataset.partial(0, train_size)
valid_dataset = dataset.partial(train_size, len(dataset))
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=NUM_WORKERS)
valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=NUM_WORKERS)
from core.utils.utils import InputPadder

#train_loader = DataLoader(StereoDataset("/bean/depth/08-01-19-37-50", flow3d_driving_json=True, gt_depth=True), batch_size=args.batch_size, shuffle=True, num_workers=4)

In [None]:
model = nn.DataParallel(RAFTStereoFusion(args)).cuda()

# strict=False lets parameters absent from the checkpoint (presumably the new
# fusion layers) keep their initial values. The catch: a key mismatch, e.g. a
# missing/extra "module." prefix from DataParallel, would load *nothing* without
# any error. So report what was matched and fail loudly if nothing was.
# NOTE(review): assumes the checkpoint file is a plain state_dict mapping.
checkpoint = torch.load(args.restore_ckpt, map_location="cpu")
load_result = model.load_state_dict(checkpoint, strict=False)
matched_keys = len(checkpoint) - len(load_result.unexpected_keys)
print(
    f"Restored {args.restore_ckpt}: {matched_keys}/{len(checkpoint)} checkpoint tensors matched, "
    f"{len(load_result.missing_keys)} model tensors not in checkpoint (kept at initial values)"
)
if matched_keys == 0:
    raise RuntimeError(
        f"No tensor from {args.restore_ckpt} matched the model's state_dict keys; "
        f"first unexpected keys: {load_result.unexpected_keys[:5]}"
    )
# Trailing ';' suppresses the very long module repr in the notebook output.
model.train();

In [None]:
from tqdm.notebook import tqdm
from train_fusion.train import train, self_supervised_real_batch, flow_gt_batch
from train_fusion.loss_function import self_supervised_loss, self_fm_loss, gt_loss

# Train with batches prepared by `flow_gt_batch` and scored by `gt_loss`.
# NOTE(review): `self_supervised_real_batch`, `self_supervised_loss` and
# `self_fm_loss` are imported but unused here — presumably alternatives for a
# self-supervised run; confirm before removing them.
train(
    args=args,
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    tqdm=tqdm,
    batch_loader_function=flow_gt_batch,
    loss_function=gt_loss,
)

In [None]:
from visualize.batch_input_visualize import batch_input_visualize

# Sanity-check the data pipeline by rendering the first validation batch.
iterator = iter(valid_loader)
first_valid_batch = next(iterator)
batch_input_visualize(first_valid_batch)

In [None]:
from datastructure.train_input import TrainInput
from visualize.batch_input_visualize import batch_input_visualize
from core.utils.utils import InputPadder
import cv2
from train_fusion.loss_function import warp_reproject_loss, reproject_disparity

# Run the model on one validation batch and visualise its prediction next to
# the inputs.
iterator = iter(valid_loader)
inputs = next(iterator)
# NOTE(review): assumes inputs[1:] is exactly six tensors
# (viz left/right, nir left/right, two disparities) — confirm against StereoDataset.
image0, image1, image2, image3, dis1, dis2 = [x.cuda() for x in inputs[1:]]

# The model is still in train mode here (model.train() above / training loop).
# With context_norm="batch", a train-mode forward uses per-batch statistics and
# updates BatchNorm running stats even under torch.no_grad(), which both distorts
# this prediction and mutates the model. Evaluate in eval mode, then restore the
# previous mode so any later training is unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        _, flow = model(TrainInput(
                {
                    "image_viz_left": image0,
                    "image_viz_right": image1,
                    "image_nir_left": image2,
                    "image_nir_right": image3,
                    # Validation-set, test-mode pass: use the validation iteration count.
                    "iters": args.valid_iters,
                    "test_mode": True,
                    "flow_init": None,
                    "heuristic_nir": False,
                }
            ).data_dict)
finally:
    model.train(was_training)

print(inputs[1].shape, flow.shape)
batch_input_visualize(inputs, flow)
