In [1]:
import sys
import torch
sys.path.insert(0, '..')
from datasets import ShapeNetDataset, PointCloudNormalize
from torch.utils.data import DataLoader
from models import PointNet
import torch.nn.functional as F
import k3d
import seaborn as sns
import numpy as np
from distinctipy import distinctipy

In [2]:
def apply_projector(x, model, layer_idx):
    """Run ``model.mlp`` on every point of a channel-first batch.

    The input is unpacked as ``(batch, dim, n_points)``. Points are flattened
    into rows so ``model.mlp`` receives a 2-D ``(batch * n_points, dim)``
    tensor. The rows are then folded back into a channel-first layout.

    Args:
        x: tensor of shape ``(batch, dim, n_points)``.
        model: object exposing ``mlp(tensor, idx=...)``.
        layer_idx: forwarded to ``model.mlp`` as ``idx``. Presumably it
            selects which projector layer's output is returned; confirm
            against the model code.

    Returns:
        Tensor of shape ``(batch, out_dim, n_points)``.
    """
    bsz, n_channels, n_pts = x.shape
    rows = x.transpose(2, 1).contiguous().view(-1, n_channels)
    projected = model.mlp(rows, idx=layer_idx)
    projected = projected.contiguous().view(bsz, n_pts, -1)
    return projected.transpose(2, 1)

def center(points, swap_yz=False):
    """Convert a channel-first batch of point clouds to a point-major NumPy array.

    Despite its name, this function does NOT translate or rescale the points.
    It only copies them, optionally swaps coordinate rows 1 and 2, and moves
    the coordinate axis last.

    Args:
        points: tensor indexed as ``points[cloud, coord, point]``. The code
            assumes at least 3 coordinate rows when ``swap_yz`` is True.
        swap_yz: if True, exchange coordinate rows 1 and 2 (y and z) before
            transposing. This was previously left in the body as
            commented-out code. Defaults to False, which is the old behavior.

    Returns:
        np.ndarray of shape ``(batch, n_points, n_coords)``. It never shares
        memory with ``points``. Works for tensors that require grad or live
        on a GPU.
    """
    # detach() so .numpy() works on tensors that require grad;
    # clone() so the caller's tensor is never modified or aliased.
    out = points.detach().clone()
    if swap_yz:
        # The RHS advanced index materializes a copy first, so the swap is safe in place.
        out[:, [1, 2], :] = out[:, [2, 1], :]
    return out.transpose(2, 1).cpu().numpy()

@torch.no_grad()
def get_features(model, x, layer_idx=-1):
    """Return L2-normalized projector embeddings for every point in ``x``.

    Steps:
        1. Switch ``model`` to eval mode. It is left in eval mode afterwards.
        2. Compute ``model.forward_features(x)``.
        3. Pass the result through ``apply_projector`` with ``layer_idx``.
        4. L2-normalize along dim 1, the feature dimension of
           ``apply_projector``'s ``(batch, out_dim, n_points)`` output.

    Autograd is disabled by the ``torch.no_grad`` decorator.

    Args:
        model: network exposing ``eval``, ``forward_features`` and ``mlp``.
        x: batch of point clouds as accepted by ``model.forward_features``.
        layer_idx: forwarded to ``apply_projector``, and from there to ``model.mlp``.

    Returns:
        Tensor of per-point embeddings of shape ``(batch, out_dim, n_points)``,
        normalized along dim 1.
    """
    model.eval()
    backbone_out = model.forward_features(x)
    return F.normalize(apply_projector(backbone_out, model, layer_idx), dim=1)

def get_similarity_scores(features, query_sample_idx, query_point_idx, support_sample_idx):
    """Dot product of one query point's feature with every point of a support cloud.

    ``features`` is indexed as ``features[cloud, channel, point]``. When the
    features are L2-normalized along the channel dim (as ``get_features``
    returns them), the scores are cosine similarities.

    Args:
        features: tensor of shape ``(batch, channels, n_points)``.
        query_sample_idx: cloud holding the query point.
        query_point_idx: index of the query point within that cloud.
        support_sample_idx: cloud whose points are scored.

    Returns:
        CPU tensor of shape ``(n_points,)`` with one score per support point.
    """
    query_vec = features[query_sample_idx, :, query_point_idx]
    support_mat = features[support_sample_idx]
    scores = torch.matmul(query_vec, support_mat)
    return scores.cpu()
    

def convert_labels2colors(labels, palette=None):
    """Map integer label ids to RGB colors with a single vectorized lookup.

    Args:
        labels: integer array-like of label ids. It is typically
            ``(n_clouds, n_points)``, but any shape is accepted.
        palette: array-like of colors indexed by label id. Only the first 3
            channels are used, so RGBA palettes work. Defaults to the
            notebook-level ``sem_palette``.

    Returns:
        float64 array of shape ``labels.shape + (3,)``.

    Raises:
        IndexError: if a label id is out of range for ``palette``.
    """
    if palette is None:
        palette = sem_palette
    labels = np.asarray(labels)
    if labels.size == 0:
        # Nothing to look up; an empty label array may not even be integer-typed.
        return np.zeros(labels.shape + (3,))
    # Fancy indexing replaces the former per-point Python double loop.
    return np.asarray(palette)[labels, :3].astype(np.float64)

# Categorical palette: 70 visually distinct colors, one per label id
# (the datasets built with n_classes=70 need ids 0..69).
# distinctipy samples candidate colors randomly. A fixed seed makes the
# palette, and every exported color file, identical across runs.
sem_palette = np.array(distinctipy.get_colors(70, pastel_factor=0.2, rng=0))
# To reuse a previously exported palette instead: sem_palette = np.load('../vis/palette.npy')
# Continuous colormap used to color similarity scores.
sim_palette = sns.color_palette("plasma", as_cmap=True)

In [3]:
# Validation split of ShapeNet, built twice from the same files:
#   * normalized_dataset: PointCloudNormalize('box') transform, n_classes=50
#   * dataset:            no transform argument,           n_classes=70
# NOTE(review): both request point_labels_level='local' but pass different
# n_classes. Confirm this is intentional.
SHAPENET_H5 = '../../datasets/hdfs/shapenet.h5'
LABELS_H5 = '../../datasets/shapenet_labels_global.h5'

normalized_dataset = ShapeNetDataset(SHAPENET_H5, ['val'], ['all'],
                                     points_labels_path=LABELS_H5,
                                     transform=PointCloudNormalize('box'),
                                     point_labels_level='local',
                                     n_classes=50)
dataset = ShapeNetDataset(SHAPENET_H5, ['val'], ['all'],
                          points_labels_path=LABELS_H5,
                          point_labels_level='local',
                          n_classes=70)

# Unshuffled batches of 4 from each dataset.
loader = DataLoader(normalized_dataset, shuffle=False, batch_size=4)
un_loader = DataLoader(dataset, shuffle=False, batch_size=4)

In [None]:
# Prefer the hard-coded GPU 2. Fall back to CPU when that GPU does not exist,
# so the notebook can still run on other machines.
device = 'cuda:2' if torch.cuda.device_count() > 2 else 'cpu'
model = PointNet().to(device)
# The checkpoint file is a dict whose 'model' entry is the state_dict.
# NOTE: torch.load unpickles the file. Only load trusted checkpoints.
model.load_state_dict(torch.load('../weights/simclr_run_1kindykb_ckp_150.pt', map_location=device)['model'])

In [None]:
# Inspect GPU occupancy (presumably to choose the device string used above).
!nvidia-smi

In [None]:
def find_query_point(colors, target=(220 / 255, 27 / 255, 27 / 255)):
    """Return the index of the first point whose color matches ``target``.

    Used to recover the highlighted (red) query point from an exported color
    array. Matching uses ``np.isclose`` with its default tolerances, which is
    the same test ``np.allclose`` applies per row.

    Args:
        colors: array of shape ``(n_points, 3)`` with one RGB row per point.
        target: RGB color to look for. Defaults to the red used to mark the
            query point.

    Returns:
        int index of the first matching row, or None if no row matches
        (including when ``colors`` is empty).
    """
    row_matches = np.isclose(np.asarray(colors), np.asarray(target)).all(axis=1)
    hits = np.flatnonzero(row_matches)
    return int(hits[0]) if hits.size else None

In [None]:
# Load a previously exported visualization. Its arrays are read eagerly into a
# plain dict so the .npz file handle is closed instead of left open.
with np.load('../vis/activations_all/headphones_arc.npz') as npz_file:
    d = dict(npz_file)

In [None]:
# Index of the red query point in the first color array of the loaded export.
find_query_point(colors=d['colors'][0])

In [None]:
# Six normalized validation clouds, taken as three pairs of dataset indices.
pair_sample_ids = [2, 5, 1204, 1205, 1520, 1519]
pair_samples = [normalized_dataset[i] for i in pair_sample_ids]

x = torch.from_numpy(np.array([pts for pts, _ in pair_samples]))
labels = torch.from_numpy(np.array([lbl for _, lbl in pair_samples]))
# `device` comes from the model-loading cell. It is deliberately not
# re-assigned here, because a second hard-coded value would silently
# override the device the model was moved to.
x = x.to(device)

In [6]:
# Same validation split, but with point labels at the 'global' level.
gl_dataset = ShapeNetDataset('../../datasets/hdfs/shapenet.h5', ['val'], ['all'],
                             points_labels_path='../../datasets/shapenet_labels_global.h5',
                             point_labels_level='global',
                             n_classes=70)

# One index list is shared by the local- and global-label lookups, so the
# two cannot drift apart.
algo_sample_ids = [2, 5, 6, 7]
algo_samples = [normalized_dataset[i] for i in algo_sample_ids]
labels = torch.from_numpy(np.array([lbl for _, lbl in algo_samples]))
# Move the clouds to the model's device: get_features(model, x) needs x and
# the model on the same device. `labels` stays on CPU for NumPy indexing.
x = torch.from_numpy(np.array([pts for pts, _ in algo_samples])).to(device)

global_labels = np.array([gl_dataset[i][1] for i in algo_sample_ids])

In [7]:
# Three colorings of the same clouds, stacked along axis 0:
#   1. uniform grey, 2. `labels` colors, 3. `global_labels` colors.
# The grey block is sized from `labels` (assumed (n_clouds, n_points))
# instead of a hard-coded (4, 2048).
base_grey = np.full((*labels.shape, 3), 134 / 255)
colors = np.concatenate([base_grey, sem_palette[labels], sem_palette[global_labels]], axis=0)

In [8]:
# Same clouds three times, one copy per coloring stacked in `colors`.
# The (cloud, coord, point) -> (cloud, point, coord) transpose is computed once.
clouds_xyz = x.transpose(2, 1).cpu().numpy()
points = np.concatenate([clouds_xyz] * 3)

In [None]:
# Quick look at the first cloud colored by its labels; k3d wants (point, coord).
first_xyz = x[0].cpu().t()
k3d.points(first_xyz, attribute=labels[0], point_size=0.08)

In [None]:
# Normalized per-point embeddings from the last projector layer.
feats = get_features(model, x, layer_idx=-1)

In [None]:
pl = k3d.plot()

query_sample_idx = 6
query_point_idx = 117
support_sample_idx = 7

# Fail early with a clear message. The sample-loading cells in this notebook
# build batches of 6 or 4 clouds, so indices 6/7 only work if a larger batch
# was loaded elsewhere.
n_clouds = min(len(x), len(feats))
assert max(query_sample_idx, support_sample_idx) < n_clouds, (
    f'sample indices {query_sample_idx}/{support_sample_idx} are out of range '
    f'for a batch of {n_clouds} clouds'
)

# Cosine similarity (features are L2-normalized) between the query point and
# every point of the support cloud.
sim = get_similarity_scores(feats, query_sample_idx, query_point_idx, support_sample_idx)

# Draw the query cloud shifted by +2 along coordinate row 1 so it does not
# overlap the support cloud.
y = x[query_sample_idx].clone()
y[1] += 2
pl += k3d.points(y.cpu().t(), point_size=0.08)
# Enlarged marker on the query point itself.
pl += k3d.points(y[:, query_point_idx:query_point_idx+1].cpu().t(), point_size=0.2)
# Support cloud colored by similarity to the query point.
pl += k3d.points(x[support_sample_idx].cpu().t(), point_size=0.08, attribute=sim)
pl

In [None]:
# Per-point colors for the query cloud: grey everywhere, red on the query point.
# Sized from the point count of x (indexed (cloud, coord, point)) rather than a hard-coded 2048.
grey = np.full((x.shape[2], 3), 134 / 255)
grey[query_point_idx] = [220 / 255, 27 / 255, 27 / 255]

In [9]:
# Every point needs exactly one color: check the stacked clouds and colorings
# line up before exporting.
assert points.shape[:-1] == colors.shape[:-1], (points.shape, colors.shape)
# np.savez appends the .npz extension, so this writes ../vis/algo_ex.npz.
np.savez('../vis/algo_ex', points=points, colors=colors)

In [None]:
# Shape of the stacked color array.
np.shape(colors)

In [None]:
# Colors for the exported query/support pair:
#   [0] query cloud (grey with the red query point),
#   [1] support cloud colored by similarity via the plasma colormap (alpha dropped).
# NOTE(review): matplotlib colormaps map inputs below 0 to the "under" color.
# Negative cosine similarities therefore all share the lowest plasma color,
# unlike the auto-ranged k3d preview. Confirm this is intended.
sim_rgb = sim_palette(sim.numpy())[:, :3]
colors = np.array([grey, sim_rgb])

In [None]:
# Export the query/support pair shown in the similarity plot, in the same
# order as `colors` ([query, support]). Indexing by the variables replaces the
# old x[6:8] slice: that slice silently returned fewer clouds when the batch
# was smaller, whereas list indexing raises IndexError.
pair_xyz = x[[query_sample_idx, support_sample_idx]].transpose(2, 1).cpu().numpy()
assert pair_xyz.shape[:-1] == colors.shape[:-1], (pair_xyz.shape, colors.shape)
np.savez('../vis/headphones_arc', points=pair_xyz, colors=colors)

In [None]:
# Two clouds each from three hand-picked categories, taken from `dataset`
# (built without the PointCloudNormalize transform). The category names come
# from the variable names, not from the dataset itself.
airplane, airplane_labels = dataset[0]
airplane2, airplane2_labels = dataset[2]

chair, chair_labels = dataset[2011]
chair2, chair2_labels = dataset[2013]

car, car_labels = dataset[1161]
car2, car2_labels = dataset[1169]

# Stacked per-point labels. This assumes all six clouds have the same point count.
# NOTE(review): the following cell overwrites `labels` with the contents of data.npz.
labels = np.array([airplane_labels, airplane2_labels, chair_labels, chair2_labels, car_labels, car2_labels])

In [None]:
# Load precomputed clouds and labels; this replaces the `labels` built in the previous cell.
# Reading inside `with` closes the .npz file handle once the arrays are loaded.
with np.load('data.npz') as data:
    points = data['points']
    labels = data['labels']

In [None]:
# Per-point RGB colors from the semantic palette.
colors = convert_labels2colors(labels=labels)

In [None]:
# FIXME(review): `r` is never defined in this notebook (presumably a 3x3 matrix applied to
# every point), so this cell raises NameError on a fresh kernel. Define `r` before running.
k3d.points((points @ r)[0], point_size=0.01)

In [None]:
# Distinct label ids present in each car cloud.
{*car_labels}, {*car2_labels}

In [None]:
# Stack the six clouds and swap the last two axes, presumably (coord, point) -> (point, coord).
points = np.array([airplane, airplane2, chair, chair2, car, car2]).transpose(0, 2, 1)

In [None]:
# FIXME(review): `r` is never defined in this notebook, so this export raises NameError on a
# fresh kernel. Define `r` (presumably a 3x3 matrix applied to every point) before running.
np.savez('../vis/patches_spectral.npz', colors=colors, points=points @ r)

In [None]:
# `colors` may come from data.npz labels while `points` may come from the six dataset
# samples; at least make sure they align point-for-point before exporting.
assert points.shape[:-1] == colors.shape[:-1], (points.shape, colors.shape)
np.savez('../vis/patches.npz', colors=colors, points=points)

In [None]:
# Distinct label ids present in the second chair cloud.
{*chair2_labels}

In [None]:
# Second car cloud, transposed for k3d and colored by its labels.
car2_xyz = car2.T
k3d.points(car2_xyz, attribute=car2_labels, point_size=0.01)

In [None]:
# Removed a stray fragment (`bright'`): it was an unterminated string literal
# and raised SyntaxError whenever this cell ran.

In [None]:
! pip install distinctipy

In [None]:
# Persist the semantic palette (np.save appends .npy, giving ../vis/palette.npy)
# so the exact colors can be reloaded instead of regenerated.
np.save('../vis/palette', sem_palette)