In [1]:
from pynndescent import NNDescent

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons, make_moons, fetch_openml
from sklearn.model_selection import train_test_split


import cluster_represent.similarities as sims
import cluster_represent.crs as crs

# Setup

In [31]:
# Similarity function used both to build the per-class kNN graphs and to
# compare test samples against class representatives.
similarity = sims.cosine
# Dataset switch: 'moons' or 'MNIST_Fashion' (see the loading cell).
dataset = 'moons'

# Load the Dataset

In [32]:
# Each branch must define the full feature matrix X and the label vector y.
if dataset == 'moons':  # MOONS
    # scikit-learn's synthetic two-class "moons" data (labels 0 and 1).
    X, y = make_moons(n_samples=400, random_state=1, noise=0.1)
elif dataset == 'MNIST_Fashion':  # MNIST Fashion
    # NOTE: 'mnist_784' on OpenML is the handwritten-digit MNIST; the Fashion
    # variant is published as 'Fashion-MNIST'.
    # as_frame=False keeps X a NumPy array, so the boolean-mask and
    # positional row indexing used in later cells (samples[i]) work.
    X, y = fetch_openml('Fashion-MNIST', version=1, return_X_y=True, as_frame=False)
else:
    # Fail here instead of with a confusing NameError on X/y further down.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'moons' or 'MNIST_Fashion'")

In [33]:
# Hold out 20% of the samples for evaluation; the fixed seed makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    random_state=0,
    test_size=0.2,
)

# Get Individual Classes

In [34]:
# Sorted, de-duplicated labels that occur in the *training* split.
# - Using y_train (not the full y) guarantees that every class iterated below
#   has at least one training sample to pick representatives from.
# - np.unique gives a fixed, sorted order. Iterating a set is unordered, and
#   for string labels it can vary between runs (hash randomisation), which
#   would change tie-breaking in the nearest-representative classifier.
classes = np.unique(y_train)

In [35]:
# Group the training rows by label: class_id -> rows of X_train labelled class_id.
class_samples = {
    class_id: X_train[y_train == class_id]
    for class_id in classes
}

# Find Representatives

In [50]:
# For every class, build a kNN graph over its training samples and keep the
# samples that crs.find_representatives selects.
# NOTE(review): 3 is presumably the neighbour count k and 0.95 a selection
# threshold for find_representatives -- confirm against cluster_represent.crs.
representatives = {}
for class_id in class_samples:
    members = class_samples[class_id]
    knn_neighbors, knn_distances = crs.create_knn_graph(members, 3, similarity=similarity)
    chosen = crs.find_representatives(members, knn_neighbors, knn_distances, 0.95)
    representatives[class_id] = list(members[idx] for idx in chosen)

  self._set_arrayXarray(i, j, x)
  self._set_arrayXarray(i, j, x)


In [51]:
# How many representatives were kept for each class.
for cls, rep_list in representatives.items():
    print(f"{cls} : {len(rep_list)}")

0 : 36
1 : 30


# Classification

In [52]:
def nearest_representative_label(sample, representatives, similarity):
    """Return the class whose most similar representative is closest to `sample`.

    Args:
        sample: one feature vector to classify.
        representatives: mapping class_id -> iterable of representative vectors.
        similarity: callable(representative, sample) -> number; higher means more similar.

    Returns:
        The winning class_id. On a tie, the first class encountered in
        `representatives` wins. Returns None only when there are no
        representatives at all, or every similarity is NaN.
    """
    # Start below any possible score: cosine similarity can be negative, so a
    # floor of 0 would leave such samples unlabelled (and counted as misses).
    best_sim = -np.inf
    guess_label = None
    for class_id, reps in representatives.items():
        for representative in reps:
            sim = similarity(representative, sample)
            if sim > best_sim:  # strict: ties keep the earlier class
                best_sim = sim
                guess_label = class_id
    return guess_label


hits = 0
misses = 0
for sample, label in zip(X_test, y_test):
    if nearest_representative_label(sample, representatives, similarity) == label:
        hits += 1
    else:
        misses += 1

In [53]:
# Summary of the nearest-representative classifier on the held-out split.
print(f"Hits: {hits} \nMisses: {misses}")

Hits: 61 
Misses: 19
