In [None]:
cd /private/home/ronghanghu/workspace/mmf_nr

In [None]:
import os
import numpy as np
import skimage.io
import cv2
import torch
import torchvision
import argparse
import omegaconf
import matplotlib.pyplot as plt
import quaternion
from tqdm import tqdm

from mmf.utils.env import setup_imports
from mmf.utils.configuration import Configuration
from mmf.utils.build import build_config, build_model
from mmf.common.sample import SampleList, Sample


def get_config_from_opts(opts):
    """Build an MMF config from a list of ``key=value`` override strings.

    MMF's registry imports are set up first; ``opts`` is then wrapped in an
    argparse-style namespace (with no config override) and parsed through
    MMF's ``Configuration`` / ``build_config``.

    Args:
        opts: list of ``key=value`` strings, as they would appear on the
            MMF command line.

    Returns:
        The config object produced by ``build_config``.
    """
    setup_imports()
    cli_args = argparse.Namespace(config_override=None, opts=opts)
    return build_config(Configuration(cli_args))


def load_model(config, device, ckpt_file=None):
    """Build the model named by ``config.model``, move it to ``device`` and set eval mode.

    Args:
        config: MMF config. ``config.model_config[config.model]`` is either the
            model's attribute node or a string naming another entry of
            ``config.model_config`` to use instead.
        device: target device for the model and for checkpoint tensors.
        ckpt_file: optional checkpoint path; its ``"model"`` entry is loaded
            as the state dict.

    Returns:
        The built model (in eval mode).

    If strict state-dict loading raises, the error is printed and loading is
    retried with ``strict=False``.
    """
    model_attrs = config.model_config[config.model]
    if isinstance(model_attrs, str):
        # A string entry is an alias pointing at another model's config section.
        model_attrs = config.model_config[model_attrs]

    # The attribute node may be struct-locked; open it to record the model name.
    with omegaconf.open_dict(model_attrs):
        model_attrs.model = config.model

    model = build_model(model_attrs).to(device)
    model.eval()

    if ckpt_file is None:
        return model

    state_dict = torch.load(ckpt_file, map_location=device)["model"]
    try:
        model.load_state_dict(state_dict, strict=True)
    except Exception as err:
        print(err)
        print('retry loading with `strict=False`')
        model.load_state_dict(state_dict, strict=False)
    return model


def build_sample_list(img_0, depth_0, R_0, T_0, R_1, T_1, image_transform, device=None):
    """Pack one input view and a target camera into a single-sample ``SampleList``.

    Args:
        img_0: input image, channel-last (it is permuted with ``(2, 0, 1)``
            before ``image_transform`` is applied).
        depth_0: depth map of the input view.
        R_0, T_0: camera rotation / translation of the input view.
        R_1, T_1: camera rotation / translation of the view to render.
        image_transform: callable applied to the channel-first image tensor;
            its result is stored as ``trans_img_0``.
        device: device the batch is moved to. When omitted, the notebook-level
            ``device`` global is used (the previous implicit behavior), so
            existing callers keep working.

    Returns:
        A ``SampleList`` containing the single sample, on ``device``.
    """
    if device is None:
        # Backward-compatible fallback to the global defined in a later cell.
        device = globals()["device"]

    sample = Sample()
    sample.orig_img_0 = torch.tensor(img_0)
    sample.depth_0 = torch.tensor(depth_0)
    sample.trans_img_0 = image_transform(sample.orig_img_0.permute((2, 0, 1)))
    sample.R_0 = torch.tensor(R_0)
    sample.T_0 = torch.tensor(T_0)
    sample.R_1 = torch.tensor(R_1)
    sample.T_1 = torch.tensor(T_1)
    return SampleList([sample]).to(device)

In [None]:
def get_habitat_position_rotation(R, T):
    """Convert a PyTorch3D camera ``(R, T)`` into a Habitat-style world pose.

    ``R`` (3x3, row-vector convention, hence the transpose below) and ``T``
    (3,) describe PyTorch3D's world-to-view transform. The camera X and Z
    axes are flipped to the Habitat convention, and the inverse (camera-to-
    world) transform is decomposed into a position and a numpy-quaternion.
    This undoes ``get_pytorch3d_camera_RT``.

    Returns:
        (position, rotation): camera position in world coordinates (3,) and
        camera orientation as a ``quaternion.quaternion``.
    """
    world_to_view = np.eye(4, dtype=np.float32)
    world_to_view[:3, :3] = R.T  # row-major (PyTorch3D) -> column-vector form
    world_to_view[:3, 3] = T

    # Flip the camera X and Z axes: PyTorch3D view frame -> Habitat view frame.
    world_to_view[[0, 2]] *= -1

    view_to_world = np.linalg.inv(world_to_view)
    position = view_to_world[:3, 3]
    rotation = quaternion.from_rotation_matrix(view_to_world[:3, :3])
    return position, rotation

    
def get_pytorch3d_camera_RT(position, rotation):
    """Convert a Habitat-style camera pose into PyTorch3D camera ``(R, T)``.

    Args:
        position: camera position in world coordinates, shape (3,).
        rotation: camera orientation as a numpy-quaternion.

    Returns:
        (R, T): PyTorch3D world-to-view rotation (3x3, row-vector convention,
        i.e. transposed) and translation (3,), as float32. Inverse of
        ``get_habitat_position_rotation``.
    """
    view_to_world = np.eye(4, dtype=np.float32)
    view_to_world[:3, :3] = quaternion.as_rotation_matrix(rotation)
    view_to_world[:3, 3] = position
    world_to_view = np.linalg.inv(view_to_world)

    # Flip the camera X and Z axes: Habitat view frame -> PyTorch3D view frame.
    world_to_view[[0, 2]] *= -1

    R = world_to_view[:3, :3].T  # to row major, as PyTorch3D expects
    T = world_to_view[:3, 3]
    return R, T

In [None]:
def rotate_camera(R_in, T_in, degree_right):
    """Turn a PyTorch3D camera horizontally by ``degree_right`` degrees.

    The rotation is about the world Y (vertical) axis and the camera
    position is kept fixed. Positive values turn the camera to the right
    (towards its +X axis, since forward is -Z in the Habitat camera frame).

    Args:
        R_in, T_in: PyTorch3D camera rotation (3x3) and translation (3,).
        degree_right: yaw angle in degrees.

    Returns:
        (R_out, T_out): the rotated camera in PyTorch3D convention.
    """
    position_in, rotation_in = get_habitat_position_rotation(R_in, T_in)
    # Turning right is a negative rotation about +Y.
    angle = -degree_right * np.pi / 180

    # A unit quaternion for a rotation by `angle` about axis u is
    # (cos(angle / 2), sin(angle / 2) * u). Using the full angle here would
    # rotate the camera by twice the requested amount.
    half_angle = angle / 2
    horizontal_rotation = quaternion.from_float_array(
        [np.cos(half_angle), 0, np.sin(half_angle), 0]
    )  # wxyz-format
    # Left-multiplication applies the yaw in the world frame.
    rotation_out = horizontal_rotation * rotation_in
    R_out, T_out = get_pytorch3d_camera_RT(position_in, rotation_out)
    return R_out, T_out


def move_camera(R_in, T_in, front, right, distance):
    """Translate a PyTorch3D camera horizontally, keeping its orientation.

    ``front`` and ``right`` give a direction in the camera frame (forward is
    -Z, right is +X). The direction is rotated into world coordinates, its
    vertical (Y) component is dropped, and it is re-normalized. Only the
    direction matters; ``distance`` sets how far the camera moves.

    Args:
        R_in, T_in: PyTorch3D camera rotation (3x3) and translation (3,).
        front, right: camera-frame direction components.
        distance: world-space distance to move along that direction.

    Returns:
        (R_out, T_out): the moved camera in PyTorch3D convention.

    Raises:
        ValueError: if the direction has no horizontal component (e.g.
            ``front == right == 0``, or the camera looks straight up/down).
    """
    position_in, rotation_in = get_habitat_position_rotation(R_in, T_in)

    # transform direction vector from camera to world
    direction_vec = np.array([right, 0, -front], np.float32)
    direction_vec = quaternion.as_rotation_matrix(rotation_in) @ direction_vec

    # remove motion along Y axis (vertical) and re-normalize
    direction_vec[1] = 0.
    horizontal_norm = np.linalg.norm(direction_vec)
    if horizontal_norm < 1e-6:
        # Normalizing would divide by ~0 and silently produce NaN poses.
        raise ValueError(
            "move_camera: direction has no horizontal component "
            f"(front={front}, right={right})"
        )
    direction_vec = direction_vec / horizontal_norm

    position_out = position_in + direction_vec * distance
    R_out, T_out = get_pytorch3d_camera_RT(position_out, rotation_in)
    return R_out, T_out

In [None]:
exp_name = "depth"
split = "val"

# MMF command-line style overrides for the mesh renderer evaluated on DIODE.
opts = [
    f"config=projects/neural_rendering/configs/diode/{exp_name}.yaml",
    "datasets=diode",
    "model=mesh_renderer",
    "training.batch_size=1",
    "model_config.mesh_renderer.return_rendering_results_only=True",
#     "model_config.mesh_renderer.fill_z_with_gt=True",
    "model_config.mesh_renderer.grid_stride=4",
]

device = torch.device("cuda:1")
torch.cuda.set_device(device)

# Prefer the best checkpoint; fall back to the most recent one.
ckpt_file = f"./save/diode/{exp_name}/best.ckpt"
if not os.path.exists(ckpt_file):
    ckpt_file = ckpt_file.replace("best.ckpt", "current.ckpt")
if not os.path.exists(ckpt_file):
    # An explicit error survives `python -O` and says which file is missing.
    raise FileNotFoundError(f"no checkpoint found: {ckpt_file}")

config = get_config_from_opts(opts)
model = load_model(config, device, ckpt_file)

In [None]:
import matplotlib.pyplot as plt
import sys
# Make the DIODE devkit (which provides the `diode` module) importable.
# Guarded so re-running this cell doesn't keep appending duplicates to sys.path.
DIODE_DEVKIT_DIR = '/private/home/ronghanghu/workspace/DATASETS/diode-devkit/'
if DIODE_DEVKIT_DIR not in sys.path:
    sys.path.append(DIODE_DEVKIT_DIR)

import diode
import numpy as np
import scipy as sp
import torch
import skimage.transform


class DiodeProcessor:
    """Center-crop and resize a DIODE RGB image / depth map pair to a square.

    Args:
        crop_w: width (in pixels) of the central band of columns kept before
            resizing; rows are not cropped. Defaults to 734.
        out_size: side length of the square output. Defaults to 256.
    """

    def __init__(self, crop_w=734, out_size=256):
        self.CROP_W = crop_w
        self.OUT_SIZE = out_size

    def _center_crop(self, im):
        """Return the central ``self.CROP_W`` columns (axis 1) of ``im``.

        Raises:
            ValueError: if ``im`` is narrower than ``self.CROP_W``.
        """
        W = im.shape[1]
        if W < self.CROP_W:
            raise ValueError(
                f"image width {W} is smaller than crop width {self.CROP_W}"
            )
        left = (W - self.CROP_W) // 2
        # Explicit end index: `im[:, left:-left]` would return an empty array
        # when left == 0 and keep CROP_W + 1 columns when W - CROP_W is odd.
        return im[:, left:left + self.CROP_W]

    def crop_and_resize(self, im, order):
        """Center-crop ``im`` to ``self.CROP_W`` columns, then resize it to
        ``(OUT_SIZE, OUT_SIZE)`` with ``skimage.transform.resize``.

        ``order`` is the interpolation order passed to skimage (``None`` uses
        skimage's default; 0 is nearest neighbour).
        """
        cropped = self._center_crop(im)
        out = skimage.transform.resize(cropped, (self.OUT_SIZE, self.OUT_SIZE), order=order)
        return out

    def __call__(self, im, de, de_mask):
        """Return the processed ``(image, depth)`` pair.

        Depth is zeroed where ``de_mask`` is 0, both before resizing and
        (via the resized mask) after it, so 0 marks invalid depth in the
        output. The output depth is float32.
        """
        de = de.copy()  # don't mutate the caller's depth map
        de[de_mask == 0] = 0
        im = self.crop_and_resize(im, order=None)
        de = self.crop_and_resize(de, order=0)  # nearest neighbor sampling on depth map
        de = de.astype(np.float32)

        # also downsample the mask and use it to mask invalid regions:
        # keep only pixels whose interpolated mask value is (numerically) 1
        de_mask = self.crop_and_resize(de_mask, order=None)
        de_mask = (1 - de_mask < 1e-8)
        de[~de_mask] = 0
        return im, de


# DIODE outdoor scenes, using the same `split` as configured above.
dataset = diode.DIODE(
    meta_fname='/private/home/ronghanghu/workspace/DATASETS/diode-devkit/diode_meta.json',
    data_root='/checkpoint/ronghanghu/neural_rendering_datasets/diode/',
    splits=[split],
    scene_types=['outdoor'],
)

processor = DiodeProcessor()

# Index of the sample to render (it also appears in the output video filename).
idx = 103
im, de, de_mask = dataset[idx]
im_out, de_out = processor(im, de, de_mask)

In [None]:
# saved_data_file = f"/checkpoint/ronghanghu/neural_rendering_datasets/replica/{split}/data/{scene}.npz"
# saved_results_file = f'./save/visualization/{exp_name}/{split}/{scene}_outputs.npz'

# d = np.load(saved_data_file)
# data = dict(d)
# d.close()
# d = np.load(saved_results_file)
# data.update(d)
# d.close()

# Per-channel ImageNet statistics, as used for ResNet-50 preprocessing.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
image_transform = torchvision.transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)

In [None]:
# Identity pose used as the input (source) camera by `take_pic`.
R_init = np.eye(3, dtype=np.float32)
T_init = np.zeros(3, dtype=np.float32)

# Input view consumed by `take_pic`.
data = {
    'orig_img_0': im_out.astype(np.float32),
    'depth_0': de_out,
}

plt.figure(figsize=(11, 4))
plt.subplot(1, 2, 1)
plt.imshow(im_out)
plt.title('input RGB')
plt.subplot(1, 2, 2)
# Invalid pixels have depth 0, so log gives -inf there; matplotlib masks
# non-finite values, so only the divide-by-zero warning needs silencing.
with np.errstate(divide='ignore'):
    plt.imshow(np.log(de_out))
plt.title('input log-depth')
plt.colorbar();

In [None]:
def take_pic(r, t, return_torch_tensors=False, show=True):
    """Render the notebook's input view (``data`` at ``R_init``/``T_init``) from camera ``(r, t)``.

    Args:
        r, t: target camera rotation / translation.
        return_torch_tensors: if True, run the model with autograd enabled
            and return ``(rgb, depth)`` tensors (used for fitting).
        show: when returning numpy output, also plot RGB and log-depth.

    Returns:
        ``(rgb_tensor, depth_tensor)`` if ``return_torch_tensors``; otherwise
        the rendered RGB as a numpy array clipped to [0, 1].
    """
    sample_list = build_sample_list(
        img_0=data['orig_img_0'],
        depth_0=data['depth_0'],
        R_0=R_init,
        T_0=T_init,
        R_1=r,
        T_1=t,
        image_transform=image_transform,
    )
    sample_list = sample_list.to(device)

    if not return_torch_tensors:
        with torch.no_grad():
            results = model.forward(sample_list)
        rgb = results['rgba_out_rec_list'][1][0, ..., :3].cpu().numpy()
        rgb = np.clip(rgb, 0, 1)
        depth = results['depth_out_rec_list'][1][0].cpu().numpy()
        if show:
            plt.figure(figsize=(11, 4))
            plt.subplot(1, 2, 1)
            plt.imshow(rgb)
            plt.subplot(1, 2, 2)
            plt.imshow(np.log(depth))
            plt.colorbar()
        return rgb

    # Differentiable path: keep the graph so a loss can backprop into the model.
    results = model.forward(sample_list)
    return (
        results['rgba_out_rec_list'][1][0, ..., :3],
        results['depth_out_rec_list'][1][0],
    )


def _path_samples(waypoints, sampling):
    """Yield scalars along a piecewise-linear path from 0 through ``waypoints``.

    ``waypoints`` is a number or a list/tuple of numbers. Each segment from
    the previous waypoint (0 for the first) to the next yields
    ``int(|end - start| * sampling)`` evenly spaced values, endpoints
    included. A non-empty segment yields at least 2 values so its end is
    always reached; a zero-length segment yields nothing.
    """
    if not isinstance(waypoints, (list, tuple)):
        waypoints = [waypoints]
    start = 0
    for end in waypoints:
        if end != start:
            num = max(int(np.abs(end - start) * sampling), 2)
            yield from np.linspace(start, end, num)
        start = end


def rotate(R, T, angles, cameras, sampling=3):
    """Append camera poses that turn from ``(R, T)`` through ``angles`` (degrees right).

    Angles are absolute offsets from the starting pose ``(R, T)``, not
    increments. ``sampling`` is frames per degree. Every generated pose is
    appended to ``cameras``.

    Returns:
        The final pose, or ``(R, T)`` unchanged if no motion was requested.
    """
    R_new, T_new = R, T
    for angle in _path_samples(angles, sampling):
        R_new, T_new = rotate_camera(R, T, angle)
        cameras.append((R_new, T_new))
    return R_new, T_new


def move(R, T, distances, cameras, sampling=6):
    """Append camera poses that move forward from ``(R, T)`` through ``distances``.

    Distances are absolute offsets from the starting pose ``(R, T)`` along
    its horizontal forward direction (negative moves backward). ``sampling``
    is frames per unit distance. Every generated pose is appended to
    ``cameras``.

    Returns:
        The final pose, or ``(R, T)`` unchanged if no motion was requested.
    """
    R_new, T_new = R, T
    for dist in _path_samples(distances, sampling):
        R_new, T_new = move_camera(R, T, 1, 0, dist)
        cameras.append((R_new, T_new))
    return R_new, T_new

In [None]:
# Test-time fitting: fine-tune the model so the depth it renders for the input
# view (identity pose) matches the processed ground-truth depth.
d_gt = torch.tensor(de_out, device=device)
# DiodeProcessor zeroes invalid depth, so pixels with depth <= 0.1 are ignored.
d_mask = d_gt.gt(0.1).float()

solver = torch.optim.Adam(model.parameters(), lr=1e-4)
for n_iter in range(40):
    solver.zero_grad()
    # NOTE(review): `d` and `d_gt` are assumed to have identical shapes; a
    # mismatch such as (H, W) vs (H, W, 1) would silently broadcast — confirm.
    _, d = take_pic(R_init, T_init, return_torch_tensors=True, show=False)
    masked_abs_err = ((d - d_gt) * d_mask).abs()
    loss = masked_abs_err.sum()
    print(loss.item())
    loss.backward()
    solver.step()

In [None]:
plt.close('all')

cameras = []
R_new, T_new = R_init, T_init

# Hold the starting view for 30 frames.
cameras.extend([(R_new, T_new)] * 30)

take_pic(R_new, T_new)

# Camera path: each step continues from the previous pose; a preview is
# rendered after every step.
trajectory = [
    (move, 5),
    (rotate, -20),
    (move, 5),
    (rotate, 60),
    (rotate, -60),
    (move, -5),
    (rotate, 20),
    (move, -5),
    (rotate, [30, -20, 0]),
]
for step_fn, amount in trajectory:
    R_new, T_new = step_fn(R_new, T_new, amount, cameras)
    take_pic(R_new, T_new)

# Hold the final view for 30 frames.
cameras.extend([(R_new, T_new)] * 30)

In [None]:
# Render every camera pose on the trajectory.
frames = []
for r, t in tqdm(cameras):
    rgba_out = take_pic(r, t, show=False)

    # float32 -> uint8, RGB -> BGR (OpenCV expects BGR)
    frames.append(skimage.img_as_ubyte(rgba_out[..., ::-1]))

if not frames:
    raise ValueError("no frames were rendered; `cameras` is empty")

video_file = f"/private/home/ronghanghu/workspace/gtdepth65x65fit_depth_diode_{idx:04d}.mp4"
# Lower-case "mp4v" is the MPEG-4 tag FFmpeg accepts for .mp4 containers;
# "MP4V" only works through OpenCV's fallback (with a warning).
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
fps = 30
# OpenCV wants (width, height). Derive it from the frames: a size mismatch
# makes VideoWriter silently drop frames.
frame_size = (frames[0].shape[1], frames[0].shape[0])
writer = cv2.VideoWriter(video_file, fourcc, fps, frame_size)
if not writer.isOpened():
    raise RuntimeError(f"could not open video writer for {video_file}")

try:
    for img in frames:
        writer.write(img)
finally:
    writer.release()