### Import

In [14]:
import os
from os.path import join
import random
import pickle
from time import time
from datetime import datetime
import shutil

from icecream import ic
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import ConfusionMatrixDisplay

from utils import get_angles
from feature_extractor import surflet_pairs_feature

### Data

In [15]:
def select_ids(ids, labels, samples_per_class, ground_truth=None):
    """Randomly pick up to ``samples_per_class`` ids for each requested label.

    Args:
        ids: Iterable of sample ids to choose from.
        labels: Labels to draw samples for. The result is grouped by label,
            in the order the labels are given here.
        samples_per_class: Maximum number of ids drawn per label. The value is
            clamped to the number of ids available for that label, so a
            "take everything" sentinel such as ``10**9`` is safe.
        ground_truth: Mapping ``id -> {'label': ...}``. When omitted, the
            notebook-level ``gt`` mapping is used.

    Returns:
        A flat list of selected ids. Ids whose label is not in ``labels`` are
        ignored, and a label with no matching ids contributes nothing.
    """
    if ground_truth is None:
        ground_truth = gt  # fall back to the notebook-level ground truth

    # Bucket only the labels that were asked for; any other label value found
    # in the ground truth is skipped instead of raising KeyError.
    ids_by_label = {label: [] for label in labels}
    for sample_id in ids:
        label = ground_truth[sample_id]['label']
        if label in ids_by_label:
            ids_by_label[label].append(sample_id)

    selected = []
    for label in labels:
        pool = ids_by_label[label]
        # random.sample raises ValueError when k > len(pool), hence the clamp.
        selected += random.sample(pool, min(samples_per_class, len(pool)))
    return selected


def get_train_data(ids, filepath, labels=[1, 2, 3, 4, 5], samples_per_class=10**9):
    """Build the training feature matrix and label vector.

    Args:
        ids: Candidate sample ids; the actual subset is chosen by ``select_ids``.
        filepath: ``%``-style pattern that is formatted with a sample id to
            obtain the path handed to ``surflet_pairs_feature``.
        labels: Labels to keep.
        samples_per_class: Upper bound on kept samples per label (also
            forwarded to ``select_ids``).

    Returns:
        ``(x_train, y_train)`` as NumPy arrays: one feature per kept sample
        (whatever ``surflet_pairs_feature`` returns with ``n_pairs=100``) and
        the matching ground-truth labels.
    """
    features, targets = [], []
    taken = dict.fromkeys(labels, 0)  # per-label count of samples kept so far
    for sample_id in tqdm(select_ids(ids, labels, samples_per_class)):
        target = gt[sample_id]['label']
        # Skip labels that were not requested or whose quota is already full.
        if target not in labels or taken[target] >= samples_per_class:
            continue
        features.append(surflet_pairs_feature(filepath=filepath % sample_id, n_pairs=100))
        targets.append(target)
        taken[target] += 1
    return np.array(features), np.array(targets)


def get_test_data(ids, filepath, labels=[1, 2, 3, 4, 5], samples_per_class=10**9):
    """Build the test features, labels and the ids they came from.

    Args:
        ids: Candidate sample ids; the actual subset is chosen by ``select_ids``.
        filepath: ``%``-style pattern that is formatted with a sample id to
            obtain the path handed to ``surflet_pairs_feature``.
        labels: Labels to keep.
        samples_per_class: Upper bound on kept samples per label (also
            forwarded to ``select_ids``).

    Returns:
        ``(x_test, y_test, test_ids)`` as three parallel Python lists. Unlike
        ``get_train_data``, nothing is converted to a NumPy array.
    """
    features, targets, kept_ids = [], [], []
    taken = dict.fromkeys(labels, 0)  # per-label count of samples kept so far
    for sample_id in tqdm(select_ids(ids, labels, samples_per_class)):
        target = gt[sample_id]['label']
        # Skip labels that were not requested or whose quota is already full.
        if target not in labels or taken[target] >= samples_per_class:
            continue
        features.append(surflet_pairs_feature(filepath=filepath % sample_id, n_pairs=100))
        targets.append(target)
        kept_ids.append(sample_id)
        taken[target] += 1
    return features, targets, kept_ids

In [16]:
def init_dataset(labels=[1, 2, 3, 4, 5], samples_per_class=[10**9, 10**9], train_full_data=False):
    """Load features for training and, optionally, for a held-out test split.

    Args:
        labels: Labels to include.
        samples_per_class: Two-element sequence ``[train_cap, test_cap]`` with
            the per-label sample limits for the train and test sets.
        train_full_data: When True, ids 1..46000 are all used as training
            candidates and no test set is built. When False, the train/test
            id split is read from ``./honv/train_test_ids.pkl``.

    Returns:
        ``(x_train, y_train)`` when ``train_full_data`` is True, otherwise
        ``((x_train, y_train), (x_test, y_test, test_ids))``.
    """
    # Train and test samples are both read from the same "training" folder;
    # the split is defined purely by the id lists.
    ply_pattern = './dataset/ply/training/pointCloud/pointCloud%d.ply'

    if train_full_data:
        train_ids = list(range(1, 46001))
    else:
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        # The context manager closes the handle that was previously leaked.
        with open('./honv/train_test_ids.pkl', 'rb') as split_file:
            train_ids, test_ids = pickle.load(split_file)

    x_train, y_train = get_train_data(
        ids=train_ids,
        filepath=ply_pattern,
        labels=labels,
        samples_per_class=samples_per_class[0],
    )
    if train_full_data:
        return x_train, y_train

    x_test, y_test, test_ids = get_test_data(
        ids=test_ids,
        filepath=ply_pattern,
        labels=labels,
        samples_per_class=samples_per_class[1],
    )
    return (x_train, y_train), (x_test, y_test, test_ids)

### Eval utils

In [17]:
def predict(model, xs, labels):
    """Return ``model.predict(xs)``.

    ``labels`` is part of the signature but is not used by this function.
    """
    predictions = model.predict(xs)
    return predictions


def evaluate(model=None, x_test=None, y_test=None, test_ids=None, labels=None, y_pred=None, save_fig=True, names=None, ):
    """Compute accuracy and a confusion matrix, log per-cell details and plot.

    Relies on the notebook globals ``LOGGING_PATH`` and ``NAME`` (output
    location) and, optionally, ``cf.n_neighbors`` (shown in the plot title).

    Args:
        model: Fitted estimator; only needed when ``y_pred`` is not given.
        x_test: Test features. Needed to compute predictions when ``y_pred``
            is None, and to fill ``confusion_hists``. When None, the
            per-cell feature log is left empty.
        y_test: Ground-truth labels (required).
        test_ids: Ids parallel to ``y_test``. When None, the positional index
            of each sample is logged instead.
        labels: List of label values defining row/column order (required).
        y_pred: Precomputed predictions; when None they come from ``model``.
        save_fig: When True, the figure is written to
            ``LOGGING_PATH/<NAME>.jpg``.
        names: Mapping ``label -> display name`` for the plot axes. When None,
            ``str(label)`` is used.

    Returns:
        ``(acc, confusion_matrix)`` where ``confusion_matrix`` holds raw int32
        counts with true labels on rows and predictions on columns. ``acc`` is
        NaN when there are no samples.

    Side effects:
        Writes ``confusion_ids.pkl`` and ``confusion_hists.pkl`` into
        ``LOGGING_PATH``, shows the plot and prints the accuracy.
    """
    assert labels is not None and y_test is not None
    n_labels = len(labels)
    confusion_matrix = np.zeros((n_labels, n_labels))
    confusion_ids = [[[] for _ in range(n_labels)] for _ in range(n_labels)]
    confusion_hists = [[[] for _ in range(n_labels)] for _ in range(n_labels)]
    N = len(y_test)

    if y_pred is None:
        assert model is not None and x_test is not None
        y_pred = predict(model, x_test[:N], labels)

    for i in range(N):
        # Resolve the cell once per sample instead of once per container.
        row = labels.index(y_test[i])
        col = labels.index(y_pred[i])
        confusion_matrix[row, col] += 1
        confusion_ids[row][col].append(test_ids[i] if test_ids is not None else i)
        if x_test is not None:
            # NOTE(review): random.choice picks ONE element of x_test[i]. That
            # makes sense if each entry is a collection of histograms; if it is
            # a single feature vector this stores one random bin — confirm.
            confusion_hists[row][col].append(random.choice(x_test[i]))

    # Context managers guarantee the pickles are flushed and closed.
    with open(join(LOGGING_PATH, 'confusion_ids.pkl'), 'wb') as ids_file:
        pickle.dump(confusion_ids, ids_file)
    with open(join(LOGGING_PATH, 'confusion_hists.pkl'), 'wb') as hists_file:
        pickle.dump(confusion_hists, hists_file)

    total = confusion_matrix.sum()
    # Guard the empty case to avoid a 0/0 RuntimeWarning; the result stays NaN.
    acc = np.trace(confusion_matrix) / total if total else float('nan')
    confusion_matrix = confusion_matrix.astype('int32')

    if names is not None:
        display_labels = [names[label] for label in labels]
    else:
        display_labels = [str(label) for label in labels]
    disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=display_labels)
    plt.rc('font', size=12)
    fig, ax = plt.subplots(figsize=(10, 10))

    # The K line is optional: only tolerate a missing `cf` / missing attribute
    # instead of swallowing every possible error with a bare except.
    try:
        k_line = f'K={cf.n_neighbors}\n'
    except (NameError, AttributeError):
        k_line = ''
    # NOTE(review): the matrix plotted below contains raw counts, so the
    # "Each row has a sum of 1" caption does not match — either normalise the
    # rows or reword the caption.
    ax.set_title('accuracy = %.4f\n' % acc + k_line + '\nEach row has a sum of 1')
    disp.plot(cmap='Reds', ax=ax)
    plt.show()
    if save_fig:
        fig.savefig(join(LOGGING_PATH, f'{NAME}.jpg'))

    print('Accuracy:', acc)
    return acc, confusion_matrix

### Run

In [18]:
class Config:
    """Attribute bag for run settings, built from arbitrary keyword arguments."""

    def __init__(self, **kwargs):
        """Expose every keyword argument as an instance attribute."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def save(self, root):
        """Write one ``<name> <value>`` line per attribute to ``root/config.txt``."""
        lines = [f'{attr} {value}\n' for attr, value in vars(self).items()]
        with open(join(root, 'config.txt'), 'w') as f:
            f.write(''.join(lines))

In [24]:
# Run configuration: K for the KNN classifier and the label subset to use.
# (The commented-out evaluate call further down names labels 3 and 5
# 'sphere' and 'torus'.)
cf = Config(
    n_neighbors = 15,
    labels=[3, 5]
)

# NAME identifies this run; it is used for the log folder and the saved figure.
# NOTE(review): presumably encodes "labels 3,5 / K=15 / full data" — confirm.
NAME = '35_15_full'
LOGGING_PATH = f'./SP/{NAME}/'
# Create the log folder first so cf.save (and later dumps) have a target.
os.makedirs(LOGGING_PATH, exist_ok=True)
cf.save(LOGGING_PATH)

In [26]:
# Ground truth used by the data helpers: maps sample id -> {'label': ...}.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
# Context managers are used so every file handle is closed (and flushed).
with open('./metadata/ground_truth.pkl', 'rb') as fh:
    gt = pickle.load(fh)

# Alternative: build a train/test split instead of training on everything.
# (x_train, y_train), (x_test, y_test, test_ids) = init_dataset(labels=cf.labels, samples_per_class=[8280, 920])
# pickle.dump(((x_train, y_train), (x_test, y_test, test_ids)), open(join(LOGGING_PATH, 'data.pkl'), 'wb'))

# Full-data run: up to 9200 samples per label, no test set (the second cap is unused).
x_train, y_train = init_dataset(labels=cf.labels, samples_per_class=[9200, 0], train_full_data=True)
with open(join(LOGGING_PATH, 'data.pkl'), 'wb') as fh:
    pickle.dump((x_train, y_train), fh)

# Alternative: reuse features cached by an earlier run.
# (x_train, y_train), (x_test, y_test, test_ids) = pickle.load(open('./SP/35_1/data.pkl', 'rb'))

  0%|          | 0/18400 [00:00<?, ?it/s]

In [None]:
# Fit a K-nearest-neighbours classifier on the extracted features and persist it.
model = KNeighborsClassifier(n_neighbors=cf.n_neighbors)
model.fit(x_train, y_train)
# The context manager closes the file so the pickle is fully written to disk.
with open(join(LOGGING_PATH, 'model.pkl'), 'wb') as fh:
    pickle.dump(model, fh)

# Evaluation is disabled for the full-data run because no test split exists.
# acc, confusion_matrix = evaluate(model, x_test, y_test, test_ids, cf.labels,
#     names={3:'sphere', 5:'torus'}
# )

In [94]:
# confusion_ids = pickle.load(open(f'{LOGGING_PATH}/confusion_ids.pkl', 'rb'))
# confusion_hists = pickle.load(open(f'./{LOGGING_PATH}/confusion_hists.pkl', 'rb'))

# os.makedirs(join(LOGGING_PATH, 'confusion/ply/'), exist_ok=True)
# os.makedirs(join(LOGGING_PATH, 'confusion/hist/'), exist_ok=True)

# names = ['sphere', 'torus']
# for i in range(2):
#     for j in range(2):
#         N = len(confusion_ids[i][j])
#         indices = random.sample(list(range(N)), min(N, 5))
#         for index in indices:
#             id = confusion_ids[i][j][index]
#             hist = confusion_hists[i][j][index]
#             src = f'./dataset/ply/training/pointCloud/pointCloud{id}.ply'
#             dst = join(LOGGING_PATH, f'confusion/ply/{names[i]}_{names[j]}_{id}.ply')
#             shutil.copy(src, dst)