In [1]:
import time
from pprint import pprint

import jax
import jax.numpy as jnp
import jax.random as random
import optax
import torch
import torchvision
import torchvision.transforms as transforms
#jax.config.update("jax_debug_nans", True)
#jax.config.update("jax_debug_infs", True)
#jax.config.update("jax_enable_x64", True)
#jax.disable_jit(disable=True)

In [2]:
# check if GPU is working: initialise the default backend, then place a tiny
# array on the first GPU device (this lookup is expected to fail if JAX cannot
# see a GPU). The placed array is the cell's displayed value.
jax.default_backend()
jax.device_put(jnp.ones(1), jax.devices('gpu')[0])

Array([1.], dtype=float32)

In [3]:
# set up params
batch_size = 4

# first load the dataset (downloaded into ./ on first run); one shared transform
# turns every image into a tensor
to_tensor = transforms.ToTensor()
train_data = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=to_tensor)
test_data = torchvision.datasets.MNIST(root='./', train=False, download=True, transform=to_tensor)

# Only the training loader shuffles; the evaluation loader keeps a fixed order so
# repeated evaluations walk the test set identically.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

In [4]:
# convert to jnp/np
# zip(*dataset) walks the dataset once and splits it into a tuple of image
# tensors and a tuple of labels. The images are stacked in torch first so JAX
# receives one array per split instead of converting every sample separately.
x_train, y_train = zip(*train_data)
x_train, y_train = jnp.asarray(torch.stack(x_train).numpy()), jnp.array(y_train)

x_test, y_test = zip(*test_data)
x_test, y_test = jnp.asarray(torch.stack(x_test).numpy()), jnp.array(y_test)

In [5]:
# flatten each x: keep the leading sample axis and collapse everything else into
# one feature vector, using a single reshape rather than one ravel per sample
x_train = x_train.reshape(x_train.shape[0], -1)
x_test = x_test.reshape(x_test.shape[0], -1)

In [6]:
# convert ys to one-hot
# the class count is the number of distinct labels present in the training set
classes = int(jnp.unique(y_train).size)
print(classes)
# from n -> one-hot of n, applied to both splits with the same width
y_train, y_test = (jax.nn.one_hot(labels, classes) for labels in (y_train, y_test))

10


In [7]:
#jax.device_put(x_train, device=jax.devices('gpu')[0])
#jax.device_put(y_train, device=jax.devices('gpu')[0])

In [8]:
# Preview the first ten training samples: the shape of each input row next to
# its label row.
for sample_x, sample_y in zip(x_train[:10], y_train[:10]):
  print(sample_x.shape, sample_y)

# mean value of the first sample (displayed as the cell output)
jnp.mean(x_train[0])

(784,) [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
(784,) [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
(784,) [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
(784,) [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
(784,) [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
(784,) [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
(784,) [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
(784,) [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
(784,) [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
(784,) [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]


Array(0.13768007, dtype=float32)

In [None]:
## functions
# pool of 10 PRNG keys derived from one fixed seed
keys = random.split(random.PRNGKey(10298213), num=10)
# layer widths, input first: five layers of 28*28 units followed by 10 outputs
neurons = [28*28] * 5 + [10]

def init_mlp_params(wkey, neurons):
  """Create the parameters of a fully connected network.

  Args:
    wkey: PRNG key. It is split into one subkey per layer so that layers of the
      same shape do not start out with identical weight matrices.
    neurons: layer widths, input width first and output width last.

  Returns:
    A dict mapping "layer_{i}" to {"weight": array of shape
    (neurons[i], neurons[i+1]), "bias": zero vector of shape (neurons[i+1],)}.
    Empty when `neurons` has fewer than two entries.
  """
  n_layers = len(neurons) - 1
  # max(..., 1) keeps the split valid when there are no layers to build
  layer_keys = random.split(wkey, max(n_layers, 1))
  mlp_params = {
    f"layer_{i}" : {
      # remember, its xW, not Wx, so W should be (in_vector_size, out_vector_size)
      # so that (m,) @ (m,n) => (n,)
      # He initialization: norm(0,1) * sqrt(2 / fan_in), with fan_in = neurons[i]
      "weight" : random.normal(layer_keys[i], shape=(neurons[i], neurons[i+1])) * jnp.sqrt(2 / neurons[i]),
      # initialize biases as 0 vectors
      "bias" : jnp.zeros(shape=neurons[i+1])
    } for i in range(n_layers)
  }
  return mlp_params

def mlp_forward(params, x_batch):
  """Run the network on a batch and return the raw LOGITS.

  The depth is taken from `params` itself (one "layer_{i}" entry per layer), so
  the function works for any parameter dict with that layout.

  Args:
    params: dict mapping "layer_{i}" to {"weight", "bias"}.
    x_batch: inputs; each layer computes x @ weight + bias (xW, not Wx).

  Returns:
    The output of the last layer. ReLU is applied after every layer except the
    last one; no softmax is applied, so that get_loss can use log_softmax.
  """
  n_layers = len(params)
  x = x_batch
  for i in range(n_layers):
    layer = params[f"layer_{i}"]
    x = x @ layer["weight"] + layer["bias"]
    # hidden layers get a ReLU; the final layer stays linear (logits)
    if i < n_layers - 1:
      x = jax.nn.relu(x)
  return x

def accuracy(params, x_batch, y_batch):
  """Fraction of batch rows whose arg-max logit equals the arg-max of the target row."""
  predicted_class = jnp.argmax(mlp_forward(params, x_batch), axis=1)
  target_class = jnp.argmax(y_batch, axis=1)
  hits = predicted_class == target_class
  return jnp.mean(hits)


def get_loss(params, x_batch, y_batch):
  """Mean cross-entropy between the network's predictions and `y_batch`.

  The logits from mlp_forward are turned into log-probabilities with
  jax.nn.log_softmax, so no explicit log of a probability is ever taken here.
  Each row's log-probabilities are weighted by the matching row of `y_batch`
  and summed over axis 1; the result is the negated mean over the batch.
  """
  log_probs = jax.nn.log_softmax(mlp_forward(params, x_batch))
  per_example_log_likelihood = jnp.sum(y_batch * log_probs, axis=1)
  return -jnp.mean(per_example_log_likelihood)


def param_norms(params):
  """Return the log of the norm of every weight and bias array.

  Accepts the per-layer layout produced by init_mlp_params
  ({"layer_{i}": {"weight", "bias"}}) as well as the flat layout
  {"weights": [...], "biases": [...]}.

  Returns:
    {"weights": [...], "biases": [...]} with one log-norm per layer, in layer
    order. An all-zero array has norm 0, so its entry is -inf.
  """
  if "weights" in params and "biases" in params:
    weights, biases = params["weights"], params["biases"]
  else:
    layers = [params[f"layer_{i}"] for i in range(len(params))]
    weights = [layer["weight"] for layer in layers]
    biases = [layer["bias"] for layer in layers]
  norms = {
      "weights" : [jnp.log(jnp.linalg.norm(w)) for w in weights],
      'biases'  : [jnp.log(jnp.linalg.norm(b)) for b in biases]
  }
  return norms


# Adam with a fixed step size; train_step reads this module-level optimizer
learning_rate = 0.01
optimizer = optax.adam(learning_rate=learning_rate)


@jax.jit
def train_step(params, x_batch, y_batch, opt_state):
  """Apply one optimizer update computed from a single batch.

  Args:
    params: current network parameters.
    x_batch, y_batch: inputs and targets passed to get_loss.
    opt_state: current state of the module-level `optimizer`.

  Returns:
    (updated_params, updated_opt_state, losses, grads), where `losses` is the
    value of get_loss on the batch before the update and `grads` is its
    gradient with respect to `params`.
  """
  # one combined pass yields both the loss value and its gradient
  losses, grads = jax.value_and_grad(get_loss)(params, x_batch, y_batch)
  # updates are different than grads: the optimizer turns the grads into the
  # actual change to apply to the params (e.g. adam uses the grads to update its
  # moments, and the moments plus the learning rate give the updates)
  updates, updated_opt_state = optimizer.update(grads, opt_state)
  updated_params = optax.apply_updates(params, updates)
  return updated_params, updated_opt_state, losses, grads


def train_loop(batch_size=8, epochs=500000000, time_limit=30000000):
  """Train the MLP on (x_train, y_train) and return the trained parameters.

  Every epoch reshuffles the training set with a permutation seeded by the epoch
  number and walks it in consecutive batches (a trailing partial batch is
  dropped). Accuracy is measured on each batch with the parameters from before
  that batch's update. Training stops when the batch loss is exactly 0, when
  `time_limit` is exceeded, or when `epochs` epochs have run.

  Args:
    batch_size: number of samples per update.
    epochs: maximum number of passes over the training set.
    time_limit: wall-clock budget in seconds.

  Returns:
    The parameter dict after the last update.
  """
  params = init_mlp_params(keys[0], neurons)
  opt_state = optimizer.init(params)

  train_datapoints = len(x_train)
  batches = train_datapoints // batch_size
  start_time = time.time()

  for epoch in range(epochs):
    indices = random.permutation(random.PRNGKey(epoch), train_datapoints)
    for batch in range(batches):
      batch_start = batch * batch_size
      batch_indices = indices[batch_start:batch_start + batch_size]
      x_batch, y_batch = x_train[batch_indices], y_train[batch_indices]

      acc = accuracy(params, x_batch, y_batch)
      # train_step also returns the gradients; they are not used here
      params, opt_state, loss, _grads = train_step(params, x_batch, y_batch, opt_state)

      print(f"epoch {epoch}, batch {batch}, loss={loss}, acc={acc}")

      duration = time.time() - start_time
      if loss == 0 or duration >= time_limit:
        print(f"DONE in {duration}s")
        # count the batch that was just trained on as well (hence the + 1)
        samples_seen = (epoch * batches + batch + 1) * batch_size
        print("Samples trained on per second: ", samples_seen / duration)
        return params

  return params

# bind the result so the trained parameters stay available to later cells and
# are not echoed as the cell output
trained_params = train_loop()

