In [12]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import pandas as pd


In [13]:
# All evaluation artefacts live in one directory, relative to the kernel's working directory.
EVAL_DATA_DIR = "../src/eval_data"

# map_location="cpu": later cells hand these tensors to numpy (np.random.choice) and
# pandas (pd.DataFrame), which only accept CPU tensors, so never keep them on a GPU.
objaverse_dict = torch.load(f"{EVAL_DATA_DIR}/objaverse_dict.pt", map_location="cpu")
ov_idx2category = objaverse_dict["idx2category"]  # class index -> category name
ov_category2idx = objaverse_dict["category2idx"]  # category name -> class index
ov_labels_all = objaverse_dict["labels_all"]      # true class index per instance
ov_logits_all = objaverse_dict["logits_all"]      # model logits per instance
objaverse_dict = None  # drop the container; the entries above hold everything still needed

ov_xyz = torch.load(f"{EVAL_DATA_DIR}/ov_xyz.pt", map_location="cpu")
ov_rgb = torch.load(f"{EVAL_DATA_DIR}/ov_rgb.pt", map_location="cpu")

# Precomputed confusion matrices (normalised and raw counts).
ov_res_norm = torch.load(f"{EVAL_DATA_DIR}/ov_confusion_norm.pt", map_location="cpu")
ov_confusion = torch.load(f"{EVAL_DATA_DIR}/ov_confusion.pt", map_location="cpu")

In [11]:
def get_confusion_matrix(idx2category, labels_all, logits_all, verbose=True):
    """Build a confusion matrix from classifier logits.

    Rows are indexed by the true label and columns by the top-1 predicted label
    (``logits_all.argmax(dim=1)``), so ``res[i, j]`` counts samples of class ``i``
    that were predicted as class ``j``.

    Args:
        idx2category: mapping from class index to category name. Only its length
            is used, to size the (C x C) matrix; assumes indices are 0..C-1 — TODO confirm.
        labels_all: integer tensor with one true class index per sample.
        logits_all: tensor of shape (num_samples, C).
        verbose: print input/output shapes (the original debug output). Defaults to True.

    Returns:
        (res, res_norm): ``res`` is the (C x C) count matrix in the default float
        dtype; ``res_norm`` is ``res`` normalised per row, so every class with at
        least one sample has a row summing to 1. Rows for classes with no samples
        stay all-zero.
    """
    num_classes = len(idx2category)
    if verbose:
        print("labels_all: ", labels_all.shape)

    labels = labels_all.reshape(-1).long()
    pred = logits_all.argmax(dim=1).reshape(-1)

    # Encode each (true, pred) pair as one flat bin id so a single bincount fills
    # the whole matrix. This replaces two Python loops (correct / incorrect
    # predictions) that both did exactly the same increment.
    flat_ids = labels * num_classes + pred
    res = torch.bincount(flat_ids, minlength=num_classes * num_classes)
    res = res.reshape(num_classes, num_classes).to(torch.get_default_dtype())
    if verbose:
        print("res: ", res.shape)

    # Per-class prediction rates. The original called F.softmax here (F was never
    # imported), which exponentiates raw counts rather than normalising them.
    # clamp(min=1) keeps empty rows at 0 instead of producing 0/0 = NaN.
    row_totals = res.sum(dim=1, keepdim=True).clamp(min=1)
    res_norm = res / row_totals

    return res, res_norm

In [4]:
def plot_pointcloud(ov_xyz, ov_rgb, list_instance, apply_color=True, marker_size=4, opacity=None):
    """Show one interactive Plotly 3-D scatter per selected point-cloud instance.

    Args:
        ov_xyz: indexable per-instance coordinates; ``ov_xyz[ins]`` must convert to
            an array with 3 columns (x, y, z).
        ov_rgb: indexable per-instance point colours; ``ov_rgb[ins]`` is scaled by
            255 and cast to int, so values are presumably in [0, 1] — TODO confirm.
        list_instance: iterable of instance indices to plot.
        apply_color: colour every point with its own RGB value; otherwise use
            Plotly's default single colour.
        marker_size: marker size applied to all traces.
        opacity: marker opacity. ``None`` keeps the previous defaults (0.2 when
            coloured, 0.1 when not). An explicit value is honoured in both modes;
            previously it was ignored when ``apply_color`` was False.
    """
    if opacity is None:
        opacity = 0.2 if apply_color else 0.1

    for ins in list_instance:
        df = pd.DataFrame(np.asarray(ov_xyz[ins]), columns=["x", "y", "z"])

        if apply_color:
            # .tolist() yields plain ints, avoiding per-element 0-d tensor formatting.
            rgb_rows = (ov_rgb[ins] * 255).type(torch.int).tolist()
            color_list = [f"rgb({r}, {g}, {b})" for r, g, b in rgb_rows]
            # "identity": use the colour strings themselves rather than mapping them to a palette.
            fig = px.scatter_3d(
                df, x="x", y="y", z="z",
                color=color_list, color_discrete_map="identity", opacity=opacity,
            )
        else:
            fig = px.scatter_3d(df, x="x", y="y", z="z", opacity=opacity)

        fig.update_traces(marker_size=marker_size)
        fig.show()
    

In [9]:
# Given a list of categories, plot num_ins random instances drawn from all of them combined.
def plot_multiple_ins(ov_xyz, ov_rgb, ov_idx2category, ov_category2idx, ov_labels_all, obj_list, num_ins, rng=None):
    """Plot ``num_ins`` random instances whose true label is any category in ``obj_list``.

    Args:
        ov_xyz, ov_rgb: per-instance coordinates / colours, forwarded to ``plot_pointcloud``.
        ov_idx2category: class index -> category name.
        ov_category2idx: category name -> class index. Raises KeyError for unknown names.
        ov_labels_all: tensor of true class indices, presumably one per instance — TODO confirm.
        obj_list: category names to sample from.
        num_ins: number of distinct instances to plot (sampled without replacement).
        rng: optional ``np.random.Generator`` for reproducible sampling. Defaults to
            numpy's global RNG, as before.

    Returns:
        List of category names of the plotted instances, in plot order.

    Raises:
        ValueError: if fewer than ``num_ins`` instances belong to ``obj_list``.
    """
    target_idx = torch.tensor([ov_category2idx[o] for o in obj_list])
    mask = torch.isin(ov_labels_all, target_idx)
    indices = torch.nonzero(mask).flatten()
    print(f"indices: {indices}")

    if indices.numel() < num_ins:
        raise ValueError(
            f"requested {num_ins} instances but only {indices.numel()} belong to {obj_list}"
        )

    sampler = np.random if rng is None else rng
    list_ins = sampler.choice(indices, size=num_ins, replace=False)
    plot_pointcloud(ov_xyz, ov_rgb, list_ins, apply_color=True)
    return [ov_idx2category[o.item()] for o in ov_labels_all[list_ins]]

In [6]:
# Plot instances the model confused, e.g. true label "vodka" but predicted "water_bottle".
def plot_multiple_confused_obj(ov_labels_all, ov_xyz, ov_rgb, true_obj, pred_obj_list, num_ins,
                               category2idx=None, logits_all=None, rng=None):
    """Plot random instances of ``true_obj`` whose top-1 prediction is in ``pred_obj_list``.

    Args:
        ov_labels_all: tensor of true class indices, presumably one per instance — TODO confirm.
        ov_xyz, ov_rgb: per-instance coordinates / colours, forwarded to ``plot_pointcloud``.
        true_obj: category name of the ground-truth class.
        pred_obj_list: category names the model predicted instead.
        num_ins: maximum number of instances to plot. Fewer are plotted if fewer match.
        category2idx: category name -> class index. Defaults to the notebook global
            ``ov_category2idx``, which the original function already relied on.
        logits_all: model logits of shape (num_instances, num_classes). Defaults to
            the notebook global ``ov_logits_all``.
        rng: optional ``np.random.Generator`` for reproducible sampling. Defaults to
            numpy's global RNG.

    Returns:
        numpy array of the plotted instance indices; empty if nothing matched.
    """
    if category2idx is None:
        category2idx = ov_category2idx
    if logits_all is None:
        logits_all = ov_logits_all

    true_idx = category2idx[true_obj]
    pred_idx = torch.tensor([category2idx[o] for o in pred_obj_list])
    preds = logits_all.argmax(dim=1)

    # The previous version compared an (N,) tensor to a (k,) tensor with ==, which
    # only broadcasts when k == 1. It also filtered on the TRUE label being a
    # "pred" class and never used true_obj.
    mask = (ov_labels_all == true_idx) & torch.isin(preds, pred_idx)
    candidates = torch.nonzero(mask).flatten()
    if candidates.numel() == 0:
        print(f"No instances of {true_obj!r} predicted as any of {pred_obj_list}")
        return np.array([], dtype=np.int64)

    sampler = np.random if rng is None else rng
    list_ins = sampler.choice(candidates, size=min(num_ins, candidates.numel()), replace=False)
    plot_pointcloud(ov_xyz, ov_rgb, list_ins, apply_color=True)
    return list_ins

In [10]:
# Sample and plot a single shopping_bag instance; `objs` receives its category name.
# Multi-category example with the current signature:
#   plot_multiple_ins(ov_xyz, ov_rgb, ov_idx2category, ov_category2idx, ov_labels_all,
#                     ['sherbert', 'icecream', 'popsicle', 'plow_(farm_equipment)', 'pocket_watch', 'pliers'], 4)
objs = plot_multiple_ins(
    ov_xyz,
    ov_rgb,
    ov_idx2category,
    ov_category2idx,
    ov_labels_all,
    obj_list=['shopping_bag'],
    num_ins=1,
)

indices: tensor([36005, 36006, 36007, 36008, 36009, 36010, 36011, 36012, 36013, 36014,
        36015, 36016, 36017, 36018, 36019, 36020, 36021, 36022, 36023, 36024,
        36025, 36026, 36027, 36028, 36029, 36030, 36031, 36032, 36033, 36034,
        36035, 36036, 36037, 36038, 36039, 36040, 36041, 36042, 36043, 36044,
        36045, 36046, 36047, 36048, 36049, 36050, 36051, 36052, 36053, 36054,
        36055, 36056, 36057])


In [13]:
# Size of the loaded point-cloud tensor (displayed as the cell output).
ov_xyz.size()

torch.Size([46205, 10000, 3])

In [33]:
# Display the category names returned by the most recent plot_multiple_ins call.
# NOTE(review): the saved output lists 'fighter_jet' six times, but the plot_multiple_ins
# cell in this notebook requests one 'shopping_bag', so that output is stale. Re-run to refresh.
objs

['fighter_jet',
 'fighter_jet',
 'fighter_jet',
 'fighter_jet',
 'fighter_jet',
 'fighter_jet']