In [5]:
from jaxtyping import Float, Array, Int, PyTree

import jax
import jax.numpy as jnp
from jax import random as jrandom

import equinox as eqx
from equinox import Module

import optax as opt

from src.nn import *

import numpy as np

from matplotlib import pyplot as plt

import streamlit as st
from src.nn import *

In [6]:
# Hyperparameters
BATCH_SIZE = 64
LEARNING_RATE = 3e-4
EPOCHS = 2000
PRINT_EVERY = 100
SEED = 0

In [7]:
# Data
def load_data(b: int = 100, key: jrandom.PRNGKey = jrandom.PRNGKey(0)):
    key, subkey = jrandom.split(key)
    X = jrandom.normal(subkey, (b, 2, 10, 10))
    Y = X[:, 0] @ X[:, 1]

    return X, Y


In [8]:
# Evaluation and Training
def loss(model: NN, x: Float[Array, "b 2 10 10"], y: Float[Array, "b 10 10"]) -> Float[Array, ""]:
    y_pred = jax.vmap(model)(x[:, 0], x[:, 1])
    return jnp.mean((y - y_pred) ** 2)
    # return cross_entropy(y, y_pred)

def cross_entropy(y: Float[Array, "b 10 10"], y_pred: Float[Array, "b 10 10"]) -> Float[Array, ""]:
    return -jnp.mean(y * jnp.log(y_pred))

def accuracy(y: Float[Array, "b 10 10"], y_pred: Float[Array, "b 10 10"]) -> Float[Array, ""]:
    return jnp.mean(y == y_pred)

def evaluate(model: NN, b: int = 100, key: jrandom.PRNGKey = jrandom.PRNGKey(0)) -> Float[Array, ""]:
    X, Y = load_data(b, key)
    return loss(model, X, Y)

def train(
        model: NN,
        optim: opt.GradientTransformation,
        steps: int,
        print_every: int
) -> NN:
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    @eqx.filter_jit
    def step(
        model: NN,
        opt_state: PyTree,
        x: Float[Array, "b 2 10 10"],
        y: Float[Array, "b 10 10"]
    ) -> NN:
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_step = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_step, loss_value
    
    def infinite_trainloader():
        while True:
            X, Y = load_data(BATCH_SIZE)
            yield X, Y

    for s, (x, y) in zip(range(steps), infinite_trainloader()):
        model, opt_state, loss_value = step(model, opt_state, x, y)

        if s % print_every == 0 or s == steps - 1:
            print(f"Step {s}, Loss: {loss_value}")

    return model


### Initialize and train the model

In [11]:
print("Training CNN")

net = CNN()

print("Network parameters: ", net.param_count(), "\n")

net = train(net, opt.adam(LEARNING_RATE), EPOCHS, PRINT_EVERY)

print("\nNetwork parameters: ", net.param_count(), "\n")
net.split_layer(5, 25)

print(net, "\n")

net = train(net, opt.adam(LEARNING_RATE), 2*EPOCHS, PRINT_EVERY)


Training CNN
Network parameters:  79588 

Step 0, Loss: 10.12450885772705
Step 100, Loss: 8.0065279006958
Step 200, Loss: 3.3893675804138184
Step 300, Loss: 0.8956431150436401
Step 400, Loss: 0.12749134004116058
Step 500, Loss: 0.010441564954817295
Step 600, Loss: 0.0008798211347311735
Step 700, Loss: 8.334965241374448e-05
Step 800, Loss: 8.176562005246524e-06
Step 900, Loss: 7.860908226575702e-07
Step 1000, Loss: 7.070565999356404e-08
Step 1100, Loss: 5.716437989633505e-09
Step 1200, Loss: 4.062150060768488e-10
Step 1300, Loss: 3.002399717733084e-11
Step 1400, Loss: 4.542908150356739e-12
Step 1500, Loss: 1.6038366381504465e-12
Step 1600, Loss: 8.326447712737883e-13
Step 1700, Loss: 5.250597049158423e-13
Step 1800, Loss: 3.983249006241779e-13
Step 1900, Loss: 3.163669142196984e-13
Step 1999, Loss: 2.759897616434054e-13

Network parameters:  79588 

CNN(
  layers=[
    Conv2d(
      num_spatial_dims=2,
      weight=f32[4,2,4,4],
      bias=f32[4,1,1],
      in_channels=2,
      out_chan