In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import importlib
import os
from pathlib import Path

import cv2
import matplotlib
import matplotlib.pyplot as plt
from mmengine import Config, DictAction
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import visu3d as v3d

In [None]:
def build_camera_model(H : int, W : int, intrinsics : list) -> np.ndarray:
    """
    Encode the camera intrinsic parameters (focal length and principal point) to a 4-channel map.

    Args:
        H: height of the map in pixels.
        W: width of the map in pixels.
        intrinsics: camera intrinsics as [fx, fy, u0, v0].

    Returns:
        Array of shape [H, W, 4] with channels
        (x_center, y_center, fov_x, fov_y):
        x_center / y_center are the pixel offsets from the principal point
        divided by W / H, and fov_x / fov_y are the arctangent of those
        offsets divided by the mean focal length normalized by W / H.
    """
    fx, fy, u0, v0 = intrinsics
    focal = (fx + fy) / 2.0

    # Normalized offsets from the principal point, one value per column / row.
    x_offsets = (np.arange(0, W).astype(np.float32) - u0) / W
    y_offsets = (np.arange(0, H).astype(np.float32) - v0) / H

    # Expand the 1-D profiles to full [H, W] maps.
    x_center = np.broadcast_to(x_offsets[None, :], (H, W))
    y_center = np.broadcast_to(y_offsets[:, None], (H, W))

    # FoV
    fov_x = np.arctan(x_center / (focal / W))
    fov_y = np.arctan(y_center / (focal / H))

    # np.stack copies, so the result is a regular writable array.
    return np.stack([x_center, y_center, fov_x, fov_y], axis=2)

def resize_for_input(image, output_shape, intrinsic, canonical_shape, to_canonical_ratio):
    """
    Resize and pad the input image and build the matching camera-model map.

    Resizing consists of two steps: 1) to the canonical space (adjust the
    camera model); 2) resize the image while the camera model holds. Thus the
    label will be scaled with the resize factor.

    Args:
        image: input image, [H, W, C].
        output_shape: target (height, width) after resizing and padding.
        intrinsic: camera intrinsics [fx, fy, u0, v0]. Not modified; a copy is
            scaled internally.
        canonical_shape: (height, width) used as the reference for the resize ratio.
        to_canonical_ratio: extra factor multiplied into the resize ratio.

    Returns:
        Tuple (image, cam_model, pad, label_scale_factor):
        the resized and padded image, the padded camera-model map,
        pad as [top, bottom, left, right] pixel counts, and
        label_scale_factor = 1 / to_scale_ratio.
    """
    padding = [123.675, 116.28, 103.53]
    h, w, _ = image.shape
    resize_ratio_h = output_shape[0] / canonical_shape[0]
    resize_ratio_w = output_shape[1] / canonical_shape[1]
    # Use the smaller ratio so the resized image fits inside output_shape.
    to_scale_ratio = min(resize_ratio_h, resize_ratio_w)

    resize_ratio = to_canonical_ratio * to_scale_ratio

    reshape_h = int(resize_ratio * h)
    reshape_w = int(resize_ratio * w)

    # Split the padding between both sides; the odd pixel goes to bottom/right.
    pad_h = max(output_shape[0] - reshape_h, 0)
    pad_w = max(output_shape[1] - reshape_w, 0)
    pad_h_half = int(pad_h / 2)
    pad_w_half = int(pad_w / 2)

    # resize
    image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
    # padding
    image = cv2.copyMakeBorder(
        image,
        pad_h_half,
        pad_h - pad_h_half,
        pad_w_half,
        pad_w - pad_w_half,
        cv2.BORDER_CONSTANT,
        value=padding)

    # Resize, adjust principal point. Work on a copy so the caller's
    # intrinsic sequence is not modified as a side effect.
    intrinsic = list(intrinsic)
    intrinsic[2] = intrinsic[2] * to_scale_ratio
    intrinsic[3] = intrinsic[3] * to_scale_ratio

    cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
    # Padded camera-model pixels are filled with -1.
    cam_model = cv2.copyMakeBorder(
        cam_model,
        pad_h_half,
        pad_h - pad_h_half,
        pad_w_half,
        pad_w - pad_w_half,
        cv2.BORDER_CONSTANT,
        value=-1)

    pad = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]
    label_scale_factor = 1 / to_scale_ratio
    return image, cam_model, pad, label_scale_factor

def transform_test_data_scalecano(rgb, intrinsic, data_basic):
    """
    Prepare an RGB image and its camera model for a forward pass, using the
    `label scale canonical transformation.'

        Args:
            rgb: input rgb image. [H, W, 3]
            intrinsic: camera intrinsic parameter, [fx, fy, u0, v0]
            data_basic: predefined canonical space in configs; the keys
                "canonical_space" (with 'focal_length') and "crop_size" are read.
        Returns:
            (rgb, cam_model_stacks, pad, label_scale_factor): the normalized
            CHW image tensor on the GPU, a list of camera-model tensors
            downsampled by 2, 4, 8, 16 and 32, the padding returned by
            resize_for_input, and the combined label scale factor.
    """
    canonical_space = data_basic["canonical_space"]
    forward_size = data_basic["crop_size"]
    channel_mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None]
    channel_std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None]

    src_h, src_w, _ = rgb.shape

    # Ratio that maps the mean input focal length onto the canonical focal length.
    mean_focal = (intrinsic[0] + intrinsic[1]) / 2
    focal_ratio = canonical_space['focal_length'] / mean_focal

    # Only the focal lengths are rescaled here; the principal point is kept.
    scaled_intrinsic = [
        intrinsic[0] * focal_ratio,
        intrinsic[1] * focal_ratio,
        intrinsic[2],
        intrinsic[3],
    ]

    resized, cam_map, pad, resize_scale = resize_for_input(
        rgb, forward_size, scaled_intrinsic, [src_h, src_w], 1.0)

    label_scale_factor = focal_ratio * resize_scale

    # HWC -> CHW, normalize per channel, move to the GPU.
    image_tensor = torch.from_numpy(resized.transpose((2, 0, 1))).float()
    image_tensor = torch.div((image_tensor - channel_mean), channel_std)
    image_tensor = image_tensor.cuda()

    cam_tensor = torch.from_numpy(cam_map.transpose((2, 0, 1))).float()
    cam_tensor = cam_tensor[None, :, :, :].cuda()

    # Camera-model pyramid: one bilinear downsample per stride.
    cam_pyramid = []
    for stride in [2, 4, 8, 16, 32]:
        level_size = (cam_tensor.shape[2]//stride, cam_tensor.shape[3]//stride)
        cam_pyramid.append(
            torch.nn.functional.interpolate(cam_tensor, size=level_size, mode='bilinear', align_corners=False)
        )
    return image_tensor, cam_pyramid, pad, label_scale_factor

def align_scale(pred: torch.Tensor, target: torch.Tensor):
    """
    Rescale a prediction so its median matches the target's median.

    Only positions where target > 0 are used. When 10 or fewer such positions
    exist, no alignment is done and the scale is the plain integer 1.

    Args:
        pred: predicted values.
        target: reference values; entries <= 0 are treated as invalid.

    Returns:
        (pred_scaled, scale): pred multiplied by scale, and the scale itself
        (a 0-dim tensor, or the int 1 in the fallback case).
    """
    mask = target > 0
    if torch.sum(mask) > 10:
        # The small epsilon guards against a zero median in the prediction.
        scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
    else:
        scale = 1
    pred_scaled = pred * scale
    return pred_scaled, scale

def get_prediction(
    model: torch.nn.Module,
    input: torch.Tensor,
    cam_model: list,
    pad_info: list,
    scale_info: float,
    gt_depth: torch.Tensor,
    normalize_scale: float,
    ori_shape: list = None,
):
    """
    Run model.inference and convert its output back to the original image frame.

    Args:
        model: object exposing inference(data) that returns
            (pred_depth, confidence, output_dict).
        input: network input passed through as data["input"].
        cam_model: camera-model input passed through as data["cam_model"].
        pad_info: [top, bottom, left, right] padding to crop from the outputs.
        scale_info: label scale factor; the depth is divided by it.
        gt_depth: optional ground-truth depth. When given, the prediction is
            resized to its shape and a median-aligned copy is also returned.
        normalize_scale: factor multiplied into the predicted depth.
        ori_shape: optional (height, width) to resize to when gt_depth is None.
            None or an empty sequence keeps the cropped prediction size.

    Returns:
        (pred_depth, pred_depth_scale, scale, output_dict, confidence).
        pred_depth_scale and scale are None when gt_depth is None.
        confidence is cropped but not resized.
    """
    data = dict(
        input=input,
        cam_model=cam_model,
    )
    pred_depth, confidence, output_dict = model.inference(data)

    # Remove the padding that was added around the network input.
    pred_depth = pred_depth.squeeze()
    pred_depth = pred_depth[pad_info[0] : pred_depth.shape[0] - pad_info[1], pad_info[2] : pred_depth.shape[1] - pad_info[3]]
    confidence = confidence.squeeze()
    confidence = confidence[pad_info[0] : confidence.shape[0] - pad_info[1], pad_info[2] : confidence.shape[1] - pad_info[3]]

    # Target size priority: ground truth, then explicit ori_shape, then as-is.
    if gt_depth is not None:
        resize_shape = gt_depth.shape
    elif ori_shape is not None and len(ori_shape) > 0:
        resize_shape = ori_shape
    else:
        resize_shape = pred_depth.shape

    pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], resize_shape, mode='bilinear').squeeze() # to original size
    pred_depth = pred_depth * normalize_scale / scale_info
    if gt_depth is not None:
        pred_depth_scale, scale = align_scale(pred_depth, gt_depth)
    else:
        pred_depth_scale = None
        scale = None

    return pred_depth, pred_depth_scale, scale, output_dict, confidence

def gray_to_colormap(img, cmap='rainbow'):
    """
    Transfer gray map to matplotlib colormap

    Args:
        img: 2-D array of values. Negative values are treated as 0. The input
            array is not modified.
        cmap: name of the matplotlib colormap to use.

    Returns:
        uint8 RGB array of shape [H, W, 3]. Pixels whose clipped value is
        below 1e-10 are set to black.
    """
    assert img.ndim == 2

    # np.maximum returns a new array, so the caller's data is left untouched.
    img = np.maximum(img, 0)
    mask_invalid = img < 1e-10
    img = img / (img.max() + 1e-8)
    # vmax is 1.1, so the normalized maximum maps below the top of the colormap.
    norm = matplotlib.colors.Normalize(vmin=0, vmax=1.1)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # is available in both older and newer releases.
    cmap_m = plt.get_cmap(cmap)
    mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_m)
    colormap = (mappable.to_rgba(img)[:, :, :3] * 255).astype(np.uint8)
    colormap[mask_invalid] = 0
    return colormap

def vis_surface_normal(normal: torch.Tensor, mask: torch.Tensor = None) -> np.ndarray:
    """
    Visualize surface normal. Transfer surface normal value from [-1, 1] to [0, 255]

    Args:
        normal (torch.Tensor, [h, w, 3]): surface normal
        mask (torch.Tensor, [h, w]): valid masks; any dtype, non-zero entries
            count as valid. Pixels outside the mask are set to 0.

    Returns:
        uint8 array with the same spatial layout as normal (after squeeze).
    """
    normal = normal.detach().cpu().numpy().squeeze()
    # Normalize each vector to unit length; the epsilon avoids division by zero.
    n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True))
    n_img_norm = normal / (n_img_L2 + 1e-8)
    normal_vis = n_img_norm * 127
    normal_vis += 128
    normal_vis = normal_vis.astype(np.uint8)
    if mask is not None:
        # Cast to bool so `~` is a logical NOT even for int or float masks.
        mask = mask.detach().cpu().numpy().squeeze().astype(bool)
        normal_vis[~mask] = 0
    return normal_vis

class BaseDepthModel(nn.Module):
    """Wrapper module around a depth pipeline selected by cfg.model.type."""

    def __init__(self, cfg, **kwargs) -> None:
        super().__init__()
        # Resolve the pipeline constructor by its dotted path and build it.
        pipeline_path = 'mono.model.model_pipelines.' + cfg.model.type
        pipeline_builder = get_func(pipeline_path)
        self.depth_model = pipeline_builder(cfg)

    def forward(self, data):
        """Call the pipeline with data as keyword arguments.

        Returns (prediction, confidence, full_output_dict).
        """
        result = self.depth_model(**data)
        prediction = result['prediction']
        return prediction, result['confidence'], result

    def inference(self, data):
        """Run forward without gradients; returns (prediction, confidence)."""
        with torch.no_grad():
            depth, conf, _unused = self.forward(data)
        return depth, conf


def get_func(func_name):
    """
        Helper to return a function object by name. func_name must identify
        a function in this module or the path to a function relative to the base
        module.
        @ func_name: function name. An empty string returns None.

        Raises RuntimeError (chained to the underlying error) when the name
        cannot be resolved.
    """
    if func_name == '':
        return None
    try:
        parts = func_name.split('.')
        # Refers to a function in this module
        if len(parts) == 1:
            return globals()[parts[0]]
        # Otherwise, import the dotted module path and fetch the attribute
        module_name = '.'.join(parts[:-1])
        module = importlib.import_module(module_name)
        return getattr(module, parts[-1])
    except Exception as err:
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate; `from err` keeps the original failure in the traceback.
        raise RuntimeError(f'Failed to find function: {func_name}') from err

class DepthModel(BaseDepthModel):
    """BaseDepthModel variant whose inference also returns the full output dict."""

    def __init__(self, cfg, **kwargs):
        # Extra keyword arguments are accepted but not used.
        super().__init__(cfg)

    def inference(self, data):
        """Run forward without gradients.

        Returns (pred_depth, confidence, output_dict), i.e. one more element
        than BaseDepthModel.inference.
        """
        with torch.no_grad():
            pred_depth, confidence, output_dict = self.forward(data)
        return pred_depth, confidence, output_dict

def get_monodepth_model(
    cfg : dict,
    **kwargs
    ) -> nn.Module:
    """Construct a DepthModel from cfg, forwarding any extra keyword arguments."""
    depth_net = DepthModel(cfg, **kwargs)
    assert isinstance(depth_net, nn.Module)
    return depth_net

def get_configured_monodepth_model(
    cfg: dict,
    ) -> nn.Module:
    """
        Build the depth model for the given configuration.

        Args:
        @ cfg: configuration for the network.
        Returns:
        # model: depth model, as created by get_monodepth_model(cfg).
    """
    return get_monodepth_model(cfg)

def load_ckpt(load_path, model, optimizer=None, scheduler=None, strict_match=True, loss_scaler=None):
    """
    Load the check point for resuming training or finetuning.

    Args:
        load_path: path to the checkpoint file.
        model: module whose weights are loaded from checkpoint['model_state_dict'].
        optimizer: optional; restored from checkpoint['optimizer'] when given.
        scheduler: optional; restored from checkpoint['scheduler'] when given.
        strict_match: forwarded as `strict` to model.load_state_dict.
        loss_scaler: optional; restored from checkpoint['scaler'] when given
            and that key is present.

    Returns:
        (model, optimizer, scheduler, loss_scaler). When load_path is not an
        existing file nothing is loaded, a warning is emitted and the inputs
        are returned unchanged.
    """
    if not os.path.isfile(load_path):
        import warnings

        warnings.warn(f"No checkpoint found at '{load_path}'; nothing was loaded.")
        return model, optimizer, scheduler, loss_scaler

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(load_path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'], strict=strict_match)

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler'])
    if loss_scaler is not None and 'scaler' in checkpoint:
        # The scaler state belongs to loss_scaler, not to the scheduler.
        loss_scaler.load_state_dict(checkpoint['scaler'])
    # Drop the reference so the CPU copy of the checkpoint can be freed.
    del checkpoint
    return model, optimizer, scheduler, loss_scaler

def get_pcd_base(H, W, u0, v0, fx, fy):
    """
    Build the per-pixel ray directions (x, y, 1) for an H x W image.

    x is (column - u0) / fx and y is (row - v0) / fy; z is 1 everywhere.
    Returns an array of shape [H, W, 3].
    """
    # With the default 'xy' indexing both grids have shape [H, W].
    col_idx, row_idx = np.meshgrid(np.arange(0, W), np.arange(0, H))
    x_dir = (col_idx.astype(np.float32) - u0) / fx
    y_dir = (row_idx.astype(np.float32) - v0) / fy
    z_dir = np.ones_like(x_dir)
    return np.stack([x_dir, y_dir, z_dir], axis=2)  # [h, w, c]

def reconstruct_pcd(depth, fx, fy, u0, v0, pcd_base=None, mask=None):
    """
    Back-project a depth map to a per-pixel 3D point map.

    Args:
        depth: 2-D depth map, either a NumPy array or a torch tensor (tensors
            are detached, moved to the CPU, converted and squeezed).
        fx, fy, u0, v0: camera intrinsics, used only when pcd_base is None.
        pcd_base: optional precomputed [H, W, 3] ray map; built with
            get_pcd_base when omitted.
        mask: optional index/boolean mask; the selected points are set to 0.

    Returns:
        Array of shape [H, W, 3]: median-filtered depth times the ray map.
    """
    # type(depth) == torch.__name__ compared a type with the string 'torch'
    # and was never true; isinstance performs the intended check.
    if isinstance(depth, torch.Tensor):
        depth = depth.detach().cpu().numpy().squeeze()
    # 5x5 median filter applied to the depth before back-projection.
    depth = cv2.medianBlur(depth, 5)
    if pcd_base is None:
        H, W = depth.shape
        pcd_base = get_pcd_base(H, W, u0, v0, fx, fy)
    pcd = depth[:, :, None] * pcd_base
    # `if mask:` is ambiguous for arrays, so test against None explicitly.
    if mask is not None:
        pcd[mask] = 0
    return pcd

In [None]:
# Build the depth network from the config file and load weights from a local checkpoint.
cfg = Config.fromfile('barrelnet/metric3d/vit.raft5.large.py')
model = get_configured_monodepth_model(cfg)
# strict_match=False is forwarded to load_state_dict(strict=...). Note that load_ckpt
# returns the model unchanged when the checkpoint path is not an existing file.
model, _,  _, _ = load_ckpt('checkpoints/metric_depth_vit_large_800k.pth', model, strict_match=False)
# model = torch.hub.load("yvanyin/metric3d", "metric3d_vit_large", pretrain=True)
# Move the model to the GPU and switch it to eval mode.
model.cuda().eval()

In [None]:
# Predict depth and surface normals for one image using the helper functions above.
imgpath = Path("data/barrel2/cropped0005.jpg")
# imgpath = Path("data/files_underwater.jpg")
# Focal lengths in pixels — hard-coded here; TODO confirm they match the camera.
fx = 1000.0
fy = 1000.0
img = cv2.imread(str(imgpath))
cv_image = np.array(img)
img = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
# [fx, fy, u0, v0] with the principal point placed at the image center.
intrinsic = [fx, fy, img.shape[1] / 2, img.shape[0] / 2]
rgb_input, cam_models_stacks, pad, label_scale_factor = transform_test_data_scalecano(img, intrinsic, cfg.data_basic)
with torch.no_grad():
    # gt_depth is None, so pred_depth_scale and scale come back as None.
    pred_depth, pred_depth_scale, scale, output, confidence = get_prediction(
        model = model,
        input = rgb_input.unsqueeze(0),
        cam_model = cam_models_stacks,
        pad_info = pad,
        scale_info = label_scale_factor,
        gt_depth = None,
        normalize_scale = cfg.data_basic.depth_range[1],
        ori_shape=[img.shape[0], img.shape[1]],
    )
    # Take the first three channels of the first normal output and crop the padding.
    pred_normal = output['normal_out_list'][0][:, :3, :, :] 
    H, W = pred_normal.shape[2:]
    pred_normal = pred_normal[:, :, pad[0]:H-pad[1], pad[2]:W-pad[3]]

# Depth as a NumPy array with negative values clamped to 0, plus a color rendering.
pred_depth = pred_depth.squeeze().cpu().numpy()
pred_depth[pred_depth<0] = 0
pred_color = gray_to_colormap(pred_depth)

# Resize the normals to the image size, reorder to [H, W, 3] and build a uint8 visualization.
pred_normal = torch.nn.functional.interpolate(pred_normal, [img.shape[0], img.shape[1]], mode='bilinear').squeeze()
pred_normal = pred_normal.permute(1,2,0)
pred_color_normal = vis_surface_normal(pred_normal)
pred_normal = pred_normal.cpu().numpy()

In [None]:
# Show the predicted depth array as the cell output.
pred_depth

In [None]:
# Plot the depth map divided by its maximum value.
plt.imshow(pred_depth / np.max(pred_depth))

In [None]:
# Plot the uint8 surface-normal visualization.
plt.imshow(pred_color_normal)

In [None]:
# Back-project the depth map to per-pixel 3D points; intrinsic holds [fx, fy, u0, v0].
pcd = reconstruct_pcd(pred_depth, intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3])

In [None]:
# Flatten the point map to (N, 3) and take its visu3d figure as the cell output.
v3d.Point3d(p=pcd.reshape(-1, 3)).fig

In [None]:
# Manual preprocessing of one image: keep-ratio resize, pad to the model input size, normalize.
imgpath = Path("data/barrel2/cropped0005.jpg")
intrinsic = [1250.0, 1250.0, 960.0, 437.5]  # [fx, fy, cx, cy]
rgb_origin = cv2.cvtColor(cv2.imread(str(imgpath)), cv2.COLOR_BGR2RGB)

#### adjust input size to fit pretrained model
# keep ratio resize
input_size = (616, 1064) # for vit model
# input_size = (544, 1216) # for convnext model
h, w = rgb_origin.shape[:2]
# The smaller ratio makes the resized image fit inside input_size.
scale = min(input_size[0] / h, input_size[1] / w)
rgb = cv2.resize(rgb_origin, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
# remember to scale intrinsic, hold depth
intrinsic = [intrinsic[0] * scale, intrinsic[1] * scale, intrinsic[2] * scale, intrinsic[3] * scale]
# padding to input_size
# The fill values equal the per-channel mean used for normalization below.
padding = [123.675, 116.28, 103.53]
h, w = rgb.shape[:2]
pad_h = input_size[0] - h
pad_w = input_size[1] - w
pad_h_half = pad_h // 2
pad_w_half = pad_w // 2
rgb = cv2.copyMakeBorder(rgb, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=padding)
# [top, bottom, left, right] padding, used later to crop the prediction.
pad_info = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]

#### normalize
mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None]
std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None]
# HWC -> CHW, normalize, add a batch dimension and move to the GPU.
rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float()
rgb = torch.div((rgb - mean), std)
rgb = rgb[None, :, :, :].cuda()

In [None]:
# Run the model on the manually preprocessed image and convert the output to metric depth.
model.cuda().eval()
with torch.no_grad():
    pred_depth, confidence, output_dict = model.inference({"input": rgb})
# un pad
pred_depth = pred_depth.squeeze()
pred_depth = pred_depth[pad_info[0] : pred_depth.shape[0] - pad_info[1], pad_info[2] : pred_depth.shape[1] - pad_info[3]]

# upsample to original size
pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], rgb_origin.shape[:2], mode="bilinear").squeeze()
#### de-canonical transform
# intrinsic[0] is the focal length already multiplied by the resize scale in the preprocessing cell.
canonical_to_real_scale = intrinsic[0] / 1000.0 # 1000.0 is the focal length of canonical camera
pred_depth = pred_depth * canonical_to_real_scale # now the depth is metric
# Limit the depth to the range [0, 300].
pred_depth = torch.clamp(pred_depth, 0, 300)
# pred_normal = output_dict["prediction_normal"][:, :3, :, :] # only available for Metric3Dv2 i.e., ViT models
# normal_confidence = output_dict["prediction_normal"][:, 3, :, :] # see https://arxiv.org/abs/2109.09881 for details

In [None]:
# Copy the depth tensor to the CPU as a NumPy array and show it as the cell output.
depths = pred_depth.cpu().numpy()
depths

In [None]:
# Plot the metric depth map.
plt.imshow(depths)

In [None]:
# Per-pixel image coordinates laid out like depths (H, W). With meshgrid's
# default 'xy' indexing the first argument runs along columns, so it must be the
# width (shape[1]) and the second the height (shape[0]). This way the flattened
# xx (column), yy (row) and depths arrays pair each depth with its own pixel.
xx, yy = np.meshgrid(np.arange(depths.shape[1]).astype(float), np.arange(depths.shape[0]).astype(float))

In [None]:
# Stack the flattened xx, yy and depth values into an (N, 3) array and wrap it as visu3d points.
p = np.array([xx.reshape(-1), yy.reshape(-1), depths.reshape(-1)]).T
pc = v3d.Point3d(p=p)

In [None]:
# Take the point cloud's visu3d figure as the cell output.
pc.fig

In [None]:
# End-to-end demo on a KITTI sample: preprocess, run the torch.hub model, evaluate
# against the ground-truth depth and write a surface-normal visualization.
#### prepare data
rgb_file = 'data/kitti_demo/rgb/0000000050.png'
depth_file = 'data/kitti_demo/depth/0000000050.png'
intrinsic = [707.0493, 707.0493, 604.0814, 180.5066]
# The ground-truth depth image is divided by this value after loading.
gt_depth_scale = 256.0
# Reverse the channel order of the image returned by cv2.imread.
rgb_origin = cv2.imread(rgb_file)[:, :, ::-1]

#### adjust input size to fit pretrained model
# keep ratio resize
input_size = (616, 1064) # for vit model
# input_size = (544, 1216) # for convnext model
h, w = rgb_origin.shape[:2]
scale = min(input_size[0] / h, input_size[1] / w)
rgb = cv2.resize(rgb_origin, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
# remember to scale intrinsic, hold depth
intrinsic = [intrinsic[0] * scale, intrinsic[1] * scale, intrinsic[2] * scale, intrinsic[3] * scale]
# padding to input_size
padding = [123.675, 116.28, 103.53]
h, w = rgb.shape[:2]
pad_h = input_size[0] - h
pad_w = input_size[1] - w
pad_h_half = pad_h // 2
pad_w_half = pad_w // 2
rgb = cv2.copyMakeBorder(rgb, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=padding)
# [top, bottom, left, right] padding, used below to crop the predictions.
pad_info = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]

#### normalize
mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None]
std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None]
rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float()
rgb = torch.div((rgb - mean), std)
rgb = rgb[None, :, :, :].cuda()

###################### canonical camera space ######################
# inference
# NOTE: this rebinds the notebook-level `model` to the torch.hub ViT-small model.
model = torch.hub.load('yvanyin/metric3d', 'metric3d_vit_small', pretrain=True)
model.cuda().eval()
with torch.no_grad():
    pred_depth, confidence, output_dict = model.inference({'input': rgb})

# un pad
pred_depth = pred_depth.squeeze()
pred_depth = pred_depth[pad_info[0] : pred_depth.shape[0] - pad_info[1], pad_info[2] : pred_depth.shape[1] - pad_info[3]]

# upsample to original size
pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], rgb_origin.shape[:2], mode='bilinear').squeeze()
###################### canonical camera space ######################

#### de-canonical transform
canonical_to_real_scale = intrinsic[0] / 1000.0 # 1000.0 is the focal length of canonical camera
pred_depth = pred_depth * canonical_to_real_scale # now the depth is metric
pred_depth = torch.clamp(pred_depth, 0, 300)

#### you can now do anything with the metric depth 
# such as evaluate predicted depth
if depth_file is not None:
    # Flag -1 is passed to cv2.imread to load the depth file.
    gt_depth = cv2.imread(depth_file, -1)
    gt_depth = gt_depth / gt_depth_scale
    gt_depth = torch.from_numpy(gt_depth).float().cuda()
    assert gt_depth.shape == pred_depth.shape
    
    # Evaluate only where the ground truth is positive.
    mask = (gt_depth > 1e-8)
    abs_rel_err = (torch.abs(pred_depth[mask] - gt_depth[mask]) / gt_depth[mask]).mean()
    print('abs_rel_err:', abs_rel_err.item())

#### normal are also available
if 'prediction_normal' in output_dict: # only available for Metric3Dv2, i.e. vit model
    pred_normal = output_dict['prediction_normal'][:, :3, :, :]
    normal_confidence = output_dict['prediction_normal'][:, 3, :, :] # see https://arxiv.org/abs/2109.09881 for details
    # un pad and resize to some size if needed
    pred_normal = pred_normal.squeeze()
    pred_normal = pred_normal[:, pad_info[0] : pred_normal.shape[1] - pad_info[1], pad_info[2] : pred_normal.shape[2] - pad_info[3]]
    # you can now do anything with the normal
    # such as visualize pred_normal
    pred_normal_vis = pred_normal.cpu().numpy().transpose((1, 2, 0))
    # Shift values by +1 and halve them, then scale by 255 for the uint8 image.
    pred_normal_vis = (pred_normal_vis + 1) / 2
    cv2.imwrite('normal_vis.png', (pred_normal_vis * 255).astype(np.uint8))