# A Notebook for Testing the In-Context Learning Ability of Trained Sequential Policies

Note: this notebook is designed for evaluation in MuJoCo environments and is therefore probably not compatible with other environments.

## Make environment, Set seed and device

In [15]:
import gymnasium as gym
from gymnasium.envs.registration import register
from torchkit.pytorch_utils import set_gpu_mode

# --- Seed and device -------------------------------------------------------
seed = 42
gpu_id = 0
device = f'cuda:{gpu_id}'
set_gpu_mode(True, gpu_id)


# --- Environment registration ----------------------------------------------
# Supported task names mapped to the entry point of their environment class.
ENTRY_POINTS = {
    "cheetah-vel": "envs.mujoco:HalfCheetahVelEnv",
    "ant-dir": "envs.mujoco:AntDirEnv",
    "hopper-param": "envs.mujoco:HopperRandParamsEnv",
    "walker-param": "envs.mujoco:Walker2DRandParamsEnv",
}
env_name = 'ant-dir'  # Example environment name
entry_point = ENTRY_POINTS[env_name]

# cheetah-vel does not have is_healthy, so it is registered without extra kwargs.
if env_name in ["cheetah-vel"]:
    env_kwargs = {}
else:
    env_kwargs = dict(terminate_when_unhealthy=True)

register(
    env_name,
    entry_point=entry_point,
    max_episode_steps=200,
    kwargs=env_kwargs,
)

# --- Environment construction and seeding ----------------------------------
env = gym.make(env_name)
# Make the episode horizon available as `env.max_episode_steps`; when the env
# does not already expose it, fall back to the value stored in the spec.
env.max_episode_steps = getattr(
    env, "max_episode_steps", env.spec.max_episode_steps
)
env.reset(seed=seed)  # Set random seed
env.action_space.seed(seed)
env.observation_space.seed(seed)

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


42

In [16]:
# Keep handles to both spaces and the size of their leading dimension; the
# agent and buffer constructors below are parameterised by these sizes.
action_space, observation_space = env.action_space, env.observation_space
act_dim, obs_dim = action_space.shape[0], observation_space.shape[0]

for label, space in (("obs space", observation_space), ("act space", action_space)):
    print(label, space)
print("obs_dim", obs_dim, "act_dim", act_dim)

obs space Box(-inf, inf, (105,), float64)
act space Box(-1.0, 1.0, (8,), float32)
obs_dim 105 act_dim 8


## Instantiate agent and buffer

In [17]:
# Imports used by this cell, grouped together.
from policies.models.policy_rnn_shared import ModelFreeOffPolicy_Shared_RNN as Policy_Shared_RNN
from configs.seq_models.gpt_default import get_config as get_config_gpt
from configs.seq_models.hist_default import get_config as get_config_hist
from configs.seq_models.lstm_default import get_config as get_config_lstm
from configs.rl.sac_default import get_config as get_config_rl
from buffers.rollout_buffer import RolloutBuffer

agent_class = Policy_Shared_RNN

# Sequence-model configuration. The GPT defaults are selected here; the hist
# and lstm getters are imported as well, presumably so this line can be swapped.
config_seq = get_config_gpt()
config_seq.seq_model.max_seq_length = env.max_episode_steps + 1
if config_seq.seq_model.name == "hist":
    config_seq.seq_model.init_emb_mode = "parameter" # For hist
config_rl = get_config_rl()

# Build the agent and move it to the evaluation device.
agent_kwargs = dict(
    obs_dim=obs_dim,
    action_dim=act_dim,
    config_seq=config_seq,
    config_rl=config_rl,
    freeze_critic=True,
)
agent = agent_class(**agent_kwargs).to(device)

# The buffer's normalisation statistics are what `act` reads during evaluation.
buffer = RolloutBuffer(
    observation_dim=obs_dim,
    action_dim=act_dim,
    max_episode_len=env.max_episode_steps,
    num_episodes=10000, # Not used in ICL testing
    normalize_transitions=True,
    is_ppo=False,
)

{'h.0.ln_1.weight': torch.Size([128]), 'h.0.ln_1.bias': torch.Size([128]), 'h.0.attn.c_attn.weight': torch.Size([128, 384]), 'h.0.attn.c_attn.bias': torch.Size([384]), 'h.0.attn.c_proj.weight': torch.Size([128, 128]), 'h.0.attn.c_proj.bias': torch.Size([128]), 'h.0.ln_2.weight': torch.Size([128]), 'h.0.ln_2.bias': torch.Size([128]), 'h.0.mlp.c_fc.weight': torch.Size([128, 512]), 'h.0.mlp.c_fc.bias': torch.Size([512]), 'h.0.mlp.c_proj.weight': torch.Size([512, 128]), 'h.0.mlp.c_proj.bias': torch.Size([128]), 'ln_f.weight': torch.Size([128]), 'ln_f.bias': torch.Size([128])}
Normalize transitions: True


### Load checkpoints

In [18]:
import torch
# Both checkpoints belong to the same training run. Defining the run directory
# once keeps the agent and buffer paths from drifting apart when it is changed.
checkpoint_dir = "./logs/mujoco/ant-dir/gpt_1_2025-12-06-22:00:20"
agent_checkpoint_path = f"{checkpoint_dir}/policy_checkpoint_latest.pth"
buffer_checkpoint_path = f"{checkpoint_dir}/buffer_checkpoint_latest.pth"

# NOTE(security): torch.load unpickles the file contents, so only load
# checkpoints that come from a trusted source.
agent.load_state_dict(torch.load(agent_checkpoint_path, map_location=device))
buffer.load_state_dict(torch.load(buffer_checkpoint_path, map_location=device))

## Run evaluation

### Define necessary functions

In [19]:
def get_initial_dummies(env, obs):
    """Build placeholder "previous step" inputs for the first policy call of an episode.

    Args:
        env: Environment whose ``action_space.sample()`` supplies the dummy action.
        obs: Initial observation tensor. The callers in this notebook pass a
            tensor of shape (1, obs_dim).

    Returns:
        Tuple ``(prev_obs, action, reward, term)``:
            prev_obs: a copy of ``obs``.
            action: a randomly sampled action as a float32 tensor reshaped to (1, -1).
            reward: a (1, 1) tensor of zeros.
            term: a (1, 1) tensor of zeros.
        All returned tensors live on the same device as ``obs``.
    """
    # Follow the observation's device instead of relying on a notebook-level global.
    target_device = obs.device
    prev_obs = obs.clone()
    # Convert the sampled action directly. Wrapping the numpy array in a Python
    # list before building the tensor is slow and makes PyTorch emit a warning.
    action = torch.as_tensor(
        env.action_space.sample(), dtype=torch.float32, device=target_device
    ).reshape(1, -1)  # (1, A) for continuous action, (1, 1) for discrete action
    # Allocate directly on the target device rather than on CPU followed by a copy.
    reward = torch.zeros((1, 1), device=target_device)
    term = torch.zeros((1, 1), device=target_device)
    return prev_obs, action, reward, term

In [20]:
def act(internal_state, action, reward, prev_obs, obs, deterministic, initial):
    """Query the loaded agent for its next action, normalising inputs if configured.

    Reads the notebook-level ``buffer`` and ``agent`` objects.

    Args:
        internal_state: State returned by the previous call, forwarded to
            ``agent.act`` as ``prev_internal_state`` (``None`` at episode start
            in this notebook).
        action: Previous action tensor, forwarded as ``prev_action``.
        reward: Previous reward tensor, forwarded as ``prev_reward``.
        prev_obs: Previous observation tensor.
        obs: Current observation tensor.
        deterministic: Forwarded unchanged to ``agent.act``.
        initial: Forwarded unchanged to ``agent.act``.

    Returns:
        Tuple ``(action, internal_state)`` as returned by ``agent.act``.
    """
    if buffer.normalize_transitions:
        # Apply the statistics stored in the buffer to both observations and the reward.
        obs = buffer.observation_rms.norm(obs)
        prev_obs = buffer.observation_rms.norm(prev_obs)
        reward = buffer.rewards_rms.norm(reward)
    # Evaluation only: disabling autograd avoids building a graph that would
    # otherwise be kept alive through the carried internal state.
    with torch.no_grad():
        action, internal_state = agent.act(
            prev_internal_state=internal_state,
            prev_action=action,
            prev_reward=reward,
            prev_obs=prev_obs,
            obs=obs,
            deterministic=deterministic,
            initial=initial,
        )
    return action, internal_state

### Vanilla Evaluation Cell

In [21]:
import numpy as np

num_rollouts = 10
max_ep_len = env.max_episode_steps
returns = []

print("Vanilla Rollout Test")
print("Architecture:", config_seq.seq_model.name)
for rollout_idx in range(num_rollouts):
    # Every rollout starts from scratch: fresh observation, dummy "previous
    # step" inputs and no carried-over internal state.
    obs, info = env.reset()
    obs = torch.from_numpy(obs).float().to(device).unsqueeze(0) # (1, obs_dim)
    prev_obs, action, reward, term = get_initial_dummies(env, obs)
    internal_state = None
    initial = True
    ep_return = 0.0
    done = False
    t = 0

    while not (done or t >= max_ep_len):
        action, internal_state = act(
            internal_state=internal_state,
            action=action,
            reward=reward,
            prev_obs=prev_obs,
            obs=obs,
            deterministic=True,
            initial=initial,
        )
        initial = False

        # Step the environment with the action as a numpy array.
        np_action = action.detach().cpu().numpy().squeeze(0) # (act_dim,)
        next_obs, reward, terminated, truncated, info = env.step(np_action)
        ep_return += reward
        done = terminated or truncated

        # Convert the step results back to tensors for the next policy call.
        next_obs = torch.from_numpy(next_obs).float().to(device).unsqueeze(0)  # (1, obs_dim)
        reward = torch.FloatTensor([[reward]]).to(device)  # (1, 1)
        prev_obs, obs = obs.clone(), next_obs.clone()
        t += 1

    returns.append(ep_return)
    print(f"Rollout {rollout_idx + 1}/{num_rollouts}, Return: {ep_return}")

returns = np.array(returns)
print(f"Return over {num_rollouts} rollouts: avg {returns.mean()}, std {returns.std()}")

Vanilla Rollout Test
Architecture: gpt
Rollout 1/10, Return: 767.9789963147491
Rollout 2/10, Return: 483.54009839051014
Rollout 3/10, Return: 463.68521597318636
Rollout 4/10, Return: 236.57234305020677
Rollout 5/10, Return: 642.2710001487184
Rollout 6/10, Return: 587.5448901108724
Rollout 7/10, Return: 944.7995966988555
Rollout 8/10, Return: 1028.2935457978717
Rollout 9/10, Return: 746.8665812540876
Rollout 10/10, Return: 673.5070801388183
Return over 10 rollouts: avg 657.5059347877876, std 221.49697124327716


### In Context Learning Evaluation Cell

In [22]:
import numpy as np
print("In Context Learning Rollout Test")
print("Architecture:", config_seq.seq_model.name)

num_rollouts = 10
max_ep_len = env.max_episode_steps
returns = []

# Rollouts come in pairs. Even indices are "demo" episodes that start with no
# internal state; odd indices are "in-context" episodes that reset the env
# with keep_context=True and reuse the internal state left by the demo.
# NOTE(review): a demo + in-context pair can span up to 2 * max_ep_len steps,
# while config_seq.seq_model.max_seq_length was set to
# env.max_episode_steps + 1 earlier in this notebook. Confirm the sequence
# model handles contexts longer than that.
total_rollouts = 2 * num_rollouts
for rollout_idx in range(total_rollouts):
    keep_context = rollout_idx % 2 == 1
    done = False
    ep_return = 0.0
    t = 0
    obs, info = env.reset(options={'keep_context': keep_context})
    obs = torch.from_numpy(obs).float().to(device).unsqueeze(0) # (1, obs_dim)
    prev_obs, action, reward, term = get_initial_dummies(env, obs)
    if not keep_context:
        internal_state = None
        initial = True

    while not done and t < max_ep_len:
        action, new_internal_state = act(
            internal_state=internal_state,
            action=action,
            reward=reward,
            prev_obs=prev_obs,
            obs=obs,
            deterministic=True,
            initial=initial,
        )
        # At t == 0 of an in-context episode the inputs are the dummy reset
        # values, so the state returned for that step is discarded and the
        # state carried over from the demo episode is kept.
        if not (t == 0 and keep_context):
            internal_state = new_internal_state
        initial = False
        np_action = action.to("cpu").detach().numpy().squeeze(0) # (act_dim,)
        next_obs, reward, terminated, truncated, info = env.step(np_action)
        ep_return += reward
        next_obs = torch.from_numpy(next_obs).float().to(device).unsqueeze(0)  # (1, obs_dim)
        reward = torch.FloatTensor([[reward]]).to(device)  # (1, 1)
        done = terminated or truncated

        prev_obs = obs.clone()
        obs = next_obs.clone()
        t += 1

    returns.append(ep_return)
    agent_label = "In-Context agent" if keep_context else "Demo agent"
    print(f"Rollout {rollout_idx + 1}/{total_rollouts} ({agent_label}), Return: {ep_return}")

returns = np.array(returns)

# Report the two kinds of rollout separately: even positions hold demo
# returns and odd positions hold in-context returns.
demo_returns = returns[0::2]
icl_returns = returns[1::2]
print(f"Demo agent return over {num_rollouts} rollouts: avg {np.mean(demo_returns)}, std {np.std(demo_returns)}")
print(f"In-Context agent return over {num_rollouts} rollouts: avg {np.mean(icl_returns)}, std {np.std(icl_returns)}")

In Context Learning Rollout Test
Architecture: gpt
Rollout 1/20 (Demo agent), Return: 187.54519195144542
Rollout 2/20 (In-Context agent), Return: 58.77346646799395
Rollout 3/20 (Demo agent), Return: 652.6070923081727
Rollout 4/20 (In-Context agent), Return: 117.84690021426667
Rollout 5/20 (Demo agent), Return: 466.12637530942567
Rollout 6/20 (In-Context agent), Return: 117.93786456226694
Rollout 7/20 (Demo agent), Return: 783.9232823757804
Rollout 8/20 (In-Context agent), Return: 85.04459597740666
Rollout 9/20 (Demo agent), Return: 353.91509173610245
Rollout 10/20 (In-Context agent), Return: 98.979252761565
Rollout 11/20 (Demo agent), Return: 1158.353785926193
Rollout 12/20 (In-Context agent), Return: 75.37303808055076
Rollout 13/20 (Demo agent), Return: 416.1202213205586
Rollout 14/20 (In-Context agent), Return: 54.6367613971334
Rollout 15/20 (Demo agent), Return: 875.5109840309419
Rollout 16/20 (In-Context agent), Return: 122.20902937183098
Rollout 17/20 (Demo agent), Return: 835.353