In [1]:
import os
import clip
import torch
import numpy as np
from tqdm import tqdm
from clip_images import init_CLIP
from ae_training import load_model
from plyfile import PlyData, PlyElement
from prompts2 import idx2class, idx2prompts
from readers.clip_reader import clipReader
from torch.utils.data import DataLoader, TensorDataset
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Prompts used for zero-shot classification: each class index maps to a list of text prompts.
idx2prompts

{0: ['an image of a wall',
  'a photo of a wall in a bedroom',
  'a closeup of a wall in an indoor scene'],
 1: ['an image of a door',
  'a photo of a door in a bedroom',
  'a closeup of a door in an indoor scene'],
 2: ['an image of a bed',
  'a photo of a bed in a bedroom',
  'a closeup of a bed in an indoor scene'],
 3: ['an image of a chair',
  'a photo of a chair in a bedroom',
  'a closeup of a chair in an indoor scene'],
 4: ['an image of a stool',
  'a photo of a stool in a bedroom',
  'a closeup of a stool in an indoor scene'],
 5: ['an image of a nightstand',
  'a photo of a nightstand in a bedroom',
  'a closeup of a nightstand in an indoor scene'],
 6: ['an image of a toilet',
  'a photo of a toilet in a bedroom',
  'a closeup of a toilet in an indoor scene'],
 7: ['an image of a dressing table',
  'a photo of a dressing table in a bedroom',
  'a closeup of a dressing table in an indoor scene'],
 8: ['an image of a wardrobe',
  'a photo of a wardrobe in a bedroom',
  'a clo

In [None]:

# Scene-specific autoencoder. Only the first of the four values returned by
# load_model is kept; the other three are discarded.
# NOTE(review): confirm load_model returns the model in eval mode and on `device`;
# its decoder is applied directly to the splat features in the similarity search.
ae,_,_,_ = load_model("/home/akshaysm/semantics/understanding_room1/ae_model.pth")

In [None]:
# Load the two Gaussian-splatting checkpoints for this scene.
# map_location=device lets checkpoints that were saved on a GPU load on a
# CPU-only machine, matching the cuda/cpu fallback used to choose `device`.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.

# semantic gs result
ckpt_sem = torch.load(
    "/home/akshaysm/semantics/understanding_room1/ckpt_sem.pt",
    map_location=device,
)
# visual gs result
ckpt_col = torch.load(
    "/home/akshaysm/semantics/understanding_room1/ckpt_col.pt",
    map_location=device,
)

# Per-Gaussian tensors, all moved to `device`:
#   coords   - "means" entry of the semantic checkpoint
#   features - "features" entry of the semantic checkpoint (input to ae.decode)
#   colors   - "sh0" entry of the visual checkpoint with its size-1 dim 1 removed
coords = ckpt_sem["splats"]["means"].to(device)
features = ckpt_sem["splats"]["features"].to(device)
colors = ckpt_col['splats']["sh0"].squeeze(1).to(device)

coords.shape, features.shape, colors.shape

(torch.Size([2014184, 3]), torch.Size([2014184, 3]), torch.Size([2014184, 3]))

In [None]:
# Embedding prompts: initialise CLIP; its text encoder is used below to embed the class prompts.
clip_model, preprocess = init_CLIP()

def gen_prompts(prompt_list):
    """Embed a list of text prompts with CLIP and average them into one vector.

    Every prompt embedding is scaled to unit L2 norm before averaging. The
    mean itself is returned as-is, without being re-normalised.

    Args:
        prompt_list: Text prompts handed to ``clip.tokenize``.

    Returns:
        The mean over dim 0 (the prompt axis) of the normalised text embeddings.
    """
    token_ids = clip.tokenize(prompt_list).to(device)
    with torch.no_grad():
        text_feats = clip_model.encode_text(token_ids)
        unit_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
        return unit_feats.mean(dim=0)

def contrastive_embeds(embeds, alpha=0.8):
    """Return a new tensor in which every row of ``embeds`` has unit L2 norm.

    The contrastive step this function is named after (subtracting ``alpha``
    times the normalised mean of all *other* rows from each row before
    normalising) is currently disabled, so ``alpha`` is accepted only to keep
    the signature stable and has no effect.

    The input tensor is never modified. The earlier implementation normalised
    row views in place, which silently rewrote the caller's tensor.

    Args:
        embeds: Tensor whose first dimension indexes the embeddings.
        alpha: Unused; weight of the disabled contrastive term.

    Returns:
        A tensor of the row-normalised embeddings stacked along dim 0. An
        input with zero rows yields an (empty) copy of the input.
    """
    # torch.vstack rejects an empty list, so handle the zero-row case explicitly.
    if embeds.shape[0] == 0:
        return embeds.clone()

    # Out-of-place division produces fresh rows and leaves `embeds` untouched.
    return torch.vstack([row / torch.norm(row) for row in embeds])

# One query vector per class: embed each class's prompts, stack the per-class
# means in class-index order (keys 0 .. len-1), then normalise every row.
num_classes = len(idx2prompts)
class_embeds = [gen_prompts(idx2prompts[class_idx]) for class_idx in range(num_classes)]
queries = contrastive_embeds(torch.vstack(class_embeds)).to(device)

queries.shape

torch.Size([21, 512])

In [6]:
# Sanity check: the first query vector should have unit norm.
queries[0].norm()

tensor(1.0000, device='cuda:0')

In [None]:
# For every Gaussian, find the class whose query vector scores highest.
# Features are decoded batch-wise by the autoencoder and compared with all
# query vectors by dot product.

dataset = TensorDataset(features)
loader = DataLoader(dataset, batch_size=2048, shuffle=False)

best_class_per_batch = []
with torch.no_grad():
    for (latents,) in tqdm(loader, desc="sim search", unit="batch"):
        decoded = ae.decode(latents)    # batchsize x 512
        scores = decoded @ queries.T
        best_class_per_batch.append(scores.argmax(dim=1))

# Class index per Gaussian, in the same order as `features` (shuffle=False).
idxs = torch.cat(best_class_per_batch)

sim search: 100%|██████████| 984/984 [00:08<00:00, 122.92batch/s]


In [None]:
# Invert idx2class so a class can be looked up by its name.
c2i = {class_name: class_idx for class_idx, class_name in idx2class.items()}

# Boolean masks over all Gaussians, one per class of interest.
chairs = idxs == c2i["a chair"]
beds = idxs == c2i["a bed"]
night = idxs == c2i["a nightstand"]
walls = idxs == c2i["a wall"]

# Number of Gaussians assigned to each class.
print(f"chairs: {chairs.sum()}")
print(f"beds: {beds.sum()}")
print(f"night: {night.sum()}")
print(f"walls: {walls.sum()}")

chairs: 6458
beds: 19084
night: 1196
walls: 37444


In [None]:
# Color matches for the current class (walls) red, and non-matches grey.
#
# A fresh tensor is built instead of assigning into `colors` in place:
# `colors` comes from `.squeeze(1).to(device)` on a checkpoint tensor, and both
# calls can return a view/alias of the original storage, so in-place writes
# could silently overwrite the loaded checkpoint's "sh0" values.
match_color = torch.tensor([1.0, 0.0, 0.0], dtype=colors.dtype, device=colors.device)
other_color = torch.tensor([.5, .5, .5], dtype=colors.dtype, device=colors.device)

# walls is one boolean per Gaussian; unsqueeze so it broadcasts across the 3 channels.
colors = torch.where(walls.to(colors.device).unsqueeze(1), match_color, other_color)

In [None]:
def filter_ckpt(coords, colors, classbin):
    """Select the entries of ``coords`` and ``colors`` picked out by ``classbin``.

    Args:
        coords: Indexable collection of coordinates.
        colors: Indexable collection of colors, indexed the same way as ``coords``.
        classbin: Index (e.g. a boolean mask) applied to both inputs.

    Returns:
        Tuple ``(coords[classbin], colors[classbin])``.
    """
    return coords[classbin], colors[classbin]

def save_ply(coords: torch.Tensor, colors: torch.Tensor, filename: str):
    """Write points and their colors to a PLY file with a single "vertex" element.

    Args:
        coords: Tensor whose columns 0, 1, 2 are stored as float32 ``x``, ``y``, ``z``.
        colors: Tensor whose columns 0, 1, 2 are stored as uint8 ``red``,
            ``green``, ``blue``. Anything that is not already uint8 is
            multiplied by 255, clipped to [0, 255] and cast to uint8.
        filename: Destination path; the file is written with ``text=True``.
    """
    xyz = coords.detach().cpu().numpy()
    rgb = colors.detach().cpu().numpy()

    # Non-uint8 colors are rescaled by 255 before the cast.
    if rgb.dtype != np.uint8:
        rgb = (rgb * 255).clip(0, 255).astype(np.uint8)

    axis_names = ("x", "y", "z")
    channel_names = ("red", "green", "blue")
    vertex_dtype = [(axis, "f4") for axis in axis_names]
    vertex_dtype += [(channel, "u1") for channel in channel_names]

    # One structured record per point, filled column by column.
    vertices = np.empty(xyz.shape[0], dtype=vertex_dtype)
    for column, axis in enumerate(axis_names):
        vertices[axis] = xyz[:, column]
    for column, channel in enumerate(channel_names):
        vertices[channel] = rgb[:, column]

    # Wrap the records in a PlyElement and write them to disk.
    vertex_element = PlyElement.describe(vertices, "vertex")
    PlyData([vertex_element], text=True).write(filename)


# Export all points with their current colors; the relative path resolves
# against the notebook's working directory.
save_ply(coords, colors, "walls_SPLM.ply")

In [None]:
# Host-side NumPy copies of the class masks, used to index PLY vertex arrays.
chairsnp, bedsnp, wallsnp = (mask.cpu().numpy() for mask in (chairs, beds, walls))

chairsnp

array([False, False, False, ..., False, False, False], shape=(2014184,))

In [None]:
# Save the Gaussians classified as beds as a new PLY file.

# Source: PLY file constructed during visual Gaussian splatting.
plyfile = "/home/akshaysm/semantics/understanding_room1/point_cloud_29999.ply"
vertices = PlyData.read(plyfile)['vertex'].data

# Keep only the vertex records selected by the beds mask; the mask must have
# one entry per vertex in the file.
filtered_vertices = vertices[bedsnp]
filtered_ply = PlyElement.describe(filtered_vertices, "vertex")

# Binary little-endian output.
beds_ply_data = PlyData([filtered_ply], text=False, byte_order='<')
beds_ply_data.write("/home/akshaysm/semantics/understanding_room1/beds_gs_splm.ply")

In [13]:
# Inspect the vertex records stored in the chairs point-cloud file.
# NOTE(review): this rebinds `plyfile` and `vertices` from the previous cell, and
# no cell visible here writes point_cloud_chairs.ply -- confirm where it comes from.
plyfile = "/home/akshaysm/semantics/understanding_room1/point_cloud_chairs.ply"
vertices = PlyData.read(plyfile)['vertex'].data

vertices

array([(-0.13204901, -1.422234  ,  0.76465446, 0.64077586,  0.6375941 ,  0.5260317 ,  0.00709344,  0.00752561,  0.00800285, -0.0203096 ,  0.02501735, -0.02524622, -0.01980346, -0.02405356, -0.0040809 ,  0.0012599 , -0.00640167, -0.00667609, -0.00120154, -0.00567731, -0.00136321,  0.01465897,  0.0150622 , -0.00388402, -0.01949572, 0.02071368, -0.02112596, -0.01898693, -0.01960926, -0.00382663, -0.00219843, -0.00652901, -0.0067387 ,  0.00248392, -0.00544363,  0.00204266,  0.0156681 ,  0.01642465, -0.00877442, -0.02670435, 0.02257509, -0.02325078, -0.02643657, -0.02064864, -0.0052585 , -0.01179841, -0.01034671, -0.01087158,  0.01235199, -0.00844992,  0.01118252,  9.9849205, -3.8728185, -6.1704063, -5.9277606, 0.05318695,  0.03800137,  1.4349699 , 0.5080759),
       ( 0.31580058, -1.4559342 ,  0.8290082 , 0.43206057,  0.4049391 ,  0.20257333,  0.01067102,  0.01069162,  0.00254201, -0.00126558,  0.00938305, -0.00938071, -0.00123671, -0.00935895, -0.01439709,  0.00387776, -0.01444059, -0.014