In [None]:
import os
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import json
from utils.experiment import Experiment

In [None]:
# `slice_setting == 'over'` selects the larger n_slices value in the dataset setup
# cell; any other value selects the smaller one.
slice_setting = 'over'
# One of 'point_mnist', 'modelnet40', 'oxford' (the branches of the setup cell).
dataset = 'point_mnist'

In [None]:
# Per-dataset setup: derive `set_size_median` (median number of points per training
# sample, passed as ref_size= to every Experiment) and `n_slices` (passed only to
# the 'swe' experiments). `slice_setting == 'over'` picks the larger slice count.
if dataset == 'point_mnist':
    df_train = pd.read_csv('../dataset/pointcloud_mnist_2d/train.csv')

    # First column is the label; the remaining columns are the flattened points.
    X = df_train[df_train.columns[1:]].to_numpy()
    y = df_train[df_train.columns[0]].to_numpy()

    # Regroup each flat row into (n_points, 3).
    X = X.reshape(X.shape[0], -1, 3)

    # Count points whose third value is positive; presumably non-positive values
    # mark padding slots -- TODO confirm against the dataset description.
    num_points = np.sum((X[:, :, 2] > 0).astype(int), axis=1)

    set_size_median = np.median(num_points).astype(int)
    n_slices = 8 if slice_setting == 'over' else 2

elif dataset == 'modelnet40':
    # Fixed values; no data is loaded in this cell for this dataset.
    set_size_median = 512
    n_slices = 16 if slice_setting == 'over' else 3

elif dataset == 'oxford':
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open('../dataset/oxford/train_test_AE8.pkl', 'rb') as f:
        data = pickle.load(f)

    X_train, y_train, X_test, y_test, classnames = data

    # The first axis of each training sample is taken as its point count.
    num_points = np.array([i.shape[0] for i in X_train])

    set_size_median = np.median(num_points).astype(int)
    n_slices = 128 if slice_setting == 'over' else 8

else:
    # Fail fast: otherwise an unknown name silently skips every branch and the
    # notebook only breaks later with a NameError on set_size_median.
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected 'point_mnist', 'modelnet40' or 'oxford'"
    )

print(dataset, set_size_median, n_slices)

In [None]:
# Settings shared by the experiment cells below.
ref = 'rand'                # forwarded as ref_func= (the Cov and GeM cells do not pass it)
code_length = 1024          # forwarded as code_length= to every 'faiss-lsh' Experiment
seeds = [0, 1, 4, 10, 16]   # random_state values; each (seed, k) pair is one run
ks = [4, 8, 16]             # k values swept for every pooling method
reports = list()            # every run appends its report here

### FSPool (`fs`)

In [None]:
# FSPool ('fs') runs for every (seed, k) pair.
# Drop reports left by a previous execution of this cell so that re-running it does
# not duplicate rows in the summary tables. Assumes each report's 'pooling' entry is
# the raw key used by the `labels` mapping ('fs' here) -- TODO confirm in Experiment.
reports[:] = [r for r in reports if r['pooling'] != 'fs']

for seed in seeds:
    for k in ks:
        exp = Experiment(dataset, 'fs', 'faiss-lsh',
                         random_state=seed, ref_func=ref, k=k,
                         ref_size=set_size_median, code_length=code_length)
        exp.test()
        report = exp.get_exp_report()
        print(report)
        reports.append(report)

### SLOSH (`swe`)

In [None]:
# SLOSH ('swe') runs for every (seed, k) pair; the only cell that passes num_slices.
# Drop reports left by a previous execution of this cell so that re-running it does
# not duplicate rows in the summary tables. Assumes each report's 'pooling' entry is
# the raw key used by the `labels` mapping ('swe' here) -- TODO confirm in Experiment.
reports[:] = [r for r in reports if r['pooling'] != 'swe']

for seed in seeds:
    for k in ks:
        exp = Experiment(dataset, 'swe', 'faiss-lsh',
                         random_state=seed, ref_func=ref, k=k,
                         ref_size=set_size_median, code_length=code_length,
                         num_slices=n_slices)
        exp.test()
        report = exp.get_exp_report()
        print(report)
        reports.append(report)

### WE

In [None]:
# WE ('we') runs for every (seed, k) pair.
# Drop reports left by a previous execution of this cell so that re-running it does
# not duplicate rows in the summary tables. Assumes each report's 'pooling' entry is
# the raw key used by the `labels` mapping ('we' here) -- TODO confirm in Experiment.
reports[:] = [r for r in reports if r['pooling'] != 'we']

for seed in seeds:
    for k in ks:
        exp = Experiment(dataset, 'we', 'faiss-lsh',
                         random_state=seed, ref_func=ref, k=k,
                         ref_size=set_size_median, code_length=code_length)
        exp.test()
        report = exp.get_exp_report()
        print(report)
        reports.append(report)

### Cov

In [None]:
# Cov ('cov') runs for every (seed, k) pair; no ref_func is passed for this pooling.
# Drop reports left by a previous execution of this cell so that re-running it does
# not duplicate rows in the summary tables. Assumes each report's 'pooling' entry is
# the raw key used by the `labels` mapping ('cov' here) -- TODO confirm in Experiment.
reports[:] = [r for r in reports if r['pooling'] != 'cov']

for seed in seeds:
    for k in ks:
        exp = Experiment(dataset, 'cov', 'faiss-lsh',
                         random_state=seed, k=k,
                         ref_size=set_size_median, code_length=code_length)
        exp.test()
        report = exp.get_exp_report()
        print(report)
        reports.append(report)

### GeM-1

In [None]:
# GeM with power=1 runs for every (seed, k) pair.
# Drop reports left by a previous execution of this cell so that re-running it does
# not duplicate rows in the summary tables. Assumes each report's 'pooling' entry is
# the raw key used by the `labels` mapping ('gem-1' for power=1) -- TODO confirm.
reports[:] = [r for r in reports if r['pooling'] != 'gem-1']

for seed in seeds:
    for k in ks:
        exp = Experiment(dataset, 'gem', 'faiss-lsh',
                         random_state=seed, k=k,
                         ref_size=set_size_median, code_length=code_length,
                         power=1)
        exp.test()
        report = exp.get_exp_report()
        print(report)
        reports.append(report)

### GeM-2

In [None]:
# GeM with power=2 runs for every (seed, k) pair.
# Drop reports left by a previous execution of this cell so that re-running it does
# not duplicate rows in the summary tables. Assumes each report's 'pooling' entry is
# the raw key used by the `labels` mapping ('gem-2' for power=2) -- TODO confirm.
reports[:] = [r for r in reports if r['pooling'] != 'gem-2']

for seed in seeds:
    for k in ks:
        exp = Experiment(dataset, 'gem', 'faiss-lsh',
                         random_state=seed, k=k,
                         ref_size=set_size_median, code_length=code_length,
                         power=2)
        exp.test()
        report = exp.get_exp_report()
        print(report)
        reports.append(report)

### GeM-4

In [None]:
# GeM with power=4 runs for every (seed, k) pair.
# Drop reports left by a previous execution of this cell so that re-running it does
# not duplicate rows in the summary tables. Assumes each report's 'pooling' entry is
# the raw key used by the `labels` mapping ('gem-4' for power=4) -- TODO confirm.
reports[:] = [r for r in reports if r['pooling'] != 'gem-4']

for seed in seeds:
    for k in ks:
        exp = Experiment(dataset, 'gem', 'faiss-lsh',
                         random_state=seed, k=k,
                         ref_size=set_size_median, code_length=code_length,
                         power=4)
        exp.test()
        report = exp.get_exp_report()
        print(report)
        reports.append(report)

In [None]:
import altair as alt

In [None]:
# Display name for each raw pooling key that appears in the experiment reports.
labels = {
    'fs': 'FSPool',
    'swe': 'SLOSH',
    'we': 'WE',
    'cov': 'Cov',
    'gem-1': 'GeM-1',
    'gem-2': 'GeM-2',
    'gem-4': 'GeM-4',
}

In [None]:
# One row per run; raw pooling keys are swapped for their display names
# (an unknown key raises KeyError, as with direct dict indexing).
data = pd.DataFrame(reports).assign(
    pooling=lambda frame: frame['pooling'].map(labels.__getitem__)
)

In [None]:
# Scatter of mean accuracy against mean embedding time per pooling method,
# restricted to the runs with k == 4.
points = (
    alt.Chart(data.loc[data['k'].eq(4)])
    .mark_point()
    .encode(
        x=alt.X('mean(emb_time_per_sample):Q', title='Average Embedding Time'),
        y=alt.Y('mean(acc):Q', title='Accuracy'),
        color=alt.Color('pooling:N', legend=None),
    )
    .properties(height=240, width=240)
)

In [None]:
# Reuse the scatter's data and encodings, but draw each pooling name as a label
# placed 5 px to the right of its point.
text = (
    points
    .mark_text(dx=5, size=15, align='left', baseline='middle')
    .encode(text='pooling:N')
)

In [None]:
# Overlay the labels on the scatter and enlarge the axis fonts.
# `points + text` already builds a layered chart, so it is configured directly
# instead of being wrapped in a second, redundant alt.layer(...).
(points + text).configure_axis(
    labelFontSize=12,
    titleFontSize=16
)

In [None]:
# Display floats in the summary tables below with two decimals and thousands separators.
pd.set_option('display.float_format', '{:,.2f}'.format)

In [None]:
# Mean of precision_k and acc across seeds for each (pooling, k) pair.
data.groupby(['pooling', 'k'])[['precision_k', 'acc']].agg('mean')

In [None]:
# Sample standard deviation (pandas default ddof=1) of precision_k and acc across
# seeds for each (pooling, k) pair.
data.groupby(['pooling', 'k'])[['precision_k', 'acc']].agg('std')