In [None]:
import os
import torch
import numpy as np
from glob import glob
from tqdm.notebook import tqdm
from os.path import join, exists
import open3d as o3d
import matplotlib.pyplot as plt
from itertools import combinations
import copy
from tabulate import tabulate
import clip

# Pick the compute device first and place CLIP on it explicitly. clip.load's own
# default is "cuda" when available and "cpu" otherwise, so this is the same placement.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, _ = clip.load("ViT-L/14@336px", device=device)

In [None]:
import sys
sys.path.append('../')

from utils import *

In [None]:
# Load the augmented ScanNet scene scene0024_00 (with inserted birds) together with
# its 2D-fused and 3D-distilled OpenScene features.
_birds_root = "/mnt/project/AT3DCV_Data/Preprocessed_OpenScene/data/augmented/birds_new"
_scene = "scene0024_00"

source_path = f"{_birds_root}/scannet_3d/example/{_scene}_vh_clean_2.pth"
fused_path = f"{_birds_root}/fused/{_scene}_0.pt"
distilled_path = f"{_birds_root}/features_3D/{_scene}_vh_clean_2_openscene_feat_distill.npy"

# Points, colours and per-point labels of the scene (loader from utils).
source_points, source_colors, source_labels = load_scene(source_path, False)

# The fused loader returns the features plus the filtered points/colours/labels and
# the indices it kept (presumably the points that received a fused feature).
(fused_f, filtered_pc, filtered_pc_c,
 filtered_pc_labels, indices) = load_fused_features(
    fused_path, source_points, source_colors, source_labels)
# Distilled features, presumably restricted to the same point indices.
distilled_f = load_distilled_features(distilled_path, indices)

In [None]:
# info about scene with birds
# Nicobar Pigeon          class label = 20, 1797 points, on the table
# Eastern Rosella          class label = 21, 1787 points, on the stairs

In [None]:
# query
#query = ["a bird which is grey-breasted", "a bird which is brown-crowned", "a bird which is yellow-eyed"]
#query = ["a bird"]
#query = ["Eastern Rosella bird in a scene"]
query = ["nicobar pigeon bird in a scene"]
#query = ["a bird which has Slender-bodied.", 'a bird which has Long-tail.',
#                                  'a bird which has White-barred-wings.','a bird which has Red-legs.',
#                                  'a bird which has Grey-breast.']
        
query = ['a bird which has White-tipped wings.',
 'a bird which has Yellow belly.',
 'a bird which has Orange cheeks.',
 'a bird which has Broad wings.',
 'a bird which has White eye-ring..']
        
query = ["a bird"]    
        
similarity = highlight_query(query, "fused", "max", distilled_f, fused_f, filtered_pc, filtered_pc_c, device)

In [None]:
# manually set descriptors if necessary
# Keys are listed in ground-truth label order. The next cell appends these lists to
# SCANNET_LABELS_20 in dict order, so the first key becomes class index 20 and the
# second index 21. BUG FIX: "Eastern Rosella" used to come first, although the
# Nicobar Pigeon points carry label 20 and the Eastern Rosella points label 21 (see
# the scene info above). Each bird's descriptors were therefore scored against the
# other bird's points by an index-based evaluation.
descriptors = {"nicobar pigeon" : ['a bird which is/has Purple-neck.',
                                   'a bird which is/has Red-breast.',
                                   'a bird which is/has Long-tail.',
                                   'a bird which is/has Black-beak.',
                                   'a bird which is/has Gray-crown.'],
               "Eastern Rosella": ['a bird which is/has Orange-breast.',
                                   'a bird which is/has Blue-tail.',
                                   'a bird which is/has White-cheek.',
                                   'a bird which is/has Black-eye-ring.',
                                   'a bird which is/has Yellow-wing-bar.'],
              }

In [None]:
SCANNET_LABELS_20 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa',
                     'table', 'door', 'window', 'bookshelf', 'picture','counter', 'desk', 'curtain', 'refrigerator', 'shower curtain',
                     'toilet', 'sink', 'bathtub', 'otherfurniture']
UNKNOWN_ID = 255
NO_FEATURE_ID = 256

for key, value in descriptors.items():
    SCANNET_LABELS_20.append(value)

In [None]:
SCANNET_LABELS_20

In [None]:
# Evaluate with the utils version of `evaluate` (5 return values).
# NOTE(review): the legacy `evaluate` defined at the end of this notebook takes a
# different argument list and returns 4 values. If that cell was executed, it shadows
# this one and the call below fails.
class_ious, class_accs, mean_iou, mean_acc, pred_ids = evaluate(SCANNET_LABELS_20, descriptors, model, "fused", "max" , distilled_f, fused_f, filtered_pc_labels)

In [None]:
class_ious

In [None]:
# Optional: merge all added object labels (21, 22) into label 20.
# gt_ids = np.where(np.logical_or(filtered_pc_labels == 21, filtered_pc_labels == 22), 20, filtered_pc_labels)

In [None]:
mean_acc

In [None]:
# Per-class IoU table (helper from utils).
print_results_table(SCANNET_LABELS_20, class_ious, descriptors)

# experiment with different descriptor combinations

In [None]:
import clip

# CLIP ViT-L/14@336px is loaded in the first cell. Only load it again when this
# section is run on a kernel that does not have it yet, because reading the
# checkpoint is slow and the result would be identical.
if "model" not in globals():
    model, _ = clip.load("ViT-L/14@336px")

In [None]:
# Parse descriptors from the OpenAI API with GPT; set _nr to the number of
# descriptors to retrieve per class. The reply is non-deterministic, so the next
# cell pins the descriptors that were actually used in the experiments.
_nr = 10
# BUG FIX: the prompt used to end with a stray double quote ('etc."').
_prompt = (
    f"Generate {_nr} visual descriptors for each of the following categories, "
    "they are bird species: [Nicobar Pigeon, Eastern Rosella]. "
    "The descriptors will be used for input queries for a CLIP model. "
    "The descriptors should be concise and distinct from the descriptors of the other classes. "
    "Do not focus on behavior, but purely on attributes which are recognizable by the CLIP model. "
    "The output should be in the following form as a string: "
    "*bird name*: *descriptor1*, *descriptor2*, etc."
)
descriptors = descriptors_from_prompt(_prompt, verbose = True)

In [None]:
# Descriptors returned by GPT for _nr = 10, pinned here so the experiments are
# reproducible. This overwrites the live result of the previous cell.
# Key order (Nicobar Pigeon, Eastern Rosella) matches the ground-truth ids 20 and 21,
# assuming the classes are appended to the label set in dict order.
# NOTE(review): two prompts end in '..' (apparently a parsing artefact). They are kept
# verbatim because the recorded scores were presumably produced with these exact strings.
descriptors = {'Nicobar Pigeon': ['a bird which is/has Green-blue plumage.',
  'a bird which is/has Metallic-sheen feathers.',
  'a bird which is/has Slender body.',
  'a bird which is/has Long tail feathers.',
  'a bird which is/has Light brown head.',
  'a bird which is/has Red beak.',
  'a bird which is/has White eye-ring.',
  'a bird which is/has Pink feet.',
  'a bird which is/has Dark neck patch.',
  'a bird which is/has Yellow shoulder stripe..'],
 'Eastern Rosella': ['a bird which is/has Red head.',
  'a bird which is/has Red shoulder patches.',
  'a bird which is/has Blue wings.',
  'a bird which is/has White breast.',
  'a bird which is/has Yellow belly.',
  'a bird which is/has Orange-red tail.',
  'a bird which is/has Black beak.',
  'a bird which is/has Blue-green back.',
  'a bird which is/has White eye-ring.',
  'a bird which is/has Long tail feathers..']}

In [None]:
# Reset to the plain ScanNet-20 label set. The bird classes are presumably added per
# descriptor combination by try_diff_combs rather than appended here.
SCANNET_LABELS_20 = [
    'wall', 'floor', 'cabinet', 'bed', 'chair',
    'sofa', 'table', 'door', 'window', 'bookshelf',
    'picture', 'counter', 'desk', 'curtain', 'refrigerator',
    'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture',
]
# Special label ids (not referenced elsewhere in this notebook).
UNKNOWN_ID, NO_FEATURE_ID = 255, 256

In [None]:
# to test without descriptors
comb_dict_list =[ {
  }]

In [None]:
# get combinations of 5 out of *_nr* descriptors for each class
comb_dict_list = combinations_descriptor(descriptors, 5)

In [None]:
comb_dict_list

In [None]:
# Merge label 21 (Eastern Rosella) into 20 so both birds form one "bird" class.
# NOTE(review): this mutates filtered_pc_labels in place and cannot be undone by
# re-running cells. Afterwards:
#   - the ground truth has no label 21, so a per-class "Eastern Rosella" IoU may be
#     missing or meaningless;
#   - the label-21 colouring in the visualization cells stays empty.
# Skip this cell when evaluating the two birds separately.
filtered_pc_labels[np.where(filtered_pc_labels == 21)] = 20

In [None]:
np.unique(filtered_pc_labels)

In [None]:
# iterate all the combinations and store their results in a list
class_IoU_result_list, class_accs_result_list, mean_iou_result_list, mean_acc_result_list, pred_ids_list = try_diff_combs(SCANNET_LABELS_20, comb_dict_list, model, "fused", "max", distilled_f, fused_f, filtered_pc_labels)

In [None]:
# the i-th result belongs to the descriptors at index i of comb_dict_list
class_IoU_result_list[0]

In [None]:
c1, c2= [], [] # Nicobar Pigeon, Easter Rosella
# store tp/ (tp + fp + fn) values in list per augmented class
for idx in range(len(class_IoU_result_list)):
    c1.append(class_IoU_result_list[idx]["Nicobar Pigeon"][0])
    c2.append(class_IoU_result_list[idx]["Eastern Rosella"][0])

In [None]:
comb_dict_list[c1.index(max(c1))]

In [None]:
max(c1), max(c2)

In [None]:
c2[c1.index(max(c1))]

In [None]:
# these 5 descriptors gives the highest class IoU for Easter Rosella
comb_dict_list[c2.index(max(c2))]['Eastern Rosella']

In [None]:
# these 5 descriptors gives the highest class IoU for Mouse-colored Tyrannulet
comb_dict_list[c1.index(max(c1))]['Nicobar Pigeon']

# visualization, birds as single category

In [None]:
# Ground truth vs. prediction of the best Nicobar Pigeon combination. Points with
# label 20 (all birds, once label 21 was merged into 20) are red; everything else is
# light green. The prediction cloud is shifted by +10 along y so both are visible.
pred_labels = pred_ids_list[c1.index(max(c1))].numpy()
other_color = np.array([0.773, 0.922, 0.651])
bird_color = [1, 0.294, 0.165]  # nicobar pigeon: red
color_gt = np.tile(other_color, (len(filtered_pc_labels), 1))
color_pred = np.tile(other_color, (len(filtered_pc_labels), 1))
color_gt[np.where(filtered_pc_labels == 20)] = bird_color
color_pred[np.where(pred_labels == 20)] = bird_color

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(filtered_pc))
# BUG FIX: this used the undefined name `color_array` (NameError); the unshifted
# cloud shows the ground-truth colours.
pcd.colors = o3d.utility.Vector3dVector(color_gt)

pcd_pred = o3d.geometry.PointCloud()
pcd_pred.points = o3d.utility.Vector3dVector(np.asarray(filtered_pc) + [0,10,0])
pcd_pred.colors = o3d.utility.Vector3dVector(color_pred)

# Visualize the point cloud
o3d.visualization.draw_geometries([pcd,pcd_pred])


# visualization by color

In [None]:
pred_labels = pred_ids_list[c1.index(max(c1))].numpy()

In [None]:
pred_labels_bad = pred_ids_list[c1.index(max(c1))].numpy()

In [None]:
other_color = np.array([0.773, 0.922, 0.651])
color_gt = np.tile(other_color, (len(filtered_pc_labels), 1))
color_pred = np.tile(other_color, (len(filtered_pc_labels), 1))

In [None]:
color_gt[np.where(filtered_pc_labels == 20)] = [1, 0.294, 0.165] # nicobar pigeon :red
color_gt[np.where(filtered_pc_labels == 21)] = [0.024, 0.788, 1] # eastern rosella  :blue 

color_pred[np.where(pred_labels == 20)] = [1, 0.294, 0.165] # nicobar pigeon :red
color_pred[np.where(pred_labels == 21)] = [0.024, 0.788, 1] # eastern rosella  :blue 


In [None]:
# Create an Open3D point cloud
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(filtered_pc)
pcd.colors = o3d.utility.Vector3dVector(color_array)

pcd_pred = o3d.geometry.PointCloud()
pcd_pred.points = o3d.utility.Vector3dVector(np.asarray(filtered_pc) + [0,10,0])
pcd_pred.colors = o3d.utility.Vector3dVector(color_pred)

# Visualize the point cloud
o3d.visualization.draw_geometries([pcd,pcd_pred])

# experiments

In [None]:
import plotly.graph_objects as go
import plotly.io as pio

In [None]:
# "mean" aggregation: best class IoU (max over descriptor combinations) for
# 0..10 descriptors; the list index is the number of descriptors.
_np = [0.3961, 0.6884, 0.5992, 0.6133, 0.5813, 0.5519, 0.5377, 0.5163, 0.5050, 0.4872, 0.4593]  # Nicobar Pigeon
_er = [0.0, 0.6626, 0.7036, 0.7115, 0.7289, 0.7233, 0.7122, 0.7138, 0.7006, 0.6887, 0.6778]  # Eastern Rosella
_numbers = list(range(len(_np)))

In [None]:
# "max" aggregation: best class IoU (max over descriptor combinations) for
# 0..10 descriptors; the list index is the number of descriptors.
_np = [0.3961, 0.6884, 0.7332, 0.7332, 0.7332, 0.7332, 0.7327, 0.7061, 0.6391, 0.4610, 0.1698]  # Nicobar Pigeon
_er = [0.0, 0.6626, 0.6815, 0.7330, 0.7346, 0.7346, 0.7321, 0.6792, 0.6790, 0.5396, 0.4600]  # Eastern Rosella
_numbers = list(range(len(_np)))

In [None]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=_numbers, y=_np, mode='lines+markers', name='Nicobar Pigeon',  line=dict(width=6)))
fig.add_trace(go.Scatter(x=_numbers, y=_er, mode='lines+markers', name='Eastern Rosella', line=dict(width=6)))

fig.update_layout(title='Effect of Descriptors,  Aggregation : max',
                  xaxis_title='# of descriptors',
                  yaxis_title='max mIoU scores',
                  showlegend=True,
                  plot_bgcolor='rgba(250,250, 250,1)', 
                  font=dict(
                      family='Arial',  # Set the font family
                      size=26,         # Set the font size
                      color='black'    # Set the font color
                  )
                  )

# Show the plot
fig.show()

In [None]:
pio.write_image(fig, 'max agg.png', format='png', width=1200, height=1080, scale=4)

In [None]:
import textwrap

In [None]:
# bar chart for different prompts
x_labels = ['a bird', 'Nicobar Pigeon', 'a Nicobar Pigeon bird', 'a Nicobar Pigeon bird in a scene']
y_values = [0.2993, 0.4, 0.3856, 0.4141]

#x_labels = ['a bird', 'Eastern Rosella', 'a Nicobar Pigeon bird', 'a Nicobar Pigeon bird in a scene']
#y_values = [0.2993, 0.0, 0.3856, 0.4141]

# Create a bar graph using the go.Bar function
bar_graph = go.Bar(
    x=x_labels,  # x-axis labels
    y=y_values   # y-axis values
)

# Define the layout settings (optional)
layout = go.Layout(
    title='Different Prompts',
    xaxis=dict(
        title='Prompts',          # x-axis label
        tickmode='array',            # Set tickmode to 'array' for custom tick labels
        tickvals=list(range(len(x_labels))),  # Set tick positions           # Set tick labels
        tickangle=0,                 # Rotate labels to 0 degrees (horizontal)
        automargin=True,
        # Automatically adjust margins to fit labels
        tickfont=dict(size=30),  
        ticktext=["<br>".join(textwrap.wrap(label, width=12)) for label in x_labels]# Set font size for tick labels
    ),  # x-axis label
    yaxis=dict(title='mIoU'),
    plot_bgcolor='rgba(250,250, 250,1)',
    font=dict(family='Arial',  
            size=26,         
            color='black'),
   # y-axis label
)

# Create the figure and add the bar graph to it
fig = go.Figure(data=[bar_graph], layout=layout)
fig.update_traces(
    marker=dict(line=dict(width=1), color='green'),  # Set the border width of the bars
    width=0.2                        # Set the width of the bars (0.4 means 40% of the available space)
)
# Add a horizontal line on the maximum y value

max_y_value = max(y_values)
fig.add_shape(
    type='line',
    x0=-0.5,   # Starting x position (corresponding to the first bar)
    x1=len(x_labels) - 0.5,  # Ending x position (corresponding to the last bar)
    y0=max_y_value,  # y position of the horizontal line (maximum y value)
    y1=max_y_value,  # y position of the horizontal line (maximum y value)
    line=dict(color='red', width=2),  # Line properties (color and width)
)
fig.update_layout(
    uniformtext_minsize=8,  # Set the minimum size of text to avoid overlapping
    uniformtext_mode='hide' # Hide text when it does not fit
)
# Display the graph
fig.show()

In [None]:
pio.write_image(fig, 'prompts.png', format='png', width=1220, height=1080, scale=4)

In [None]:
# tested descriptors for different number of descriptors
# (displayed for reference only; this expression is not assigned to any variable)
{'Nicobar Pigeon': ['a bird which is/has Green-blue plumage.',
  'a bird which is/has Metallic-sheen feathers.',
  'a bird which is/has Slender body.',
  'a bird which is/has Long tail feathers.',
  'a bird which is/has Light brown head.',
  'a bird which is/has Red beak.',
  'a bird which is/has White eye-ring.',
  'a bird which is/has Pink feet.',
  'a bird which is/has Dark neck patch.',
  'a bird which is/has Yellow shoulder stripe..'],
 'Eastern Rosella': ['a bird which is/has Red head.',
  'a bird which is/has Red shoulder patches.',
  'a bird which is/has Blue wings.',
  'a bird which is/has White breast.',
  'a bird which is/has Yellow belly.',
  'a bird which is/has Orange-red tail.',
  'a bird which is/has Black beak.',
  'a bird which is/has Blue-green back.',
  'a bird which is/has White eye-ring.',
  'a bird which is/has Long tail feathers..']}

# ---------------------------------------------------------------------

In [None]:
# should be the preprocessed file path
sample_path_0 = "/mnt/project/AT3DCV_Data/Preprocessed_OpenScene/data/augmented/birds/scannet_3d/example/scene0000_00_vh_clean_2.pth"
#sample_path_1 = "D:/AT3DCV_Data/Preprocessed_OpenScene/data/scannet_3d/train/scene0000_01_vh_clean_2.pth"
#sample_path_2 = "D:/AT3DCV_Data/Preprocessed_OpenScene/data/scannet_3d/train/scene0000_02_vh_clean_2.pth"

In [None]:
sample_0 = torch.load(sample_path_0) # coords,colors,labels
#sample_1 = torch.load(sample_path_1) # coords,colors,labels
#sample_2 = torch.load(sample_path_2) # coords,colors,labels

In [None]:
len(sample_0[0])

In [None]:
# aggregating all of the partial point clouds of the same scene (they don't overlap perfectly)
#sample_points = np.concatenate((sample_0[0], sample_1[0], sample_2[0]))
#sample_colors = np.concatenate((sample_0[1], sample_1[1], sample_2[1]))

# single partial point cloud
sample_points  = sample_0[0]
sample_colors = sample_0[1]
sample_labels = sample_0[2]

In [None]:
#to view original scene
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(sample_points))
#original colors
pcd.colors = o3d.utility.Vector3dVector(np.asarray(sample_colors))
#------
#paint uniform
#sample_paint_uniform = np.asarray([200,200,200])/255.0 #redish
#pcd.paint_uniform_color(sample_paint_uniform)
o3d.visualization.draw_geometries([pcd])

# load fused features

In [None]:
# should be the fused feature path
feature_path = "/mnt/project/AT3DCV_Data/Preprocessed_OpenScene/data/augmented/birds/fused/scene0000_00_0.pt"

In [None]:
feature = torch.load(feature_path)

In [None]:
feature["mask_full"].shape

In [None]:
feature["feat"].shape

In [None]:
# Get the indices where the mask is True
indices = torch.nonzero(feature["mask_full"]).squeeze()

In [None]:
filtered_point_cloud = sample_points[indices, :]
filtered_point_cloud_colors = sample_colors[indices, :]
filtered_point_cloud_labels = sample_labels[indices]
gt_ids = filtered_point_cloud_labels

In [None]:
np.unique(filtered_point_cloud_labels)

In [None]:
# Replace every occurrence of 21 with 20 if necessary
gt_ids= np.where(filtered_point_cloud_labels == 21.0, 20.0, filtered_point_cloud_labels)
gt_ids= np.where(gt_ids == 22.0, 20.0, gt_ids)
# gt_ids = filtered_point_cloud_labels

In [None]:
np.unique(gt_ids)

In [None]:
unique_values, counts = np.unique(gt_ids, return_counts=True)

In [None]:
counts

In [None]:
filtered_point_cloud.shape

# using clip model

In [None]:
import clip

# Load CLIP ViT-L/14@336px (and its image preprocessing transform) only if this
# kernel does not have both yet; reloading the checkpoint is slow and gives the same model.
if "model" not in globals() or "preprocess" not in globals():
    model, preprocess = clip.load("ViT-L/14@336px")

In [None]:
# Highlight with a threshold: paint red every point whose similarity to the query
# is at least `threshold_percentage` of the best point's similarity.
query = ["dragon"]  # type the query here

with torch.no_grad():
    all_text_embeddings = []
    for category in tqdm(query):
        # Use the notebook-wide `device` instead of hard-coding .cuda().
        texts = clip.tokenize(category).to(device)
        text_embeddings = model.encode_text(texts)  # embed with the CLIP text encoder
        text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
        # Average the unit embeddings of one category's prompts, then re-normalise.
        text_embedding = text_embeddings.mean(dim=0)
        text_embedding /= text_embedding.norm()
        all_text_embeddings.append(text_embedding)

    # One column per query.
    all_text_embeddings = torch.stack(all_text_embeddings, dim=1)

# L2-normalise the fused point features (the eps avoids division by zero). They are
# stored under a new name so the `fused_f` of scene0024_00 loaded at the top of the
# notebook is not silently replaced by this scene's features.
fused_feat_norm = (feature["feat"] / (feature["feat"].norm(dim=-1, keepdim=True) + 1e-5)).half()
# Match the text embeddings' device and dtype (CLIP is fp16 on GPU, fp32 on CPU).
similarity_matrix = fused_feat_norm.to(device=device, dtype=all_text_embeddings.dtype) @ all_text_embeddings

# set higher to increase the certainty (not always correct)
threshold_percentage = 0.9
cap = similarity_matrix.max().item()
# Row (point) indices above the threshold for any query. The previous
# `.squeeze().T[0]` became a scalar when exactly one point matched and listed a
# point once per matching query.
found_indices = (similarity_matrix > cap * threshold_percentage).any(dim=1).nonzero(as_tuple=True)[0]

# creating pc
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud))
pcd.colors = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud_colors))

found_region = pcd.select_by_index(found_indices.tolist())
found_region.paint_uniform_color([1.0, 0, 0])  # paint related points red
rest = pcd.select_by_index(found_indices.tolist(), invert=True)
o3d.visualization.draw_geometries([rest, found_region])

In [None]:
# Highlight with a heatmap: colour every point by its similarity to the query.
# A query entry that is a list of prompts is averaged into one text embedding.
# Other queries tried:
#   ["deathwing"]
#   [" a blue-faced, yellow-crowned, white-breasted, black-eyed, long-billed, hooked-beak, yellow-beaked, yellow-breasted, yellow-throated and black-tailed bird"]
#   ["bird"]
#   [["Mouse-colored Tyrannulet bird"]]
#   mouse-colored tyrannulet:
#   [["grey-bodied","yellow-breasted","black-crowned",
#     "white-eyed","black-winged","yellow-throated",
#     "white-breasted","yellow-billed","grey-headed",
#     "long-tailed","bird"]]

# diamond firetail
query = [["red-breasted","black-crowned","gold-winged",
          "black-winged","white-eyed","yellow-billed",
          "red-headed","black-tailed","long-tailed",
          "white-breasted","bird"]]

with torch.no_grad():
    all_text_embeddings = []
    for category in tqdm(query):
        # Use the notebook-wide `device` instead of hard-coding .cuda().
        texts = clip.tokenize(category).to(device)
        text_embeddings = model.encode_text(texts)  # embed with the CLIP text encoder
        text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
        text_embedding = text_embeddings.mean(dim=0)
        text_embedding /= text_embedding.norm()
        all_text_embeddings.append(text_embedding)

    # One column per query.
    all_text_embeddings = torch.stack(all_text_embeddings, dim=1)

# L2-normalised fused features under a new name, so the scene0024_00 `fused_f`
# from the top of the notebook is not overwritten.
fused_feat_norm = (feature["feat"] / (feature["feat"].norm(dim=-1, keepdim=True) + 1e-5)).half()
# Match the text embeddings' device and dtype (CLIP is fp16 on GPU, fp32 on CPU).
similarity_matrix = fused_feat_norm.to(device=device, dtype=all_text_embeddings.dtype) @ all_text_embeddings

# creating pc
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud))
pcd.colors = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud_colors))

# heatmap
cmap = plt.get_cmap('cividis')

# Min-max normalise the first (here: only) query column to [0, 1] in float32. The
# floor on the range keeps a constant map at 0 instead of NaN (0/0).
scores = similarity_matrix[:, 0].float()
score_range = (scores.max() - scores.min()).clamp_min(1e-12)
normalized_tensor = (scores - scores.min()) / score_range

colors = cmap(normalized_tensor.detach().cpu().numpy())
pcd_heatmap = o3d.geometry.PointCloud()
# Heatmap copy shifted by +10 along y so it sits beside the original cloud.
pcd_heatmap.points = o3d.utility.Vector3dVector(np.asarray(pcd.points) + [0, 10, 0])
pcd_heatmap.colors = o3d.utility.Vector3dVector(colors[:, :3])

o3d.visualization.draw_geometries([pcd, pcd_heatmap])

# mIoU evaluation

In [None]:
SCANNET_LABELS_20 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa',
                     'table', 'door', 'window', 'bookshelf', 'picture','counter', 'desk', 'curtain', 'refrigerator', 'shower curtain',
                     'toilet', 'sink', 'bathtub', 'otherfurniture']
UNKNOWN_ID = 255
NO_FEATURE_ID = 256

SCANNET_LABELS_20.append(query[0])
#SCANNET_LABELS_20.append("bird")

CLASS_LABELS = SCANNET_LABELS_20

In [None]:
with torch.no_grad():
    label_embeds = []
    for category in tqdm(SCANNET_LABELS_20):
        texts = clip.tokenize(category)  #tokenize
        texts = texts.cuda()
        text_embeddings = model.encode_text(texts)  #embed with text encoder
        text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
        text_embedding = text_embeddings.mean(dim=0)
        text_embedding /= text_embedding.norm()
        label_embeds.append(text_embedding)

    label_embeds = torch.stack(label_embeds, dim=1)


In [None]:
# Per-class IoU table for CLASS_LABELS. Each class_ious entry is read as
# (iou, numerator, denominator), matching the format string below.
print('classes          IoU')
print('----------------------------')
# BUG FIX: N_CLASSES and target_label were never defined at notebook level
# (N_CLASSES only exists inside `evaluate`), so this cell raised NameError.
for label_name in CLASS_LABELS:
    if not isinstance(label_name, str):
        # A list of prompts stands for one class. Report it under its `descriptors`
        # key when it has one; otherwise fall back to the joined prompts.
        label_name = next((key for key, value in descriptors.items() if value == label_name),
                          ' / '.join(label_name))
    try:
        entry = class_ious[label_name]
        print('{0:<14s}: {1:>5.5f}   ({2:>6d}/{3:<6d})'.format(
                label_name, entry[0], entry[1], entry[2]))
    except (KeyError, IndexError, TypeError, ValueError) as err:
        # Classes without ground-truth points are usually absent from class_ious.
        # A bare `except:` used to hide every other error as well.
        print(f'{label_name} error! ({type(err).__name__}: {err})')
        continue

In [None]:
import getpass

import openai

# SECURITY FIX: never hard-code API keys in a notebook. The key previously committed
# here must be treated as leaked and revoked. Read the key from the environment, or
# prompt for it without echoing.
openai.api_key = os.environ.get("OPENAI_API_KEY") or getpass.getpass("OpenAI API key: ")

# NOTE(review): openai.Completion / text-davinci-003 belong to the legacy (<1.0) SDK
# and model lineup — confirm they are still available.
response = openai.Completion.create(
    engine="text-davinci-003",
    prompt="Could you generate 5 visual descriptors for each of the following object classes, they are bird species: [Blue-faced Honeyeater, Diamond Firetail, Mouse-colored Tyrannulet]. The descriptors will be used for input queries for a CLIP model. The descriptors should be concise and distinct from one another. Do not focus on behavior, but purely on attributes which are recognizable by the CLIP model. The output should be in the following form, without any additional text: object class 1, visual descriptor 1.1, visual descriptor 1.2",
    temperature=0.5,
    max_tokens=200,
)

In [None]:
# old version of the aggregating text embeddings, it's not properly working
# NOTE: running this cell shadows the `highlight_query` imported from utils, which the
# cells at the top call with a different argument list.
def highlight_query(query, feature_type, model, distill, fused, fpc, fpcc, device):
    """Show a point cloud next to a heatmap of its similarity to a text query (legacy).

    Each entry of ``query`` is tokenized and encoded with CLIP. The prompt embeddings
    of one entry are L2-normalised, averaged and re-normalised, giving one embedding
    per entry. Point features are scored against these embeddings. The scores are
    min-max normalised and drawn with the ``cividis`` colormap on a copy of the cloud
    shifted by +10 along y.

    Args:
        query: list of categories; each is a prompt string or a list of prompts.
        feature_type: "fused", "distilled" or "ensembled". "ensembled" takes, per
            point, the fused feature where its best score beats the distilled one.
        model: CLIP model providing ``encode_text``.
        distill, fused: per-point feature tensors, one row per point of ``fpc``
            (assumed already L2-normalised — TODO confirm).
        fpc, fpcc: point coordinates and colours, array-like of shape (N, 3).
        device: torch device used for text encoding and the similarity matmul.

    Raises:
        ValueError: if ``feature_type`` is not one of the three supported values
            (previously this crashed with UnboundLocalError).
    """
    if feature_type not in ("fused", "distilled", "ensembled"):
        raise ValueError(f"unknown feature_type: {feature_type!r}")

    with torch.no_grad():
        all_text_embeddings = []
        for category in tqdm(query):
            texts = clip.tokenize(category).to(device)
            text_embeddings = model.encode_text(texts)  # embed with the CLIP text encoder
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            text_embedding /= text_embedding.norm()
            all_text_embeddings.append(text_embedding)

        # One column per query.
        all_text_embeddings = torch.stack(all_text_embeddings, dim=1)

    def _to_text_space(feat):
        # Match the embeddings' device and dtype (fp16 CLIP on GPU, fp32 on CPU).
        return feat.to(device=device, dtype=all_text_embeddings.dtype)

    if feature_type == "fused":
        similarity_matrix = _to_text_space(fused) @ all_text_embeddings
    elif feature_type == "distilled":
        similarity_matrix = _to_text_space(distill) @ all_text_embeddings
    else:  # "ensembled"
        fused_dev = _to_text_space(fused)
        distill_dev = _to_text_space(distill)
        pred_fusion = fused_dev @ all_text_embeddings
        pred_distill = distill_dev @ all_text_embeddings
        mask_ = pred_distill.max(dim=-1)[0] < pred_fusion.max(dim=-1)[0]
        # BUG FIX: this read the global `fused_f` instead of the `fused` argument, and
        # kept the ensemble on the CPU while the mask and embeddings were on `device`.
        feat_ensemble = distill_dev.clone()
        feat_ensemble[mask_] = fused_dev[mask_]
        similarity_matrix = feat_ensemble @ all_text_embeddings

    print(similarity_matrix.shape)
    # creating pc
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.asarray(fpc))
    pcd.colors = o3d.utility.Vector3dVector(np.asarray(fpcc))

    # heatmap
    cmap = plt.get_cmap('cividis')

    # Min-max normalise to [0, 1] in float32. The floor on the range keeps a constant
    # map at 0 instead of NaN (0/0).
    sim = similarity_matrix.float()
    sim_min = sim.min()
    normalized_tensor = (sim - sim_min) / (sim.max() - sim_min).clamp_min(1e-12)

    colors = cmap(normalized_tensor.detach().cpu().numpy().squeeze())
    pcd_heatmap = o3d.geometry.PointCloud()
    # Heatmap copy shifted by +10 along y so it sits beside the original cloud.
    pcd_heatmap.points = o3d.utility.Vector3dVector(np.asarray(pcd.points) + [0, 10, 0])
    pcd_heatmap.colors = o3d.utility.Vector3dVector(colors[:, :3])

    o3d.visualization.draw_geometries([pcd, pcd_heatmap])

def evaluate(labelset, descriptors, feature_type, model, distill, fused, gt_ids):
    """Per-class IoU / accuracy of CLIP zero-shot point labelling (legacy version).

    Every entry of ``labelset`` (a class name, or a list of descriptor prompts for one
    class) is encoded with CLIP. Each point is assigned the class with the highest
    feature-text similarity, and the predictions are compared with ``gt_ids``.

    NOTE: running this cell shadows the `evaluate` imported from utils, which the cells
    above call with a different argument list and 5 return values.

    Args:
        labelset: class names; list entries are descriptor lists for one class.
            The position in the list is the class id compared against ``gt_ids``.
        descriptors: dict mapping a class name to its descriptor list; used to name
            list entries of ``labelset`` in the results.
        feature_type: "fused", "distilled" or "ensembled".
        model: CLIP model providing ``encode_text``.
        distill, fused: per-point feature tensors aligned with ``gt_ids``.
        gt_ids: per-point ground-truth class ids (compared with ``==``).

    Returns:
        (class_ious, class_accs, mean_iou, mean_acc). The two dicts are keyed by
        class name and only contain classes with ground-truth points; the means are
        averaged over those classes.

    Raises:
        ValueError: for an unknown ``feature_type`` (previously UnboundLocalError).
    """
    if feature_type not in ("fused", "distilled", "ensembled"):
        raise ValueError(f"unknown feature_type: {feature_type!r}")

    with torch.no_grad():
        label_embeds = []
        for category in tqdm(labelset):
            # The notebook-wide `device` is used instead of a hard-coded .cuda().
            texts = clip.tokenize(category).to(device)
            text_embeddings = model.encode_text(texts)  # embed with the CLIP text encoder
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            text_embedding /= text_embedding.norm()
            label_embeds.append(text_embedding)

        # One column per class.
        label_embeds = torch.stack(label_embeds, dim=1)

    def _to_text_space(feat):
        # Match the embeddings' device and dtype (fp16 CLIP on GPU, fp32 on CPU).
        return feat.to(device=device, dtype=label_embeds.dtype)

    if feature_type == "fused":
        similarity_matrix = _to_text_space(fused) @ label_embeds
    elif feature_type == "distilled":
        similarity_matrix = _to_text_space(distill) @ label_embeds
    else:  # "ensembled"
        fused_dev = _to_text_space(fused)
        distill_dev = _to_text_space(distill)
        pred_fusion = fused_dev @ label_embeds
        pred_distill = distill_dev @ label_embeds
        mask_ = pred_distill.max(dim=-1)[0] < pred_fusion.max(dim=-1)[0]
        # BUG FIX: this read the global `fused_f` instead of the `fused` argument.
        feat_ensemble = distill_dev.clone()
        feat_ensemble[mask_] = fused_dev[mask_]
        similarity_matrix = feat_ensemble @ label_embeds

    # Predicted class id per point = argmax over the class columns.
    pred_ids = torch.max(similarity_matrix, 1)[1].detach().cpu()

    N_CLASSES = len(labelset)
    # confusion_matrix / get_iou are presumably provided by utils.
    confusion = confusion_matrix(pred_ids, gt_ids, N_CLASSES)
    class_ious = {}
    class_accs = {}
    mean_iou = 0
    mean_acc = 0

    count = 0
    for i, label_name in enumerate(labelset):
        if not isinstance(label_name, str):
            # Name a descriptor list by its `descriptors` key (last match wins, as
            # before). Fall back to the joined prompts instead of failing on an
            # unhashable list key.
            matches = [key for key, value in descriptors.items() if value == label_name]
            label_name = matches[-1] if matches else " / ".join(label_name)

        n_gt = (gt_ids == i).sum()
        if n_gt == 0:  # at least 1 point needs to be in the evaluation for this class
            continue

        class_ious[label_name] = get_iou(i, confusion)
        class_accs[label_name] = class_ious[label_name][1] / n_gt
        count += 1

        mean_iou += class_ious[label_name][0]
        mean_acc += class_accs[label_name]

    # BUG FIX: average over the classes that were actually evaluated (`count`).
    # Dividing by N_CLASSES counted every class absent from this scene as 0.
    if count:
        mean_iou /= count
        mean_acc /= count

    return class_ious, class_accs, mean_iou, mean_acc
