In [1]:

import functools
from typing import Callable
from better_partial import _, partial
from jax import vmap


def compose(f, g):
    """Return a callable that runs ``f`` on its arguments and feeds the result to ``g``."""
    def piped(*args, **kw):
        intermediate = f(*args, **kw)
        return g(intermediate)

    return piped

class PointFreeFunction:
    """Callable wrapper that adds point-free helpers to a plain callable.

    The wrapped callable is stored on ``self.func``. ``f`` partially applies it
    via ``better_partial.partial``, ``vmap`` vectorises it via ``jax.vmap`` and
    ``>>`` pipes its result into another callable. Every helper returns a new
    ``PointFreeFunction`` and leaves the receiver untouched.
    """

    def __init__(self, func):
        """Wrap ``func`` and copy its metadata (name, docstring, ...) onto self."""
        # update_wrapper merges func.__dict__ into self.__dict__, so it must run
        # BEFORE self.func is assigned. Otherwise a wrapped object that carries
        # its own ``func`` attribute (for example another PointFreeFunction)
        # would silently replace the callable stored here.
        functools.update_wrapper(self, func)
        self.func = func

    def f(self, *args, **kw):
        """Apply ``partial(self.func)`` to the given arguments and wrap the result."""
        return PointFreeFunction(partial(self.func)(*args, **kw))

    def vmap(self, *args, **kw):
        """Wrap ``jax.vmap(self.func, *args, **kw)``; arguments are forwarded as-is."""
        return PointFreeFunction(vmap(self.func, *args, **kw))

    def __call__(self, *args, **kw):
        """Call the wrapped callable with the given arguments."""
        return self.func(*args, **kw)

    # The annotation is quoted so it is never evaluated at definition time;
    # ``'PointFreeFunction' | Callable`` as an expression needs Python 3.10+.
    def __rshift__(self, other: "PointFreeFunction | Callable"):
        """``self >> other``: call self first, then pass its result to ``other``.

        Raises:
            TypeError: if ``other`` is neither a PointFreeFunction nor callable.
        """
        if isinstance(other, PointFreeFunction):
            return PointFreeFunction(compose(self.func, other.func))
        elif callable(other):
            return PointFreeFunction(compose(self.func, other))
        else:
            raise TypeError("other must be callable or PointFreeFunction")

# Short alias so call sites can write F(fn).f(...), F(fn).vmap(...) or F(fn) >> g.
F = PointFreeFunction

In [2]:
from jax.nn import sigmoid
from jax.numpy.linalg import norm
from jax.numpy import sqrt

In [3]:

def affine(x, W, b):
    """Affine map: multiply ``x`` by the transpose of ``W`` and add ``b``."""
    projected = W.T @ x
    return projected + b

def swish(x):
    """Swish activation: ``x`` scaled element-wise by ``sigmoid(x)``."""
    gate = sigmoid(x)
    return x * gate

def sglu(x, wv, wu, wo):
    """Gated linear unit without an activation.

    Projects ``x`` through ``wv`` and ``wu``, multiplies the two projections
    element-wise and maps the product through ``wo``.
    """
    value_branch, gate_branch = x @ wv, x @ wu
    gated = value_branch * gate_branch
    return gated @ wo

def rmsn(x, d):
    """Rescale ``x`` by dividing it by ``norm(x) / sqrt(d)``.

    When ``d`` equals the number of elements of ``x`` the divisor is the
    root-mean-square of ``x``.
    """
    scale = norm(x) / sqrt(d)
    return x / scale

def mglu(x, ws):
    """Run ``x`` through a stack of layers.

    Each entry of ``ws`` is a mapping with an ``'sglu'`` entry (keyword
    arguments for ``sglu``) and an ``'rmsn'`` entry (keyword arguments for
    ``rmsn``). Per layer, ``sglu`` is applied first and its result is passed
    to ``rmsn``. The output of the last layer is returned.
    """
    hidden = x
    for layer in ws:
        # Bind this layer's sglu weights, leaving only the input open.
        gated_unit = F(sglu).f(_, **layer['sglu'])
        hidden = rmsn(gated_unit(hidden), **layer['rmsn'])
    return hidden


In [38]:
# normal_init_std = 0.01
#
# #W1 = init_utils.zerO_init_2D((d_in, d_h1))
# W1 = init_utils.normal_init(next(key_gen), normal_init_std, (d_in, d_h1))
# b1 = zeros((d_h1))
# #W2 = init_utils.zerO_init_2D((d_h1, d_out))
# W2 = init_utils.normal_init(next(key_gen), normal_init_std, (d_h1, d_out))
# b2 = zeros((d_out))
import sys

sys.path.append("../berries")
import init_utils, random_utils
from init_utils import zerO_init_2D
from jax.random import split


def sglu_config(d_in, d_h, d_out, init):
    """Describe the three weight matrices of one ``sglu`` block.

    Returns a dict keyed ``"wv"``, ``"wu"``, ``"wo"`` (in that order); each
    value holds the matrix ``"size"`` and the shared ``init`` specification.
    """
    shapes = {
        "wv": (d_in, d_h),
        "wu": (d_in, d_h),
        "wo": (d_h, d_out),
    }
    return {name: {"size": shape, "init": init} for name, shape in shapes.items()}

def rmsn_config(d_out):
    """Describe the ``rmsn`` settings: a constant ``d`` equal to ``float(d_out)``."""
    width = float(d_out)
    constant = {"const": width}
    return {"d": constant}

def mglu_layer_config(d_in, d_h, d_out, init):
    """Describe one layer: an ``"sglu"`` block followed by its ``"rmsn"`` settings."""
    layer = {}
    layer["sglu"] = sglu_config(d_in, d_h, d_out, init)
    layer["rmsn"] = rmsn_config(d_out)
    return layer


def mglu_config(d_in, d_h_layer, d_out, d_h, n_layers, init):
    """Describe a stack of layers as a tuple of per-layer configs.

    The stack is: one layer ``d_in -> d_h_layer``, then ``n_layers - 1``
    layers ``d_h_layer -> d_h_layer``, then one layer ``d_h_layer -> d_out``.
    Every layer uses ``d_h`` as its inner sglu width. For ``n_layers <= 1``
    there are no middle layers, so the result has two entries.
    """
    first = mglu_layer_config(d_in, d_h, d_h_layer, init)
    # Build each middle layer separately. `[cfg] * k` would repeat a single
    # dict object, so editing one layer's config would edit all of them.
    middle = [
        mglu_layer_config(d_h_layer, d_h, d_h_layer, init)
        for _layer in range(n_layers - 1)
    ]
    last = mglu_layer_config(d_h_layer, d_h, d_out, init)
    return (first, *middle, last)



# Seed handed to the key generator below.
seed = 0
# NOTE(review): infinite_safe_keys comes from the local random_utils module.
# It is consumed with next(key_gen) further down, so it is presumably an
# endless iterator of PRNG keys derived from `seed` — confirm.
key_gen = random_utils.infinite_safe_keys(seed)


def init_weight(key, init, size):
    """Create one weight of shape ``size`` according to ``init["type"]``.

    ``"normal"`` delegates to ``init_utils.normal_init`` with ``init["std"]``;
    ``"zer0"`` delegates to ``zerO_init_2D`` and ignores ``key``.

    Raises:
        ValueError: for any other ``init["type"]``.
    """
    kind = init["type"]
    if kind == "normal":
        return init_utils.normal_init(key, init["std"], size)
    if kind == "zer0":
        return zerO_init_2D(size)
    raise ValueError(f"Unknown init type: {kind}")


def init_weights(key, configs):
    """Recursively turn a config tree into a tree of weights of the same layout.

    Args:
        key: object exposing ``split(n)``, used to derive one sub-key per
            child (assumed to be the key type produced by ``random_utils`` —
            confirm against that module). It is not touched for ``"const"``
            leaves.
        configs: a tuple of configs, or a dict that is one of
            a constant leaf (has ``"const"``), a weight leaf (has ``"init"``;
            the whole dict is forwarded to ``init_weight`` as keyword
            arguments), or a mapping of named sub-configs.

    Returns:
        A tuple / dict mirroring ``configs`` with every leaf replaced by its
        constant or initialised weight.

    Raises:
        TypeError: if ``configs`` is neither a tuple nor a dict. Previously
            such input fell through and produced ``None``.
    """
    if isinstance(configs, tuple):
        keys = key.split(len(configs))
        return tuple(init_weights(k, c) for c, k in zip(configs, keys))
    if isinstance(configs, dict):
        if "const" in configs:
            return configs["const"]
        if "init" in configs:
            return init_weight(key, **configs)
        keys = key.split(len(configs))
        return {
            name: init_weights(k, config)
            for (name, config), k in zip(configs.items(), keys)
        }
    raise TypeError(
        f"configs must be a tuple or dict, got {type(configs).__name__}"
    )


def fmt_weights(weights, indent=0):
    """Render a weight tree as indented text and count its parameters.

    Tuples and dicts are walked recursively (children indented by four more
    spaces) and followed by a ``total params`` line. Tuple elements are
    labelled ``tuple``, dict entries by their key. A leaf with a ``shape``
    attribute contributes ``weights.size`` parameters; any other leaf counts
    as one.

    Returns:
        ``(text, n_params)``.
    """
    pad = ' ' * indent

    if isinstance(weights, tuple):
        labelled = [("tuple", item) for item in weights]
    elif isinstance(weights, dict):
        labelled = list(weights.items())
    else:
        # Leaf: either an array-like or a plain scalar.
        if hasattr(weights, 'shape'):
            return f"{pad}array shape: {weights.shape}\n", weights.size
        return f"{pad}{weights}\n", 1

    pieces = []
    total = 0
    for label, child in labelled:
        child_text, child_count = fmt_weights(child, indent + 4)
        pieces.append(f"{pad}{label}:\n{child_text}")
        total += child_count
    pieces.append(f"{pad}total params: {total}\n")
    return "".join(pieces), total


In [35]:
from datasets import load_dataset

# Load the "mnist" dataset in JAX format and flatten every image to a vector.
mnist = load_dataset("mnist").with_format("jax")
mnistData = mnist['train']
X_img = mnistData['image']
y = mnistData['label']
X_img_test = mnist["test"]["image"]
n_test_samples = X_img_test.shape[0]
X_test = X_img_test.reshape((n_test_samples, -1))
y_test = mnist["test"]["label"]
# Unpack into named throwaways instead of `_`: `_` is the better_partial
# placeholder that mglu looks up at call time, and rebinding it here breaks
# every later `.f(_, ...)` until it is re-imported.
n_samples, _img_rows, _img_cols = X_img.shape
X = X_img.reshape((n_samples, -1))
n_samples, d_in = X.shape
# Number of classes = number of distinct training labels.
d_out = len(set(y.tolist()))

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [28]:
# Output width of every layer except the last one (mglu_config's d_h_layer).
d_h_layer = 128
# Inner width of each sglu block: wv/wu map into d_h, wo maps back out of it.
d_h = 64
# mglu_config emits n_layers - 1 middle layers between a first and a last
# layer, so n_layers = 1 gives a two-layer stack.
n_layers = 1


In [39]:

import optax
from jax.numpy import mean
from jax import grad, jit
from jax.tree_util import tree_map
from better_partial import _

def accuracy(logits, y):
    """Fraction of positions where the arg-max over the last axis of ``logits`` equals ``y``."""
    predicted = logits.argmax(-1)
    hits = predicted == y
    return hits.mean()


# Batched forward pass: map over axis 0 of x and share the weights across the
# batch. in_axes needs exactly one entry per positional argument of
# mglu(x, ws); the previous three-entry tuple would make any call fail.
mglu_b = F(mglu).vmap((0, None), 0)

def sce_loss(to_logits, x, y, W):
    """Softmax cross-entropy between ``to_logits(x, W)`` and the integer label ``y``."""
    logits = to_logits(x, W)
    return optax.softmax_cross_entropy_with_integer_labels(logits, y)




# Per-sample loss: bind mglu as the logits function, leave (x, y, W) open, then
# vmap over axis 0 of x and y while sharing W across the batch.
loss_b_all = F(sce_loss).f(mglu, _, _, _).vmap((0, 0, None), 0)
# Fix the training data; what remains is a function of W returning one loss per sample.
loss_d = loss_b_all.f(X, y, _)
# Scalar training objective: mean of the per-sample losses.
loss_b_d = loss_d >> mean

lr = 0.001
# NOTE(review): rmsn_config stores its constant as float(d_out), so
# `not isinstance(x, int)` is True for every leaf built by this notebook and
# the mask selects the "const" entries too. If the constants are meant to stay
# fixed, this mask does not achieve that — confirm the intent.
mask_fn = lambda p: tree_map(lambda x: not isinstance(x, int), p)
# opt = optax.multi_transform({"sgd": optax.rmsprop(lr), "zero": optax.set_to_zero()}, mask_fn)
opt = optax.masked(optax.rmsprop(lr), mask_fn)


# "zer0" makes init_weight call zerO_init_2D; "std" is only read by the "normal" branch.
method = {"type": "zer0", "std": 0.001}
# method = {"type": "normal", "std": 0.01}
# Draws one key from key_gen, so re-running this cell advances the generator.
W = init_weights(next(key_gen), mglu_config(d_in, d_h_layer, d_out, d_h, n_layers, method))
# Loss before any training step, printed for reference.
loss0 = loss_b_d(W)
print(loss0)
print(fmt_weights(W)[0])
state = opt.init(W)

@jit
def update(W, opt_state):
    """Perform one optimisation step.

    Differentiates ``loss_b_d`` with respect to ``W``, lets ``opt`` turn the
    gradients into updates, and applies them.

    Returns:
        ``(new_W, new_opt_state)``.
    """
    gradients = grad(loss_b_d)(W)
    deltas, next_state = opt.update(gradients, opt_state)
    return optax.apply_updates(W, deltas), next_state

3.1588132
tuple:
    sglu:
        wv:
            array shape: (784, 64)
        wu:
            array shape: (784, 64)
        wo:
            array shape: (64, 128)
        total params: 108544
    rmsn:
        d:
            128.0
        total params: 1
    total params: 108545
tuple:
    sglu:
        wv:
            array shape: (128, 64)
        wu:
            array shape: (128, 64)
        wo:
            array shape: (64, 10)
        total params: 17024
    rmsn:
        d:
            10.0
        total params: 1
    total params: 17025
total params: 125570



In [37]:

def get_accuracy(x, y, W):
    """Accuracy of the ``mglu`` forward pass with weights ``W`` on input ``x`` against ``y``."""
    logits = mglu(x, W)
    return accuracy(logits, y)

# Per-sample accuracy, vmapped over axis 0 of x and y with W shared.
get_accuracy_b = F(get_accuracy).vmap(in_axes=(0, 0, None), out_axes=0)
# Mean accuracy as a function of W on the training set ...
get_accuracy_b_d = get_accuracy_b.f(X, y, _) >> mean
# ... and on the test set.
get_accuracy_b_t = get_accuracy_b.f(X_test, y_test, _) >> mean

# Training loop: rebinds W and state on every step, so re-running this cell
# continues from the current weights rather than from the initial ones.
for i in range(1000):
    W, state = update(W, state)
    # Every 50 steps report train accuracy, test accuracy and mean train loss.
    if i % 50 == 0:
        print(get_accuracy_b_d(W), get_accuracy_b_t(W), loss_b_d(W))
# Per-sample training losses after the final step (one value per sample).
print(loss_d(W))

0.09871667 0.098 2.9003065
0.9199167 0.923 0.66302013
0.9577 0.95239997 0.43886253
0.97125 0.9625 0.37138352
0.97688335 0.96819997 0.35107404
0.9789 0.9691 0.34038138
0.98256665 0.9693 0.32953045
0.9867833 0.97249997 0.31877425
0.98845 0.97389996 0.31075525
0.99121666 0.9736 0.3010577
0.98918337 0.9719 0.30412602
0.9899833 0.96999997 0.3012402
0.9920667 0.9711 0.2938076
0.9943 0.9741 0.2851047
0.99365 0.97309995 0.28448004
0.99635 0.974 0.27580982
0.9953167 0.9738 0.27615047
0.9942167 0.97239995 0.27758375
0.99805003 0.9745 0.26383737
0.99738336 0.9729 0.265202
[0.25933945 0.24239717 0.25948963 ... 0.24418825 0.2602984  0.25811207]
