In [1]:
# Fetch the DepthSplat fork used by this notebook. The directory test makes the
# cell safe to re-run: git refuses to clone into an existing non-empty directory.
!test -d depthsplat || git clone https://github.com/munholy/depthsplat

Cloning into 'depthsplat'...
remote: Enumerating objects: 215, done.[K
remote: Counting objects: 100% (215/215), done.[K
remote: Compressing objects: 100% (171/171), done.[K
remote: Total 215 (delta 49), reused 185 (delta 34), pack-reused 0 (from 0)[K
Receiving objects: 100% (215/215), 2.25 MiB | 1.10 MiB/s, done.
Resolving deltas: 100% (49/49), done.


In [1]:
!pip install -r depthsplat/requirements.txt



In [2]:
# Create the checkpoint directory and download the small RE10K weights.
# -p keeps mkdir quiet when the directory already exists, and -nc stops wget
# from fetching a second copy (saved as *.pth.1) when the cell is re-run.
!mkdir -p depthsplat/pretrained
!wget -nc -P depthsplat/pretrained https://huggingface.co/haofeixu/depthsplat/resolve/main/depthsplat-gs-small-re10k-256x256-49b2d15c.pth

mkdir: cannot create directory ‘depthsplat/pretrained’: File exists
--2024-11-04 07:39:03--  https://huggingface.co/haofeixu/depthsplat/resolve/main/depthsplat-gs-small-re10k-256x256-49b2d15c.pth
Resolving huggingface.co (huggingface.co)... 3.168.178.101, 3.168.178.31, 3.168.178.58, ...
Connecting to huggingface.co (huggingface.co)|3.168.178.101|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.hf.co/repos/e2/4e/e24e7686e9c27778b01e5c45cf39b8fe8e3b18ca04a96fc220da07a8c9c27d6d/49b2d15cada75e1bb1b3a243567a7e0b52fa98bf4bc1d2719a99b73d4bf9b9ae?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27depthsplat-gs-small-re10k-256x256-49b2d15c.pth%3B+filename%3D%22depthsplat-gs-small-re10k-256x256-49b2d15c.pth%22%3B&Expires=1730965143&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczMDk2NTE0M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zL2UyLzRlL2UyNGU3Njg2ZTljMjc3NzhiMDFlNWM0NWN

In [3]:
import os

# Remember where the notebook started the first time this cell runs. Building
# the target from that fixed root (instead of the current directory) keeps the
# cell idempotent: re-running it no longer descends into depthsplat/depthsplat/...
NOTEBOOK_ROOT = globals().get('NOTEBOOK_ROOT', os.getcwd())
os.chdir(os.path.join(NOTEBOOK_ROOT, 'depthsplat'))
print(os.getcwd())

/root/mun/depthsplat/depthsplat


In [4]:
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import torchvision.transforms as tf
from einops import repeat
from omegaconf import OmegaConf

from src.model.encoder.encoder_depthsplat import EncoderDepthSplat
from src.config import load_typed_config
from src.model.encoder import EncoderCfg
import rerun as rr

In [5]:

def process_images(image_paths, target_shape):
    """Load images from disk as RGB tensors of a common size and stack them.

    Args:
        image_paths: Iterable of paths readable by ``PIL.Image.open``.
        target_shape: Desired ``(height, width)``. A single int is accepted and
            means a square output, matching what ``F.interpolate`` does with an
            int ``size``.

    Returns:
        Tensor of shape ``(len(image_paths), 3, height, width)`` as produced by
        ``torchvision.transforms.ToTensor`` (and bilinear resizing when needed).

    Raises:
        RuntimeError: From ``torch.stack`` if ``image_paths`` is empty.
    """
    # Normalise to a plain tuple so the size comparison below also works when
    # the caller passes a list (a tuple never compares equal to a list).
    if isinstance(target_shape, int):
        target_hw = (target_shape, target_shape)
    else:
        target_hw = tuple(target_shape)

    to_tensor = tf.ToTensor()
    images = []
    for path in image_paths:
        # The context manager releases the file handle promptly. convert("RGB")
        # guarantees three channels, so RGBA / greyscale / palette files do not
        # produce tensors that fail to stack or mismatch the model input.
        with Image.open(path) as pil_img:
            img = to_tensor(pil_img.convert("RGB"))
        if tuple(img.shape[-2:]) != target_hw:
            img = F.interpolate(img.unsqueeze(0), size=target_hw,
                                mode='bilinear', align_corners=True).squeeze(0)
        images.append(img)

    return torch.stack(images)

class DepthSplatDepthPredictor:
    """Depth-only inference wrapper around the DepthSplat encoder.

    The constructor builds the encoder configuration for the model size named
    in the checkpoint path, loads the weights and moves the encoder to CUDA
    when available (CPU otherwise). ``predict`` runs the encoder on a set of
    posed views and can optionally log cameras, depth maps and point clouds to
    rerun.
    """

    def __init__(self, checkpoint_path):
        """
        Args:
            checkpoint_path: Path to the pretrained depth model checkpoint. The
                model size is chosen from the substring 'large', 'base' or
                'small' found in this path (checked in that order).

        Raises:
            AssertionError: If none of the three size names occurs in the path.
        """
        import warnings  # local import keeps this block self-contained

        # Load base config
        encoder_cfg = OmegaConf.load("config/model/encoder/depthsplat.yaml")

        # Settings shared by every model size.
        encoder_cfg['num_depth_candidates'] = 128
        encoder_cfg['costvolume_unet_feat_dim'] = 128
        encoder_cfg['costvolume_unet_channel_mult'] = [1,1,1]
        encoder_cfg['costvolume_unet_attn_res'] = [4]
        encoder_cfg['gaussians_per_pixel'] = 1
        encoder_cfg['depth_unet_feat_dim'] = 32
        encoder_cfg['depth_unet_attn_res'] = [16]
        encoder_cfg['depth_unet_channel_mult'] = [1,1,1,1,1]

        if 'large' in checkpoint_path:
            ## large
            encoder_cfg['num_scales']=2
            encoder_cfg['upsample_factor']=2
            encoder_cfg['lowest_feature_resolution']=4
            encoder_cfg['monodepth_vit_type']='vitl'
            encoder_cfg['gaussian_regressor_channels']=64
            encoder_cfg['color_large_unet']=True
            encoder_cfg['feature_upsampler_channels']=128
        elif 'base' in checkpoint_path:
            ## base
            encoder_cfg['num_scales'] = 2
            encoder_cfg['upsample_factor'] = 2
            encoder_cfg['lowest_feature_resolution'] = 4
            encoder_cfg['monodepth_vit_type'] = "vitb"
            encoder_cfg['color_large_unet'] = True
            encoder_cfg['feature_upsampler_channels'] = 128
            encoder_cfg['gaussian_regressor_channels'] = 32
        else:
            ## small
            assert 'small' in checkpoint_path, (
                "checkpoint_path must contain 'large', 'base' or 'small' "
                f"to select the model size, got {checkpoint_path!r}"
            )
            encoder_cfg['upsample_factor']=4
            encoder_cfg['lowest_feature_resolution']=4
            encoder_cfg['gaussian_regressor_channels']=16
            encoder_cfg['feature_upsampler_channels']=64

        # Ask the encoder to return depth maps.
        encoder_cfg['return_depth'] = True
        encoder_cfg['train_depth_only'] = True

        cfg_encoder = load_typed_config(encoder_cfg, EncoderCfg)

        self.encoder = EncoderDepthSplat(cfg_encoder)

        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a source you trust.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Checkpoint keys carry an 'encoder.' prefix that the bare encoder does
        # not use. load_state_dict copies the values into the parameters, so no
        # extra clone of every tensor is needed here.
        state_dict = {
            key.removeprefix('encoder.'): value
            for key, value in checkpoint['state_dict'].items()
        }
        load_result = self.encoder.load_state_dict(state_dict, strict = False)
        if load_result.missing_keys:
            # strict=False would otherwise hide that some weights keep their
            # random initialisation.
            missing = load_result.missing_keys
            warnings.warn(
                f"{len(missing)} encoder parameter(s) were not found in "
                f"{checkpoint_path!r} and keep their initial values "
                f"(first few: {missing[:5]})"
            )

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder = self.encoder.to(self.device)
        self.encoder.eval()

    def predict(self, images, intrinsics, extrinsics, translation_scaling = 5.0, logging = False):
        """Predict per-view depth maps for a batch of posed images.

        Args:
            images: Image tensor whose first two dimensions are batch and view;
                the driver cell passes ``(1, 2, 3, H, W)``.
            intrinsics: Camera intrinsics per view, ``(B, V, 3, 3)`` in the
                driver cell, normalised by image width / height.
            extrinsics: Camera poses per view, ``(B, V, 4, 4)`` in the driver
                cell. The logging code treats them as camera-to-world. The
                caller's tensor is left unmodified.
            translation_scaling: Factor applied to the pose translations before
                the encoder runs; predicted depths are divided by it again.
            logging: If True, log cameras, depth maps and coloured point clouds
                of the first batch element to rerun.

        Returns:
            ``output["depths"]`` of the encoder divided by
            ``translation_scaling`` (presumably ``(B, V, H, W)`` — the logging
            code iterates views of element 0 and unpacks ``(h, w)``).
        """
        # Process inputs
        images = images.to(self.device)
        intrinsics = intrinsics.to(self.device)
        extrinsics = extrinsics.to(self.device)

        # .to() returns the very same tensor when it already lives on the
        # target device, so scale a copy: an in-place multiply would corrupt the
        # caller's poses and compound the scaling on every further call.
        scaled_extrinsics = extrinsics.clone()
        scaled_extrinsics[...,:3,3:] *= translation_scaling

        batch_size, num_views = images.shape[:2]

        if logging:
            self._log_cameras(extrinsics[0])

        # Near / far values, one per (batch element, view).
        near = torch.full((batch_size, num_views), 0.5, dtype=torch.float32, device=self.device)
        far = torch.full((batch_size, num_views), 100.0, dtype=torch.float32, device=self.device)

        context = {
            'image': images,
            'extrinsics': scaled_extrinsics,
            'intrinsics': intrinsics,
            'near': near,
            'far': far
        }

        # Forward pass. bfloat16 autocast is only enabled on CUDA; on the CPU
        # fallback it stays off instead of requesting an unavailable backend.
        use_autocast = self.device.type == "cuda"
        with torch.inference_mode(), torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=use_autocast):
            output = self.encoder(context, global_step=0, deterministic=False)
            # Undo the translation scaling so depths match the input poses.
            pred_depths = output["depths"] / translation_scaling

        if logging:
            self._log_depths_and_points(images[0], intrinsics[0], extrinsics[0], pred_depths[0])

        return pred_depths

    def _log_cameras(self, extrinsics):
        """Log the (unscaled) pose of every view of one batch element to rerun."""
        for idx, pose in enumerate(extrinsics):
            pose = pose.cpu().numpy()
            rr.log(f"world/camera_{idx}",
                rr.Transform3D(translation=pose[:3, 3],
                                mat3x3=pose[:3, :3], axis_length= 0.1))

    def _log_depths_and_points(self, images, intrinsics, extrinsics, depths):
        """Log depth maps and coloured world-space point clouds for one batch element."""
        # All views share one resolution, so the normalised pixel grid is built once.
        # NOTE(review): coordinates are i / size (pixel corners, not centres) —
        # confirm this matches the convention of the normalised intrinsics.
        h, w = depths.shape[-2:]
        y, x = torch.meshgrid(torch.arange(h) / float(h), torch.arange(w) / float(w), indexing='ij')
        pixels = torch.stack([x, y], dim=-1).float().to(self.device)
        pixel_x = pixels[..., 0].reshape(-1)
        pixel_y = pixels[..., 1].reshape(-1)

        for idx, depth in enumerate(depths):
            # Log depth map
            rr.log(f"world/camera_{idx}/depth", rr.DepthImage(depth.cpu(), meter = 1.0))

            # Get camera intrinsics
            fx = intrinsics[idx, 0, 0].item()
            fy = intrinsics[idx, 1, 1].item()
            cx = intrinsics[idx, 0, 2].item()
            cy = intrinsics[idx, 1, 2].item()

            # Back-project every pixel to a 3D point in the camera frame.
            Z = depth.reshape(-1)
            X = (pixel_x - cx) * Z / fx
            Y = (pixel_y - cy) * Z / fy
            points = torch.stack([X, Y, Z], dim=-1)

            # Transform points to world coordinates
            R = extrinsics[idx, :3, :3]
            T = extrinsics[idx, :3, 3]
            world_points = (R @ points.T).T.add(T.view(1,3)).cpu()

            # Get colors from original image
            colors = images[idx].permute(1,2,0).cpu().numpy().reshape(-1, 3)

            # Log point cloud
            rr.log(f"world/points_{idx}",
                rr.Points3D(world_points[:, :3],
                            colors=colors,
                            radii=0.002))

In [6]:
# Open a rerun recording; predict(..., logging=True) below logs into it.
rr.init("depthsplat_depth_prediction")

# The predictor picks the "small" configuration from the checkpoint file name.
model_path = "pretrained/depthsplat-gs-small-re10k-256x256-49b2d15c.pth"
predictor = DepthSplatDepthPredictor(model_path)

# Stereo pair; paths are relative to the repository directory entered earlier.
image_paths = ['../sample/checker_left.png', '../sample/checker_right.png']

# Camera poses: rotation matrix and translation vector for each of the two views.
R0 = torch.tensor([[-0.9859, -0.0948, -0.1377],
                   [ 0.0489, -0.9512,  0.3045],
                   [-0.1599,  0.2935,  0.9425]])
T0 = torch.tensor([ 0.2313, -0.3366, -0.4551])

R1 = torch.tensor([[-0.9574,  0.1549,  0.2438],
                   [-0.0665, -0.9396,  0.3359],
                   [ 0.2811,  0.3053,  0.9098]])
T1 = torch.tensor([ 0.0512, -0.3418, -0.4622])

# 4x4 homogeneous poses: transposed rotation in the upper-left 3x3 block,
# translation in the last column, identity elsewhere.
ext1, ext2 = torch.eye(4), torch.eye(4)
ext1[:3, :3], ext1[:3, 3] = R0.t(), T0
ext2[:3, :3], ext2[:3, 3] = R1.t(), T1

# Shape (1, 2, 4, 4): one batch element with two views.
extrinsics = torch.stack([ext1, ext2]).unsqueeze(0)

# Camera intrinsics: the same pinhole matrix for both views, shape (1, 2, 3, 3).
target_shape = (480 , 640)
intrinsics = torch.tensor([[386.0,   0.0, 320.0],
                           [  0.0, 386.0, 240.0],
                           [  0.0,   0.0,   1.0]]).expand(1, 2, 3, 3).clone()

##### Normalized intrinsics !!!! #####
# Row 0 (fx, cx) is divided by the image width, row 1 (fy, cy) by the height.
intrinsics[:, :, 0] /= target_shape[1]
intrinsics[:, :, 1] /= target_shape[0]

# Load both images at the target resolution and add the batch dimension.
images = process_images(image_paths, target_shape).unsqueeze(0)

# Get depth predictions
print(images.shape)
print(intrinsics.shape)
print(extrinsics.shape)
depth_maps = predictor.predict(images, intrinsics, extrinsics, logging=True)

# Embed the rerun viewer in the notebook output.
rr.notebook_show(width=1600, height=800)

Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


torch.Size([1, 2, 3, 480, 640])
torch.Size([1, 2, 3, 3])
torch.Size([1, 2, 4, 4])


Viewer()