In [1]:
import jax.numpy as jnp
import jax.random as random
import jax
import optax
import torch
import time
from pprint import pprint

In [2]:
# load mnist stuff
batch_size = 4

import torchvision
import torchvision.transforms as transforms
# first load the dataset
train_data = torchvision.datasets.MNIST(root = './', train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
test_data = torchvision.datasets.MNIST(root = './', train=False, download=True, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=2)


# convert to jnp/np
x_train, y_train = zip(*train_data)
x_train, y_train = jnp.array(x_train), jnp.array(y_train)

x_test, y_test = zip(*test_data)
x_test, y_test = jnp.array(x_test), jnp.array(y_test)

# convert ys to one-hot
classes = len(set(y_train.tolist()))
print(classes)
y_train = jax.nn.one_hot(y_train, classes) # from n -> one-hot of n
y_test = jax.nn.one_hot(y_test, classes)

10


Custo model after this CNN: CNN but with skip connections as the first few layers

custom training after this: CNN with 3d convolution for minecraft worlds
  - single minecraft world
  - put chunks in
  - classify biome


In [3]:
# Util functions
def init_cnn_params(key, all_layer_params, layer_types):
  keys = random.split(key, 20)
  cnn_params = []
  for i in range(len(all_layer_params)):
    layer_type, params = all_layer_params[i]

    if layer_type == "input":
      input_size = params
      layer_params = input_size
      current_shape = input_size # invariant: this branch will always run first

    elif layer_type == "conv":
      kernel_count, n, m = params
      current_layer_count = current_shape[0]
      current_layer_rows = current_shape[1]
      current_layer_columns = current_shape[2]
      layer_params = [random.normal(keys[0], shape=(n, m)) for kernel in range(kernel_count*current_layer_count)]
      current_shape = (current_layer_count*kernel_count, current_layer_rows - n, current_layer_columns - m)

    elif layer_type == "pool":
      n, m = params
      layer_params = (n, m)
      current_shape = (current_shape[0], current_shape[1] - n, current_shape[2] - m)

    elif layer_type == "fc":
      # idk man just pile em all together.
      output_shape = params
      current_layer_count = current_shape[0]
      current_layer_rows = current_shape[1]
      current_layer_columns = current_shape[2]
      input_shape = current_layer_count*current_layer_rows*current_layer_columns
      layer_params = {
        'weights' : random.normal(keys[1], shape=(input_shape, output_shape)),
        'biases' : jnp.zeros(shape=(output_shape,))
        }
      current_shape = output_shape

    else:
      raise ValueError('invalid layer type')
    
    cnn_params.append(layer_params)
  return cnn_params

# if conv, init kernel weights and output layer biases
# if pool ... actually im not sure if this even has weights
# if fc just do a normal linear layer


def convolve(layer, kernel):
  # just use upper left pixel of kernel as the output for now
  out = jnp.zeros(shape=(len(layer) - len(kernel), len(layer[0]) - len(kernel[0])))
  for row in range(len(layer) - len(kernel)):
    for col in range(len(layer[0]) - len(kernel[0])):
      ksum = 0
      for krow in range(len(kernel)):
        for kcol in range(len(kernel[0])):
          ksum += layer[row + krow][col + kcol] * kernel[krow][kcol]
      out = out.at[row, col].set(ksum)
  return out

def convolve_layers(input_layers, kernels):
  output_layers = []
  for layer in input_layers:
    for kernel in kernels:
      output_layers.append(convolve(layer, kernel))
  return output_layers


def maxpool(layer, maxpool_shape):
  # just use upper left pixel of kernel as the output for now
  out = jnp.zeros(shape=(len(layer) - maxpool_shape[0], len(layer[0]) - maxpool_shape[1]))
  for row in range(len(layer) - maxpool_shape[0]):
    for col in range(len(layer[0]) - maxpool_shape[1]):
      max_val = -10000
      for krow in range(maxpool_shape[0]):
        for kcol in range(maxpool_shape[1]):
          max_val = max(max_val, layer[row + krow][col + kcol])
      out = out.at[row, col].set(max_val)
  return out

def maxpool_layers(input_layers, maxpool_shape):
  output_layers = []
  for layer in input_layers:
    output_layers.append(maxpool(layer, maxpool_shape))
  return output_layers

def crossentropyloss(logits, y):
  yhat = jax.nn.log_softmax(logits)
  return jnp.sum(-yhat * y)



In [4]:
# training functions

def cnn_forward(cnn_params, cnn_layer_types, x):
  for i in range(len(cnn_params)):
    try:
      print(x.shape)
    except:
      print(len(x), x[0].shape)
    layer_params, layer_type  = cnn_params[i], cnn_layer_types[i]
    if layer_type == "input":
      continue
    elif layer_type == "conv":
      print("convolving")
      kernels = layer_params
      x = convolve_layers(x, kernels)
    elif layer_type == "pool":
      print('maxpooling')
      maxpool_shape = layer_params
      x = maxpool_layers(x, maxpool_shape)
    elif layer_type == "fc":
      print('fcing')
      x = jnp.ravel(jnp.array(x))
      x = x @ layer_params["weights"] + layer_params["biases"]
    else:
      raise ValueError(f'invalid layer type {layer_type}')
  return x


# output logits if classes. convert to yhat in other functions.

def get_loss(cnn_params, cnn_layer_types, x, y):
  logits = cnn_forward(cnn_params, cnn_layer_types, x)
  loss = crossentropyloss(logits, y)
  return loss




In [6]:
# do training

layers = [
  ("input", (1, 28, 28)),
  ("conv", (3, 4, 4)), # 3 4x4 kernels for each input layer => 3x layers next
  ("pool", (3, 3)), # pool all by a 3x3 kernel
  ("fc", 10) # activations (no relu tho) are output logits for 0-9
]
cnn_layer_types = [
  layer[0] for layer in layers
]

key = random.PRNGKey(2983)
keys = random.split(key, 100)

cnn_params = init_cnn_params(keys[0], layers, cnn_layer_types)

learning_rate = 0.01
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(cnn_params)

@jax.jit
def train_step(cnn_params, x, y, opt_state):
  loss, grads = jax.value_and_grad(get_loss)(cnn_params, cnn_layer_types, x, y)
  updates, updated_opt_state = optimizer.update(grads, opt_state)
  updated_cnn_params = optax.apply_updates(cnn_params, updates)
  return updated_cnn_params, updated_opt_state, loss

x = x_train[0]
y = y_train[0]
train_step(cnn_params, x, y, opt_state)


TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.