# TPUs and KNN

In [1]:
import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds

from jax.config import config
config.update("jax_enable_x64", True)

In [2]:
def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds

In [224]:
train, test = get_datasets()

X_train = train["image"].reshape(-1, 28 ** 2)
X_train = 4 * X_train - 2
y_train = train["label"]

X_test = test["image"].reshape(-1, 28 ** 2)
print(X_test.shape)
X_test = 4 * X_test - 2
y_test = test["label"]

n_train = 30_000
n_test = 30

X_train = X_train[:n_train]
y_train = y_train[:n_train]

X_test = X_test[:n_test]
y_test = y_test[:n_test]

(10000, 784)


In [229]:
30_000 / 8

3750.0

In [253]:
def compute_diff(u, v):
#     return (u[None, :] - v[:, None]) ** 2
    return (u[:, None] - v[None, :]) ** 2


devices = 1
ndev_train = n_train // devices
compute_diff = jax.vmap(compute_diff, in_axes=1, out_axes=-1)
compute_diff = jax.pmap(compute_diff, in_axes=0)

def compute_distance(U, V):
    return compute_diff(U, V).mean(axis=-1)

p_argsort = jax.pmap(jnp.argsort, in_axes=0)
def compute_k_closest(U, V, k):
    D = compute_distance(U, V)
    
    D = D.reshape(devices, n_test // devices, -1)
    nearest = p_argsort(D)[..., 1:k+1]
    return nearest

In [255]:
compute_k_closest(X_test, X_train, 5).shape

(1, 30, 5)

In [251]:
jax.local_device_count()

1

In [None]:
Dj = compute_distance(X_test, X_train)
Dj.shape

In [202]:
Dreshape = Dj.reshape(1, 10, -1)
Dreshape

DeviceArray([[[2.1976, 2.4276, 2.206 , ..., 2.3649, 3.0102, 1.5677],
              [1.3495, 1.3143, 2.1858, ..., 1.8039, 2.6597, 2.2468],
              [3.2746, 3.3349, 3.0668, ..., 2.3765, 2.0354, 2.8196],
              ...,
              [1.0756, 1.2676, 2.4354, ..., 2.65  , 2.4429, 2.4995],
              [1.778 , 1.6293, 2.7443, ..., 1.2514, 2.6058, 2.1182],
              [2.0825, 2.3377, 1.6483, ..., 1.8558, 2.5907, 1.6154]]],            dtype=float32)

In [216]:
p_argsort = jax.pmap(jnp.argsort, in_axes=0)
p_argsort(Dreshape).shape

(1, 10, 90000)

In [217]:
pmap_k_closest = jax.pmap(compute_k_closest, "k", in_axes=(None, None, 0))

In [186]:
pmap_k_closest(X_test, X_train, np.array([1]))

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

In [121]:
jnp.set_printoptions(suppress=True, precision=4)

In [49]:
np.random.seed(314)
k_closest = np.random.randint(1, 100, (3 * 5, 5))
k_closest

array([[ 9, 74, 52, 87, 43],
       [79,  8, 99, 59, 87],
       [72, 81, 36, 23,  8],
       [89,  5, 33, 92, 69],
       [46, 89, 17, 10, 62],
       [55, 46, 67, 72, 87],
       [37,  9, 51,  1,  8],
       [83, 58, 89, 38, 78],
       [62, 94, 64, 92, 10],
       [86,  5, 51, 10, 87],
       [12, 19, 17, 86,  1],
       [81, 59, 98, 23, 27],
       [95, 66, 95, 67, 60],
       [62, 45, 67, 50,  8],
       [18,  4, 81, 92, 60]])

In [50]:
np.argsort(k_closest)

array([[0, 4, 2, 1, 3],
       [1, 3, 0, 4, 2],
       [4, 3, 2, 0, 1],
       [1, 2, 4, 0, 3],
       [3, 2, 0, 4, 1],
       [1, 0, 2, 3, 4],
       [3, 4, 1, 0, 2],
       [3, 1, 4, 0, 2],
       [4, 0, 2, 3, 1],
       [1, 3, 2, 0, 4],
       [4, 0, 2, 1, 3],
       [3, 4, 1, 0, 2],
       [4, 1, 3, 0, 2],
       [4, 1, 3, 0, 2],
       [1, 0, 4, 2, 3]])

In [51]:
k_closest[0, [0, 4, 2, 1, 3]]

array([ 9, 43, 52, 74, 87])

In [53]:
k_closest_shuffled = k_closest.reshape(3, 5, 5)
k_closest_shuffled

array([[[ 9, 74, 52, 87, 43],
        [79,  8, 99, 59, 87],
        [72, 81, 36, 23,  8],
        [89,  5, 33, 92, 69],
        [46, 89, 17, 10, 62]],

       [[55, 46, 67, 72, 87],
        [37,  9, 51,  1,  8],
        [83, 58, 89, 38, 78],
        [62, 94, 64, 92, 10],
        [86,  5, 51, 10, 87]],

       [[12, 19, 17, 86,  1],
        [81, 59, 98, 23, 27],
        [95, 66, 95, 67, 60],
        [62, 45, 67, 50,  8],
        [18,  4, 81, 92, 60]]])

In [47]:
np.argsort(A_shuffled)

array([[[0, 4, 2, 1, 3],
        [1, 3, 0, 4, 2],
        [4, 3, 2, 0, 1],
        [1, 2, 4, 0, 3],
        [3, 2, 0, 4, 1]],

       [[1, 0, 2, 3, 4],
        [3, 4, 1, 0, 2],
        [3, 1, 4, 0, 2],
        [4, 0, 2, 3, 1],
        [1, 3, 2, 0, 4]],

       [[4, 0, 2, 1, 3],
        [3, 4, 1, 0, 2],
        [4, 1, 3, 0, 2],
        [4, 1, 3, 0, 2],
        [1, 0, 4, 2, 3]]])

In [113]:
key = jax.random.PRNGKey(314)
A = jax.random.normal(key, (10, 5))
A

DeviceArray([[ 1.1104, -1.2406, -0.6909,  1.2509,  0.0491],
             [ 1.6945, -0.2667,  0.048 ,  0.1682, -0.9681],
             [-0.841 , -0.5139,  1.1786, -0.0084,  2.3075],
             [ 0.3981,  0.1632, -0.8194, -2.0301, -1.0782],
             [ 2.4647, -0.6206, -1.1689, -0.6257,  0.2767],
             [-1.1944,  1.883 ,  0.9415,  0.5941, -1.4575],
             [-0.1479, -0.2617, -0.1321, -0.1167, -1.0322],
             [ 0.357 ,  0.3082,  0.369 , -0.2063, -0.8255],
             [ 1.0441, -2.2558,  1.0567,  0.1041, -0.6379],
             [-0.0799, -0.0156,  1.1134,  0.8509, -0.7933]],            dtype=float64)

In [117]:
jax.lax.dynamic_slice(A, (1, 1), (4, 5))

DeviceArray([[ 1.6945, -0.2667,  0.048 ,  0.1682, -0.9681],
             [-0.841 , -0.5139,  1.1786, -0.0084,  2.3075],
             [ 0.3981,  0.1632, -0.8194, -2.0301, -1.0782],
             [ 2.4647, -0.6206, -1.1689, -0.6257,  0.2767]],            dtype=float64)

In [102]:
pmap_k_closest(X_test, X_train, np.array([1, 2, 3]))

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

In [79]:
k_nearest = compute_k_closest(X_test, X_train, 4)
(y_train[k_nearest, ...].mean(axis=1).round() == y_test).mean()

0.8666666666666667

In [80]:
(y_train[k_nearest, None] == jnp.arange(10)[None, None, :]).shape

(30, 4, 10)

## pKNN

In [69]:
def compute_likelihood(X, y, beta, k):
    Q = len(jnp.unique(y))
    k_closest = compute_k_closest(X, X, k=k)
    y_closest = y[k_closest]
    # Comparing the "k" closest datapoints of xi to yi
    num = y_closest == y[:, None]
    num = jnp.exp(beta * num.mean(axis=1) + num.max())
    
    den = y_closest[..., None] == jnp.arange(Q)[None, None, :]
    den = jnp.exp(beta * den.mean(axis=1)).sum(axis=-1)
    
    likelihood = num / den
    return likelihood

In [70]:
beta = 10
k = 1
L = compute_likelihood(X_train, y_train, beta, k)

In [71]:
L.prod()

DeviceArray(2.47249603e-44, dtype=float64)

In [56]:
jnp.log(L).max()

DeviceArray(-0.00040852, dtype=float64)

In [182]:
X_train.shape

(5000, 784)