In [1]:
import jax.numpy as jnp
import jax.random as random
import jax
import optax
import torch
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)
jax.config.update("jax_enable_x64", True)
#jax.disable_jit(disable=True)

In [2]:
# check if GPU is working
jax.default_backend()
jax.device_put(jax.numpy.ones(1), device=jax.devices('gpu')[0])

Array([1.], dtype=float64)

In [3]:
# set up params
batch_size = 4

import torchvision
import torchvision.transforms as transforms
# first load the dataset
train_data = torchvision.datasets.MNIST(root = './', train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
test_data = torchvision.datasets.MNIST(root = './', train=False, download=True, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=2)

In [4]:
# convert to jnp/np
x_train, y_train = zip(*train_data)
x_train, y_train = jnp.array(x_train)[:40], jnp.array(y_train)[:40]

x_test, y_test = zip(*test_data)
x_test, y_test = jnp.array(x_test), jnp.array(y_test)

In [5]:
# flatten each x
x_train = jnp.array([jnp.ravel(x) for x in x_train])
x_test = jnp.array([jnp.ravel(x) for x in x_test])

In [6]:
# convert ys to one-hot
classes = len(set(y_train.tolist()))
print(classes)
y_train = jax.nn.one_hot(y_train, classes) # from n -> one-hot of n
y_test = jax.nn.one_hot(y_test, classes)

10


In [7]:
#jax.device_put(x_train, device=jax.devices('gpu')[0])
#jax.device_put(y_train, device=jax.devices('gpu')[0])

In [8]:
# train_data[idx][0] => x   (1, 28, 28)
# train_data[idx][1] => y   int
for idx in range(10):
  print(x_train[idx].shape, y_train[idx])
  # print(x_train[idx][0][14], train_data[idx][0][0][14])

jnp.mean(jnp.array(x_train[0]))

(784,) [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
(784,) [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
(784,) [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
(784,) [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
(784,) [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
(784,) [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
(784,) [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
(784,) [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
(784,) [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
(784,) [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]


Array(0.13768007, dtype=float32)

In [9]:
## functions
keys = random.split(random.PRNGKey(10298213), 10)
neurons = [
    28*28,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10,
    10
]

def init_mlp_params(key, neurons):
  # - HE weight initialization
  # bias initializaiton as 0
  wkey, bkey = random.split(key, 2)
  mlp_params = {
      # remember, its xW, not Wx, so W should be (in_vector_size, out_vector_size)
      # so that (m,) @ (m,n) => (n,)
      "weights" : [
          # He initialization: norm(0,1) * (2/sqrt(weight.size))
          random.normal(wkey, shape=(neurons[i], neurons[i+1])) * 2 / jnp.sqrt(neurons[i]*neurons[i+1])
          for i in range(len(neurons) - 1)
      ],
      "biases" : [
          # initialize biases as 0 vectors
          jnp.zeros(shape=neurons[i+1])
          for i in range(len(neurons) - 1)
      ]
  }
  return mlp_params

def mlp_forward(params, x_batch, y_batch):
  # xW, not Wx
  # x_batch y_batch
  x = x_batch
  for i in range(len(neurons)-1):
    x = x @ params["weights"][i]
    x = x + params["biases"][i]
    if i < len(neurons)-2:
      x = jax.nn.relu(x)
    else:
      x = jax.nn.softmax(x)
  return x

def get_loss(params, x_batch, y_batch):
  y_pred_batch = mlp_forward(params, x_batch, y_batch)
  # the reason for using jax.scipy.special.xlogy instead of
  # -jnp.log(y_pred_batch) * y_batch   is that it accounts for 0 in the
  # prediction batch. otherwise, 0 produces -inf and breaks the training
  presum = -jax.scipy.special.xlogy(y_batch, y_pred_batch)
  crossentropyloss = jnp.sum(presum)
  return crossentropyloss


def param_norms(params):
  norms = {
      "weights" : [jnp.log(jnp.linalg.norm(w)) for w in params['weights']],
      'biases'  : [jnp.log(jnp.linalg.norm(b)) for b in params['biases']]
  }
  return norms

params = init_mlp_params(keys[0], neurons)
learning_rate = 0.001
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)
# for MNIST? cross entropy sum(-log(prediction)*real)

def train_step(params, x_batch, y_batch, optimizer, opt_state):
  losses = get_loss(params, x_batch, y_batch)
  grads = jax.grad(get_loss)(params, x_batch, y_batch)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  # ok so concepually, updates are different than grads. grads are used to calculate updates
  # like in adam where the grads are used to calculate the moments, and then the moments
  # combined with the learning rate are used to calculate the change to the params
  # i.e. the updates to the params
  return params, opt_state, losses, grads


import time
time_between_logs = 3
from pprint import pprint

batch_size = 4
train_datapoints = len(x_train)
batches = len(x_train)//batch_size
indices = random.permutation(keys[1], train_datapoints)
# first just overfit it on the first batch or something
epochs = 1000
last_log_time = time.time()

record = []
for epoch in range(epochs):
  for batch in range(batches):
    batch_start = batch*batch_size
    batch_end = batch_start + batch_size
    batch_indices = indices[batch_start:batch_end]
    x_batch, y_batch = x_train[batch_indices], y_train[batch_indices]

    params, opt_state, losses, norms = train_step(params, x_batch, y_batch, optimizer, opt_state)

    if time.time() - last_log_time > time_between_logs:
      last_log_time = time.time()
      print(f"epoch {epoch}, batch {batch}, loss={jnp.mean(losses)}")
      #pprint(norms)
      record.append((epoch, jnp.mean(losses)))

# optax adam

epoch 0, batch 0, loss=9.21034037179805
epoch 0, batch 6, loss=9.201590694414419
epoch 1, batch 2, loss=9.200388543379937
epoch 1, batch 7, loss=9.213365023069759
epoch 2, batch 2, loss=9.195364092740697
epoch 2, batch 7, loss=9.209113784598557
epoch 3, batch 3, loss=9.189825993957939
epoch 3, batch 9, loss=9.2222350834381
epoch 4, batch 4, loss=9.246703906830174
epoch 4, batch 9, loss=9.222059275146552
epoch 5, batch 5, loss=9.18025445698861
epoch 6, batch 0, loss=9.153390519550298
epoch 6, batch 5, loss=9.172164645898151
epoch 7, batch 0, loss=9.142362799220141
epoch 7, batch 5, loss=9.163105322862503
epoch 8, batch 0, loss=9.128399662148325
epoch 8, batch 5, loss=9.153061038150408
epoch 9, batch 0, loss=9.111605202597573
epoch 9, batch 5, loss=9.141821583320704
epoch 10, batch 0, loss=9.09167625390793
epoch 10, batch 5, loss=9.128780780383003
epoch 11, batch 0, loss=9.067599086081284
epoch 11, batch 5, loss=9.112913807656906
epoch 12, batch 0, loss=9.037443442394153
epoch 12, batch 

FloatingPointError: invalid value (inf) encountered in jit(true_divide). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 

It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. 

If you see this error, consider opening a bug report at https://github.com/jax-ml/jax.

In [31]:
def param_norms(params):
  norms = {
      "weights" : [jnp.log(jnp.linalg.norm(w)) for w in params['weights']],
      'biases'  : [jnp.log(jnp.linalg.norm(b)) for b in params['biases']]
  }
  return norms

losses = get_loss(params, x_batch, y_batch)


closed_jaxpr = jax.make_jaxpr(jax.grad(get_loss))(params, x_batch, y_batch)
jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, params, x_batch, y_batch)

ValueError: safe_map() argument 2 is shorter than argument 1

In [10]:

# ChatGPT made this
import pygame
import numpy as np
import jax.numpy as jnp
import jax

# Initialize Pygame
pygame.init()

# Constants
GRID_SIZE = 28
CELL_SIZE = 20  # Size of each cell in the grid (pixels)
WINDOW_SIZE = GRID_SIZE * CELL_SIZE

# Colors
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
GRAY = (200, 200, 200)

# Initialize the grid (28x28, like MNIST)
grid = np.zeros((GRID_SIZE, GRID_SIZE), dtype=np.float32)

# Initialize the window
screen = pygame.display.set_mode((WINDOW_SIZE, WINDOW_SIZE))
pygame.display.set_caption("Draw and Predict")

def draw_grid():
    """Draw the grid on the Pygame window."""
    screen.fill(WHITE)
    for y in range(GRID_SIZE):
        for x in range(GRID_SIZE):
            color = BLACK if grid[y, x] > 0 else WHITE
            pygame.draw.rect(screen, color, (x * CELL_SIZE, y * CELL_SIZE, CELL_SIZE, CELL_SIZE))
            pygame.draw.rect(screen, GRAY, (x * CELL_SIZE, y * CELL_SIZE, CELL_SIZE, CELL_SIZE), 1)  # Grid lines

def get_prediction(grid):
    """Predict the digit based on the current grid."""
    input_image = jnp.array(grid).reshape(1, -1)  # Flatten the grid
    y_pred = mlp_forward(params, input_image, None)  # Predict using the MLP
    predicted_class = jnp.argmax(y_pred)
    print(f"Predicted Probabilities: {y_pred}")
    print(f"Predicted Class: {predicted_class}")
    return predicted_class

def main():
    running = True
    drawing = False  # Track whether the mouse is pressed

    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False

            # Handle mouse clicks
            if event.type == pygame.MOUSEBUTTONDOWN:
                drawing = True
            if event.type == pygame.MOUSEBUTTONUP:
                drawing = False

            # Clear the grid
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_c:  # Press 'C' to clear
                    grid.fill(0)
                if event.key == pygame.K_p:  # Press 'P' to predict
                    prediction = get_prediction(grid)
                    print(f"Model Prediction: {prediction}")

        # Draw on the grid when the mouse is pressed
        if drawing:
            x, y = pygame.mouse.get_pos()
            grid_x, grid_y = x // CELL_SIZE, y // CELL_SIZE
            if 0 <= grid_x < GRID_SIZE and 0 <= grid_y < GRID_SIZE:
                grid[grid_y, grid_x] = 1  # Mark the cell as filled

        # Update the display
        draw_grid()
        pygame.display.flip()

    pygame.quit()

if __name__ == "__main__":
    main()


pygame 2.6.1 (SDL 2.28.4, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html
Predicted Probabilities: [[nan nan nan nan nan nan nan nan nan nan]]
Predicted Class: 0
Model Prediction: 0
Predicted Probabilities: [[nan nan nan nan nan nan nan nan nan nan]]
Predicted Class: 0
Model Prediction: 0
Predicted Probabilities: [[nan nan nan nan nan nan nan nan nan nan]]
Predicted Class: 0
Model Prediction: 0
