In [1]:
import sys

sys.path.append("../berries")


In [4]:
from datasets import load_dataset
cache_dir="$HOME/.cache/huggingface/datasets"
mnist = load_dataset("mnist", cache_dir=cache_dir, trust_remote_code=True).with_format("jax")
mnistData = mnist['train']
# X_img = mnistData['image']
# y = mnistData['label']
# X_img_test = mnist["test"]["image"]
# n_test_samples = X_img_test.shape[0]
# y_test = mnist["test"]["label"]
# n_samples, _, _  = X_img.shape
# # X_train = X_img.reshape((n_samples, 1, 28, 28))
# # X_test = X_img_test.reshape((n_test_samples, 1, 28, 28))
# n_channels = 1
# d_x = (28, 28)
# d_y = len(set(y.tolist()))

In [3]:
n_test_samples = X_test.shape[0]

In [4]:
X = X / 255.0
X_test = X_test / 255.0

In [5]:
d_mglu_h_layer = 64
d_mglu_h = 64
n_mglu_layers = 2

n_mixer_layers = 2
d_mixer_channels = 16

d_encode_hidden = 32

from jax.numpy import meshgrid, arange
pos_x1, pos_x2 = meshgrid(arange(d_x[0]), arange(d_x[1]))
pos_x1 = pos_x1.flatten()

pos_x2 = pos_x2.flatten()

import random_utils
seed = 0
key_gen = random_utils.infinite_safe_keys(seed)

In [6]:
from jax.numpy import array, exp, mean
import nn
import importlib
importlib.reload(nn)
import pf
importlib.reload(pf)
from nn import mglu_net_config, mglu_net, sglu, sglu_config
from pf import F, _
import optax

h_axis = 1
x_axis = 0
d_h_axis = d_mixer_channels
d_x_axis = d_x[0] * d_x[1]

def gaussian_activation(a, x):
    return exp((-0.5 * x ** 2) / a ** 2)

def pos_encode(W, x1, x2, v):
    rep = array([x1, x2, v])
    l1 = W['0'] @ rep
    a = W['a']
    a1 = gaussian_activation(a, l1)
    return W['1'] @ a1

def mixer_head(W, x):
    val_flat = x.reshape(-1)
    return F(pos_encode).vmap((None, 0, 0, 0), 0)(W, pos_x1, pos_x2, val_flat)


def mixer_block(W, X):
    mixer_h  = F(mglu_net).f(_, W['h']).vmap(x_axis, x_axis)
    mixer_x = F(mglu_net).f(_, W['x']).vmap(h_axis, h_axis)
    return X + mixer_h(X) + mixer_x(X)
    

def mixer(W, x):
    X = mixer_head(W['head'], x)
    for _ in range(n_mixer_layers):
        X = mixer_block(W['block'], X)
    return X

def mixer_net(W, x):
    X = mixer(W['mixer'], x) 
    return F(sglu).f(_, **W['out'])(X.sum(axis=x_axis))

mglu_net_conf_local = F(mglu_net_config).f(_, d_mglu_h_layer, _, d_mglu_h, n_mglu_layers, _)

def mixer_block_config(init):
    return {
        'h': mglu_net_conf_local(d_h_axis, d_h_axis, init),
        'x': mglu_net_conf_local(d_x_axis, d_x_axis, init)
    }

def mixer_head_config(init):
    return {
        'a': {
            "size": (d_encode_hidden,),
            "const": 1.0
        },
        '0': {
            "size": ( d_encode_hidden, 3),
            "init": init
        },
        '1': {
            "size": (d_mixer_channels, d_encode_hidden),
            "init": init
        }
    }


def mixer_config(init):
    return {
        'mixer': {
            'head': mixer_head_config(init),
            'block': mixer_block_config(init),
        },
        'out': sglu_config(d_mixer_channels, d_mglu_h,  d_y, init)
    }

def loss_1(W, x, y):
    return optax.softmax_cross_entropy_with_integer_labels(mixer_net(W, x), y)

loss_batch = F(loss_1).vmap((None, 0, 0), 0) >> F(mean)

In [7]:
import optax
from jax import grad, jit
from jax.tree_util import tree_map
from nn import init_weights, fmt_weights
import init_utils



lr = 0.0001
mask_fn = lambda p: tree_map(lambda x: not isinstance(x, int), p)
# opt = optax.multi_transform({"sgd": optax.rmsprop(lr), "zero": optax.set_to_zero()}, mask_fn)
opt = optax.masked(optax.adam(lr), mask_fn)

#method = {"type": "zer0", "std": 0.001}
method = {"type": "normal", "std": 0.01}
# method = {"type": "normal", "std": 0.1}
W = init_weights(next(key_gen), mixer_config(method))
print(fmt_weights(W)[0])
loss0 = loss_batch(W, X_test[:100, :], y_test[:100])
print(loss0)
state = opt.init(W)

@jit
def update(W, x, y, opt_state):
    grads = grad(loss_batch)(W, x, y)
    updates, opt_state = opt.update(grads, opt_state)
    new_W = optax.apply_updates(W, updates)
    return new_W, opt_state

mixer:
    head:
        a:
            1.0
        0:
            array shape: (32, 3)
        1:
            array shape: (16, 32)
        total params: 609
    block:
        h:
            mglu:
                tuple:
                    sglu:
                        wv:
                            array shape: (16, 64)
                        wu:
                            array shape: (16, 64)
                        wo:
                            array shape: (64, 64)
                        total params: 6144
                    rmsn:
                        d:
                            64.0
                        total params: 1
                    total params: 6145
                tuple:
                    sglu:
                        wv:
                            array shape: (64, 64)
                        wu:
                            array shape: (64, 64)
                        wo:
                            array shape: (64, 64)
                        tot

In [8]:
from plot_utils import visualize_matrix
from IPython.display import display
import math, random
import jax.numpy as np


def accuracy(logits, y):
    return (logits.argmax(-1) == y).mean()

def get_accuracy(x, y, W):
    return accuracy(mixer_net(W, x), y)

batch_size = 128

train_index = random.sample(range(n_samples), batch_size * 2)
test_index = random.sample(range(n_test_samples), batch_size * 2)

get_accuracy_b = F(get_accuracy).vmap(in_axes=(0, 0, None), out_axes=0)
get_accuracy_b_d = get_accuracy_b.f(X[train_index, :], y[array(train_index)], _) >> mean
# get_accuracy_b_t = get_accuracy_b.f(X_test[test_index, :], y_test[array(test_index)], _) >> mean
get_accuracy_b_t = get_accuracy_b.f(X_test, y_test, _) >> mean

def loss_b_dt(W):
    # return mean(loss_batch(W, X_test[test_index, :], yy[array(test_index)]))
    return mean(loss_batch(W, X_test, y_test))



def sample():
    index = random.sample(range(n_samples), batch_size)
    return X[index, :], y[array(index)]

for i in range(5000):
    xx, yy = sample()
    W, state = update(W, xx, yy, state)
    if i % 200 == 0:
        print(get_accuracy_b_d(W), get_accuracy_b_t(W), loss_b_dt(W))


print(loss_b_dt(W))


2024-06-11 13:35:25.323531: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.88GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2023817216 bytes.