In [None]:
import time
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import pytorch3d.ops
from plyfile import PlyData, PlyElement
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from argparse import ArgumentParser, Namespace
import cv2
from tqdm import tqdm
from arguments import ModelParams, PipelineParams, ModelHiddenParams
from scene import Scene, GaussianModel, FeatureGaussianModel
from gaussian_renderer import render, render_contrastive_feature, render_segmentation
from segment_anything import (SamAutomaticMaskGenerator, SamPredictor,
                              sam_model_registry)
from utils.sh_utils import SH2RGB
import imageio
from utils.segment_utils import *

%load_ext autoreload
%autoreload 2

In [None]:
import os

# Width of the per-Gaussian contrastive feature: it is the output size of the
# SAM projection MLP and the feature size given to FeatureGaussianModel.
FEATURE_DIM = 32

DATA_ROOT = './data/hypernerf/split-cookie'
# Model directory: the same value that was passed as --model_path during training.
# The folder is created by train_scene.py, which may give it a random name.
MODEL_PATH = './output/hypernerf/split-cookie'
SPIN_SCENE_NAME = 'lego_real_night_radial'
NVOS_SCENE_NAME = 'orchids'
# Training iteration whose checkpoint files are referenced by the paths below.
FEATURE_GAUSSIAN_ITERATION = 14000


def _iteration_file(filename):
    """Return the path of `filename` inside the checkpoint folder of FEATURE_GAUSSIAN_ITERATION."""
    return os.path.join(MODEL_PATH, f'point_cloud/iteration_{FEATURE_GAUSSIAN_ITERATION}/{filename}')


SAM_PROJ_PATH = _iteration_file('sam_proj.pt')
NEG_PROJ_PATH = _iteration_file('neg_proj.pt')
FEATURE_PCD_PATH = _iteration_file('contrastive_feature_point_cloud.ply')
SCENE_PCD_PATH = _iteration_file('point_cloud.ply')

SAM_ARCH = 'vit_h'
# The default is a machine-specific absolute path; set the SAM_CKPT_PATH
# environment variable to point at the checkpoint on other machines.
SAM_CKPT_PATH = os.environ.get(
    'SAM_CKPT_PATH',
    '/data/sxj/SegAnyGAussians/dependencies/sam_ckpt/sam_vit_h_4b8939.pth',
)

In [None]:
# Projection MLP: 256 input channels -> FEATURE_DIM outputs. It is applied to
# the SAM predictor features further down in the notebook.
# The layer order must not change: the checkpoint's state_dict keys are the
# positional indices of this Sequential.
nonlinear = torch.nn.Sequential(
    torch.nn.Linear(256, 64, bias=True),
    torch.nn.LayerNorm(64),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(64, 64, bias=True),
    torch.nn.LayerNorm(64),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(64, FEATURE_DIM, bias=True),
)
# Deserialize on the CPU so a checkpoint saved from a different CUDA device
# index still loads when only one GPU is visible; the module is moved to the
# GPU right after the weights are copied in.
nonlinear.load_state_dict(torch.load(SAM_PROJ_PATH, map_location='cpu'))
nonlinear = nonlinear.cuda()
nonlinear.eval()

# Rebuild the argument namespace that the project's parameter groups expect.
parser = ArgumentParser(description="Testing script parameters")
model = ModelParams(parser, sentinel=True)
# op = OptimizationParams(parser)
pipeline = PipelineParams(parser)
hp = ModelHiddenParams(parser)
parser.add_argument("--iteration", default=-1, type=int)
parser.add_argument('--target', default='scene', const='scene', nargs='?', choices=['scene', 'seg', 'feature', 'coarse_seg_everything', 'contrastive_feature', 'xyz'])
parser.add_argument('--idx', default=0, type=int)
parser.add_argument("--configs", type=str, default = "./arguments/hypernerf/default.py")
parser.add_argument('--precomputed_mask', default=None, type=str)

# get_combined_args is not defined in this notebook; presumably it comes from
# the star import of utils.segment_utils - TODO confirm.
args = get_combined_args(parser, MODEL_PATH)

# Overlay the values from the config file onto the parsed arguments.
if args.configs:
    import mmcv
    from utils.params_utils import merge_hparams
    config = mmcv.Config.fromfile(args.configs)
    args = merge_hparams(args, config)

dataset = model.extract(args)
hyperparam = hp.extract(args)
dataset.need_features = True
dataset.need_masks = True

# Only the feature Gaussians are instantiated; the plain Gaussian model slot
# of Scene is deliberately passed as None.
# gaussians = GaussianModel(dataset.sh_degree, hyperparam)
gaussians = None
feature_gaussians = FeatureGaussianModel(dataset.sh_degree, FEATURE_DIM, hyperparam)
# NOTE(review): both iteration arguments are -1 here, while SAM_PROJ_PATH is
# tied to FEATURE_GAUSSIAN_ITERATION - confirm both resolve to the same checkpoint.
scene = Scene(dataset, gaussians, feature_gaussians, load_iteration=-1, feature_loaded_iteration=-1, target='contrastive_feature')

xyz = feature_gaussians.get_xyz
point_features = feature_gaussians.get_sam_features
# print(xyz.device)
# print(point_features.device)

In [None]:
# Build the SAM variant named by SAM_ARCH from its checkpoint, move it to the
# GPU and wrap it in a predictor for prompt-based queries.
model_type = SAM_ARCH
sam = sam_model_registry[model_type](checkpoint=SAM_CKPT_PATH)
sam = sam.to('cuda')
predictor = SamPredictor(sam)

In [None]:
# Materialize the scene's video cameras so they can be indexed by view id.
cameras = list(scene.getVideoCameras())
print("There are",len(cameras),"views in the dataset.")

In [None]:
# Index (into `cameras`) of the view used as the reference image.
ref_img_camera_id = 0
mask_img_camera_id = 0

with torch.no_grad():
    view = cameras[ref_img_camera_id]

    # RGB rendering of the reference view; this image is displayed and then
    # handed to SAM.
    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
    rendering = render(view, feature_gaussians, pipeline, background, cam_type="blender")["render"]
    print(rendering.shape)
    # to8b is not defined in this notebook (presumably from utils.segment_utils).
    # The transpose moves axis 0 last - assumes a channel-first render, TODO confirm.
    img = to8b(rendering).transpose(1,2,0)
    plt.imshow(img)
    # plt.axis('off')
    plt.show()

    # Run the SAM image encoder on the rendered image and keep its embedding.
    predictor.set_image(img)
    sam_feature = predictor.features
    # sam_feature = view.original_features

    # All-zero background for the FEATURE_DIM-channel feature rendering.
    bg_feature = [0] * FEATURE_DIM
    background_feature = torch.tensor(bg_feature, dtype=torch.float32, device="cuda")

    # CUDA kernels launch asynchronously, so synchronize on both sides of the
    # call to make the wall-clock measurement cover the actual rendering.
    torch.cuda.synchronize()
    start_time = time.time()
    rendered_feature = render_contrastive_feature(view, feature_gaussians, pipeline.extract(args), background_feature)['render']
    torch.cuda.synchronize()
    # Seconds spent rendering the feature map (previously reset to 0 right
    # after being measured, which discarded the result).
    time1 = time.time() - start_time

# Spatial size of the SAM embedding; reused when the embedding is flattened
# for the projection MLP.
H, W = sam_feature.shape[-2:]
print("sam_features: ", sam_feature.shape)
# print("rendered_feature: ", rendered_feature.shape)

# print("time1: ", time1)

In [None]:
# One prompt point with label 1 per point; these are passed to SAM as
# point_coords / point_labels in the next cell.
input_point = np.array([[300, 400]])
input_label = np.ones(input_point.shape[0])

# Show the reference image with the prompt overlaid.
# plt.figure(figsize=(10,10))
plt.imshow(img)
ax = plt.gca()
show_points(input_point, input_label, ax)
plt.axis('on')
plt.show()

In [None]:
# Query SAM with the point prompt; multimask_output=True asks for several
# candidate masks, each with its own score.
with torch.no_grad():
    sam_outputs = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    vanilla_masks, scores, logits = sam_outputs

l = len(vanilla_masks)

# One figure per candidate: image, mask overlay, prompt point and score.
for i, (mask, score) in enumerate(zip(vanilla_masks, scores)):
    plt.figure()
    # plt.subplot(1, l, i+1)
    plt.imshow(img)
    overlay_ax = plt.gca()
    show_mask(mask, overlay_ax)
    show_points(input_point, input_label, overlay_ax)
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
# Resize all candidate masks to 64x64 with bilinear interpolation, then
# re-binarize: values above 0.5 become 1, everything else 0.
print("vanilla_masks: ", vanilla_masks.shape)
masks = torch.from_numpy(vanilla_masks).float().unsqueeze(0)
masks = torch.nn.functional.interpolate(masks, size=(64,64), mode='bilinear')
masks = (masks.squeeze(0).cuda() > 0.5).float()
print("masks resized: ", masks.shape)

In [None]:
# Use the candidate mask with the highest SAM score as the reference mask.
# mask_id = 0
mask_id = np.argmax(scores)
origin_ref_mask = torch.tensor(vanilla_masks[mask_id]).float().cuda()

# SAM features projected by the MLP: flatten the embedding to
# (H*W, channels), apply `nonlinear`, and fold the result back to
# (channels, H, W). Inference only, so no autograd graph is recorded.
with torch.no_grad():
    low_dim_features = nonlinear(
        sam_feature.view(-1, H*W).permute([1,0])
    ).squeeze().permute([1,0]).reshape([-1, H, W])

# Feature Field query
# Alternative that pools the projected SAM features instead of the rendered ones:
# mask_low_dim_features = ref_mask.unsqueeze(0) * low_dim_features
# mask_pooling_prototype = mask_low_dim_features.sum(dim = (1,2)) / torch.count_nonzero(ref_mask)

# Resize the reference mask to the spatial size of the rendered feature map.
# The size is read from the tensor (instead of a hard-coded resolution) so the
# cell keeps working for other views and datasets.
feature_hw = tuple(rendered_feature.shape[-2:])
ref_mask = torch.nn.functional.interpolate(origin_ref_mask[None, None, :, :], feature_hw, mode='bilinear').squeeze().cuda()
# Re-binarize after the bilinear resize.
ref_mask = (ref_mask > 0.5).float()

# Masked average pooling: mean rendered feature over the pixels inside the mask.
mask_low_dim_features = ref_mask.unsqueeze(0) * rendered_feature
mask_pooling_prototype = mask_low_dim_features.sum(dim = (1,2)) / torch.count_nonzero(ref_mask)


In [None]:
import kmeans_pytorch
import importlib
importlib.reload(kmeans_pytorch)
from kmeans_pytorch import kmeans

# K-means or not

# Re-render the feature map of the current view on an all-zero background.
# NOTE: this overwrites the 3-channel `bg_color` / `background` created by the
# RGB rendering cell with FEATURE_DIM-channel versions.
bg_color = [0] * FEATURE_DIM
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
# Inference only: the result is detached below anyway, so skip autograd.
with torch.no_grad():
    rendered_feature = render_contrastive_feature(view, feature_gaussians, pipeline.extract(args), background)['render']

# similarity = torch.cosine_similarity(mask_pooling_prototype.cuda(), rendered_feature.permute([1, 2, 0]), dim=-1)

# Per-pixel dot product between the pooled prototype and the rendered features.
temp_mask = torch.einsum('C,CHW->HW', mask_pooling_prototype.cuda(), rendered_feature)
similarity = temp_mask.clone().detach()
# Binarize the similarity at 0 on a 64x64 grid.
temp_mask = torch.nn.functional.interpolate(similarity.float().unsqueeze(0).unsqueeze(0), (64,64), mode='bilinear').squeeze().cuda()
temp_mask = (temp_mask > 0).float()


# Reference mask on the same 64x64 grid (replaces the feature-resolution
# `ref_mask` computed in the previous cell).
ref_mask = torch.nn.functional.interpolate(origin_ref_mask[None, None, :, :], (64, 64), mode='bilinear').squeeze().cuda()
ref_mask = (ref_mask > 0.5).float()
# Fraction of the reference mask that is covered by the similarity mask.
iob = (temp_mask * ref_mask).sum(dim = (-1, -2)) / ref_mask.sum()
print(iob.item())

# The pooled prototype is accepted only when it recovers more than 90% of the
# reference mask. NOTE: `fmask_prototype` is left undefined otherwise, because
# the k-means fallback branch is currently commented out.
if iob > 0.9:
    fmask_prototype = mask_pooling_prototype.unsqueeze(0)
# else:
#     # fmask_prototype = mask_pooling_prototype.unsqueeze(0)
#     downsampled_masks = torch.nn.functional.adaptive_avg_pool2d(ref_mask.unsqueeze(0).unsqueeze(0), (8,8)).squeeze()
#     downsampled_features = torch.nn.functional.adaptive_avg_pool2d(mask_low_dim_features.unsqueeze(0), (8,8)).squeeze(0)
#     downsampled_features /= downsampled_masks.unsqueeze(0)

#     downsampled_masks[downsampled_masks != 0]= 1
#     init_prototypes = downsampled_features[:, downsampled_masks.bool()].permute([1,0])


#     masked_sam_features = low_dim_features[:, ref_mask.bool()]
#     masked_sam_features = masked_sam_features.permute([1,0])
    
#     num_clusters = init_prototypes.shape[0]
#     print(num_clusters)
    
#     if num_clusters <= 1:
#         num_clusters = min(int(masked_sam_features.shape[0] ** 0.5), 32)
#         init_prototypes = []

#     cluster_ids_x, cluster_centers = kmeans(
#         X=masked_sam_features, num_clusters=num_clusters, distance='euclidean', device=torch.device('cuda')
#     )

    # temp_mask = torch.sigmoid(torch.einsum('NC,CHW->NHW', cluster_centers.cuda(), rendered_feature))
    # temp_mask = torch.nn.functional.interpolate(temp_mask.float().unsqueeze(1), (64,64), mode='bilinear').squeeze().cuda()
    # temp_mask[temp_mask >= 0.5] = 1
    # temp_mask[temp_mask != 1] = 0
    # temp_mask = temp_mask.squeeze()

    # ioa = (temp_mask * ref_mask[None,:,:]).sum(dim = (-1, -2)) / (temp_mask.sum(dim = (-1, -2)) + 1e-5)
    # iob = (temp_mask * ref_mask[None,:,:]).sum(dim = (-1, -2)) / ref_mask.sum()
    # ioa = ioa.squeeze()
    # iob = iob.squeeze()
    # cluster_mask = ioa > 0.75

    # # NMS
    # for i in range(len(cluster_mask)):
    #     if not cluster_mask[i]:
    #         continue

    #     for j in range(i+1, len(cluster_mask)):
    #         if not cluster_mask[j]:
    #             continue

    #         if (temp_mask[j] * temp_mask[i]).sum() / ((temp_mask[j] + temp_mask[i]).sum() - (temp_mask[j] * temp_mask[i]).sum()) > 0.75:
    #             if ioa[i] > ioa[j]:
    #                 cluster_mask[j] = False
    #             else:
    #                 cluster_mask[i] = False
    #                 break
    
    # cluster_centers = cluster_centers.cuda()
    # cluster_centers = cluster_centers[cluster_mask, :]
#     fmask_prototype = torch.cat([mask_pooling_prototype.unsqueeze(0), cluster_centers.cuda()], dim = 0)

In [None]:
# similarity = torch.einsum('NC, CHW -> NHW', fmask_prototype, rendered_feature)
# similarity = similarity.max(0)[0]

# normalized_similarity_mask = (similarity_mask - similarity_mask.min()) / (similarity_mask.max() - similarity_mask.min())
# feature_map = torch.sigmoid(similarity).detach().cpu().numpy()

# Display the raw prototype/feature similarity map as an image.
similarity_np = similarity.detach().cpu().numpy()
plt.imshow(similarity_np)
# plt.imshow(feature_map)
plt.axis('off')
plt.show()