In [2]:
%load_ext autoreload
%autoreload 2
import torch
import math
import os
import imageio
import mcubes
import trimesh
import numpy as np
import argparse
from PIL import Image

from lrm.models.generator import LRMGenerator
from lrm.cam_utils import build_camera_principle, build_camera_standard, center_looking_at_camera_pose
from lrm.inferrer import LRMInferrer



In [None]:
infer = LRMInferrer('openlrm-base-obj-1.0')

In [None]:
type(infer.model) # lrm.models.generator.LRMGenerator
infer.model.encoder_feat_dim # 768
infer.model.camera_embed_dim # 1024
infer.model.encoder # DInoWrapper
infer.model.camera_embedder 
# CameraEmbedder(
#   (mlp): Sequential(
#     (0): Linear(in_features=16, out_features=1024, bias=True)
#     (1): SiLU()
#     (2): Linear(in_features=1024, out_features=1024, bias=True)
#   )
# )
infer.model.transformer


In [None]:
infer.infer_kwargs

In [None]:
source_image = './assets/sample_input/hydrant.png'
source_image_size = infer.infer_kwargs['source_size']
image = torch.tensor(np.array(Image.open(source_image))).permute(2, 0, 1).unsqueeze(0) / 255.0
print("image shape :  ",image.shape)
image = torch.nn.functional.interpolate(image, size=(source_image_size, source_image_size), mode='bicubic', align_corners=True)
print("reshaped image shape :  ",image.shape)
image = torch.clamp(image, 0, 1)
image = image.to(infer.device)
render_size = infer.infer_kwargs['render_size']
mesh_size = 384
mesh_thres = 3.0
chunck_size = 2
batch_size = 1

source_camera = infer._default_source_camera(batch_size).to(infer.device)
print("source_camera shape :  ",source_camera.shape)
print("source_camera :  \n",source_camera.reshape(-1, 4, 4))
render_cameras = infer._default_render_cameras(batch_size).to(infer.device)
print("render_cameras shape :  ",render_cameras.shape)

In [None]:
with torch.no_grad():
    planes = infer.model.forward_planes(image, source_camera)
    print(planes.shape)

In [None]:
with torch.no_grad():
    N = image.shape[0]
    print("image shape :  ",image.shape)
    image_feat = infer.model.encoder(image)
    print("image_feat shape :  ",image_feat.shape)
    camera_embed = infer.model.camera_embedder(source_camera)
    print("source_camera shape :  ",source_camera.shape)
    print("camera_embed shape :  ",camera_embed.shape)
    planes = infer.model.transformer(image_feat, camera_embed)
    print("planes shape :  ",planes.shape)

In [None]:
import torch
import torch.nn as nn
from lrm.models.transformer import ConditionModulationBlock
class TriplaneTransformer(nn.Module):
    """
    Transformer with condition and modulation that generates a triplane representation.
    
    Reference:
    Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
    """
    def __init__(self, inner_dim: int, image_feat_dim: int, camera_embed_dim: int,
                 triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
                 num_layers: int, num_heads: int,
                 eps: float = 1e-6):
        super().__init__()

        # attributes
        self.triplane_low_res = triplane_low_res
        self.triplane_high_res = triplane_high_res
        self.triplane_dim = triplane_dim

        # modules
        # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
        self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
        self.layers = nn.ModuleList([
            ConditionModulationBlock(
                inner_dim=inner_dim, cond_dim=image_feat_dim, mod_dim=camera_embed_dim, num_heads=num_heads, eps=eps)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)

    def forward(self, image_feats, camera_embeddings):
        # image_feats: [N, L_cond, D_cond]
        # camera_embeddings: [N, D_mod]

        assert image_feats.shape[0] == camera_embeddings.shape[0], \
            f"Mismatched batch size: {image_feats.shape[0]} vs {camera_embeddings.shape[0]}"

        N = image_feats.shape[0]
        H = W = self.triplane_low_res
        L = 3 * H * W

        x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
        for layer in self.layers:
            x = layer(x, image_feats, camera_embeddings)
        x = self.norm(x)

        # separate each plane and apply deconv
        x = x.view(N, 3, H, W, -1)
        x = torch.einsum('nihwd->indhw', x)  # [3, N, D, H, W]
        x = x.contiguous().view(3*N, -1, H, W)  # [3*N, D, H, W]
        x = self.deconv(x)  # [3*N, D', H', W']
        x = x.view(3, N, *x.shape[-3:])  # [3, N, D', H', W']
        x = torch.einsum('indhw->nidhw', x)  # [N, 3, D', H', W']
        x = x.contiguous()

        assert self.triplane_high_res == x.shape[-2], \
            f"Output triplane resolution does not match with expected: {x.shape[-2]} vs {self.triplane_high_res}"
        assert self.triplane_dim == x.shape[-3], \
            f"Output triplane dimension does not match with expected: {x.shape[-3]} vs {self.triplane_dim}"

        return x


In [None]:
checkpoint = torch.load('./.cache/openlrm-base-obj-1.0/model.pth', map_location=infer.device)
checkpoint['kwargs']
print(checkpoint['kwargs'])
triplaneTrans = TriplaneTransformer(
    inner_dim = checkpoint['kwargs']['model']['transformer_dim'],
    num_layers=checkpoint['kwargs']['model']['transformer_layers'],
    num_heads=checkpoint['kwargs']['model']['transformer_heads'],
    image_feat_dim=768,
    camera_embed_dim=checkpoint['kwargs']['model']['camera_embed_dim'],
    triplane_low_res=checkpoint['kwargs']['model']['triplane_low_res'],
    triplane_high_res=checkpoint['kwargs']['model']['triplane_high_res'],
    triplane_dim=checkpoint['kwargs']['model']['triplane_dim'],
    
)

In [None]:
infer.model.transformer.pos_embed.shape

In [None]:
print(infer._default_intrinsics())
# tensor([[384., 384.], fx fy
        # [256., 256.], cx cy
        # [512., 512.]]) w h
print(infer._default_source_camera(batch_size).reshape(-1, 4, 4))
# tensor([[[ 1.0000,  0.0000,  0.0000,  0.0000], RT
#          [ 0.0000,  0.0000, -1.0000, -2.0000],
#          [ 0.0000,  1.0000,  0.0000,  0.0000],
#          [ 0.7500,  0.7500,  0.5000,  0.5000]]]) 384/512
print(infer._default_source_camera(batch_size).shape)
print(infer._default_render_cameras(batch_size).shape)

In [None]:
print(infer._get_surrounding_views().shape) # defualt - 160 surrounding views RT
print(infer._get_surrounding_views()[40])

# print(infer._default_render_cameras(batch_size).shape)
# print(infer._default_render_cameras(batch_size).reshape(-1, 4, 4))


Test source camera

In [6]:
from lrm.inferrer import LRMInferrer
source_image = './assets/sample_input/hydrant.png'
infer = LRMInferrer('openlrm-base-obj-1.0')
source_image_size = infer.infer_kwargs['source_size']
image = torch.tensor(np.array(Image.open(source_image))).permute(2, 0, 1).unsqueeze(0) / 255.0
print("image shape :  ",image.shape)
image = torch.nn.functional.interpolate(image, size=(source_image_size, source_image_size), mode='bicubic', align_corners=True)
print("reshaped image shape :  ",image.shape)
image = torch.clamp(image, 0, 1)
image = image.to(infer.device)
render_size = infer.infer_kwargs['render_size']
mesh_size = 384
mesh_thres = 3.0
chunck_size = 2
batch_size = 1

source_camera = infer._default_source_camera(batch_size).to(infer.device)
source_camera2 = infer._default_source_camera2(batch_size).to(infer.device)
print("source_camera :  \n",source_camera.reshape(-1, 4, 4))
print("source_camera2 : \n",source_camera2.reshape(-1, 4, 4))
render_cameras = infer._default_render_cameras(batch_size).to(infer.device)
# print("render_cameras shape :  ",render_cameras.shape)
with torch.no_grad():
    planes = infer.model.forward_planes(image, source_camera)
    grid_out = infer.model.synthesizer.forward_grid(planes = planes, grid_size=mesh_size)
    planes2 = infer.model.forward_planes(image, source_camera2)
    grid_out2 = infer.model.synthesizer.forward_grid(planes = planes2, grid_size=mesh_size)
    



image shape :   torch.Size([1, 3, 512, 512])
reshaped image shape :   torch.Size([1, 3, 256, 256])
source_camera :  
 tensor([[[ 1.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -1.0000, -2.0000],
         [ 0.0000,  1.0000,  0.0000,  0.0000],
         [ 0.7500,  0.7500,  0.5000,  0.5000]]], device='cuda:0')
source_camera2 : 
 tensor([[[ 1.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  1.0000,  0.0000, -2.0000],
         [ 0.0000,  0.0000,  1.0000,  0.0000],
         [ 0.7500,  0.7500,  0.5000,  0.5000]]], device='cuda:0')


In [34]:
print(grid_out.keys())
print(grid_out['rgb'].shape)
print(grid_out['sigma'].shape)
vtx, faces = mcubes.marching_cubes(grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), mesh_thres)
print(vtx.shape, faces.shape, faces.shape[0] / vtx.shape[0])
print(vtx.min(axis = 0))
print(vtx.max(axis = 0))
vtx_post = vtx / mesh_size * 2 - 1
print(vtx_post.min(axis = 0), vtx_post.max(axis = 0))
vtx_tensor = torch.tensor(vtx_post, dtype = torch.float32, device = infer.device).unsqueeze(0)
vtx_color  = infer.model.synthesizer.forward_points(planes , vtx_tensor)
minmax = vtx_post.max(axis = 0) - vtx_post.min(axis = 0)


dict_keys(['rgb', 'sigma'])
torch.Size([1, 384, 384, 384, 3])
torch.Size([1, 384, 384, 384, 1])
(268222, 3) (536468, 3) 2.0000894781188716
[90.12259924 99.13544342  0.2275228 ]
[291.81161441 278.64620625 381.41512328]
[-0.53061146 -0.48366957 -0.99881499] [0.51985216 0.45128232 0.9865371 ]


In [35]:
print(grid_out2.keys())
print(grid_out2['rgb'].shape)
print(grid_out2['sigma'].shape)
vtx2, faces2 = mcubes.marching_cubes(grid_out2['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), mesh_thres)
print(vtx2.shape, faces2.shape, faces2.shape[0] / vtx2.shape[0])
print(vtx2.min(axis = 0))
print(vtx2.max(axis = 0))
vtx_post2 = vtx2 / mesh_size * 2 - 1
print(vtx_post2.min(axis = 0), vtx_post2.max(axis = 0))
vtx_tensor2 = torch.tensor(vtx_post2, dtype = torch.float32, device = infer.device).unsqueeze(0)
vtx_color2  = infer.model.synthesizer.forward_points(planes2 , vtx_tensor2)
minmax2 = vtx_post2.max(axis = 0) - vtx_post2.min(axis = 0)
print( vtx_post2.max(axis=0)- vtx_post2.min(axis=0))


dict_keys(['rgb', 'sigma'])
torch.Size([1, 384, 384, 384, 3])
torch.Size([1, 384, 384, 384, 1])
(244114, 3) (488240, 3) 2.000049157360905
[ 94.09826899 100.20328071  12.0248415 ]
[287.78613444 275.51346265 371.76421675]
[-0.50990485 -0.47810791 -0.93737062] [0.49888612 0.43496595 0.93627196]
[1.00879097 0.91307386 1.87364258]


In [36]:
print(minmax2 [2]/minmax2[0], minmax2[2]/minmax2[1])

1.7836339522051574