# A Notebook for Visualizing the Memory of Trained Sequential Policies

Note that this codebase is designed for evaluation in MuJoCo environments and is therefore probably not compatible with other environments.

## Set log directory and import configs

In [None]:
import yaml
class DotDict(dict):
    """A dict whose keys can also be accessed as attributes, e.g. ``d.a``.

    Reading an attribute that is not a key returns ``None`` (the lookup is
    delegated to ``dict.get``); assigning or deleting an attribute writes or
    removes the corresponding key. Deleting a missing key raises ``KeyError``.
    """

    def __getattr__(self, name):
        # Only invoked when regular attribute lookup fails, so dict methods
        # such as ``items`` keep working even if a key shares their name.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)

    def to_dict(self):
        """Rebuild this mapping using only plain ``dict`` and ``list`` containers.

        Nested dicts (including ``DotDict`` instances) and lists are converted
        recursively; every other value is returned as-is.
        """
        def strip(node):
            if isinstance(node, dict):
                return {key: strip(val) for key, val in node.items()}
            if isinstance(node, list):
                return [strip(item) for item in node]
            return node

        return strip(self)

def to_dotdict(x):
    """Recursively wrap every dict found in ``x`` in a ``DotDict``.

    Lists are rebuilt element by element; any value that is neither a dict
    nor a list is returned unchanged.
    """
    if isinstance(x, list):
        return [to_dotdict(item) for item in x]
    if not isinstance(x, dict):
        return x
    return DotDict({key: to_dotdict(val) for key, val in x.items()})

def unwrap_value_nodes(obj):
    """Recursively collapse ``{"value": v}`` wrapper dicts into ``v``.

    A dict counts as a wrapper only when ``"value"`` is its sole key. Other
    dicts and lists are rebuilt with their children unwrapped; every other
    object is returned unchanged.
    """
    if isinstance(obj, list):
        return [unwrap_value_nodes(item) for item in obj]
    if not isinstance(obj, dict):
        return obj
    if len(obj) == 1 and "value" in obj:
        return unwrap_value_nodes(obj["value"])
    return {key: unwrap_value_nodes(val) for key, val in obj.items()}

def load_cfg(config_path):
    """Load the environment, RL and sequence-model sections from a YAML config.

    The file is parsed with ``yaml.safe_load``, ``{"value": ...}`` wrapper
    nodes are collapsed via ``unwrap_value_nodes``, and only the sections
    ``config_env``, ``config_rl`` and ``config_seq`` are kept.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        A ``DotDict`` with the keys ``config_env``, ``config_rl`` and
        ``config_seq``, each converted for attribute-style access.

    Raises:
        TypeError: If the parsed document is not a mapping (for example an
            empty file, for which ``yaml.safe_load`` returns ``None``).
        KeyError: If any of the three required sections is missing.
    """
    required = ("config_env", "config_rl", "config_seq")

    with open(config_path, "r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)

    raw = unwrap_value_nodes(raw)

    # Fail with an actionable message instead of an opaque subscript error.
    if not isinstance(raw, dict):
        raise TypeError(
            f"Expected a mapping at the top level of {config_path!r}, "
            f"got {type(raw).__name__}."
        )
    missing = [key for key in required if key not in raw]
    if missing:
        raise KeyError(
            f"Config file {config_path!r} is missing required section(s): {missing}"
        )

    # Keep only the three required sections.
    picked = {key: raw[key] for key in required}

    # Convert so that the sections support dot access.
    return to_dotdict(picked)

In [None]:
import os

# Directory of the training run to analyse. The placeholder must be replaced
# by hand; the check below stops the notebook early if it was left untouched.
log_dir = "SET_YOUR_LOG_DIR_HERE" # Example: "logs/mujoco/ant-dir/run_name_2026-01-05-17:27:32"
if log_dir == "SET_YOUR_LOG_DIR_HERE":
    raise ValueError("Please set the 'log_dir' variable to your actual log directory path.")

# NOTE: despite its name, `config_dir` is the path of the config.yaml *file*
# inside the run's wandb folder, not a directory.
config_dir = os.path.join(log_dir, "wandb/latest-run/files/config.yaml")


# Load the three config sections and expose them as notebook-level names
# that the following cells read.
cfg = load_cfg(config_dir)
config_env = cfg.config_env
config_rl = cfg.config_rl
config_seq = cfg.config_seq
print("Environment Config:", config_env)
print("RL Config:", config_rl)
print("Sequence Model Config:", config_seq)

## Make environment, Set seed and device

In [None]:
import gymnasium as gym
from gymnasium.envs.registration import register
from torchkit.pytorch_utils import set_gpu_mode

# Seed and device used by the rest of the notebook.
seed = 42
gpu_id = 0
# NOTE(review): the device is hard-coded to CUDA; this cell presumably fails
# on a machine without a GPU -- confirm before running on CPU.
device = f'cuda:{gpu_id}'
set_gpu_mode(True, gpu_id)


# Maps each supported env name to the "module:Class" entry point that is
# registered with gymnasium below. Any other env name raises KeyError.
ENTRY_POINTS = {"cheetah-vel": "envs.mujoco:HalfCheetahVelEnv", "ant-dir": "envs.mujoco:AntDirEnv", 
                "hopper-param": "envs.mujoco:HopperRandParamsEnv", "walker-param": "envs.mujoco:Walker2DRandParamsEnv"}
env_name = config_env.env_name
entry_point = ENTRY_POINTS[env_name]
# Register the env under its config name with a fixed 200-step episode limit.
# `terminate_when_unhealthy` is forwarded for every env except cheetah-vel.
register(
    env_name,
    entry_point=entry_point,
    max_episode_steps=200,
    kwargs=dict(terminate_when_unhealthy=config_env.terminate_when_unhealthy) if env_name not in ["cheetah-vel"] else {} # cheetah-vel does not have is_healthy
)

env = gym.make(env_name)
# Make sure `env.max_episode_steps` exists: keep the attribute if the env
# already exposes one, otherwise fall back to the value from the registered spec.
env.max_episode_steps = getattr(
    env, "max_episode_steps", env.spec.max_episode_steps
)
env.reset(seed=seed) # Set random seed
# Seed the spaces as well so that sampled actions/observations are reproducible.
env.action_space.seed(seed)
env.observation_space.seed(seed)

In [None]:
# Grab the spaces once and derive the flat dimensions from the first axis
# of their shapes; these names are used when building the agent and buffer.
action_space, observation_space = env.action_space, env.observation_space
act_dim, obs_dim = action_space.shape[0], observation_space.shape[0]

# Quick sanity print of what the environment exposes.
print("obs space", observation_space)
print("act space", action_space)
print("obs_dim", obs_dim, "act_dim", act_dim)

## Instantiate agent and buffer

In [None]:
from policies.models.policy_rnn_shared import ModelFreeOffPolicy_Shared_RNN as Policy_Shared_RNN
agent_class = Policy_Shared_RNN

# Build the agent from the loaded configs and move it to the chosen device.
# NOTE(review): `freeze_critic=True` presumably disables critic updates, which
# are not needed for evaluation -- confirm against the agent class.
agent = agent_class(
    obs_dim=obs_dim,
    action_dim=act_dim,
    config_seq=config_seq,
    config_rl=config_rl,
    freeze_critic=True,
).to(device)


from buffers.rollout_buffer import RolloutBuffer
# In this notebook the buffer is only used for its normalization statistics
# (see `act`), which are restored from the checkpoint when
# `config_rl.normalize_transitions` is set.
buffer = RolloutBuffer(observation_dim=obs_dim,
            action_dim=act_dim,
            max_episode_len=env.max_episode_steps,
            num_episodes=10000, # Not used in ICL testing
            normalize_transitions=config_rl.normalize_transitions,
            is_ppo=False,
            )

### Load checkpoints

In [None]:
import torch
# Restore the latest policy weights saved in the run directory.
# NOTE(review): recent PyTorch releases default `torch.load` to
# `weights_only=True`; if a checkpoint stores non-tensor objects this call may
# need an explicit `weights_only=False` -- confirm with the installed version.
agent_checkpoint_path = os.path.join(log_dir, "policy_checkpoint_latest.pth")
agent.load_state_dict(torch.load(agent_checkpoint_path, map_location=device))
# The buffer checkpoint is only loaded when transitions are normalized.
if config_rl.normalize_transitions:
    buffer_checkpoint_path = os.path.join(log_dir, "buffer_checkpoint_latest.pth")
    buffer.load_state_dict(torch.load(buffer_checkpoint_path, map_location=device))

## Run evaluation

### Define necessary functions

In [None]:
def get_initial_dummies(env, obs):
    """Build placeholder "previous step" inputs for the first policy call of an episode.

    Args:
        env: Environment whose ``action_space`` is sampled for the dummy action.
        obs: Initial observation tensor; a copy of it is used as the previous
            observation.

    Returns:
        Tuple ``(prev_obs, action, reward, term)`` where ``prev_obs`` is a clone
        of ``obs``, ``action`` is a randomly sampled action reshaped to one row,
        and ``reward`` / ``term`` are ``(1, 1)`` zero tensors. The last three
        live on the notebook-level ``device``.
    """
    prev_obs = obs.clone()
    # Convert the sample directly instead of wrapping it in a Python list:
    # building a tensor from a list of ndarrays is slow and emits a warning.
    action = torch.as_tensor(
        env.action_space.sample(), dtype=torch.float32, device=device
    ).reshape(1, -1)  # (1, A) for continuous action, (1, 1) for discrete action
    # Allocate on the target device right away rather than CPU-then-copy.
    reward = torch.zeros((1, 1), device=device)
    term = torch.zeros((1, 1), device=device)
    return prev_obs, action, reward, term

In [None]:
def act(internal_state, action, reward, prev_obs, obs, deterministic, initial):
    """Query the agent for its next action, normalizing the inputs when required.

    When ``buffer.normalize_transitions`` is set, both observations are
    normalized with ``buffer.observation_rms`` and the reward with
    ``buffer.rewards_rms`` (``scale=False``) before they reach the agent.

    Args:
        internal_state: Sequence-model state from the previous call
            (``None`` at the start of an episode in the rollout loop).
        action: Previous action tensor.
        reward: Previous reward tensor.
        prev_obs: Previous observation tensor.
        obs: Current observation tensor.
        deterministic: Forwarded to ``agent.act``.
        initial: Forwarded to ``agent.act``; the rollout loop passes ``True``
            only for the first step of an episode.

    Returns:
        Tuple ``(action, internal_state)`` as returned by ``agent.act``.
    """
    # This helper is used for inference only. Disabling autograd avoids
    # building a graph that the carried-over internal state could otherwise
    # keep alive across all steps of an episode; forward values are unchanged.
    with torch.no_grad():
        if buffer.normalize_transitions:
            obs = buffer.observation_rms.norm(obs)
            prev_obs = buffer.observation_rms.norm(prev_obs)
            reward = buffer.rewards_rms.norm(reward, scale=False)
        action, internal_state = agent.act(
            prev_internal_state=internal_state,
            prev_action=action,
            prev_reward=reward,
            prev_obs=prev_obs,
            obs=obs,
            deterministic=deterministic,
            initial=initial,
        )
    return action, internal_state

### Policy Rollout

In [None]:
import numpy as np

# Roll the policy out for `num_rollouts` episodes and record, per episode,
# the return, the final memory (hidden) vector and the task context.
num_rollouts = 500
max_ep_len = env.max_episode_steps
returns = []
memories = []
contexts = []

print("Rollout Start")
print("Architecture:", config_seq.seq_model.name)
for rollout_idx in range(num_rollouts):
    done = False
    ep_return = 0.0
    t = 0
    obs, info = env.reset()
    obs = torch.from_numpy(obs).float().to(device).unsqueeze(0) # (1, obs_dim)
    # Dummy previous-step inputs for the very first policy call (`term` is unused here).
    prev_obs, action, reward, term = get_initial_dummies(env, obs)
    # Start every episode with a fresh sequence-model state.
    internal_state = None
    initial=True

    while not done and t < max_ep_len:
        action, internal_state = act(
            internal_state=internal_state,
            action=action,
            reward=reward,
            prev_obs=prev_obs,
            obs=obs,
            deterministic=True,
            initial=initial,
        )
        initial=False
        np_action = action.to("cpu").detach().numpy().squeeze(0) # (act_dim,)
        next_obs, reward, terminated, truncated, info = env.step(np_action)
        # Accumulate the raw environment reward before it is wrapped in a tensor.
        ep_return += reward
        next_obs = torch.from_numpy(next_obs).float().to(device).unsqueeze(0)  # (1, obs_dim)
        reward = torch.tensor([[reward]], dtype=torch.float32).to(device)  # (1, 1)
        done = terminated or truncated

        # Shift the observation window for the next step.
        prev_obs = obs.clone()
        obs = next_obs.clone()
        t += 1

    returns.append(ep_return)
    # Memory vector of the sequence model at the end of the episode.
    hidden = agent.head.seq_model.internal_state_to_hidden(internal_state)[0][0] # (1, 1, hidden_size) -> (hidden_size,)
    if config_seq.project_output:
        # Rescale to L2 norm sqrt(hidden_size); the clamp guards against division by ~0.
        # NOTE(review): presumably mirrors the model's own output projection -- confirm.
        hidden = hidden / torch.linalg.vector_norm(hidden).clamp(min=1e-6) * np.sqrt(len(hidden))
    memories.append(hidden.detach().cpu().numpy())
    # Context is read from the `info` dict of the last step of the episode.
    contexts.append(info["context"])
    print(f"Rollout {rollout_idx + 1}/{num_rollouts}, Return: {ep_return}")

returns = np.array(returns)
memories = np.array(memories)  # (num_rollouts, hidden_size)
contexts = np.array(contexts)  # (num_rollouts, context_dim)
print(f"Return over {num_rollouts} rollouts: avg {np.mean(returns)}, std {np.std(returns)}")
print(f"memories shape: {memories.shape}, contexts shape: {contexts.shape}")

### TSNE Visualization Cell

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Plot title, built from the environment and sequence-model names.
name = f'tsne_{env_name}_{config_seq.seq_model.name}'

# Initialize t-SNE (n_components=2 for 2D visualization)
# random_state is recommended for reproducibility
# t-SNE requires perplexity < n_samples. Keep scikit-learn's default of 30
# for the usual case and shrink it only when few rollouts were collected.
perplexity = min(30.0, max(len(memories) - 1, 1))
tsne = TSNE(n_components=2, random_state=0, perplexity=perplexity)

# Fit and transform the data to the low-dimensional embedding
memories_embedded = tsne.fit_transform(memories)

# Colour each point by a single scalar per rollout. If every context is
# already a scalar, `contexts` is 1-D and is used directly; otherwise the
# 0th component is taken (sufficient for cheetah-vel, may change per env).
context_values = contexts if contexts.ndim == 1 else contexts[:, 0]

plt.figure(figsize=(10, 8))

plt.scatter(
    memories_embedded[:, 0],
    memories_embedded[:, 1],
    c=context_values,
    s=12,                  # Marker size
)

# Add a colorbar to visualize the colormap scale
plt.colorbar(label="context")

plt.title(name)
plt.tight_layout()
plt.show()
