In [33]:
import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds

from jax.config import config
config.update("jax_enable_x64", True)

In [34]:
def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds

In [35]:
train, test = get_datasets()

X_train = train["image"].reshape(-1, 28 ** 2)
X_train = 4 * X_train - 2
y_train = train["label"]

X_test = test["image"].reshape(-1, 28 ** 2)
X_test = 4 * X_test - 2
y_test = test["label"]

n_train = 1_000
n_test = 30

X_train = X_train[:n_train]
y_train = y_train[:n_train]

X_test = X_test[:n_test]
y_test = y_test[:n_test]

In [37]:
def compute_diff(u, v):
#     return jnp.sum((u - v) ** 2)
    return (u[None, :] - v[:, None]) ** 2

compute_diff = jax.vmap(compute_diff, in_axes=1, out_axes=-1)

def compute_distance(U, V):
    return compute_diff(U, V).mean(axis=-1)

def compute_k_closest(U, V, k):
    D = compute_distance(U, V)
    nearest = jnp.argsort(D)[..., 1:k+1]
    return nearest

In [38]:
Dj = compute_distance(X_test[:n_test, :], X_train[:n_train, :])
Dj.shape

(1000, 30)

In [39]:
k_nearest = compute_k_closest(X_train, X_train, 4)

In [42]:
(y_train[k_nearest, None] == jnp.arange(10)[None, None, :]).shape

(1000, 4, 10)

In [67]:
def compute_likelihood(X, y, beta, k):
    Q = len(jnp.unique(y))
    k_closest = compute_k_closest(X, X, k=k)
    y_closest = y[k_closest]
    # Comparing the "k" closest datapoints of xi to yi
    num = y_closest == y[:, None]
    num = jnp.exp(beta * num.mean(axis=1) + num.max())
    
    den = y_closest[..., None] == jnp.arange(Q)[None, None, :]
    den = jnp.exp(beta * den.mean(axis=1)).sum(axis=-1)
    
    likelihood = num / den
    return likelihood

In [68]:
beta = 10
k = 1
L = compute_likelihood(X_train, y_train, beta, k)

In [70]:
L.prod()

DeviceArray(2.47249603e-44, dtype=float64)

In [56]:
jnp.log(L).max()

DeviceArray(-0.00040852, dtype=float64)

In [182]:
X_train.shape

(5000, 784)

In [179]:
X_test[..., None] - X_train.T[:, None, :]

(30, 784, 1)

In [183]:
X_train.T[:, None, :].shape

(784, 1, 5000)

In [190]:
X_test.shape

(30, 784)

In [201]:
def compute_k_closest(u, v, k):
    print(u.shape)
    print(v.shape)
    D = (u[None, :] - v[:, None]) ** 2
    print(D.shape)

compute_k_closest = jax.vmap(compute_k_closest, in_axes=[1, 1, None], out_axes=-1)
compute_k_closest(X_test, X_train, 4).shape

(30,)
(5000,)
(5000, 30)


AttributeError: 'NoneType' object has no attribute 'shape'

In [1]:
knn = jax.pmap(knn)

train, test = get_datasets()

X_train = train["image"].reshape(-1, 28 ** 2)
X_train = 4 * X_train - 2
y_train = train["label"]

X_test = test["image"].reshape(-1, 28 ** 2)
X_test = 4 * X_test - 2
y_test = test["label"]

print(X_train.shape)
print(y_train.shape)

n_train = 30_000
n_test = 30
k = 20
# import pdb; pdb.set_trace()
yhat = knn(X_test[:n_test, :], X_train[:n_train, :], y_train[:n_train], k)
print((y_test[:n_test] == yhat).mean())

Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


(60000, 784)
(60000,)


ValueError: pmap got arg 3 of rank 0 but axis to be mapped 0. The tree of ranks is:
((2, 2, 1, 0), {})