In [None]:
# Load IPython's autoreload extension so edits to project modules (torchdrive) can be
# picked up without restarting the kernel; the reload mode is set by `%autoreload 2`
# in the next cell.
%load_ext autoreload

In [None]:
%autoreload 2
from torchdrive.data import collate
from torchdrive.datasets.rice import MultiCamDataset
from torchdrive.notebook import display_img, display_color, display, to_pil_image
from torchdrive.transforms.batch import NormalizeCarPosition

a = MultiCamDataset(
    index_file="../../openape/snapshots/out-mar23/index.txt",
    mask_dir="../../openape/masks",
    cameras=["main", "narrow", "fisheye", "leftpillar", "leftrepeater", "rightpillar", "rightrepeater", "backup"],
    cam_shape=(480, 640),
    nframes_per_point=5,
    limit_size=1000,
    dynamic=True,
)
print(len(a))
example = a[500]

#transform = NormalizeCarPosition(start_frame=0)
#batch = collate([example])
#batch = transform(batch)

#for cam in a.cameras:
#    display_img(example.color[cam][0].float())

In [None]:
%autoreload 2

import math
import torch
from pythreejs import *
from IPython.display import display
from torchdrive.transforms.batch import (
    NormalizeCarPosition,
    Compose,
    RandomRotation,
    RandomTranslation,
)

batch = collate([example])
transform=Compose(
    NormalizeCarPosition(start_frame=0),
    RandomRotation(),
    #RandomTranslation(distances=(5.0, 5.0, 0.0)),
)
batch = transform(batch)

SCALE = 3
D = 256/SCALE
W = 256/SCALE
H = 12/SCALE

view_width = 600
view_height = 400
camera = PerspectiveCamera( position=[-10, 6, 10], aspect=640/view_height)
camera.up = (0, 0, 1)
key_light = DirectionalLight(position=[0, 10, 10])
ambient_light = AmbientLight()

grid_helper1 = GridHelper(20, 20, '#888', '#444')
grid_helper10 = GridHelper(100, 10, '#888', '#444')
grid_helper1.rotateX(math.pi/2)
grid_helper10.rotateX(math.pi/2)

scene = Scene(children=[grid_helper1, grid_helper10, camera, key_light, ambient_light], background='#111')
#frame = 0
num_positions = batch.cam_T[0].size(1)
for frame in [0, 3,4]:#range(num_positions):

    # render car positions
    T = batch.car_to_world(frame)
    geo = AxesHelper(1)
    geo.matrixAutoUpdate = False
    geo.matrix = tuple(T.T.contiguous().view(-1).tolist())
    scene.add([geo])

    # render camera positions
    for cam in batch.T:
        T = batch.cam_to_world(cam, frame)
        geo = AxesHelper(1)
        geo.matrixAutoUpdate = False
        geo.matrix = tuple(T.T.contiguous().view(-1).tolist())
        scene.add([geo])

#sphere.position = (0, 0, 0)
renderer = Renderer(camera=camera, scene=scene, controls=[OrbitControls(controlling=camera)], width=view_width, height=view_height)
display(renderer)

In [None]:
import torch
import torch.nn.functional as F
from torchdrive.transforms.depth import Project3D, BackprojectDepth
from torchdrive.losses import multi_scale_projection_loss

torch.set_printoptions(precision=3, sci_mode=False)

# Prefer the GPU, but still run on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Depth-fitting settings ---------------------------------------------------
NUM_STEPS = 1          # optimizer iterations (raise to actually fit the depth); must be >= 1
LOG_EVERY = 200        # print progress every LOG_EVERY iterations
DEPTH_DOWNSAMPLE = 4   # the optimized depth map is 1/DEPTH_DOWNSAMPLE of image resolution
INITIAL_DEPTH = 1.0    # constant initial value of the depth map
LEARNING_RATE = 1e-2

# projecting from src to target
offset = 1
target_cam = "main"
target_frame = 2
src_cam = "main"
src_frame = target_frame + offset

batch = batch.to(device)

target_color = batch.color[target_cam][:, target_frame].float()
src_color = batch.color[src_cam][:, src_frame].float()
h, w = src_color.shape[-2:]

backproject_depth = BackprojectDepth(h, w).to(device)
project_3d = Project3D(h, w).to(device)


def to_image_space_K(K, width, height):
    """Return a copy of intrinsics `K` converted to image (pixel) space.

    For every matrix in the batch, row 0 is multiplied by `width` and row 1 by
    `height`. The input tensor is not modified.
    NOTE(review): assumes batch.K is stored normalized by image size -- confirm.
    """
    K = K.clone()
    K[:, 0] *= width
    K[:, 1] *= height
    return K


src_K = to_image_space_K(batch.K[src_cam], backproject_depth.width, backproject_depth.height)
target_K = to_image_space_K(batch.K[target_cam], backproject_depth.width, backproject_depth.height)
target_inv_K = target_K.pinverse()

# Camera poses don't depend on the depth being optimized, so compute them once.
target_cam_to_world = batch.cam_to_world(target_cam, target_frame)
# (world to cam) * camera motion
world_to_src_cam = batch.world_to_cam(src_cam, src_frame)

# Low-resolution depth map -- the only optimized parameter -- set to a constant.
depth = torch.full(
    (1, 1, h // DEPTH_DOWNSAMPLE, w // DEPTH_DOWNSAMPLE),
    INITIAL_DEPTH,
    device=device,
    requires_grad=True,
)

optimizer = torch.optim.AdamW([depth], lr=LEARNING_RATE)
for i in range(NUM_STEPS):
    optimizer.zero_grad()

    # Upsample to exactly (h, w). An explicit size keeps the depth aligned with the
    # color images even when h or w is not divisible by DEPTH_DOWNSAMPLE.
    target_depth = F.interpolate(depth, size=(h, w), mode='bilinear', align_corners=False)

    # convert target-camera points to world using the current depth
    world_points = backproject_depth(
        target_depth, target_inv_K, target_cam_to_world
    ).clone()

    # project the world points into the source camera to get sampling coordinates
    pix_coords = project_3d(world_points, src_K, world_to_src_cam)

    # resample the source image at those coordinates; compared against target below
    proj_color = F.grid_sample(
        src_color,
        pix_coords,
        mode="bilinear",
        padding_mode="border",
        align_corners=False,
    )

    proj_loss = multi_scale_projection_loss(proj_color, target_color, scales=6)
    loss = proj_loss.mean()
    loss.backward()
    optimizer.step()

    if i % LOG_EVERY == 0:
        print(i, loss.item(), target_depth.detach().aminmax())

# --- Results ------------------------------------------------------------------
# All tensors below come from the last iteration and were computed before that
# iteration's optimizer.step(). Detach them so the display helpers receive plain
# tensors rather than autograd-tracked ones.
display_color(proj_loss[0, 0].detach())
display_img(src_color[0])
print("proj")
display_img(proj_color[0].detach())
print("target")
display_img(target_color[0])
diff = (target_color[0] - proj_color[0].detach()).abs().mean(dim=0)
print(diff.sum(), diff.aminmax())
display_color(diff)
print("depth", target_depth.shape)
display_color(target_depth[0, 0].detach())
display_color(1 / target_depth[0, 0].detach())