# Basic MNIST Example

This basic example builds a cover tree over MNIST. You can specify all of the parameters in a YAML file, but here we load the raw data with TensorFlow and set the parameters in code.


In [None]:
import numpy as np
import tensorflow as tf
from pygoko import CoverTree

mnist = tf.keras.datasets.mnist

# Load the raw images and scale the pixel values to [0, 1].
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Convert to float32 features and int64 labels, then flatten each 28x28 image to 784 values.
x_train = x_train.astype(np.float32)
y_train = y_train.astype(np.int64)
x_train = x_train.reshape(-1, 28 * 28)

Here we build the cover tree, using a leaf cutoff and a minimum resolution index to control the size of the tree.

The minimum resolution index is the scale index at which the tree stops splitting. It can be used to control the L2 error (we use the standard fast implementation, which is not the most accurate), or to set a scale below which the KNN results no longer matter to you.

The leaf cutoff controls the size of the individual leaves of the tree. If a node covers fewer than this number of points, splitting stops and the node becomes a leaf.

The scale base controls the down-step of each split, that is, the factor by which the radius shrinks from one scale to the next. 1.3 is a good default: it is usually close to the fastest at creating the tree, but it can be hard to reason about. Another popular choice is 2, which means the radius halves at each step.

In [None]:
tree = CoverTree()
tree.set_leaf_cutoff(10)
tree.set_scale_base(1.3)
tree.set_min_res_index(-20)
tree.fit(x_train,y_train)
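
To make the effect of the scale base and the minimum resolution index more concrete, here is a small standalone sketch. It assumes that the covering radius at a given scale index is `scale_base ** scale_index`; that relationship is an assumption of this sketch, and the helper names `scale_radius` and `levels_between` are hypothetical, not part of pygoko.

```python
def scale_radius(scale_base: float, scale_index: int) -> float:
    """Covering radius at a scale index, ASSUMING radius = scale_base ** scale_index."""
    return scale_base ** scale_index


def levels_between(scale_base: float, r_top: float, r_bottom: float) -> int:
    """Count the down-steps needed to shrink a radius from r_top to at most r_bottom."""
    steps = 0
    radius = r_top
    while radius > r_bottom:
        radius /= scale_base
        steps += 1
    return steps


# Radii for a few consecutive scale indexes, for both scale bases mentioned above.
for base in (1.3, 2.0):
    radii = [round(scale_radius(base, i), 3) for i in range(3, -4, -1)]
    print(f"base {base}: {radii}")

# Under the assumption above, this is the radius at which splitting stops
# for scale base 1.3 and minimum resolution index -20.
print(scale_radius(1.3, -20))

# A smaller base takes more, finer steps to cover the same range of radii.
print(levels_between(2.0, 16.0, 1.0), levels_between(1.3, 16.0, 1.0))
```

With a base of 2 the radii are easy to read off as powers of two, which is what makes that choice simple to reason about. A base of 1.3 takes more levels to span the same range of radii.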

Here's a basic KNN query against this data structure, asking for the 5 nearest neighbors of an all-zero image.

In [None]:
point = np.zeros([784], dtype=np.float32)
tree.knn(point,5)
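
For a sanity check of the tree's answers, a brute-force search is a useful reference. The sketch below is plain NumPy on random stand-in data, not pygoko code. The helper `brute_force_knn` is hypothetical, and its `(distance, index)` output format is a choice made here rather than a statement about what `tree.knn` returns.

```python
import numpy as np


def brute_force_knn(data: np.ndarray, query: np.ndarray, k: int):
    """Return the k rows of `data` nearest to `query` as (distance, index) pairs."""
    # Exact L2 distance from the query to every row.
    dists = np.linalg.norm(data - query, axis=1)
    # A stable sort keeps ties in index order, so the result is deterministic.
    nearest = np.argsort(dists, kind="stable")[:k]
    return [(float(dists[i]), int(i)) for i in nearest]


# Random stand-in for the flattened images (1000 points, 784 dimensions).
rng = np.random.default_rng(0)
data = rng.random((1000, 784), dtype=np.float32)

# Querying with a point from the dataset should return that point first, at distance 0.
query = data[42]
neighbors = brute_force_knn(data, query, 5)
print(neighbors)
```

On the real data the same helper could be called as `brute_force_knn(x_train, point, 5)`. It scans every point, so it is only practical for spot checks.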

Nodes are addressable by specifying the scale index and the point index (in the originating dataset). This errors out if you supply an address that isn't known to the tree. (Currently, this surfaces as a Rust panic from unwrapping an Option that is None.) Only use this creation method with known, correct addresses.

In [None]:
root = tree.root()
print(f"Root address: {root.address()}")
for child in root.children():
    child_address = child.address()
    # The following is the same node as the child:
    copy_of_child = tree.node(child_address)
    print(f"  Child address: {copy_of_child.address()}")
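
The addressing scheme can be pictured as a table keyed by `(scale_index, point_index)` pairs. The following is a toy model of that idea in plain Python, with made-up addresses. It is not how pygoko stores nodes, and `toy_children` and `lookup` are hypothetical names. It also shows the kind of guard that avoids asking for an unknown address.

```python
from typing import Dict, List, Optional, Tuple

Address = Tuple[int, int]  # (scale_index, point_index)

# Hypothetical toy tree: each known address maps to the addresses of its children.
toy_children: Dict[Address, List[Address]] = {
    (3, 0): [(2, 0), (2, 7)],
    (2, 0): [],
    (2, 7): [],
}


def lookup(address: Address) -> Optional[List[Address]]:
    """Return the children of a known address, or None if the address is unknown."""
    return toy_children.get(address)


root_address = (3, 0)
for child_address in lookup(root_address):
    # Addresses obtained by walking the tree are always known.
    print(child_address, "known:", lookup(child_address) is not None)

# Point 5 exists in the dataset but has no node at scale 2, so the address is unknown.
print(lookup((2, 5)))
```

In practice this means only passing `tree.node` addresses that came from the tree itself, for example from `root()`, `children()`, or `path()`.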

If a query point belonged to the dataset that the tree was constructed from, but was never selected as a routing node, it would end up at a particular leaf node. This leaf node is deterministic (given the pre-built tree). The path for a query point is the sequence of node addresses from the root node to this leaf.

In [None]:
path = tree.path(point)
print(path)

print("Summary of the labels of points covered by the node at each address:")
for dist, address in path:
    node = tree.node(address)
    label_summary = node.label_summary()
    print(f"Address: {address} \t Summary: {label_summary}")
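
To illustrate why the path is deterministic once the tree is built, here is a toy descent in plain NumPy. It is a simplification for illustration, assuming the query always moves to the child with the nearest center. The data, the tree layout, and the `toy_path` helper are all hypothetical, and pygoko's actual routing rule may differ.

```python
import numpy as np

# Hypothetical toy data: four 2-D points, indexed 0..3.
points = np.array([[0.0, 0.0], [4.0, 0.0], [0.5, 0.0], [4.0, 1.0]])

# Toy tree keyed by (scale_index, point_index); values are child addresses.
children = {
    (3, 0): [(2, 0), (2, 1)],
    (2, 0): [(1, 0), (1, 2)],
    (2, 1): [(1, 1), (1, 3)],
    (1, 0): [], (1, 2): [], (1, 1): [], (1, 3): [],
}


def toy_path(query, root=(3, 0)):
    """Descend from the root, always moving to the child with the nearest center."""
    address = root
    path = [(float(np.linalg.norm(query - points[address[1]])), address)]
    while children[address]:
        # Distance from the query to the center point of each child.
        dists = [float(np.linalg.norm(query - points[c[1]])) for c in children[address]]
        best = int(np.argmin(dists))
        address = children[address][best]
        path.append((dists[best], address))
    return path


# Each entry is (distance to the node's center, node address), root first.
for dist, address in toy_path(np.array([3.9, 0.9])):
    print(f"Address: {address} \t Distance: {dist:.3f}")
```

Because every step depends only on the fixed tree and the query, repeating the call gives the same list of addresses.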

We can also query for the path of known points, by index in the original dataset.

In [None]:
path = tree.known_path(40000)

print("Summary of the labels of points covered by the node at each address:")
for dist, address in path:
    node = tree.node(address)
    label_summary = node.label_summary()
    print(f"Address: {address} \t Summary: {label_summary}")

