In [1]:
import numpy as np 
import jax
import jax.numpy as jnp 
import flax.nnx 
import optax
import orbax.checkpoint as ocp

In [2]:
class sampleCnn(flax.nnx.Module):
    """Small three-layer convolutional network used to try out checkpointing.

    Three 3x3 convolutions with stride 2 are applied one after another, with no
    activation in between; the channel count goes 3 -> 16 -> 32 -> 64.
    """

    def __init__(self,
                 rngs,
                 *args, **kwargs):
        """Create the three convolution layers.

        Args:
            rngs: RNG container handed to every ``flax.nnx.Conv`` for parameter
                initialisation (the notebook passes ``flax.nnx.Rngs(seed)``).
            *args: Extra positional arguments forwarded to the parent initializer.
            **kwargs: Extra keyword arguments forwarded to the parent initializer.
        """
        # Zero-argument super() resolves to the class that follows sampleCnn in
        # the MRO. The previous ``super(flax.nnx.Module, self)`` began the lookup
        # *after* nnx.Module and therefore skipped the direct parent's initializer.
        super().__init__(*args, **kwargs)

        # NOTE(review): keeping the Rngs object as an attribute presumably makes
        # its RNG state part of ``flax.nnx.split(model)`` and hence of any saved
        # checkpoint -- confirm that this is intended.
        self.rngs = rngs

        # input: [n, 128, 128, 3]
        self.conv1 = flax.nnx.Conv(3, 16, (3,3), strides=(2,2), rngs=self.rngs)
        self.conv2 = flax.nnx.Conv(16, 32, (3,3), strides=(2,2), rngs=self.rngs)
        self.conv3 = flax.nnx.Conv(32, 64, (3,3), strides=(2,2), rngs=self.rngs)

    def __call__(self, x):
        """Run ``x`` through conv1, conv2 and conv3 in order and return the result.

        Args:
            x: Input batch; the notebook uses an array of shape [n, 128, 128, 3].

        Returns:
            The output of the third convolution (64 feature channels).
        """
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)

        return c3
        

In [3]:
# Two models initialised from different seeds, so their weights differ.
modelA = sampleCnn(rngs=flax.nnx.Rngs(0))
modelB = sampleCnn(rngs=flax.nnx.Rngs(1))

# Draw the sample batch from a seeded generator so that "Restart & Run All"
# feeds the same input every time (the global np.random state was unseeded).
sampleInput = np.random.default_rng(0).random([3, 128, 128, 3])

# Element-wise comparison of the two outputs; shown as the cell result.
modelA(sampleInput) != modelB(sampleInput)

Array([[[[ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         ...,
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True]],

        [[ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         ...,
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True]],

        [[ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  True, ...,  True,  True,  True],
         ...,
         [ True,  True,  True, ...,  True,  True,  True],
         [ True,  True,  T

In [4]:
# saving model weights
graphdef, state = flax.nnx.split(modelA)

# orbax is configured through an options object
ocpOptions = ocp.CheckpointManagerOptions(
    save_interval_steps=2,
    max_to_keep=2,
    # other options
)
checkpointer = ocp.CheckpointManager(
    '/workspaces/byol_jax/basic_ops/checkPoints/',
    options=ocpOptions,
    )

# validating how the orbax running
try:
    for test_step in range(10):
        checkpointer.save(step=test_step, args=ocp.args.StandardSave(state))
    # orbax saves asynchronously, so wait until everything has been written
    checkpointer.wait_until_finished()
finally:
    # Always release the manager's resources, even when a save raised;
    # previously the manager was never closed.
    checkpointer.close()


# The tutorial's approach; an absolute path is required. It is not very intuitive,
# because it uses a test utility that deletes the existing folder.
# ckpt_dir = ocp.test_utils.erase_and_create_empty('/workspaces/byol_jax/basic_ops/checkPoints/') # test_utils is a testing convenience and is rarely needed otherwise
# checkpointer = ocp.StandardCheckpointer()
# checkpointer.save(ckpt_dir / 'state', state)



ERROR:absl:[process=0] Failed to run 1 Handler Commit operations or the Commit callback in background save thread, directory: /workspaces/byol_jax/basic_ops/checkPoints/2
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 207, in _thread_func
    _background_wait_for_commit_futures(
  File "/usr/local/lib/python3.11/dist-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 125, in _background_wait_for_commit_futures
    on_commit_callback()
  File "/usr/local/lib/python3.11/dist-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 403, in _callback
    _on_commit_callback(
  File "/usr/local/lib/python3.11/dist-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 52, in _on_commit_callback
    atomicity.on_commit_callback(
  File "/usr/local/lib/python3.11/dist-packages/orbax/checkpoint/_src/path/atomicity.py", line 563, in on_c

PermissionError: [Errno 13] Permission denied: '/workspaces/byol_jax/basic_ops/checkPoints/2.orbax-checkpoint-tmp-2' -> '/workspaces/byol_jax/basic_ops/checkPoints/2'

In [None]:
# Split modelB to obtain its graph definition; the restored state is merged
# into that definition below.
graphdef, state = flax.nnx.split(modelB)

# Build an orbax manager object to read the saved data
ocpOptions = ocp.CheckpointManagerOptions(
    save_interval_steps=2,
    max_to_keep=2,
    # other options
    )
checkpointerLoader = ocp.CheckpointManager(
    '/workspaces/byol_jax/basic_ops/checkPoints/',
    options=ocpOptions,
    )

# Obtain the model structure
# targetModelStateAbstract = jax.tree_map(ocp.tree.to_shape_dtype_struct, state)

try:
    latestStep = checkpointerLoader.latest_step()
    # latest_step() is None when the directory holds no finished checkpoint
    # (for example after a failed save); fail with a clear message in that case.
    if latestStep is None:
        raise FileNotFoundError(
            'No checkpoint found in /workspaces/byol_jax/basic_ops/checkPoints/'
        )
    restoreState = checkpointerLoader.restore(latestStep)
finally:
    # Release the manager's resources whether or not the restore succeeded.
    checkpointerLoader.close()

# Rebuild modelB from its graph definition and the restored weights.
modelB = flax.nnx.merge(graphdef, restoreState)

# Element-wise comparison against modelA; shown as the cell result.
modelA(sampleInput) != modelB(sampleInput)





Array([[[[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]],

        [[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]],

        [[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, Fa