In [1]:
import jax
import jax.numpy as np
import jax.random as jr

In [2]:
from jax.tree_util import register_pytree_node_class

In [14]:
from collections import namedtuple

Children = namedtuple("children", "a b")

class ParameterProp:
    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = True

@register_pytree_node_class
class Model:
    def __init__(self, a, b):
        self.a = ssm.Variable(a)
        self.b = b
        
        self.parameter_properties = dict(
            a = ParameterProp("a", True),
            b = ParameterProp("b", False),
        )
        
    def __call__(self, x):
        return self.a * x + self.b * x
    
    def tree_flatten(self):
        children = Children(a=self.a, b=self.b)
        aux_data = self.parameter_properties
        # aux_data = None
        return children, aux_data
    
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        a, b = children
        parameter_properties = aux_data
        obj = cls(a, b)
        obj.parameter_properties = parameter_properties
        return obj
    
@jax.jit
def f(x):
    return x

model = Model(1, 2)
out = f(model)

In [None]:
a = ssm.Variable(a, group=0)

model(a, 

In [15]:
model.tree_flatten()

(children(a=1, b=2),
 {'a': <__main__.ParameterProp at 0x7fa8214fdcd0>,
  'b': <__main__.ParameterProp at 0x7fa8214fdfd0>})

In [16]:
from jax.tree_util import tree_flatten
# print(out.tree_flatten())
print(tree_flatten(out))
# print(ravel_pytree(out))

([DeviceArray(1, dtype=int32, weak_type=True), DeviceArray(2, dtype=int32, weak_type=True)], PyTreeDef(CustomNode(<class '__main__.Model'>[{'a': <__main__.ParameterProp object at 0x7fa8214fdcd0>, 'b': <__main__.ParameterProp object at 0x7fa8214fdfd0>}], [*, *])))


In [17]:
import equinox as eqx
import jax.random as jr
import jax.numpy as np
import jax

model_key = jr.PRNGKey(0)

# model = eqx.nn.MLP(
#     in_size=1, out_size=1, width_size=1, depth=1, key=model_key
# )

class ParameterProp:
    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = True

@register_pytree_node_class
class Model:
    def __init__(self, a, b):
        self.a = a
        self.b = b
        
        self.parameter_properties = dict(
            a = ParameterProp("a", True),
            b = ParameterProp("b", False),
        )
        
    def __call__(self, x):
        return self.a * x + self.b * x
    
    def tree_flatten(self):
        children = (self.a, self.b)
        aux_data = self.parameter_properties
        return children, aux_data
    
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        a, b = children
        parameter_properties = aux_data
        obj = cls(a, b)
        obj.parameter_properties = parameter_properties
        return obj
    
model = Model(np.array([1.]), np.array([5.]))


# `filter_jit` and `filter_value_and_grad` are thin wrappers around the usual `jax` functions, that automatically
# inspect the arguments of the function, JIT with respect to all JAX arrays, and differentiate with respect to all
# floating point JAX arrays (i.e. the parameters of the model).
# @eqx.filter_jit
@eqx.filter_value_and_grad
def loss(model, x, y):
    pred_y = model(x)
    return np.mean((y - pred_y) ** 2)

val, grad = loss(model, np.ones((1,)), np.ones((1,)))
print(grad.a, grad.b)

[10.] [10.]


In [18]:
@jax.grad
def f(model, x):
    return model.a * x + model.b * x

grad = f(model, 1)
print(grad.a, grad.b)

TypeError: Gradient only defined for scalar-output functions. Output had shape: (1,).

In [19]:
import jax
import jax.numpy as np
import jax.random as jr

In [20]:
from jax.tree_util import register_pytree_node_class

In [21]:
from typing import Any, Callable, List, Optional, Tuple, Union
from typing_extensions import get_args
import functools as ft
import jax

PyTree = any

def _make_filter_tree(mask: Union[bool, Callable[[Any], bool]], arg: Any) -> bool:
    if isinstance(mask, bool):
        return mask
    elif callable(mask):
        return jax.tree_map(mask, arg)
    else:
        raise ValueError("`filter_spec` must consist of booleans and callables only.")

def filter(pytree: PyTree, filter_spec: PyTree, inverse: bool = False, replace: Any = None) -> PyTree:
    inverse = bool(inverse)  # just in case, to make the != trick below work reliably
    filter_tree = jax.tree_map(_make_filter_tree, filter_spec, pytree)
    return jax.tree_map(
        lambda mask, x: x if bool(mask) != inverse else replace, filter_tree, pytree
    )


def partition(pytree: PyTree, filter_spec: PyTree, replace: Any = None) -> PyTree:
    filter_tree = jax.tree_map(_make_filter_tree, filter_spec, pytree)
    left = jax.tree_map(lambda mask, x: x if mask else replace, filter_tree, pytree)
    right = jax.tree_map(lambda mask, x: replace if mask else x, filter_tree, pytree)
    return left, right

def _combine(*args):
    for arg in args:
        if arg is not None:
            return arg
    return None

def _is_none(x):
    return x is None

def combine(*pytrees: PyTree):
    pytrees = [pytree for pytree in pytrees if pytree is not None]
    return jax.tree_map(_combine, *pytrees, is_leaf=_is_none)

def filter_value_and_grad(
    fun, *, filter_spec=lambda x: True, argnums=None, **gradkwargs
):
    if argnums is not None:
        raise ValueError(
            "`argnums` should not be passed. If you need to differentiate "
            "multiple objects then collect them into a tuple and pass that "
            "as the first argument."
        )

    @ft.partial(jax.value_and_grad, argnums=0, **gradkwargs)
    def fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs):
        x = combine(diff_x, nondiff_x)
        return fun(x, *args, **kwargs)

    def fun_value_and_grad_wrapper(x, *args, **kwargs):
        diff_x, nondiff_x = partition(x, filter_spec)
        return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)

    return fun_value_and_grad_wrapper

In [22]:
from collections import namedtuple

class ParameterProp:
    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = trainable
    def __repr__(self):
        return f"<Parameter name='{self.name}' trainable={self.trainable}>"
        
@register_pytree_node_class
class Model:
    def __init__(self, a, b, c):

        self.parameters = dict(a=a, b=b, c=c)
        self.a = a
        self.b = b
        self.c = c
        
        self.parameter_properties = dict(
            a = ParameterProp("a", trainable=True),
            b = ParameterProp("b", trainable=False),
            c = ParameterProp("c", trainable=False),
        )
        
    def __repr__(self):
        return f"<Model a={self.a} b={self.b} c={self.c}>"
    
    def tree_flatten(self): --> (trainable_unconstrained_params, static), aux_data
        params = []
        static = []
        
        for param in params:
            if param.trainable == True:
                put it in params
            else:
                put it in static
        
        for param, prop in self.parameter_properties.items():
            if prop.trainable:
                params.append(self.parameters[param])
            else:
                static.append(self.parameters[param])
        children = (params, static)
        aux_data = self.parameter_properties
        return children, aux_data
    
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        params, static = children
        parameter_properties = aux_data
        obj = cls(*params, *static)
        # obj.parameter_properties = parameter_properties
        return obj

In [23]:
model = Model(1., 3., 5.)

In [24]:
pytree, structure = tree_util.tree_flatten(model)

NameError: name 'tree_util' is not defined

In [25]:
structure.walk(lambda x: x, lambda x: x+10, np.arange(structure.num_leaves))

NameError: name 'structure' is not defined

In [399]:
@jax.jit
def f(x):
    return x

out = f(model)
out

<Model a=1.0 b=3.0 c=5.0>

In [396]:
from jax import tree_util
pytree, pytree_structure = tree_util.tree_flatten(model)
left, right = partition(model, True)

### Question:

How to use parameter properties to lookup boolean for `filter_spec`?

Issue: we need same `aux_data` for `tree_map` to work...

In [332]:
params, static = partition(model, filter_spec=mask)

In [26]:
model = Model(1., 3.)
mask = Model(True, False)

def objective(model):
    return model.a * model.b

print("regular function")
value, grad = filter_value_and_grad(objective, filter_spec=mask)(model)
print(grad.a)
print(grad.b)

print()
print("jitted function")
value, grad = jax.jit(filter_value_and_grad(objective, filter_spec=mask))(model)
print(grad.a)
print(grad.b)

TypeError: __init__() missing 1 required positional argument: 'c'