In [1]:
import functools
from pprint import pprint
from typing import *

import flax
import flax.traverse_util
import jax
import numpy as np
import orbax.checkpoint as orbax
from flax import nnx
from jax import lax
from jax import numpy as jnp
from jax import tree_util as jtu
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

key = jax.random.key(0)

In [2]:
class TwoLayerMLP(nnx.Module):
    def __init__(self, dim, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
        self.linear2 = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        x = self.linear1(x)
        return self.linear2(x)


In [3]:
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
print(np.allclose(model(x), model.linear2(model.linear1(x))))

True


In [5]:
k = nnx.PathContains("linear1")
gdef, temp, state = nnx.split(model, k, nnx.Param)

In [6]:
temp
state


State({
  'linear1': {
    'bias': VariableState(
      type=Param,
      value=Array([0., 0., 0., 0.], dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=Array([[-0.8034531 , -0.34071913, -0.94082963,  0.0100597 ],
             [ 0.26146442,  1.1247739 ,  0.5456374 , -0.37416402],
             [ 1.0281807 , -0.6798804 , -0.14884011,  0.0569495 ],
             [-0.44308174, -0.60587114,  0.434087  , -0.40541086]],      dtype=float32)
    )
  }
})

In [5]:
out = gdef.apply(state, x)
print(out)

<flax.nnx.nnx.proxy_caller.CallableProxy object at 0x79ac15309330>


In [6]:
nnx.display(state)

In [32]:
import optax

model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))

gdef, state = nnx.split(model)


# 定义一个函数来过滤梯度
def filter_gradients(grads):
    def filter_fn(path, value):
        # 检查路径中是否包含 'linear1'
        # print(str(path))
        if "linear1" in str(path):
            return jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), value)
            # print(111)
            # return None
        return value

    return jtu.tree_map_with_path(filter_fn, grads)


optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(state)


In [39]:
@jax.jit
def update_step(params, x, y, opt_state):
    def loss_fn(params):
        y_pred = nnx.merge(gdef, params)(x)
        return jnp.mean((y_pred - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # print(grads)
    filtered_grads = filter_gradients(grads)
    # print(filtered_grads)
    updates, opt_state = optimizer.update(filtered_grads, opt_state)
    # print(updates)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss


# 生成一些随机数据用于训练
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (10, 4))
y = jax.random.normal(key, (10, 4))

# 训练循环
num_steps = 100
for step in range(num_steps):
    state, opt_state, loss = update_step(state, x, y, opt_state)
    if step % 10 == 0:
        print(f"Step {step}, Loss: {loss}")

initial_linear1_params = jax.tree_util.tree_leaves(model.linear1)[0]
final_linear1_params = jax.tree_util.tree_leaves(nnx.merge(gdef, state).linear1)[0]


Step 0, Loss: 2.1327953338623047
Step 10, Loss: 1.6399110555648804
Step 20, Loss: 1.2668898105621338
Step 30, Loss: 0.9820842146873474
Step 40, Loss: 0.7653554081916809
Step 50, Loss: 0.6038841605186462
Step 60, Loss: 0.48465976119041443
Step 70, Loss: 0.39579707384109497
Step 80, Loss: 0.32791709899902344
Step 90, Loss: 0.2745225131511688


In [40]:
final_linear1_params.kernel


Param(
  value=Array([[-0.8034531 , -0.34071913, -0.94082963,  0.0100597 ],
         [ 0.26146442,  1.1247739 ,  0.5456374 , -0.37416402],
         [ 1.0281807 , -0.6798804 , -0.14884011,  0.0569495 ],
         [-0.44308174, -0.60587114,  0.434087  , -0.40541086]],      dtype=float32)
)

In [41]:
initial_linear1_params.kernel


Param(
  value=Array([[-0.8034531 , -0.34071913, -0.94082963,  0.0100597 ],
         [ 0.26146442,  1.1247739 ,  0.5456374 , -0.37416402],
         [ 1.0281807 , -0.6798804 , -0.14884011,  0.0569495 ],
         [-0.44308174, -0.60587114,  0.434087  , -0.40541086]],      dtype=float32)
)