In [1]:
# Reload edited project modules automatically before each cell runs.
# NOTE(review): the render cell's captured output shows
# "[autoreload of gsplat.cuda._backend failed" followed by an AttributeError on
# a None backend. If that recurs, consider excluding gsplat from autoreload
# (e.g. `%aimport -gsplat` — TODO confirm) or disabling autoreload here.
%load_ext autoreload
%autoreload 2

In [2]:
from dataclasses import dataclass
import math
import time
from typing import Dict, Literal, Tuple

import tyro

from datasets.clip import OpenCLIPNetwork, OpenCLIPNetworkConfig
import nerfview
import torch
from torch import Tensor
import viser
from utils import SAMOptModule
from gsplat.rendering import rasterization

from gsplat._helper import *


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import matplotlib.pyplot as plt
import numpy as np

def display_image(image_array: np.ndarray):
    """
    Show an image held in a numpy array, with the axes hidden.

    Parameters:
    image_array (numpy.ndarray): Pixel data passed unchanged to ``plt.imshow``. Expected shape is (height, width, channels).
    """
    # Draw on the current pyplot axes, then strip ticks/frame before showing.
    plt.imshow(image_array)
    plt.axis("off")
    plt.show()

In [4]:
import torch
# Load the training checkpoint with every tensor mapped onto the GPU. Later
# cells read its "sam_module" and "splats" entries.
ckpt = torch.load("results/nerf_dff_dataset_depth/ckpts/ckpt_29999.pt", map_location="cuda")

In [5]:
def _get_sam_module(sam_state_dict, device="cuda"):
    """Rebuild a ``SAMOptModule`` from a checkpointed state dict.

    Every constructor argument is recovered from the parameter shapes and key
    names found in ``sam_state_dict``, so no separate config has to be stored
    next to the weights.

    Args:
        sam_state_dict: State dict containing ``embeds.weight`` plus
            ``color_head.<i>.weight`` / ``feature_head.<i>.weight`` entries.
        device: Device the rebuilt module is moved to. Defaults to ``"cuda"``.

    Returns:
        Tuple ``(sam_module, sh_degree)``: the module with the weights loaded,
        and the degree inferred from the first color-head layer's input width.

    Raises:
        KeyError: If ``embeds.weight`` or the Linear layers of either head are
            missing from ``sam_state_dict``.
    """

    def _linear_layer_indices(head):
        # Keys look like "<head>.<idx>.weight". Activation layers own no
        # parameters, so only the Linear layers of the head show up here.
        prefix, suffix = f"{head}.", ".weight"
        indices = []
        for key in sam_state_dict:
            if key.startswith(prefix) and key.endswith(suffix):
                middle = key[len(prefix):-len(suffix)]
                if middle.isdigit():
                    indices.append(int(middle))
        if not indices:
            raise KeyError(f"no '{head}.<idx>.weight' entries found in state dict")
        return sorted(indices)

    color_layers = _linear_layer_indices("color_head")
    feature_layers = _linear_layer_indices("feature_head")

    n, embed_dim = sam_state_dict["embeds.weight"].shape

    # Output width of the *last* feature-head layer, located from the keys
    # rather than a fixed index. The same value is passed as both feature_dim
    # and output_dim.
    feature_dim = sam_state_dict[f"feature_head.{feature_layers[-1]}.weight"].shape[0]

    mlp_width, color_in_features = sam_state_dict[
        f"color_head.{color_layers[0]}.weight"
    ].shape

    # First-layer input = feature_dim + embed_dim + (sh_degree + 1) ** 2.
    sh_degree = math.isqrt(color_in_features - feature_dim - embed_dim) - 1

    # Each head holds one more Linear than mlp_depth (3 Linears for depth 2 in
    # the checkpoint this notebook loads). Counting layers replaces the former
    # `len(sam_state_dict) // 2 - 4`, which was only right for depth 2.
    # NOTE(review): assumes SAMOptModule adds exactly one Linear per extra
    # depth level — confirm against utils.SAMOptModule.
    mlp_depth = len(color_layers) - 1

    sam_module = SAMOptModule(
        n=n,
        feature_dim=feature_dim,
        embed_dim=embed_dim,
        sh_degree=sh_degree,
        mlp_width=mlp_width,
        mlp_depth=mlp_depth,
        output_dim=feature_dim,
    ).to(device)

    sam_module.load_state_dict(sam_state_dict)
    print("sam_module", sam_module)

    return sam_module, sh_degree

# Rebuild the SAM head from the checkpoint. Both globals are read later by
# rasterize_splats.
sam_module, sh_degree = _get_sam_module(ckpt["sam_module"])

sam_module SAMOptModule(
  (embeds): Embedding(49, 256)
  (color_head): Sequential(
    (0): Linear(in_features=784, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=64, out_features=3, bias=True)
  )
  (feature_head): Sequential(
    (0): Linear(in_features=784, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=64, out_features=512, bias=True)
  )
)


In [6]:
# Per-Gaussian parameters stored in the checkpoint; rasterize_splats reads this
# global. The cell output lists the available keys.
splats = ckpt["splats"]
splats.keys()

odict_keys(['colors', 'features', 'means3d', 'opacities', 'quats', 'scales'])

In [7]:
# CLIP wrapper used by viewer_render_fn (clip.tokenizer, clip.model.encode_text).
# NOTE(review): the config *class* is passed here, not an instance — confirm
# OpenCLIPNetwork only reads class-level defaults, otherwise pass
# OpenCLIPNetworkConfig().
clip = OpenCLIPNetwork(OpenCLIPNetworkConfig)
clip

OpenCLIPNetwork(
  (model): CLIP(
    (visual): VisionTransformer(
      (conv1): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
      (patch_dropout): Identity()
      (ln_pre): LayerNormFp32((768,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNormFp32((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNormFp32((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
        

In [8]:
import pickle
# Load saved arguments for viewer_render_fn (presumably captured from a live
# viewer call — verify). The cell output lists the keys: camera_state, img_wh,
# kwargs.
# NOTE: pickle.load can execute arbitrary code from the file; only load pickles
# you created yourself.
with open('viewer_render_fn_inputs.pkl', 'rb') as f:
    inputs = pickle.load(f)

inputs.keys()

dict_keys(['camera_state', 'img_wh', 'kwargs'])

In [9]:
@torch.no_grad()
def viewer_render_fn(
    camera_state: nerfview.CameraState, img_wh: Tuple[int, int], **kwargs
):
    """Render one view and paint the pixels matching a text query red.

    The query text is embedded with the global ``clip`` model, the scene is
    rasterized through ``rasterize_splats``, and every pixel whose rendered
    feature vector has cosine similarity above the threshold with the text
    embedding is overwritten with pure red.

    Args:
        camera_state: Camera providing ``c2w`` and ``get_K(img_wh)``.
        img_wh: Target image size as ``(width, height)``.
        **kwargs: Must contain ``feature_query`` (text to search for). May
            contain ``feature_similarity_threshold`` (default ``0.3``). All
            entries are also forwarded unchanged to ``rasterize_splats``.

    Returns:
        Tuple of the highlighted RGB image as a numpy array, the ``features``
        value returned by ``rasterize_splats``, and the first three channels
        of the rendered xyz image as a numpy array.

    Raises:
        ValueError: If ``feature_query`` is not supplied.
    """
    feature_query = kwargs.get("feature_query", None)
    print("feature_query: ", feature_query)
    if feature_query is None:
        # Fail early with a readable message instead of inside the tokenizer.
        raise ValueError("viewer_render_fn requires a 'feature_query' keyword argument")

    # Embed the query text and L2-normalise it.
    tok_phrase = clip.tokenizer(feature_query).to("cuda")
    feature_embeds = clip.model.encode_text(tok_phrase)  # [1, 512]
    feature_embeds /= feature_embeds.norm(dim=-1, keepdim=True)

    W, H = img_wh
    c2w = torch.from_numpy(camera_state.c2w).float().to("cuda")
    K = torch.from_numpy(camera_state.get_K(img_wh)).float().to("cuda")

    render_colors, _, _, features, xyz = rasterize_splats(
        camtoworlds=c2w[None],
        Ks=K[None],
        width=W,
        height=H,
        radius_clip=3.0,  # skip GSs that have small image radius (in gt_colors)
        feature_embeds=feature_embeds,
        # segment=True,
        **kwargs,
    )
    # Channels 0-2 are RGB; the remaining channels are the rendered features.
    colors = render_colors[..., :3]  # [1, H, W, 3]
    render_features = render_colors[..., 3:]  # [1, H, W, 512]

    # Reshape using the embedding's own width rather than a hard-coded 512 so
    # it broadcasts against the per-pixel features.
    query_embed = feature_embeds.view(1, 1, 1, feature_embeds.shape[-1])
    # Use the fully qualified functional API; `F` is not imported explicitly
    # in this notebook.
    cosine_similarity = torch.nn.functional.cosine_similarity(
        render_features, query_embed, dim=-1
    )

    threshold = kwargs.get("feature_similarity_threshold", 0.3)
    mask = cosine_similarity > threshold

    # Overwrite matching pixels with pure red. `colors` is a view, so this
    # writes into `render_colors`, which is local to this call.
    colors[mask] = torch.tensor([1.0, 0.0, 0.0], device=colors.device)

    xyz = xyz[..., :3]

    return colors[0].cpu().numpy(), features, xyz[0].cpu().numpy()

def rasterize_splats(
    camtoworlds: Tensor,
    Ks: Tensor,
    width: int,
    height: int,
    **kwargs,
) -> Tuple[Tensor, Tensor, Dict, Tensor, Tensor]:
    """Rasterize the global ``splats`` twice: once for color+features, once for xyz.

    Per-Gaussian colors and features come from the global ``sam_module`` and
    are added to the values stored in ``splats``. Colors are passed through a
    sigmoid and concatenated with the features, and the result is rasterized
    in one pass. A second pass rasterizes the Gaussian means themselves as the
    "color" to produce a per-pixel xyz image.

    Args:
        camtoworlds: Camera-to-world matrices; inverted to get view matrices.
        Ks: Camera intrinsics, one per camera.
        width: Output image width.
        height: Output image height.
        **kwargs: ``image_ids`` (optional) is popped and passed to
            ``sam_module`` as ``embed_ids``. Everything else is forwarded
            unchanged to both ``rasterization`` calls.

    Returns:
        ``(render_colors, render_alphas, info, features, rendered_xyz)``:
        the first three are the outputs of the color+feature pass,
        ``features`` is the per-Gaussian feature tensor fed to that pass, and
        ``rendered_xyz`` is the image from the pass that uses the means as
        colors.
    """
    means = splats["means3d"]
    print("NUM OF GAUSSIANS: ", means.shape[0])
    quats = splats["quats"]
    scales = torch.exp(splats["scales"])
    opacities = torch.sigmoid(splats["opacities"])

    image_ids = kwargs.pop("image_ids", None)

    # get colors from sam_module
    colors, features = sam_module(
        features=splats["features"],
        embed_ids=image_ids,
        dirs=means[None, :, :] - camtoworlds[:, None, :3, 3],
        sh_degree=sh_degree,
    )
    colors = torch.sigmoid(colors + splats["colors"])
    features = features + splats["features"]
    colors_with_features = torch.cat([colors, features], dim=-1)

    # Invert once and reuse for both passes.
    viewmats = torch.linalg.inv(camtoworlds)  # [C, 4, 4]

    def _rasterize(per_gaussian_values):
        # Both passes share every argument except the per-Gaussian payload.
        return rasterization(
            means=means,
            quats=quats,
            scales=scales,
            opacities=opacities,
            colors=per_gaussian_values,
            viewmats=viewmats,
            Ks=Ks,  # [C, 3, 3]
            width=width,
            height=height,
            packed=False,
            absgrad=False,
            sparse_grad=False,
            rasterize_mode="classic",
            sh_degree=None,
            **kwargs,
        )

    render_colors, render_alphas, info = _rasterize(colors_with_features)
    # Rendering the means as colors yields a per-pixel xyz image.
    rendered_xyz, _, _ = _rasterize(means)

    return render_colors, render_alphas, info, features, rendered_xyz

In [12]:
# Text to highlight in the render, using the camera and image size replayed
# from the pickled viewer inputs.
query = "carrot"
render_image, features, xyz = viewer_render_fn(
    camera_state=inputs["camera_state"],
    img_wh=inputs["img_wh"],
    feature_query=query,
    feature_similarity_threshold=0.25,
)

[?25l[32m( ●    )[0m [1;33mgsplat: Setting up CUDA with MAX_JOBS=10 (This may take a few minutes [0m
[2K[1A[2K[32m(  ●   )[0m [1;33mgsplat: Setting up CUDA with MAX_JOBS=10 (This may take a few minutes [0m
[2K[1A[2K[32m(  ●   )[0m [1;33mgsplat: Setting up CUDA with MAX_JOBS=10 (This may take a few minutes [0m
[1;33mthe first time)[0m
[1A[2K[1A[2Kfeature_query:  carrot
NUM OF GAUSSIANS:  386490


[autoreload of gsplat.cuda._backend failed: Traceback (most recent call last):
  File "/sfs/weka/scratch/jqm9ba/Repos/gsplat/gsplat/cuda/_backend.py", line 55, in <module>
    from gsplat import csrc as _C
ImportError: cannot import name 'csrc' from 'gsplat' (/sfs/weka/scratch/jqm9ba/Repos/gsplat/gsplat/__init__.py)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/scratch/jqm9ba/envs/gsplat-cuda/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/scratch/jqm9ba/envs/gsplat-cuda/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
  File "/scratch/jqm9ba/envs/gsplat-cuda/lib/python3.9/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 613, in _exec
  File "<frozen importlib._bootstrap_external>", line 850, 

AttributeError: 'NoneType' object has no attribute 'fully_fused_projection_fwd'

In [None]:
# Show the RGB render; pixels matching the query were painted red by
# viewer_render_fn.
display_image(render_image)

NameError: name 'render_image' is not defined

In [None]:
# Show the rendered xyz image.
# NOTE(review): xyz comes from rasterizing the Gaussian means as colors, so its
# values are scene positions, not guaranteed to lie in [0, 1]. imshow clips
# float RGB data outside that range — consider normalising before display.
display_image(xyz)

NameError: name 'xyz' is not defined