# Imports set up rendering

In [1]:
#@title Check if MuJoCo installation was successful

import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
# %env MUJOCO_GL=egl
os.environ['MUJOCO_GL'] = 'egl'
os.environ['PYOPENGL_PLATFORM'] = 'egl'
try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=true "
    "--xla_gpu_triton_gemm_any=True "
    # "--xla_gpu_enable_async_collectives=true "
    # "--xla_gpu_enable_latency_hiding_scheduler=true "
    # "--xla_gpu_enable_highest_priority_async_stream=true "
)
# os.environ['CUDA_VISIBLE_DEVICES'] = '1' # Use GPU 1
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=true " "--xla_gpu_triton_gemm_any=True "
)

Wed Oct 16 23:08:11 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02             Driver Version: 550.107.02     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A6000               On  |   00000000:41:00.0 Off |                  Off |
| 30%   40C    P8             18W /  300W |      11MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A6000               On  |   00

In [3]:
%load_ext autoreload
%autoreload 2
import os

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Use GPU 1
import functools
import jax
# jax.config.update("jax_enable_x64", True)

n_gpus = jax.device_count(backend="gpu")
print(f"Using {n_gpus} GPUs")
from typing import Dict
from brax import envs
import mujoco
import pickle
import warnings
import mediapy as media
import hydra
import jax.numpy as jp

from omegaconf import DictConfig, OmegaConf
from brax.training.agents.ppo import networks as ppo_networks
from custom_brax import custom_ppo as ppo
from custom_brax import custom_wrappers
from custom_brax import custom_ppo_networks
from orbax import checkpoint as ocp
from flax.training import orbax_utils
from preprocessing.mjx_preprocess import process_clip_to_train
from envs.Fly_Env_Brax import FlyTracking, FlyMultiClipTracking
from utils.utils import *
from utils.fly_logging import log_eval_rollout

warnings.filterwarnings("ignore", category=DeprecationWarning)
# jax.config.update("jax_enable_x64", True)

from hydra import initialize, compose
from hydra.core.hydra_config import HydraConfig
from hydra.core.global_hydra import GlobalHydra


Using 2 GPUs


2024-10-16 23:11:42.296400: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
  self.hub = sentry_sdk.Hub(client)


# Load configs

In [5]:
dataset = "run"
with initialize(version_base=None, config_path="configs"):
    cfg=compose(config_name='config.yaml',overrides= [f"dataset=fly_{dataset}", f"train=train_fly_{dataset}", "paths=walle"],return_hydra_config=True,)
    HydraConfig.instance().set_config(cfg)

In [6]:
for k in cfg.paths.keys():
    if (k != 'user'):
        cfg.paths[k] = Path(cfg.paths[k])
        cfg.paths[k].mkdir(parents=True, exist_ok=True)
env_cfg = cfg.dataset
env_args = cfg.dataset.env_args
# reference_path = cfg.paths.data_dir/ f"clips/all_clips_batch_interp.p"
reference_path = cfg.paths.data_dir/ f"clips/0.p"
reference_path.parent.mkdir(parents=True, exist_ok=True)

with open(reference_path, "rb") as file:
    # Use pickle.load() to load the data from the file
    reference_clip = pickle.load(file)


# Load env

# Test retrain

In [None]:
from orbax import checkpoint as ocp
from flax.training import orbax_utils
import optax

In [66]:
dataset = 'run'

from envs.Fly_Env_Brax import FlyTracking, FlyMultiClipTracking, FlyRunSim, _bounded_quat_dist
with initialize(version_base=None, config_path="configs"):
    cfg=compose(config_name='config.yaml',overrides= [f"dataset=fly_{dataset}", f"train=train_fly_{dataset}", "paths=walle"],return_hydra_config=True,)
    HydraConfig.instance().set_config(cfg)
    

env_args = cfg.dataset.env_args
print(cfg.train.env_name)
envs.register_environment("fly_freejnt_clip", FlyTracking)
envs.register_environment("fly_freejnt_multiclip", FlyMultiClipTracking)
envs.register_environment("fly_run_policy", FlyRunSim)
# cfg_load = OmegaConf.load('/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt/run_id=21377073//logs/run_config.yaml')
# cfg_load.paths = cfg.paths
env = envs.get_environment(
    cfg.train.env_name,
    reference_clip=reference_clip,
    **cfg.dataset.env_args,
)

fly_run_policy
self._steps_for_cur_frame: 1


In [67]:
rollout_env = custom_wrappers.RenderRolloutWrapperTracking(env)

# rollout_env = custom_wrappers.RenderRolloutWrapperTracking(env)
# define the jit reset/step functions
jit_reset = jax.jit(rollout_env.reset)
jit_step = jax.jit(rollout_env.step)
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)


In [68]:
state.info['reference_obs_size']

Array(215, dtype=int32, weak_type=True)

In [50]:
model_path = Path('/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt/run_id=21377073/ckpt/21377073')
##### Get all the checkpoint files #####
ckpt_files = sorted([Path(f.path) for f in os.scandir(model_path) if f.is_dir()])
max_ckpt = ckpt_files[-1].as_posix()
env_args = cfg.dataset.env_args
print(max_ckpt)

/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt/run_id=21377073/ckpt/21377073/089


In [51]:
from custom_brax.custom_losses import PPONetworkParams
def policy_params_fn(num_steps, make_policy, params, policy_params_fn_key, model_path=model_path):
    # save checkpoints
    orbax_checkpointer = ocp.PyTreeCheckpointer()
    save_args = orbax_utils.save_args_from_target(params)
    path = Path('/data/users/eabe/biomech_model/Flybody/RL_Flybody/ckpt/run_id=21356039/ckpt/Test_path') / f'{num_steps}'
    orbax_checkpointer.save(path, params, force=True, save_args=save_args)


episode_length = (env_args.clip_length - 50 - env_cfg.ref_traj_length) * env._steps_for_cur_frame
print(f"episode_length {episode_length}")
def create_mask():
    mask = {'params': {'encoder': 'encoder', 'decoder': 'decoder'}}
    value = {'params': 'encoder'}
    return PPONetworkParams(mask,value)

train_fn = functools.partial(
            ppo.train,
            num_envs=cfg.train["num_envs"],
            num_timesteps=0,
            num_evals=1,
            num_resets_per_eval=cfg.train['num_resets_per_eval'],
            reward_scaling=cfg.train['reward_scaling'],
            episode_length=episode_length,
            normalize_observations=True,
            action_repeat=cfg.train['action_repeat'],
            clipping_epsilon=cfg.train["clipping_epsilon"],
            unroll_length=cfg.train['unroll_length'],
            num_minibatches=cfg.train["num_minibatches"],
            num_updates_per_batch=cfg.train["num_updates_per_batch"],
            discounting=cfg.train['discounting'],
            learning_rate=cfg.train["learning_rate"],
            kl_weight=cfg.train["kl_weight"],
            entropy_cost=cfg.train['entropy_cost'],
            batch_size=cfg.train["batch_size"],
            seed=cfg.train['seed'],
            network_factory=functools.partial(
                custom_ppo_networks.make_intention_ppo_networks,
                encoder_hidden_layer_sizes=cfg.train['encoder_hidden_layer_sizes'],
                decoder_hidden_layer_sizes=cfg.train['decoder_hidden_layer_sizes'],
                value_hidden_layer_sizes=cfg.train['value_hidden_layer_sizes'],
            ),
            restore_checkpoint_path=max_ckpt,
        )

make_inference_fn, params, _= train_fn(environment=env,)
params2 = (params[0],params[1].policy)
# make_inference_fn, params, _= ppo.train(environment=env, num_timesteps=0, episode_length=1000, policy_params_fn=policy_params_fn, restore_checkpoint_path=ckpt_path / '3072000')

episode_length 946
passed_init


In [52]:
from brax.training.acme import running_statistics
from custom_brax.custom_losses import PPONetworkParams
def create_mask():
    mask = {'params': {'encoder': 'encoder', 'decoder': 'decoder'}}
    value = {'params': 'encoder'}
    return PPONetworkParams(mask,value)


In [11]:
def create_mask():
    mask = {'params': {'encoder': 'encoder', 'decoder': 'decoder'}}
    value = {'params': 'encoder'}
    return PPONetworkParams(mask,value)


In [12]:

episode_length = (env_args.clip_length - 50 - env_cfg.ref_traj_length) * env._steps_for_cur_frame
print(f"episode_length {episode_length}")


train_fn = functools.partial(
    ppo.train,
    num_envs=1,
    num_timesteps=10,
    num_evals=0,
    num_resets_per_eval=cfg.train['num_resets_per_eval'],
    reward_scaling=cfg.train['reward_scaling'],
    episode_length=100,
    normalize_observations=True,
    action_repeat=cfg.train['action_repeat'],
    clipping_epsilon=cfg.train["clipping_epsilon"],
    unroll_length=cfg.train['unroll_length'],
    num_minibatches=cfg.train["num_minibatches"],
    num_updates_per_batch=cfg.train["num_updates_per_batch"],
    discounting=cfg.train['discounting'],
    learning_rate=cfg.train["learning_rate"],
    kl_weight=cfg.train["kl_weight"],
    entropy_cost=cfg.train['entropy_cost'],
    batch_size=cfg.train["batch_size"],
    seed=cfg.train['seed'],
    network_factory=functools.partial(
        custom_ppo_networks.make_intention_ppo_networks,
        encoder_hidden_layer_sizes=cfg.train['encoder_hidden_layer_sizes'],
        decoder_hidden_layer_sizes=cfg.train['decoder_hidden_layer_sizes'],
        value_hidden_layer_sizes=cfg.train['value_hidden_layer_sizes'],
    ),
    restore_checkpoint_path=None,
    freeze_fn=create_mask,
)


episode_length 946.0


In [55]:
params[1].policy['params']['encoder']

{'params': {'decoder': {'LayerNorm_0': {'bias': Array([ 0.01731048, -0.00904714,  0.09407102, -0.02594654,  0.01593495,
            0.03715621,  0.00610518, -0.00967912,  0.02021122,  0.06619641,
            0.01199926,  0.01486713,  0.05610774,  0.01528364, -0.01477283,
            0.00311791, -0.02716998, -0.04215733,  0.01167444,  0.03843452,
           -0.03186993, -0.00058468,  0.0149497 , -0.03511928,  0.05585234,
            0.02927883,  0.048905  ,  0.02959033, -0.0074123 ,  0.00204154,
           -0.00644582, -0.02323694,  0.03025895,  0.02481595, -0.00344412,
            0.01790519, -0.01900788,  0.03540667, -0.03799602,  0.04787494,
           -0.00616974,  0.04134345, -0.01724523,  0.00967702, -0.0109723 ,
            0.02648342,  0.0207802 ,  0.07306565, -0.0298298 , -0.00095152,
            0.01424145, -0.02661853, -0.09939934,  0.01710664,  0.05070613,
           -0.05165924, -0.0133182 , -0.02271423, -0.03774946,  0.02539193,
            0.02962468,  0.02722426, -0.0109

In [53]:
import custom_brax.custom_losses as ppo_losses


In [69]:
network_factory=functools.partial(
                custom_ppo_networks.make_intention_ppo_networks,
                encoder_hidden_layer_sizes=cfg.train['encoder_hidden_layer_sizes'],
                decoder_hidden_layer_sizes=cfg.train['decoder_hidden_layer_sizes'],
                value_hidden_layer_sizes=cfg.train['value_hidden_layer_sizes'],
            )

normalize = lambda x, y: x

ppo_network = network_factory(
        state.obs.shape[-1],
        5,
        env.action_size,
        preprocess_observations_fn=normalize,
    )

In [70]:
init_params = ppo_losses.PPONetworkParams(
    policy=ppo_network.policy_network.init(rng),
    value=ppo_network.value_network.init(rng),
)

In [72]:
init_params.policy['params']['encoder']

{'hidden_0': {'kernel': Array([[ 0.5868452 , -0.4743108 , -0.5653391 , ...,  0.34651327,
           0.246423  , -0.5464664 ],
         [-0.6521018 ,  0.6345195 , -0.07838096, ...,  0.5168104 ,
           0.3124815 , -0.17869227],
         [-0.07024699,  0.4433011 , -0.5715655 , ...,  0.7282831 ,
           0.13873805,  0.65838903],
         [ 0.5081505 ,  0.17985278,  0.52154577, ..., -0.3119995 ,
           0.20252481, -0.45054638],
         [ 0.2406773 ,  0.5199413 , -0.6996059 , ...,  0.6365529 ,
          -0.5714238 ,  0.3969369 ]], dtype=float32),
  'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [None]:
init_params.policy['params']

In [47]:
from flax.core import frozen_dict
from flax.core.frozen_dict import FrozenDict

def create_mask(params, label_fn):
    mask = {}
    def _map(params, mask, label_fn):
        for k in params:
            if not label_fn(k):
                mask[k] = False
            else:
                if isinstance(params[k], FrozenDict):
                    mask[k] = {}
                    _map(params[k], mask[k], label_fn)
                else:
                    mask[k] = True
    _map(params, mask, label_fn)
    return frozen_dict.freeze(mask)


def print_tree(d, depth, print_value=False):
    for k in d.keys():
        if isinstance(d[k], FrozenDict):
            print('  ' * depth, k)
            print_tree(d[k], depth + 1, print_value)
        else:
            if print_value:
                print('  ' * depth, k, d[k])
            else:
                print('  ' * depth, k)


def compare_params(lhs, rhs, depth):
    for k in lhs.keys():
        if isinstance(lhs[k], FrozenDict):
            print('  ' * depth, k)
            compare_params(lhs[k], rhs[k], depth + 1)
        else:
            print('  ' * depth, k, jp.mean(jp.abs(lhs[k] - rhs[k])))

In [85]:
def create_mask():
    mask = {'params': {'encoder': 'encoder', 'decoder': 'decoder'}}
    value = {'params': 'encoder'}
    return PPONetworkParams(mask,value)


optimizer = optax.multi_transform({'encoder': optax.adam(learning_rate=1.0),'decoder': optax.set_to_zero()},
                           create_mask())
optimizer.init(init_params)
# state = train_state.TrainState.create(apply_fn=model.apply,
#                                       params=params,
#                                       tx=tx)

PartitionState(inner_states={'encoder': MaskedState(inner_state=(ScaleByAdamState(count=Array(0, dtype=int32), mu=PPONetworkParams(policy={'params': {'decoder': MaskedNode(), 'encoder': {'LayerNorm_0': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 

In [81]:
policy=ppo_network.policy_network.init(rng)
value=ppo_network.value_network.init(rng)

In [82]:
training_state = TrainingState(  # pytype: disable=wrong-arg-types  # jax-ndarray
    optimizer_state=optimizer.init(
        init_params
    ),  # pytype: disable=wrong-arg-types  # numpy-scalars
    params=init_params,
    normalizer_params=running_statistics.init_state(
        specs.Array(env_state.obs.shape[-1:], jnp.dtype("float32"))
    ),
    env_steps=0,
)



PartitionState(inner_states={'encoder': MaskedState(inner_state=(ScaleByAdamState(count=Array(0, dtype=int32), mu=PPONetworkParams(policy={'params': {'encoder': {'LayerNorm_0': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

In [22]:

# def map_nested_fn(fn):
#     '''Recursively apply `fn` to key-value pairs of a nested dict.'''
#     def map_fn(nested_dict):
#         return {'set_zero': (map_fn(v) if (isinstance(v, dict) & k=='decoder') else fn('adam', v)) for k, v in nested_dict.items()}
    
    
# def map_nested_fn(fn):
#     '''Recursively apply `fn` to Parameters to freeze.'''
#     def map_fn(nested_val):
#         if isinstance(nested_val, dict):
#             return {k: (map_fn(v) if (if isinstance(nested_val, dict)) fn('set_zero',v) elif k=='decoder' else fn('adam', v)) for k, v in nested_val.items()}
#     return map_fn



def construct_key_value_pairs(fn):
    def map_func(leaf):
        if isinstance(leaf,dict):
            for k,v in leaf.items():
                return map_func(v)
        else:
            if leaf == 'decoder':
                return fn('decoder',leaf)
            else:
                return fn('adam',leaf)
    return map_func
label_fn = construct_key_value_pairs(lambda k,_: k )

optimizer = optax.multi_transform(
                    {
                     'adam': optax.adam(cfg.train["learning_rate"]), 
                     'set_zero': optax.set_to_zero()
                     },
                    label_fn
                    )

opt_state = optimizer.init(params[1].__dict__)



PartitionState(inner_states={'adam': MaskedState(inner_state=(ScaleByAdamState(count=Array(0, dtype=int32), mu={'policy': {'params': {'decoder': {'LayerNorm_0': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
     

In [124]:
optimizer = optax.multi_transform(
                    {
                     'adam': optax.adam(cfg.train["learning_rate"]), 
                     'decoder': optax.set_to_zero()
                     },
                    label_fn
                    )

opt_state = optimizer.init(params)

In [None]:
opt_state

In [212]:
optimizer = optax.adam(learning_rate=cfg.train['learning_rate'])
opt_state = optimizer.init(params[:2])

In [213]:
opt_state

(ScaleByAdamState(count=Array(0, dtype=int32), mu=(RunningStatisticsState(mean=Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), std=Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), count=Array(0., dtype=float32), summed_variance=Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)), PPONetworkParams(policy={'params': {'decoder': {'LayerNorm_0': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.