In [17]:
from typing import Tuple, List, Callable, Sequence

import jax
from jax import value_and_grad, jit, vmap, Array
from jax.random import KeyArray
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import chex
from optax import adam, GradientTransformation, apply_updates

from dataclasses import dataclass

@dataclass
class Config:
  sizes: Sequence[int]
  seed: int
  epochs: int
  learning_rate: float
  activation_fn: Callable
  goodness_fn: Callable
  flat_shape: Tuple[int] = (784,)

class Layer(nn.Module):
  size: int
  activation_fn: Callable

  @nn.compact
  def __call__(self, x):
    x = x / jnp.linalg.norm(x, 2, keepdims = True)
    return self.activation_fn(nn.Dense(self.size)(x))

def create_network(sizes: Sequence[int], learning_rate: float, activation_fn: Callable):  
  return [(Layer(size, activation_fn), adam(learning_rate)) for size in sizes]

ForwardForwardLayer = Tuple[Layer, GradientTransformation]
Network = List[ForwardForwardLayer]

In [162]:
from keras.datasets import mnist

def load() -> Tuple[Array, Array, Array, Array]:
  """Remotely load MNIST data to JAX Arrays."""

  # Load Data
  (X_train, y_train), (X_test, y_test) = mnist.load_data()

  # Scale & Flatten images
  X_train = (X_train.astype("float32") / 255).reshape(*X_train.shape[:-2], -1)
  X_test = (X_test.astype("float32") / 255).reshape(*X_test.shape[:-2], -1)

  return (
    jnp.array(X_train),
    jnp.array(y_train),
    jnp.array(X_test),
    jnp.array(y_test)
  )

def swap_perfect(key: KeyArray, y: Array) -> jnp.array:
  """Swap labels such that each entry is definitely different to it's initial entry.
  
  Args:
    key: jax.random.PRNGKey
    y: jnp.array
  
  Returns:
    y_out: jnp.array, of shuffled labels
  """

  def swap(args):
    key, label, uniques = (args)
    key, subkey = jax.random.split(key)
    sample = jax.random.choice(subkey, jnp.setdiff1d(uniques, label))
    return subkey, sample

  uniques = jnp.unique(y, size = 10)

  y_out = []
  for label in y.flatten():
    key, negative_label = swap((key, label, uniques))
    y_out.append(negative_label)

  return jnp.array(y_out).reshape(y.shape)

@jax.jit
def overlay(X: Array, y: Array, l: int = 25) -> Array:
  """Combines X and y into a single vector which is compatible with training
  via the forward-forward algorithm. In this example, we make the top line of pixels
  correspond to the given label. Also flattens each array for use with MLP.
  
  Args:
    X: Training examples.
    y: Correct or incorrect labels.
  
  Returns:
    out -> Xy array.
  """
  _X = X
  return _X.at[:, 0:l].set(jnp.full((l, y.shape[0]), y).T)

@jax.jit
def prep_input(key: KeyArray, X: chex.Array, y: chex.Array) -> Tuple[chex.Array, chex.Array]:
  X_pos = overlay(X, y)
  X_neg = overlay(X, jax.random.permutation(key, y))  
  return X_pos, X_neg

In [163]:
@jit
def goodness_original(a: Array) -> Array:
  """Goodness of fit as the sum of squares of activations."""
  return (a ** 2).sum()

@jit
def loss(
  A_pos: Array, 
  A_neg: Array,
  theta: float
) -> float:
  """Compute loss on positive and negative examples."""
  loss_pos = ((goodness_original(A_pos) - theta) * -1)
  loss_neg = (goodness_original(A_neg) - theta)
  return (loss_pos + loss_neg).mean()

In [159]:
X_train, y_train, X_test, y_test = load()

X_train.shape
X_test.shape

(10000, 784)

In [160]:
X_train.at[:, 0:25]

_IndexUpdateRef(Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), (slice(None, None, None), slice(0, 25, None)))

In [161]:
chex.Ar

Array([[5., 5., 5., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [4., 4., 4., ..., 0., 0., 0.],
       ...,
       [5., 5., 5., ..., 0., 0., 0.],
       [6., 6., 6., ..., 0., 0., 0.],
       [8., 8., 8., ..., 0., 0., 0.]], dtype=float32)

In [166]:
from flax.training.train_state import TrainState
from functools import partial

def train_layer(
  key: KeyArray,
  X: chex.Array,
  y: chex.Array,
  fflayer: ForwardForwardLayer,
  epochs: int,
  theta: int, 
  flat_shape: Tuple[int] = (784,)
):

  @value_and_grad
  @partial(jit, static_argnums=(3,))
  def loss(params, X_pos, X_neg, goodness_fn):
    A_pos = state.apply_fn({'params': params}, X_pos)
    A_neg = state.apply_fn({'params': params}, X_neg)
    loss_pos = -(goodness_fn(A_pos) - theta)
    loss_neg = (goodness_fn(A_neg) - theta)
    return (loss_pos + loss_neg).mean()

  @jit
  def train_step(inkey, X_pos, X_neg, state):
    inkey, subkey = jax.random.split(inkey, 2)
    loss_val, grads = loss(state.params, X_pos, X_neg)
    state = state.apply_gradients(grads=grads)
    return subkey, loss_val, state

  X_init = jax.random.normal(key, flat_shape)
  layer, optimizer = fflayer
  params = layer.init(key, X_init)

  state = TrainState.create(
        apply_fn = layer.apply,
        tx = optimizer,
        params = params['params']
    )
  
  for epoch in range(epochs):
    key, subkey = jax.random.split(key, 2)
    X_pos, X_neg = prep_input(subkey, X, y)
    key, loss_val, state = train_step(subkey, X_pos, X_neg, state)
    if epoch % 10 == 0: print(f'Epoch {epoch}, loss: {loss_val}')

  # Get out to feed to next layer
  X_in, _ = prep_input(subkey, X, y)
  X_out = state.apply_fn({'params': state.params}, X_in)

  return state, X_out

TrainedNet = List[TrainState]

def train(key: KeyArray, net: Network, X: chex.Array, y: chex.Array, epochs: int, theta: int) -> TrainedNet:
  _X = X
  trained = []

  # Train all Network Layers
  for l in net:
    state, _X = train_layer(key, _X, y, l, epochs, theta)
    trained.append(state)
  
  return trained

In [197]:
def predict(
  trainedNet: TrainedNet,
  X: Array,
  y: Array,
):
  layer_activations = []
  for label in jnp.unique(y):
    y_sgl = jnp.full(y.shape, label)
    X_t = overlay(X, y_sgl)

    activations = []
    for state in trainedNet:
      A_t = state.apply_fn({'params': state.params}, X_t)
      print(type(A_t))
      activations.append(A_t)
    
    layer_activations.append(activations)
  
  #return jnp.argmax(jnp.concatenate(preds), axis = 1)
  return layer_activations

In [198]:
activations = predict(out, X_test, y_test)

<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>
<class 'jaxlib.xla_extension.Array'>


In [204]:
arr = jnp.array([i[0] for i in activations])

In [222]:
for lab in activations:
  overall = []
  for i in lab:
    overall.append(jnp.sum(i, axis = 1))

[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02641665]
[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02641665]
[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02641665]
[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02641665]
[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02641665]
[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02641665]
[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02641665]
[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02641665]
[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02641665]
[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02641665]
[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02641665]
[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02641665]
[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02641665]
[0.02560646 0.02708386 0.02655082 ... 0.02623161 0.02462271 0.02

Array([-0.01477796, -0.01367337, -0.01355522, ...,  0.02623161,
        0.02462271,  0.02641665], dtype=float32)

In [214]:
activations[0][0].shape.flatten()

(10000, 784)

In [211]:
for arr in activations:
  print(jnp.sum(jnp.concatenate([i.flatten() for i in arr])))

57.213707
111.53833
124.179
125.82974
125.49598
124.83792
124.19403
123.63144
123.15091
122.74477


In [200]:
key = jax.random.PRNGKey(42)
X_train, y_train, _, _ = load()
net = create_network([784, 500], 0.001, jax.nn.gelu)

In [None]:
out = train(key, net, X_train, y_train, 1, 2)



Epoch 0, loss: -1.1920928955078125e-05
Epoch 0, loss: -5.960464477539062e-07


In [178]:
predict(out, X_test, y_test)

[[-3.2968895e-04  5.9063575e-04 -5.5327767e-04 ... -4.3996162e-04
   4.0091891e-04 -4.5514942e-04]
 [-7.4209651e-04  4.6982433e-04 -6.4880645e-04 ... -1.1954390e-04
   7.6915481e-04 -6.3632877e-05]
 [-5.3363788e-04  4.9608975e-04 -3.2628627e-04 ... -5.1652413e-04
   4.8414589e-04 -3.3191088e-04]
 ...
 [-1.7386554e-04  6.4366229e-04 -4.5608421e-04 ... -7.8958255e-04
   8.2040555e-04 -3.2022904e-04]
 [-6.9002621e-04  7.4529066e-04 -4.4723458e-04 ... -3.5964299e-04
   6.8084250e-04 -2.2933687e-04]
 [-4.0271605e-04  5.5800460e-04 -3.7934104e-04 ... -5.6520308e-04
   8.3004928e-04 -3.6683548e-04]]
[[-4.7052230e-04  6.0838094e-04 -5.2044558e-04 ...  3.1496963e-04
   2.7054030e-04 -2.6022096e-04]
 [-1.5891490e-04  1.8861234e-04 -5.9090240e-04 ...  3.7103155e-04
   1.4261724e-04 -2.7947058e-04]
 [-4.7179338e-04  6.1404379e-04 -5.5733329e-04 ...  4.2096520e-04
   5.4968859e-04 -5.5224111e-04]
 ...
 [-1.6870211e-04  7.4912957e-04 -5.6288246e-04 ...  2.7740659e-04
   6.4104080e-04 -2.3615651e-04]

TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

In [172]:
import flax
flax.__version__

'0.5.3'

In [173]:
import chex
chex.__version__

'0.1.5'