In [1]:
import numpy as np 
import jax
import jax.numpy as jnp 
import flax
import flax.nnx 
import optax
import orbax.checkpoint as ocp

In [2]:
class sampleCnn(flax.nnx.Module):
    """Small three-layer strided CNN used to demo NNX state manipulation.

    Layers: three ``flax.nnx.Conv`` with 3x3 kernels and stride 2; the channel
    count goes 3 -> 16 -> 32 -> 64. Assuming nnx.Conv's default 'SAME'
    padding (TODO confirm for the installed flax version), each layer halves
    the spatial size, so a [n, 128, 128, 3] input yields [n, 16, 16, 64].

    Args:
        rngs: ``flax.nnx.Rngs`` used to initialise the conv parameters. It is
            also kept on the module as ``self.rngs``.
        *args, **kwargs: forwarded unchanged to the parent ``__init__``.
    """

    def __init__(self,
                 rngs,
                 *args, **kwargs):
        # Zero-argument super() starts the MRO lookup at flax.nnx.Module.
        # The previous `super(flax.nnx.Module, self)` began *after* Module,
        # silently skipping Module.__init__.
        super().__init__(*args, **kwargs)

        # NOTE(review): storing the Rngs on the module makes its RNG state part
        # of flax.nnx.state(self). Keep only if later code needs self.rngs.
        self.rngs = rngs
        # The same Conv objects as conv1..conv3, in creation order, so callers
        # can iterate over the trainable layers.
        self.trainableLayers = []

        def append_trainable_layer(layer):
            """Record ``layer`` in self.trainableLayers and return it unchanged."""
            self.trainableLayers.append(layer)
            return layer

        # input: [n, 128, 128, 3] -- assumes channels-last (NHWC), matching in_features=3
        self.conv1 = append_trainable_layer(flax.nnx.Conv(3, 16, (3, 3), strides=(2, 2), rngs=self.rngs))
        self.conv2 = append_trainable_layer(flax.nnx.Conv(16, 32, (3, 3), strides=(2, 2), rngs=self.rngs))
        self.conv3 = append_trainable_layer(flax.nnx.Conv(32, 64, (3, 3), strides=(2, 2), rngs=self.rngs))

    def __call__(self, x):
        """Apply conv1 -> conv2 -> conv3 to ``x`` and return the conv3 output.

        NOTE(review): there is no activation between layers, so the stack is a
        composition of affine maps. That is fine for a state-manipulation demo,
        but not for a real model.
        """
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)

        return c3
        
# Two independently initialised instances of the same architecture:
# modelA from seed 0, modelB from seed 1 (built in that order).
modelA, modelB = [sampleCnn(rngs=flax.nnx.Rngs(seed)) for seed in (0, 1)]

In [3]:

# Demo 1: shift every trainable parameter of modelA by +1 and write it back.
# NOTE(review): this mutates modelA in place. Re-running the cell adds another
# +1 each time, and the blending cell below starts from these shifted weights.
modelAState = flax.nnx.state(modelA)          # full state of the model
params = modelAState.filter(flax.nnx.Param)   # only the Param entries
newParams = jax.tree_util.tree_map(
    lambda leaf: leaf + 1,
    params,
)
# Overlay the shifted params on the full state, then push it into the model.
newState = flax.nnx.State.merge(modelAState, newParams)
flax.nnx.update(modelA, newState)

flax.nnx.state(modelA)


State({
  'conv1': {
    'bias': VariableState( # 16 (64 B)
      type=Param,
      value=Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],      dtype=float32)
    ),
    'kernel': VariableState( # 432 (1.7 KB)
      type=Param,
      value=Array([[[[1.2059559 , 0.81337595, 0.845183  , 0.7616481 , 0.8203863 ,
                1.1222979 , 1.1498961 , 0.78981054, 1.3247747 , 0.64166176,
                0.7395842 , 1.0278939 , 0.765144  , 1.0499297 , 1.3375859 ,
                1.1050476 ],
               [1.1235734 , 0.9015963 , 1.1884761 , 1.0752815 , 0.8684114 ,
                0.8205348 , 1.1568873 , 1.2174608 , 1.1870264 , 0.8975728 ,
                1.1830642 , 1.3221965 , 0.5833026 , 1.0919819 , 1.348174  ,
                0.76131094],
               [0.87772715, 1.3884116 , 1.1656785 , 0.9587893 , 0.6631649 ,
                0.8303846 , 0.5929531 , 1.3070254 , 1.2815622 , 1.0730343 ,
                1.0130435 , 1.2337084 , 0.95566255, 0.92918885, 1.142758  ,
 

In [4]:
# Demo 2: blend modelB's parameters into modelA (0.9 * A + 0.1 * B).
# Reasoning steps:
# 1. Extract each model's state.
# 2. Pull the parameter tree out of the state.
# 3. Compute a new parameter tree.
# 4. Merge the new parameter tree back into the original state
#    => this requires the State.merge concept.
# 5. Update the model with the merged state.
# NOTE(review): modelA here already carries the +1 shift from the previous
# cell, so the displayed values depend on that cell having run first.

modelAState = flax.nnx.state(modelA)
modelAParams = modelAState.filter(flax.nnx.Param)

modelBState = flax.nnx.state(modelB)
modelBParams = modelBState.filter(flax.nnx.Param)

# Leaf-wise weighted average. Both trees come from the same class, so their
# structures line up for tree_map.
newTree = jax.tree_util.tree_map(
    lambda a_leaf, b_leaf: a_leaf * 0.9 + b_leaf * 0.1,
    modelAParams,
    modelBParams,
)
newState = flax.nnx.State.merge(modelAState, newTree)
flax.nnx.update(modelA, newState)

flax.nnx.state(modelA)



State({
  'conv1': {
    'bias': VariableState( # 16 (64 B)
      type=Param,
      value=Array([0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9,
             0.9, 0.9, 0.9], dtype=float32)
    ),
    'kernel': VariableState( # 432 (1.7 KB)
      type=Param,
      value=Array([[[[1.0964199 , 0.71706724, 0.77934545, 0.6659444 , 0.70190716,
                1.0385334 , 1.0454845 , 0.7050905 , 1.2033985 , 0.57091224,
                0.6861553 , 0.9125151 , 0.6717621 , 0.9454768 , 1.169522  ,
                1.0119697 ],
               [1.0005788 , 0.78959876, 1.0693257 , 0.9882534 , 0.77783   ,
                0.70466906, 1.0835153 , 1.0974486 , 1.0890632 , 0.8162275 ,
                1.0681878 , 1.2046207 , 0.4971938 , 1.0016893 , 1.2281232 ,
                0.7165636 ],
               [0.8184405 , 1.2443447 , 1.0514735 , 0.8800377 , 0.60160524,
                0.71886575, 0.53412575, 1.143737  , 1.1807564 , 0.99013287,
                0.9304484 , 1.0840666 , 0.8463092 , 0