In [1]:
import jax.numpy as jnp
import jax.random as random
import jax
import optax
import torch
import time
from pprint import pprint
import flax

In [2]:
# load mnist stuff
batch_size = 4

import torchvision
import torchvision.transforms as transforms
# first load the dataset
train_data = torchvision.datasets.MNIST(root = './', train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
test_data = torchvision.datasets.MNIST(root = './', train=False, download=True, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=2)


# convert to jnp/np
x_train, y_train = zip(*train_data)
x_train, y_train = jnp.array(x_train), jnp.array(y_train)

x_test, y_test = zip(*test_data)
x_test, y_test = jnp.array(x_test), jnp.array(y_test)

# convert ys to one-hot
classes = len(set(y_train.tolist()))
print(classes)
y_train = jax.nn.one_hot(y_train, classes) # from n -> one-hot of n
y_test = jax.nn.one_hot(y_test, classes)

10


Custo model after this CNN: CNN but with skip connections as the first few layers

custom training after this: CNN with 3d convolution for minecraft worlds
  - single minecraft world
  - put chunks in
  - classify biome


In [20]:
# Util functions
def init_cnn_params(key, input_shape, output_shape):
  keys = random.split(key, 20)
  cnn_params = {}

  # shape: (1, 28, 28) => conv1 => (3, 28-(4-1), 28-(4-1))
  conv1 = {
    #out_channels, in_channels, kernel_row, kernel_col
    # 4x4 kernel. 1 input (mnist), 3 outputs
    "conv1" : jax.nn.initializers.glorot_uniform()(keys[0], (3, 1, 4, 4))
  }
  current_shape = (3, 28-(4-1), 28-(4-1))

  conv2 = {
    "conv2" : jax.nn.initializers.glorot_uniform()(keys[2], (3, 3, 4, 4))
  }
  current_shape = (3, 25, 25) # this conv will have a skip connection and mode will be 'same'

  # shape: (3, 25, 25) => maxpool 2x2 => (3, 25 - (2-1), 25-(2-1))
  current_shape = (3, 24, 24)

  # shape: (3, 24, 24) => conv1 => (7, 24-(10-1), 24-(10-1))
  conv3 = {
    #out_channels, in_channels, kernel_row, kernel_col
    # 4x4 kernel. 1 input (mnist), 3 outputs
    "conv3" : jax.nn.initializers.glorot_uniform()(keys[3], (7, 3, 10, 10))
  }
  current_shape = (7, 15, 15)

  conv4 = {
    "conv4" : jax.nn.initializers.glorot_uniform()(keys[4], (7, 7, 10, 10))
  }
  current_shape = (7, 15, 15) # this conv will have a skip connection and mode will be 'same'
  
  # shape: (7, 15, 15) => ravel => (7 * 15 * 15,)
  current_shape = (7 * 15 * 15,)
  fc1 = {
    # since its xW for batching, W shape=(input, output)
    "fc1" : {
      "w" : random.normal(keys[1], shape=(current_shape[0], 128)) * jnp.sqrt(2 / current_shape[0]),
      'b' : jnp.zeros(shape=128)
      }
  }

  # shape: (7 * 15 * 15,) => fc1 => (128,)
  current_shape = (128,)
  fc2 = {
    "fc2" : {
      "w" : random.normal(keys[5], shape=(current_shape[0], output_shape[0])) * jnp.sqrt(2 / current_shape[0]),
      'b' : jnp.zeros(shape=output_shape)
      }
  }

  cnn_params.update(conv1)
  cnn_params.update(conv2)
  cnn_params.update(conv3)
  cnn_params.update(conv4)
  cnn_params.update(fc1)
  cnn_params.update(fc2)

  return cnn_params



def maxpool(layer, maxpool_shape):
  # just use upper left pixel of kernel as the output for now
  out = jnp.zeros(shape=(layer.shape[0] - maxpool_shape[0] + 1, len(layer[0]) - maxpool_shape[1] + 1))
  for row in range(len(layer) - maxpool_shape[0]):
    for col in range(len(layer[0]) - maxpool_shape[1]):
      max_val = -10000
      for krow in range(maxpool_shape[0]):
        for kcol in range(maxpool_shape[1]):
          max_val = max(max_val, layer[row + krow][col + kcol])
      out = out.at[row, col].set(max_val)
  return out

def maxpool_layers(input_layers, maxpool_shape):
  output_layers = []
  for layer in input_layers:
    output_layers.append(maxpool(layer, maxpool_shape))
  return jnp.array(output_layers)

def crossentropyloss(logits, y):
  yhat = jax.nn.log_softmax(logits)
  return jnp.sum(-yhat * y)



In [21]:
# training functions

# output logits if classes. convert to yhat in other functions.
def get_loss(cnn_params, x, y):
  logits = cnn_forward(cnn_params, x)
  loss = crossentropyloss(logits, y)
  return loss

# yes i know its inefficient idc
def get_accuracy(cnn_params, xs, ys):
  correct = 0
  for idx in range(len(xs)):
    logits = cnn_forward(cnn_params, xs[idx])
    yhat, y = jnp.argmax(logits), jnp.argmax(ys[idx])
    if yhat == y:
      correct += 1
  return correct / len(xs)


@jax.jit
def cnn_forward(cnn_params, x):
  # do conv, then relu, then maxpool, then fc.
  x = jnp.array([
    jax.scipy.signal.convolve(x, kernel, mode="valid")[0]
    for kernel in cnn_params["conv1"]
  ])
  x = jax.nn.relu(x)
  print(x.shape)

  # residual/skip connnection. mode=same to output the same shape.
  x_residual = jnp.array([
    jax.scipy.signal.convolve(x, kernel, mode='same')[0]
    for kernel in cnn_params["conv2"]
  ])
  x_residual = jax.nn.relu(x_residual)
  x = x + x_residual
  x = jax.nn.relu(x)
  print(x.shape)

  # maxpool 2x2
  x = jnp.array([
    # why do i have to reshape it?
    # channel[:, :, None]  (25, 25) => (25, 25, 1)
    flax.linen.max_pool(channel[:, :, None], window_shape=(2, 2))
    for channel in x
  ])
  x = jnp.squeeze(x, axis=-1) # (24, 24, 1) => (24, 24)
  print(x.shape)

  x = jnp.array([
    jax.scipy.signal.convolve(x, kernel, mode="valid")[0]
    for kernel in cnn_params["conv3"]
  ])
  x = jax.nn.relu(x)
  print(x.shape)

  x_residual = jnp.array([
    jax.scipy.signal.convolve(x, kernel, mode='same')[0]
    for kernel in cnn_params["conv4"]
  ])
  x_residual = jax.nn.relu(x_residual)
  x = x + x_residual
  x = jax.nn.relu(x)
  print(x.shape)


  x = jnp.ravel(x) # all channels to 1 layer
  print(x.shape)
  x = x @ cnn_params["fc1"]["w"] + cnn_params["fc1"]["b"]
  x = jax.nn.relu(x)

  x = x @ cnn_params["fc2"]["w"] + cnn_params["fc2"]["b"]

  return x



In [26]:
# do training

key = random.PRNGKey(29383)
keys = random.split(key, 100)

input_shape = (1, 28, 28)
output_shape = (10,)
cnn_params = init_cnn_params(keys[0], input_shape, output_shape)

learning_rate = 0.001
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(cnn_params)

@jax.jit
def train_step(cnn_params, x, y, opt_state):
  loss, grads = jax.value_and_grad(get_loss)(cnn_params, x, y)
  updates, updated_opt_state = optimizer.update(grads, opt_state)
  updated_cnn_params = optax.apply_updates(cnn_params, updates)
  return updated_cnn_params, updated_opt_state, loss

x = x_train[0]
y = y_train[0]
import time
start = time.time()

steps = 400000
print_every = 100
train_indices = random.permutation(keys[2], len(x_train))
test_indices = random.permutation(keys[3], len(x_test))
moving_avg = []
for step in range(steps):
  idx = train_indices[step % len(x_train)]
  cnn_params, opt_state, loss = train_step(cnn_params, x_train[idx], y_train[idx], opt_state)
  test_idx = test_indices[step % len(x_test)]
  val_loss = get_loss(cnn_params, x_test[test_idx], y_test[test_idx])
  moving_avg.append([loss, val_loss])
  if step % print_every == 0:
    loss, val_loss = jnp.mean(jnp.array(moving_avg), axis=0)
    accuracy = get_accuracy(cnn_params, x_test[:print_every], y_test[:print_every])
    print(f"step {step}, loss = {loss:0.4f}, val_loss = {val_loss:0.4f}, val_accuracy={accuracy:0.5f}")
    moving_avg = []
print("total: ", time.time() - start)


step 0, loss = 1.8796, val_loss = 2.5550, val_accuracy=0.08000
step 100, loss = 2.1550, val_loss = 2.0442, val_accuracy=0.39000
step 200, loss = 1.3083, val_loss = 1.3555, val_accuracy=0.44000
step 300, loss = 0.6268, val_loss = 1.0453, val_accuracy=0.78000
step 400, loss = 0.9674, val_loss = 0.8587, val_accuracy=0.85000
step 500, loss = 0.8854, val_loss = 0.7839, val_accuracy=0.90000
step 600, loss = 0.7537, val_loss = 0.4948, val_accuracy=0.86000
step 700, loss = 0.5497, val_loss = 0.6564, val_accuracy=0.89000
step 800, loss = 0.3841, val_loss = 0.4677, val_accuracy=0.89000
step 900, loss = 0.5004, val_loss = 0.7872, val_accuracy=0.78000
step 1000, loss = 0.5024, val_loss = 0.3308, val_accuracy=0.95000
step 1100, loss = 0.5568, val_loss = 0.4890, val_accuracy=0.88000
step 1200, loss = 0.4629, val_loss = 0.2669, val_accuracy=0.94000
step 1300, loss = 0.2745, val_loss = 0.1886, val_accuracy=0.95000
step 1400, loss = 0.3162, val_loss = 0.3098, val_accuracy=0.79000
step 1500, loss = 0.52

KeyboardInterrupt: 

In [29]:
import pygame
import numpy as np
import jax.numpy as jnp
import scipy.ndimage


def apply_brush(grid, x, y, intensity=40, radius=2):
    """
    Applies a Gaussian-like brush stroke to the grid.
    Args:
        grid: The drawing grid (28x28 numpy array).
        x, y: The coordinates of the center of the brush.
        intensity: The max intensity to add.
        radius: The size of the brush.
    """
    for dx in range(-radius, radius + 1):
        for dy in range(-radius, radius + 1):
            nx, ny = x + dx, y + dy
            if 0 <= nx < grid.shape[1] and 0 <= ny < grid.shape[0]:
                distance = (dx ** 2 + dy ** 2) ** 0.5
                weight = max(0, 1 - distance / radius)  # Linear fall-off
                grid[ny, nx] = min(255, int(grid[ny, nx]) + int(intensity * weight))

# Initialize Pygame
pygame.init()

# Grid settings
grid_size = 28
square_size = 20  # Each square is 20x20 pixels, making a 560x560 window
screen_size = grid_size * square_size

# Create the drawing grid (28x28)
drawing_grid = np.zeros((grid_size, grid_size), dtype=np.uint8)

# Initialize Pygame screen
screen = pygame.display.set_mode((screen_size, screen_size))
pygame.display.set_caption("Drawing")  # Default status

# Colors
black = (0, 0, 0)

# Function to draw the grid on the screen
def draw_grid():
    for x in range(grid_size):
        for y in range(grid_size):
            intensity = drawing_grid[y][x]
            color = (intensity, intensity, intensity)  # Grayscale color based on intensity
            pygame.draw.rect(
                screen,
                color,
                (x * square_size, y * square_size, square_size, square_size)
            )

# Main loop
running = True
drawing = False

while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False

        elif event.type == pygame.MOUSEBUTTONDOWN:
            drawing = True
            pygame.display.set_caption("Drawing")  # Set title to indicate drawing

        elif event.type == pygame.MOUSEBUTTONUP:
            drawing = False

        elif event.type == pygame.KEYDOWN:
            if event.key == pygame.K_c:  # Clear the grid
                drawing_grid = np.zeros((grid_size, grid_size), dtype=np.uint8)
                screen.fill(black)
                pygame.display.set_caption("Cleared")  # Update title
                pygame.display.flip()

            elif event.key == pygame.K_p:  # Predict
                # Pass the 28x28 numpy array to the CNN
                img = drawing_grid
                img_jax = jnp.expand_dims(jnp.expand_dims(jnp.array(img, dtype=jnp.float32), axis=0), axis=0)  # Shape: (1, 1, 28, 28)
                prediction = cnn_forward(cnn_params, img_jax[0])  # Forward pass with your JAX CNN
                predicted_digit = int(jnp.argmax(prediction))
                pygame.display.set_caption(f"Predicted Digit: {predicted_digit}")  # Update title with prediction

            elif event.key == pygame.K_l:  # Load a random image from x_train
                random_index = np.random.randint(0, len(x_test))
                drawing_grid = (np.array(x_test[random_index][0]) * 255).astype(np.uint8)  # Ensure NumPy and scale to [0, 255]
                screen.fill(black)
                pygame.display.set_caption("Loaded Example")  # Update title
                pygame.display.flip()

    # Handle drawing
    if drawing:
        mouse_x, mouse_y = pygame.mouse.get_pos()
        grid_x = mouse_x // square_size
        grid_y = mouse_y // square_size
        if 0 <= grid_x < grid_size and 0 <= grid_y < grid_size:
            apply_brush(drawing_grid, grid_x, grid_y, intensity=80, radius=2)

    # Update the screen
    draw_grid()
    pygame.display.flip()

pygame.quit()
