In [90]:

import functools
from typing import Callable
from better_partial import _, partial
from jax import vmap


def compose(f, g):
    return lambda *args, **kw: g(f(*args, **kw))

class PointFreeFunction:
    def __init__(self, func):
        self.func = func
        functools.update_wrapper(self, func)
    
    def f(self, *args, **kw):
        return PointFreeFunction(partial(self.func)(*args, **kw))
    
    def vmap(self, *args, **kw):
        return PointFreeFunction(vmap(self.func, *args, **kw))

    def __call__(self, *args, **kw):
        return self.func(*args, **kw)
    
    def __rshift__(self, other: 'PointFreeFunction' | Callable):
        if isinstance(other, PointFreeFunction):
            return PointFreeFunction(compose(self.func, other.func))
        elif callable(other):
            return PointFreeFunction(compose(self.func, other))
        else:
            raise TypeError("other must be callable or PointFreeFunction")

F = PointFreeFunction

In [107]:
from jax.nn import sigmoid
from jax.numpy.linalg import norm
from jax.numpy import sqrt

In [114]:

def affine(x, W, b):
    return W.T @ x + b

def swish(x):
    return x * sigmoid(x)

def sglu(x, wv, wu, wo):
    v = x @ wv
    u = x @ wu
    return (v * u) @ wo

def rmsn(x, d):
    return  x / (norm(x)/ sqrt(d))

def mglu(x, ws):
    for w in ws:
        sglu_w = F(sglu).f(_, **w['sglu'])
        x = rmsn(sglu_w(x), **w['rmsn'])
    return x


In [123]:
# normal_init_std = 0.01
#
# #W1 = init_utils.zerO_init_2D((d_in, d_h1))
# W1 = init_utils.normal_init(next(key_gen), normal_init_std, (d_in, d_h1))
# b1 = zeros((d_h1))
# #W2 = init_utils.zerO_init_2D((d_h1, d_out))
# W2 = init_utils.normal_init(next(key_gen), normal_init_std, (d_h1, d_out))
# b2 = zeros((d_out))
import sys

sys.path.append("../berries")
import init_utils, random_utils
from init_utils import zerO_init_2D
from jax.random import split


def sglu_config(d_in, d_h, d_out, init):
    return {
        "wv": {
            "size": (d_in, d_h),
            "init": init,
        },
        "wu": {
            "size": (d_in, d_h),
            "init": init,
        },
        "wo": {
            "size": (d_h, d_out),
            "init": init,
        },
    }

def rmsn_config(d_out):
    return {
        "d": {
            "const": d_out,
        }
    }

def mglu_layer_config(d_in, d_h, d_out, init):
    return {
        "sglu": sglu_config(d_in, d_h, d_out, init),
        "rmsn": rmsn_config(d_out),
    }


def mglu_config(d_in, d_h_layer, d_out, d_h, n_layers, init):
    return tuple([
        mglu_layer_config(d_in, d_h, d_h_layer, init),
        *[mglu_layer_config(d_h_layer, d_h, d_h_layer, init)] * (n_layers - 1),
        mglu_layer_config(d_h_layer, d_h, d_out, init),
    ])



seed = 0
key_gen = random_utils.infinite_safe_keys(seed)


def init_weight(key, init, size):
    if init["type"] == "normal":
        return init_utils.normal_init(key, init["std"], size)
    elif init["type"] == "zer0":
        return zerO_init_2D(size)
    else:
        raise ValueError(f"Unknown init type: {init['type']}")


def init_weights(key, configs):
    if isinstance(configs, tuple):
        keys = key.split(len(configs))
        return tuple(init_weights(k, c) for c, k in zip(configs, keys))
    elif isinstance(configs, dict):
        if "const" in configs:
            return configs["const"]
        if "init" in configs:
            return init_weight(key, **configs)
        else:
            keys = key.split(len(configs))
            return {name: init_weights(k, config) for (name, config), k in zip(configs.items(), keys)}


In [19]:
from datasets import load_dataset

mnist = load_dataset("mnist").with_format("jax")
mnistData = mnist['train']
X_img = mnistData['image']
y = mnistData['label']
X_img_test = mnist["test"]["image"]
n_test_samples = X_img_test.shape[0]
X_test = X_img_test.reshape((n_test_samples, -1))
y_test = mnist["test"]["label"]
n_samples, _, _ = X_img.shape
X = X_img.reshape((n_samples, -1))
n_samples, d_in = X.shape
d_out = len(set(y.tolist()))

In [95]:
d_h_layer = 128
d_h = 64


In [127]:

import optax
from jax.numpy import mean
from jax import grad, jit
from jax.tree_util import tree_map
from better_partial import _

def accuracy(logits, y):
    return (logits.argmax(-1) == y).mean()


mglu_b = F(mglu).vmap((0, None, None), 0)

def sce_loss(to_logits, x, y, W):
    return optax.softmax_cross_entropy_with_integer_labels(to_logits(x, W), y)




loss_b_all = F(sce_loss).f(mglu, _, _, _).vmap((0, 0, None), 0)
loss_d = loss_b_all.f(X, y, _)
loss_b_d = loss_d >> mean

lr = 0.001
mask_fn = lambda p: tree_map(lambda x: not isinstance(x, int), p)
opt = optax.multi_transform({"sgd": optax.rmsprop(lr), "zero": optax.set_to_zero()}, mask_fn)


method = {"type": "zer0", "std": 0.001}
# method = {"type": "normal", "std": 0.01}
W = init_weights(next(key_gen), mglu_config(d_in, d_h_layer, d_out, d_h, 2, method))
loss0 = loss_b_d(W)
print(loss0)
state = opt.init(W)

@jit
def update(W, opt_state):
    grads = grad(loss_b_d)(W)
    updates, opt_state = opt.update(grads, opt_state)
    new_W = optax.apply_updates(W, updates)
    return new_W, opt_state

2.303737


In [128]:

def get_accuracy(x, y, W):
    return accuracy(mglu(x, W), y)

get_accuracy_b = F(get_accuracy).vmap(in_axes=(0, 0, None), out_axes=0)
get_accuracy_b_d = get_accuracy_b.f(X, y, _) >> mean
get_accuracy_b_t = get_accuracy_b.f(X_test, y_test, _) >> mean

for i in range(1000):
    W, state = update(W, state)
    if i % 50 == 0:
        print(get_accuracy_b_d(W), get_accuracy_b_t(W), loss_b_d(W))
print(loss_d(W))

TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.