In [5]:
import jax
import jax_metrics as jm
import jax.numpy as jnp
from jax import grad, jit, vmap
from functools import partial
from jax import random
from jax.scipy.special import logsumexp
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

In [58]:
class NeuralNetwork():
    """
    We have a basic netwrk implementation
    """
    def __init__(self, network_params):
        layers, step_size, epochs, batch, n_targets = network_params
        self.layer_sizes = layers
        self.step_size = step_size
        self.num_epochs = epochs
        self.batch_size = batch
        self.n_targets = n_targets
        self.params = self._init_network_params(self.layer_sizes, random.PRNGKey(0))
        self.batched_predict = vmap(self.predict, in_axes=(None, 0))

    @staticmethod
    def _random_layer_params(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

    # Initialize all layers for a fully-connected neural network with sizes "sizes"
    def _init_network_params(self, sizes, key):
        keys = random.split(key, len(sizes))
        return [self._random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

    @staticmethod
    def relu(x):
        return jnp.maximum(0, x)

    def predict(self, params, image):
        # Layer calculations using 
        # [Layer1, Layer2, Layer3] = [:-1]
        activations = image
        for w, b in params[:-1]:
            outputs = jnp.dot(w, activations) + b
            activations = self.relu(outputs)
        # Output Layer
        final_w, final_b = params[-1]
        logits = jnp.dot(final_w, activations) + final_b
        return jax.nn.softmax(logits)
    
    def batch_predict(self, params, batched_images):
        return self.batched_predict(params, batched_images)
    
    @staticmethod
    def one_hot(x, k, dtype=jnp.float32):
        """Create a one-hot encoding of x of size k."""
        return jnp.array(x[:, None] == jnp.arange(k), dtype)
    
    def accuracy(self, images, targets):
        target_class = jnp.argmax(targets, axis=1)
        predicted_class = jnp.argmax(self.batched_predict(self.params, images), axis=1)
        return jnp.mean(predicted_class == target_class)

    def loss(self, params, images, targets):
        preds = self.batched_predict(params, images)
        return jnp.mean(jm.losses.crossentropy(preds, targets))

    @partial(jit, static_argnums=(0,))
    def update(self, params, x, y):
        grads = grad(self.loss)(params, x, y)
        return [(w - self.step_size * dw, b - self.step_size * db)
                      for (w, b), (dw, db) in zip(params, grads)] 
    
    #@partial(jit, static_argnums=(0,))
    def dloss(self, params, x, y):
        return grad(self.loss)(params, x, y)

In [54]:
layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 30
batch_size = 128
n_targets = 10
network_params = (layer_sizes, step_size, num_epochs, batch_size, n_targets)

In [59]:
network = NeuralNetwork(network_params)

In [26]:
import tensorflow as tf
tf.config.set_visible_devices([], device_type='GPU')
import tensorflow_datasets as tfds
data_dir = '/home/remote_code/Jax_Excercises/tfds'

In [27]:
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']

In [28]:
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

In [29]:
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = network.one_hot(train_labels, num_labels)

In [30]:
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = network.one_hot(test_labels, num_labels)

In [None]:
import time

def get_train_batches():
    # as_supervised=True gives us the (image, label) as a tuple instead of a dict
    ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
    # You can build up an arbitrary tf.data input pipeline
    ds = ds.batch(batch_size).prefetch(1)
    # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
    return tfds.as_numpy(ds)

for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in get_train_batches():
        x = x.astype(jnp.float32)
        x = jnp.reshape(x, (len(x), num_pixels))
        y = network.one_hot(y, num_labels)
        network.params = network.update(network.params, x, y)
    epoch_time = time.time() - start_time

    train_acc = network.accuracy(train_images, train_labels)
    test_acc = network.accuracy(test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

In [42]:
network.predict(network.params, x[0])

Array([1.6497804e-20, 2.5545521e-24, 1.5067530e-13, 3.7225091e-11,
       1.6700621e-27, 1.5232610e-21, 2.0894807e-30, 1.0000000e+00,
       7.8525401e-22, 3.5597523e-16], dtype=float32)

In [50]:
len(y)

96

In [61]:
z = jax.xla_computation(network.dloss)(network.params, x, y)

In [62]:
with open("t.dot", "w") as f:
    f.write(z.as_hlo_dot_graph())

In [None]:
jax.