In [None]:
import sys
import os

# Move one directory up (presumably to the repository root -- confirm) so the
# package imports below and the relative checkpoint path resolve.
# The relative chdir is guarded by a sentinel so that re-running this cell in
# the same kernel does not climb one more directory each time.
if "_CWD_MOVED_TO_PARENT" not in globals():
    os.chdir("../")
    _CWD_MOVED_TO_PARENT = True
from core.raft_stereo_fusion import RAFTStereoFusion
from train_fusion.dataloader import StereoDataset

import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader
from fusion_args import FusionArgs
# ---------------------------------------------------------------------------
# Experiment configuration: start from the FusionArgs defaults and override
# the fields below.
# ---------------------------------------------------------------------------
args = FusionArgs()

# Checkpoint restored into the model further down. The commented line is the
# alternative checkpoint; uncomment it (and comment the first) to switch.
args.restore_ckpt = "checkpoints/self10000_StereoFusion.pth"
# args.restore_ckpt = "models/raftstereo-realtime.pth"

# Options handed to RAFTStereoFusion through `args`.
# NOTE(review): these look like network/architecture settings -- confirm
# against the model definition which of them the model actually reads.
args.hidden_dims = [128, 128, 128]
args.corr_levels = 4
args.corr_radius = 4
args.n_downsample = 3
args.context_norm = "batch"
args.n_gru_layers = 2
args.shared_backbone = True
args.mixed_precision = True
args.corr_implementation = "reg_cuda"
args.slow_fast_gru = False

# Optimisation / run settings (presumably consumed by the training routine --
# verify against train_fusion.train).
args.lr = 0.001
args.train_iters = 7
args.wdecay = 0.0001
args.num_steps = 100000
args.name = "StereoFusion"
args.batch_size = 4
args.fusion = "AFF"

# Build the model, wrap it in DataParallel and move it to the GPU.
model = nn.DataParallel(RAFTStereoFusion(args)).cuda()

# Deserialize the checkpoint onto the CPU first: load_state_dict copies the
# values into the (already CUDA-resident) parameters, so this avoids depending
# on the device the checkpoint was saved from.
_ckpt_state = torch.load(args.restore_ckpt, map_location="cpu")

# strict=False tolerates key mismatches; report them instead of dropping them
# silently so that loading a wrong or partial checkpoint is visible.
_load_result = model.load_state_dict(_ckpt_state, strict=False)
del _ckpt_state
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        f"[restore] {args.restore_ckpt}: "
        f"{len(_load_result.missing_keys)} missing key(s), "
        f"{len(_load_result.unexpected_keys)} unexpected key(s)"
    )
    print("[restore] missing (first 10):", _load_result.missing_keys[:10])
    print("[restore] unexpected (first 10):", _load_result.unexpected_keys[:10])

# Leave the model in training mode for the training cell that follows.
model.train()

In [None]:
from core.utils.utils import InputPadder
# Build the dataset from the on-disk root with the options used for this run.
dataset = StereoDataset(
    "/bean/depth",
    real_data_json=True,
    cut_resolution=(540, 720),
)

# Index-based 80/20 split: the first 80% of the dataset is used for training,
# the remainder for validation.
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size
train_dataset = dataset.partial(0, train_size)
valid_dataset = dataset.partial(train_size, len(dataset))

# Both loaders share the same settings (batch size from args, shuffling on,
# four worker processes); the training loader is constructed first.
train_loader, valid_loader = (
    DataLoader(split, batch_size=args.batch_size, shuffle=True, num_workers=4)
    for split in (train_dataset, valid_dataset)
)

# Iterator over the training loader, kept available for manual inspection.
iterator = iter(train_loader)


In [None]:
from tqdm.notebook import tqdm
from train_fusion.train import train, self_supervised_real_batch, flow_gt_batch
from train_fusion.loss_function import self_supervised_loss, self_fm_loss, gt_loss
# Start training with the self-supervised batch loader and loss; the notebook
# tqdm is passed in so progress is rendered as a widget.
train(
    args=args,
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    tqdm=tqdm,
    batch_loader_function=self_supervised_real_batch,
    loss_function=self_supervised_loss,
)

In [None]:

import cv2
from datastructure.train_input import TrainInput
from visualize.batch_input_visualize import batch_input_visualize

from train_fusion.loss_function import warp_reproject_loss, reproject_disparity
# Qualitative check: push one training batch through the model and visualize
# the predicted flow next to the inputs.
iterator = iter(train_loader)
inputs = next(iterator)

# The first element of the batch is skipped; the remaining four are moved to
# the GPU and fed to the model as visible-left, visible-right, NIR-left and
# NIR-right images, in that order.
image0, image1, image2, image3 = [x.cuda() for x in inputs[1:]]

# Run inference in eval mode: the model is put in training mode when it is
# built, and the config uses batch normalization, so any BatchNorm layers left
# in training mode would normalize with batch statistics and update their
# running stats even under no_grad.
# The previous mode is restored afterwards so training can continue.
_was_training = model.training
model.eval()
try:
    with torch.no_grad():
        # NOTE(review): args.valid_iters is not set in this notebook's
        # configuration cell; presumably it comes from the FusionArgs
        # defaults -- confirm.
        _, flow = model(
            TrainInput({
                "image_viz_left": image0,
                "image_viz_right": image1,
                "image_nir_left": image2,
                "image_nir_right": image3,
                "iters": args.valid_iters,
                "test_mode": True,
                "flow_init": None,
                "heuristic_nir": False,
            })
        )
finally:
    model.train(_was_training)

batch_input_visualize(inputs, flow)
