In [None]:
import sys
import os
# Import the fusion model. If the package is not found, the notebook was most
# likely launched from a subdirectory of the project.
try:
    from core.raft_stereo_fusion import RAFTStereoFusion
except ModuleNotFoundError:
    # Move one directory up (presumably the repository root) and retry, on the
    # assumption that the current working directory is on the import path.
    # NOTE: the chdir persists for the whole kernel session, so every relative
    # path used later (checkpoints, etc.) resolves against the parent directory.
    os.chdir("../")
    from core.raft_stereo_fusion import RAFTStereoFusion



from train_fusion.dataloader import StereoDataset

import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader
from fusion_args import FusionArgs
args = FusionArgs()


#################################################
# Checkpoint to initialise from (alternative kept for quick switching).
args.restore_ckpt = "checkpoints/self10000_StereoFusion.pth"
#args.restore_ckpt = "models/raftstereo-realtime.pth"

#################################################
# Architecture settings (presumably consumed by RAFTStereoFusion).
#################################################

args.hidden_dims = [128, 128, 128]
args.corr_levels = 4
args.corr_radius = 4
args.n_downsample = 3
args.context_norm = "batch"
args.n_gru_layers = 2
args.shared_backbone = True
args.mixed_precision = True
args.corr_implementation = "reg_cuda"
args.slow_fast_gru = False

# Optimisation / run settings (presumably consumed by `train`).
args.lr = 0.001
args.train_iters = 7

args.wdecay = 0.0001
args.num_steps = 100000
args.name = "StereoFusion"
args.batch_size = 4
args.fusion = "AFF"

model = nn.DataParallel(RAFTStereoFusion(args)).cuda()

# strict=False tolerates differences between the checkpoint and this model, but
# it silently skips every parameter whose name does not match. Report the
# mismatch so a checkpoint that effectively loads nothing is noticed, e.g. one
# whose keys lack or carry the DataParallel "module." prefix (TODO confirm how
# the checkpoint was saved).
load_report = model.load_state_dict(torch.load(args.restore_ckpt), strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(
        f"Loaded {args.restore_ckpt} with "
        f"{len(load_report.missing_keys)} missing and "
        f"{len(load_report.unexpected_keys)} unexpected keys; "
        f"first missing: {load_report.missing_keys[:5]}, "
        f"first unexpected: {load_report.unexpected_keys[:5]}"
    )
model.train()

In [None]:
# Alternative configuration with ground-truth depth from the flying3d json:
#dataset = StereoDataset("/bean/depth", gt_depth=True, flying3d_json=True)
dataset = StereoDataset("/bean/depth",  real_data_json=True)

# Head/tail split: the first TRAIN_FRACTION of samples train, the rest validate.
# NOTE(review): no shuffling happens before the split, so if samples are stored
# in order (e.g. by sequence) the validation set is a contiguous tail — confirm.
TRAIN_FRACTION = 0.92
train_size = int(TRAIN_FRACTION * len(dataset))
valid_size = len(dataset) - train_size
# Fail fast with a clear message instead of a cryptic sampler error or an empty
# validation loader (e.g. when the data path is wrong or the dataset is tiny).
assert train_size > 0 and valid_size > 0, (
    f"cannot split {len(dataset)} samples into non-empty train/valid sets "
    f"({train_size} train / {valid_size} valid)"
)
# presumably `partial(start, end)` selects samples [start, end) — TODO confirm.
train_dataset = dataset.partial(0, train_size)
valid_dataset = dataset.partial(train_size, len(dataset))
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

In [None]:
from tqdm.notebook import tqdm
from train_fusion.train import train, flow_gt_batch
def flow_gt_batch_attention_out(args, input, padder, valid_mode=False):
    """Turn one real-data batch into model inputs plus ground truth.

    Variant of the batch loader that requests attention outputs from the model
    (``attention_out_mode`` is set to ``True``).

    Args:
        args: Config object. ``args.train_iters`` is read, or ``args.valid_iters``
            when ``valid_mode`` is true.
        input: Batch sequence. Its first element is ignored; the next six
            elements are moved to the GPU and padded together. They are unpacked
            as four images followed by ``gt`` and a second ground-truth tensor.
        padder: Object whose ``pad(*tensors)`` returns the six tensors padded,
            in the same order.
        valid_mode: Use the validation iteration count instead of the training one.

    Returns:
        A ``(model_inputs, [gt])`` pair. ``model_inputs`` is a dict of keyword
        inputs for the model, and the list holds the first ground-truth tensor.
    """
    _, *tensors = input
    on_gpu = [tensor.cuda() for tensor in tensors[:6]]
    viz_left, viz_right, nir_left, nir_right, gt, _gt_right = padder.pad(*on_gpu)

    if valid_mode:
        iteration_count = args.valid_iters
    else:
        iteration_count = args.train_iters

    model_inputs = dict(
        image_viz_left=viz_left,
        image_viz_right=viz_right,
        image_nir_left=nir_left,
        image_nir_right=nir_right,
        iters=iteration_count,
        test_mode=False,
        flow_init=None,
        heuristic_nir=False,
        attention_out_mode=True,
    )
    return model_inputs, [gt]

from train_fusion.loss_function import self_supervised_loss, self_fm_loss, gt_loss
# Run training with the project's `flow_gt_batch` batch loader and the
# self-supervised loss; `tqdm` is passed in for progress display.
# NOTE(review): `flow_gt_batch_attention_out`, defined in this cell, is not used
# here, and `self_fm_loss` / `gt_loss` are imported but unused — confirm which
# loader/loss combination is intended.
train(args, model, train_loader, valid_loader, tqdm, flow_gt_batch, self_supervised_loss)

In [None]:
from core.utils.utils import InputPadder
import cv2
from visualize.batch_input_visualize import batch_input_visualize

from train_fusion.loss_function import warp_reproject_loss, reproject_disparity
# Persistent iterator over the (shuffled) training loader, so each run of the
# inspection cell can pull a new batch with `next(iterator)`; raises
# StopIteration once the epoch is exhausted.
iterator = iter(train_loader)

In [None]:
from datastructure.train_input import TrainInput
from visualize.batch_featuremap_visualize import batch_featuremap_visualize
# Pull the next training batch. Elements 1-4 are moved to the GPU and used as
# the four input images; element 0 is skipped (the batch loaders in this
# notebook unpack it as `image_list`).
inputs = next(iterator)

image0, image1, image2, image3 = [x.cuda() for x in inputs[1:5]]

# This forward pass is inference-only, but the model was left in training mode
# by the setup cell. In training mode, BatchNorm layers (context_norm is
# "batch") normalise with batch statistics and update their running averages
# even under `torch.no_grad()`, which would silently alter the trained model.
# Switch to eval mode for this pass, then restore every submodule's previous
# mode exactly. A blanket `model.train()` could undo per-layer mode changes
# (e.g. frozen BatchNorm) made elsewhere.
previous_modes = {module: module.training for module in model.modules()}
model.eval()
try:
    with torch.no_grad():
        outputs = model(
            TrainInput(
                {
                    "image_viz_left": image0,
                    "image_viz_right": image1,
                    "image_nir_left": image2,
                    "image_nir_right": image3,
                    "iters": args.train_iters,
                    "test_mode": True,
                    "flow_init": None,
                    "heuristic_nir": False,
                    "attention_out_mode": True
                }
            )
        )
finally:
    for module, was_training in previous_modes.items():
        module.training = was_training
batch_featuremap_visualize((image0, image1, image2, image3), outputs)

In [None]:
import math

from sympy import plot
# Calibration values for the stereo rig. These look like pinhole intrinsics in
# pixels (focal lengths fx/fy and principal point cx/cy) — TODO confirm which
# camera they belong to.
fx = 1764.0609770407675
fy = 1764.8799163219228
cx = 712.9065552882782
cy = 562.1627377324658

# Capture resolution, presumably (height, width) — TODO confirm order.
INPUT_RESOLUTION = (1080, 1440)
# Presumably the resize factor applied to images before the network; the
# disparity->depth plot scales fx by this value.
INPUT_SCALE = 0.5

# Translation between the two cameras. Its length is used as the stereo
# baseline; units presumably millimetres (the depth conversion divides by
# 1000) — TODO confirm against the calibration.
T = [
    -65.81824301985549,
    0.5432960979771516,
    0.7294257881872517,
]


def plot_disparity_range(fx, T, max_disparity, INPUT_SCALE = 1.0):
    """Plot the depth implied by each integer disparity in ``[1, max_disparity)``.

    Uses the rectified-stereo relation ``depth = f * B / d``:

    - ``f`` is ``fx`` scaled by ``INPUT_SCALE``, i.e. the focal length at the
      resolution the images are resized to.
    - ``B`` is the Euclidean length of ``T``.

    The result is divided by 1000, which gives metres if ``T`` is in
    millimetres (TODO confirm the calibration units).

    Args:
        fx: Horizontal focal length in pixels at the original resolution.
        T: Stereo translation vector; its L2 norm is the baseline.
        max_disparity: Exclusive upper bound on the plotted disparities. It is
            now honoured; the old code ignored it and always used 64.
        INPUT_SCALE: Image resize factor; the focal length is scaled by it.

    Returns:
        None. Draws the curve on a new matplotlib figure.

    Raises:
        ValueError: If ``max_disparity <= 1``, which would leave nothing to plot.
    """
    import matplotlib.pyplot as plt

    if max_disparity <= 1:
        raise ValueError(f"max_disparity must be > 1, got {max_disparity}")

    # The baseline does not depend on disparity, so compute it once.
    baseline = math.sqrt(sum(component ** 2 for component in T))
    scaled_focal = fx * INPUT_SCALE

    # Disparity 0 is excluded because depth is unbounded there.
    disparity_graph = np.arange(1, max_disparity)
    depth = (scaled_focal * baseline) / disparity_graph / 1000

    _, ax = plt.subplots()
    ax.plot(disparity_graph, depth)
    ax.set_xlabel("disparity (px)")
    ax.set_ylabel("depth (m, assuming T in mm)")
    ax.set_title(f"Depth vs. disparity (scale={INPUT_SCALE}, baseline={baseline:.1f})")
    ax.grid(True)
    
    
# Use the configured INPUT_SCALE instead of repeating its value, so the plot
# follows any change to the scale constant; 64 is passed as max_disparity.
plot_disparity_range(fx, T, 64, INPUT_SCALE=INPUT_SCALE)
        
    
