In [1]:
import os
import clip
import torch
import numpy as np
from tqdm import tqdm
from clip_images import init_CLIP
from ae_training import load_model
from plyfile import PlyData, PlyElement
from torch.utils.data import DataLoader, TensorDataset
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load everything the notebook needs: the scene-specific auto-encoder, the CLIP
# model, and the semantic / visual Gaussian-splatting checkpoints.

# Single place to change when pointing the notebook at another scene.
SCENE_DIR = "/home/akshaysm/semantics/remember_room1"

# scene specific auto encoder
ae, _, _, _ = load_model(os.path.join(SCENE_DIR, "ae_model.pth"))
clip_model, preprocess = init_CLIP()

# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# map_location=device lets a checkpoint that was saved on a GPU be restored on
# a CPU-only machine instead of failing during deserialisation.

# semantic gs result
ckpt_sem = torch.load(os.path.join(SCENE_DIR, "ckpt_sem.pt"), map_location=device)

# visual gs result
ckpt_col = torch.load(os.path.join(SCENE_DIR, "ckpt_col.pt"), map_location=device)

# Per-Gaussian tensors used by the cells below.
coords = ckpt_sem["splats"]["means"].to(device)
features = ckpt_sem["splats"]["features"].to(device)
# sh0 carries a singleton axis at dim 1; drop it so colors lines up with coords.
colors = ckpt_col["splats"]["sh0"].squeeze(1).to(device)

# Displayed as the cell output: the three shapes should share their first dim.
coords.shape, features.shape, colors.shape

(torch.Size([2188803, 3]), torch.Size([2188803, 3]), torch.Size([2188803, 3]))

In [3]:
# Build the CLIP text embedding of the query; the next cell searches the scene
# features for the one that matches it best.

string = "a wall"
wall_list = [
    f"an image of {string}",
    f"a photo of {string} in a bedroom",
    f"a closeup of {string} in an indoor scene",
]

tokens = clip.tokenize(wall_list).to(device)
with torch.no_grad():
    # Unit-normalise each prompt embedding, average over the prompts, then
    # re-normalise the averaged vector.
    prompt_embeds = clip_model.encode_text(tokens)
    prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
    wall_embeds = prompt_embeds.mean(dim=0)
    wall_embeds = wall_embeds / wall_embeds.norm(dim=-1, keepdim=True)

# Displayed as the cell output: embedding shape and its norm.
wall_embeds.shape, torch.norm(wall_embeds)

(torch.Size([512]), tensor(1.0000, device='cuda:0'))

In [4]:
# Find the decoded per-Gaussian feature that best matches the text query.
#
# Every batch of stored features is decoded with the scene auto-encoder and
# scored against the query embedding with a dot product; the single
# highest-scoring decoded feature over the whole scene is kept.
# NOTE(review): the dot product is a cosine similarity only if `ae.decode`
# returns unit-norm vectors -- the norm printed below suggests so; confirm.

dataset = TensorDataset(features)
loader = DataLoader(dataset, batch_size=2048, shuffle=False)
# Start from -inf rather than 0 so a best feature is still picked when every
# similarity is zero or negative (otherwise best_feature would stay None).
best_cos_sim = float("-inf")
best_feature = None

with torch.no_grad():
    for data in tqdm(loader, desc="find best feature", unit="batch"):
        x = data[0]
        x = ae.decode(x)
        # wall_embeds is 1-D, so no transpose is needed; `.T` on a tensor that
        # is not 2-D is deprecated and emits a UserWarning.
        sim = x @ wall_embeds
        high_idx = torch.argmax(sim)
        high_sim = sim[high_idx].item()
        if high_sim > best_cos_sim:
            best_cos_sim = high_sim
            best_feature = x[high_idx]

if best_feature is None:
    # Fail with a clear message instead of an AttributeError in the print below.
    raise ValueError(
        "no best feature found: `features` is empty or all similarities are NaN"
    )

print(best_cos_sim, best_feature.shape, torch.norm(best_feature))

  sim = x @ wall_embeds.T
find best feature: 100%|██████████| 1069/1069 [00:08<00:00, 119.14batch/s]

0.31820613145828247 torch.Size([512]) tensor(1., device='cuda:0')





In [None]:
# Score every Gaussian in the scene against the best-matching feature.
#
# Each batch of stored features is decoded with the scene auto-encoder and
# compared to `best_feature` with a dot product; the per-batch scores are
# concatenated into one tensor with one entry per Gaussian.
# NOTE(review): this is a cosine similarity only if `ae.decode` returns
# unit-norm vectors -- confirm.

dataset = TensorDataset(features)
loader = DataLoader(dataset, batch_size=2048, shuffle=False)
# Separate name for the per-batch list so `cosine_sims` is only ever a tensor.
batch_sims = []

with torch.no_grad():
    for data in tqdm(loader, desc="sim search", unit="batch"):
        x = data[0]
        x = ae.decode(x)
        # best_feature is 1-D, so no transpose is needed; `.T` on a tensor that
        # is not 2-D is deprecated.
        sim = x @ best_feature
        batch_sims.append(sim)

cosine_sims = torch.concat(batch_sims)
# (std, mean) of the similarities, useful when picking a threshold.
print(torch.std_mean(cosine_sims))


sim search: 100%|██████████| 1069/1069 [00:08<00:00, 131.50batch/s]

(tensor(0.0299, device='cuda:0'), tensor(0.8624, device='cuda:0'))





In [None]:
# Select Gaussians by thresholding their similarity to the best feature.
# Thresholds noted per object:
#   chair: 0.93
#   bed:   0.942
#   walls: 0.96
#   TV:    0.985

filtered_sim = torch.gt(cosine_sims, 0.92)  # TODO: threshold for this query is still unsettled
# Displayed as the cell output: number of selected Gaussians.
filtered_sim.sum()

tensor(39868, device='cuda:0')

In [None]:
# Write a copy of the point cloud in which the selected Gaussians are colored red.

filtered_np = filtered_sim.cpu().numpy()

plyfile = "/home/akshaysm/semantics/remember_room1/point_cloud_29999.ply"
plydata = PlyData.read(plyfile)
vertices_array = np.array(plydata["vertex"].data)

# Overwrite the three f_dc_* fields of the selected vertices: 10 / 0 / 0.
for dc_field, dc_value in (("f_dc_0", 10.0), ("f_dc_1", 0.0), ("f_dc_2", 0.0)):
    vertices_array[dc_field][filtered_np] = dc_value

vertex_element = PlyElement.describe(vertices_array, "vertex")
# binary and little endian for Unity Gaussian Splatting
highlighted_ply = PlyData([vertex_element], text=False, byte_order="<")
highlighted_ply.write("/home/akshaysm/semantics/remember_room1/rem_wred_max_best.ply")