In [24]:
import os
import torch as th
import numpy as np
import json 
from datetime import datetime
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import motornet as mn
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cross_decomposition import CCA
import math
from torch.distributions import Normal
from gymnasium import Wrapper
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.callbacks import ProgressBarCallback


In [2]:
from dynamics_analysis import *
from environment import *
from networks import *
from neural_activity import *
from utils import *
from training import *

# Archive

In [77]:
class L1LossCallback(BaseCallback):
    def __init__(self, verbose=0, log_path="./logs/"):
        super().__init__(verbose)
        self.log_path = log_path
        self.losses = []

    def _on_step(self) -> bool:
        # Info dict for each parallel env
        infos = self.locals.get("infos", [])

        for info in infos:
            if 'goal' in info and 'fingertip' in info:
                goal = np.array(info['goal'])
                fingertip = np.array(info['fingertip'])

                l1_loss = np.abs(goal - fingertip).sum()
                self.losses.append(l1_loss)

                if self.verbose > 0:
                    print(f"L1 loss: {l1_loss:.4f}")

        return True

    def _on_training_end(self):
        # Optional: Save the loss values for later
        os.makedirs(self.log_path, exist_ok=True)
        np.save(os.path.join(self.log_path, "l1_losses.npy"), np.array(self.losses))

In [None]:
class MotorNetBatchLoadWrapper(Wrapper):
    """
    Ensures actions have shape (batch, n_muscles) and supplies
    default zero loads for endpoint and joints on every step().
    """

    def __init__(self, env: mn.environment.RandomTargetReach):
        super().__init__(env)
        # Probe once to figure out load shapes:
        obs, info = env.reset(options={"batch_size": 1})
        geom_state = info["states"]["geometry"]  # (1, n_points, 2)
        joint_state = info["states"]["joint"]    # (1, n_joints)

        # For endpoint_load we need one value per coordinate axis (x,y)
        self.endpoint_dim = geom_state.shape[2]  # = 2
        # For joint_load we need one per joint
        self.n_joints = info["states"]["joint"].shape[1] // 2
        print(f"Detected: endpoint_dim={self.endpoint_dim}, n_joints={self.n_joints}, n_muscles={self.n_muscles}")
        # For actions we need one per muscle
        self.n_muscles   = env.n_muscles


    def step(self, action):
        action = np.array(action, dtype=np.float32)
        if action.ndim == 1:
            action = action[None, :]  # shape (1, act_dim)

        batch = action.shape[0]

        endpoint_load = np.zeros((batch, self.endpoint_dim), dtype=np.float32)
        joint_load = np.zeros((batch, self.n_joints), dtype=np.float32)

        obs, _, terminated, truncated, info = self.env.step(
            action=action,
            endpoint_load=endpoint_load,
            joint_load=joint_load
        )

        target = info["goal"]
        fingertip = info["states"]["fingertip"]
        reward = -np.linalg.norm(fingertip - target, axis=1)  # shape (batch,)

        # Use the first sample
        obs = obs[0]
        reward = float(reward[0])
        terminated = bool(terminated)
        truncated = bool(truncated)
        info = {k: (v[0] if isinstance(v, np.ndarray) and v.shape[0] == batch else v) for k, v in info.items()}

        return obs, reward, terminated, truncated, info
     

In [None]:
class CustomACPolicy(ActorCriticPolicy):
    def __init__(self, observation_space, action_space, lr_schedule, **kwargs):
        device = kwargs.get('device', 'cpu')
        kwargs.pop('device')
        super().__init__(observation_space, action_space, lr_schedule, **kwargs)
        self._build_network(device)
        self._hidden_states = None

    def _build_network(self, device):
        input_dim = self.observation_space.shape[0]
        self.action_dim = self.action_space.shape[0]
        self.hidden_dim = 64  # Match your GRU hidden size

        # Actor (GRU-based) and Critic
        self.actor = Policy(input_dim, self.hidden_dim, self.action_dim, device)
        self.critic = Critic(input_dim, device)

    def reset_hidden(self, batch_size=1):
        self._hidden_states = self.actor.init_hidden(batch_size)

    def forward(self, obs, deterministic=False):
        if self._hidden_states is None or self._hidden_states.size(1) != obs.shape[0]:
            self.reset_hidden(obs.shape[0])
        
        mean, new_hidden = self.actor(obs, self._hidden_states)
        self._hidden_states = new_hidden.detach()

        # Define action distribution
        std = th.ones_like(mean) * 0.1  # or learnable
        dist = th.distributions.Normal(mean, std)

        # Sample action
        actions = dist.mean if deterministic else dist.sample()
        log_probs = dist.log_prob(actions).sum(dim=-1, keepdim=True)

        values = self.critic(obs)

        return actions, values, log_probs

    def evaluate_actions(self, obs, actions):
        if self._hidden_states is None or self._hidden_states.size(1) != obs.shape[0]:
            self.reset_hidden(obs.shape[0])

        mean, new_hidden = self.actor(obs, self._hidden_states)
        self._hidden_states = new_hidden.detach()

        std = th.ones_like(mean) * 0.1
        dist = th.distributions.Normal(mean, std)

        log_probs = dist.log_prob(actions).sum(dim=-1, keepdim=True)
        entropy = dist.entropy().sum(dim=-1, keepdim=True)
        values = self.critic(obs)

        if values.ndim == 1:
            values = values.unsqueeze(1)

        return values, log_probs, entropy
    
    
def PPO_train():
    # Create the MotorNet environment
    effector = mn.effector.RigidTendonArm26(muscle=mn.muscle.MujocoHillMuscle())
    env = mn.environment.RandomTargetReach(effector=effector, max_ep_duration=5.)

    # 2. Wrap it
    wrapped_env = MotorNetBatchLoadWrapper(env)

    #sanity check
    print("Wrapped action_space:", wrapped_env.action_space)
    print("Wrapped observation_space:", wrapped_env.observation_space)
    
    # Instantiate the custom PPO model
    model = PPO(
        CustomACPolicy,
        wrapped_env,
        learning_rate=3e-4,
        n_steps=10000,  # Adjust based on your environment's requirements
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        verbose=1,
        policy_kwargs={
            'device': 'cuda:0' if th.cuda.is_available() else 'cpu'
        }
    )
    callback = L1LossCallback(verbose=1)
    # Train the model
    model.learn(total_timesteps=1000, callback=callback)
    
    return model

In [79]:
trained_model = PPO_train()


Detected: endpoint_dim=6, n_joints=2, n_muscles=6
Wrapped action_space: Box(0.0, 1.0, (6,), float32)
Wrapped observation_space: Box(-inf, inf, (16,), float32)
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 501      |
|    ep_rew_mean     | -180     |
| time/              |          |
|    fps             | 160      |
|    iterations      | 1        |
|    time_elapsed    | 62       |
|    total_timesteps | 10000    |
---------------------------------


In [88]:
def evaluate_pretrained(policy, env, batch_size):
    """Evaluation function with hidden state management"""
    # Reset hidden states
    policy.hidden_states = policy.actor.init_hidden(batch_size)
    
    # Initialize environment
    obs, info = env.reset(options={"batch_size": batch_size})
    terminated = np.zeros(batch_size, dtype=bool)
    xy = [info["states"]["fingertip"][:, None, :]]
    tg = [info["goal"][:, None, :]]

    # Run evaluation episode
    while not terminated.all():
        action, _, _, _ = policy(obs, deterministic=True)  # Deterministic actions
        obs, _, terminated, _, info = env.step(action)
        xy.append(info["states"]["fingertip"][:, None, :])
        tg.append(info["goal"][:, None, :])

    # Plot results
    xy = th.cat(xy, axis=1).detach().numpy()
    tg = th.cat(tg, axis=1).detach().numpy()
    plot_simulations(xy=xy, target_xy=tg)

In [89]:
# Create the MotorNet environment
effector = mn.effector.RigidTendonArm26(muscle=mn.muscle.MujocoHillMuscle())
env = mn.environment.RandomTargetReach(effector=effector, max_ep_duration=5.)

# 2. Wrap it
wrapped_env = MotorNetBatchLoadWrapper(env)
device = th.device("cpu")
trained_model.policy.to(device)
evaluate_pretrained(trained_model.policy, wrapped_env, batch_size=1)


Detected: endpoint_dim=6, n_joints=2, n_muscles=6


RuntimeError: Input and hidden tensors are not at the same device, found input tensor at cpu and hidden tensor at cuda:0

# The one

In [10]:
DEVICE = "cuda:0" if th.cuda.is_available() else "cpu"
n_batch = 1000
batch_size = 128
total_timesteps = n_batch * batch_size

from environment import *
from utils import *
    
arm = Arm('arm26')
env = CustomTargetReach(
    effector=arm.effector, 
    obs_noise=0.0,
    proprioception_noise=0.0,
    vision_noise=0.0,
    action_noise=0.0
    )

In [4]:
class Policy(th.nn.Module):

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, device, sigma=0.1):
        super().__init__()
        self.device = device
        self.hidden_dim = hidden_dim
        self.n_layers = 1
        self.sigma = sigma
        self.noise = th.zeros(output_dim, device = device)
        self.timestep_counter = 0
        self.resample_threshold = np.random.randint(16, 25)
        self.gru = nn.GRU(input_dim, hidden_dim, 1, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

        # Apply custom initialization
        for name, param in self.named_parameters():
            
            if "gru" in name:
                if "weight_ih" in name:
                    nn.init.orthogonal_(param)
                elif "weight_hh" in name:
                    nn.init.orthogonal_(param)
                elif "bias_ih" in name:
                    nn.init.zeros_(param)
                elif "bias_hh" in name:
                    nn.init.zeros_(param)
            elif "fc" in name:
                if "weight" in name:
                    nn.init.orthogonal_(param)
                elif "bias" in name:
                    nn.init.constant_(param, -10.0)
            else:
                raise ValueError(f"Unexpected parameter: {name}")
        
        self.to(device)

    def forward(self, x, h0):
        y, h = self.gru(x.unsqueeze(1), h0)  
        u = self.sigmoid(self.fc(y.squeeze(1))) 

        # Apply periodic Gaussian noise
        self.timestep_counter += 1
        if self.timestep_counter >= self.resample_threshold:
            self.resample_noise()

        return u + self.noise, h
    
    def resample_noise(self):

        self.noise = th.randn_like(self.noise) * self.sigma
        self.timestep_counter = 0
        self.resample_threshold = np.random.randint(16, 25)
    
    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        hidden = weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(self.device)
        return hidden
    
class Critic(th.nn.Module):
    def __init__(self, input_size, device):
        super().__init__()
        self.device = device
        
        # Define network layers properly
        self.fc1 = nn.Linear(input_size, 100)
        self.fc2 = nn.Linear(100, 64)
        self.fc3 = nn.Linear(64, 1)
        
        # Activation functions
        self.tanh = nn.Tanh()

        # Apply orthogonal initialization
        self._initialize_weights()

        self.to(self.device)

    def _initialize_weights(self):
        """Apply orthogonal initialization with gain=1 and bias=0."""
        for layer in [self.fc1, self.fc2, self.fc3]:
            nn.init.orthogonal_(layer.weight, gain=1)  # Orthogonal init with gain=1
            nn.init.zeros_(layer.bias)  # Bias = 0

    def forward(self, x):
        x = self.tanh(self.fc1(x))
        x = self.tanh(self.fc2(x))
        return self.fc3(x)

In [None]:
class DummyExtractor(nn.Module):
    def __init__(self, latent_dim_pi: int, latent_dim_vf: int):
        super().__init__()
        # SB3 only needs these for sizing:
        self.latent_dim_pi = latent_dim_pi
        self.latent_dim_vf = latent_dim_vf

    def forward(self, features):
        # SB3 never actually calls it during inference for MLP policies,
        # but it expects two outputs.
        return features, features
    
class StableNetwork(ActorCriticPolicy):

    def __init__(self, observation_space, action_space, lr_schedule, **kwargs):
        device = kwargs.get('device', 'cpu')
        kwargs.pop('device')
        self.hidden_dim = 64
        self.hidden_states = None
        super().__init__(observation_space, action_space, lr_schedule, **kwargs)

        self.actor = Policy(self.observation_space.shape[0], self.hidden_dim, self.action_space.shape[0], device)
        self.critic = Critic(self.observation_space.shape[0], device)
        self.to(device)
        self.sigma = 0.01

    def _build_mlp_extractor(self):
        # Replace the old Dummy with our Module‐based one:
        self.mlp_extractor = DummyExtractor(
            latent_dim_pi=self.hidden_dim,
            latent_dim_vf=self.hidden_dim,
        )

    def forward(self, obs, deterministic=False):
        """
        Given observations, returns:
          - actions  (batch, act_dim)
          - values   (batch, 1)
          - log_prob (batch,)
        """
        # ensure obs on correct device
        obs = obs.to(self.device)

        # init hidden if first call
        batch_size = obs.shape[0]
        if self.hidden_states is None:
            self.hidden_states = self.actor.init_hidden(batch_size)

        # actor forward
        mean_actions, new_hidden = self.actor(obs, self.hidden_states)
        self.hidden_states = new_hidden.detach()  # detach BPTT

        # critic forward
        values = self.critic(obs)

        # build distribution
        log_std = th.tensor(self.sigma, device=self.device).expand_as(mean_actions)
        dist = DiagGaussianDistribution(self.action_space.shape[0])
        dist = dist.proba_distribution(mean_actions, log_std)

        # sample or mode
        actions = dist.mode() if deterministic else dist.sample()
        log_prob = dist.log_prob(actions)
        
        return actions, values, log_prob
    
    def evaluate_actions(self, obs, actions, hidden_states=None):
        """
        Used by SB3 to compute loss:
          - returns values, log_prob(actions), entropy
        """
        obs = obs.to(self.device)
        batch_size = obs.shape[0]

         # Handle GRU hidden state properly
        if hidden_states is not None:
            h = hidden_states
        else:
            # fallback: make sure self.hidden_states has correct batch size
            if self.hidden_states is None or self.hidden_states.shape[1] != batch_size:
                h = th.zeros(1, batch_size, self.actor.hidden_dim, device=self.device)
            else:
                h = self.hidden_states

        mean_actions, _ = self.actor(obs, h)
        values = self.critic(obs)
        
        log_std = th.tensor(self.sigma, device=self.device).expand_as(mean_actions)
        dist = DiagGaussianDistribution(self.action_space.shape[0])
        dist = dist.proba_distribution(mean_actions, log_std)
                
        log_prob = dist.log_prob(actions)
        entropy = dist.entropy()
        return values, log_prob, entropy
    
    def predict_values(self, obs):
        obs = obs.to(self.device)
        return self.critic(obs)

class LossTrackingCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.losses = []

    def _on_step(self) -> bool:
        # SB3 doesn't expose loss directly, so we approximate it
        # You could modify PPO class to store actual loss if needed
        if "train/policy_loss" in self.logger.name_to_value:
            self.losses.append({
                'policy_loss': self.logger.name_to_value['train/policy_loss'],
                'value_loss': self.logger.name_to_value['train/value_loss'],
                'entropy_loss': self.logger.name_to_value['train/entropy_loss']
            })
        return True

    def _on_rollout_end(self) -> None:
        # This function is called after each policy update
        # No real "loss", so you could simulate with mean reward or value loss
        ep_rewards = self.locals["rollout_buffer"].rewards
        avg_reward = sum(ep_rewards) / len(ep_rewards)
        self.losses.append(-avg_reward)  # Just an indicator proxy for loss

In [21]:
class CustomTargetReach(mn.environment.Environment):
    """A custom reaching task:
       - The effector starts at a random state.
       - The target is drawn uniformly from the joint space (projected to Cartesian space),
         with a 1 cm radius (r = 0.01 m).
       - A 200 ms go-cue delay is imposed during which no movement is allowed.
       - Episodes last for 5 seconds (max_ep_duration=5.0 s).
       - The episode terminates early if the effector's endpoint stays within the target 
         region for at least 800 ms.
       - The reward is given by: 
           Rₗ = - y_pos * L1_norm(xₜ - xₜ′) - y_ctrl * ( (uₜ * f / ∥f∥₂²)² ),
         with
           y_pos = 0 if ∥xₜ - xₜ′∥₂ < r, else 1, and
           y_ctrl = 1 if ∥xₜ - xₜ′∥₂ < r, else 0.03.
    """
    
    def __init__(self, effector, **kwargs):
        # Set task-specific parameters:
        self.elapsed = 0.0              # elapsed time in the episode
        self.distance_criteria = 0.005 
        self.cue_delay = 0.2            # 200 ms no-move phase
        # We assume the environment's dt is defined (e.g. dt = 0.02 s)
        self.dt = kwargs.pop("dt", 0.02)  # default timestep
        self.batch_size = kwargs.pop("batch_size", 128)
        self.hold_threshold = 0.5       # 500 ms hold threshold (in seconds)
        self.hold_time = th.zeros(self.batch_size, device=self.device)            # duration the effector is continuously within target

        # Pass max_ep_duration to parent (5 seconds)
        kwargs.setdefault("max_ep_duration", 5.0)
        super().__init__(effector, **kwargs)
        self.__name__ = "RandomTargetReach"
    
    def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple:
        """
        Reset the environment:
          - Draw a random target joint state.
          - Define a start joint state that is 1 cm away from the target (by adding small random noise).
          - Convert joint states to Cartesian coordinates to set the target.
          - Reset internal timers and observation buffers.
        """
        self._set_generator(seed)  # set PRNG seeds
        options = {} if options is None else options
        batch_size = options.get("batch_size", 1)
        deterministic: bool = options.get('deterministic', False)

        # 1. Sample random start joint state from the full joint space.
        q_target = self.effector.draw_random_uniform_states(batch_size)  # shape: (batch_size, n_joints)
        q_start = self.effector.draw_random_uniform_states(batch_size)

        obs, info = super().reset(seed=seed, options=options)
        
        # 4. Set the goal.
        cart_goal = self.joint2cartesian(q_target)
        self.goal = cart_goal
        info["goal"] = cart_goal if self.differentiable else self.detach(cart_goal)
        
        # 6. Reset internal timers.
        self.elapsed = 0.0
        self.hold_time = th.zeros(batch_size, device=self.device)

        # 7. Initialize observation buffers.
        action = th.zeros((batch_size, self.muscle.n_muscles)).to(self.device)
        self.obs_buffer["proprioception"] = [self.get_proprioception()] * len(self.obs_buffer["proprioception"])
        self.obs_buffer["vision"] = [self.get_vision()] * len(self.obs_buffer["vision"])
        self.obs_buffer["action"] = [action] * self.action_frame_stacking

        # 8. Get the initial observation.
        obs = self.get_obs(deterministic=deterministic)
        info.update({
        "states": self.states,
        "action": action,
        "noisy action": action,  # no noise at reset
        })
        return obs, info

    def apply_noise(self, loc, noise):
        """
        Override the default noise application to disable noise.
        This method returns the input `loc` unchanged.
        """
        return loc
    
    def step(self, action: th.Tensor, deterministic: bool = False) -> tuple:
        """
        Perform one simulation step.
        
        - During the first 200ms (cue_delay), actions are suppressed (set to zero) so that the effector stays
          at the starting position.
        - After stepping the environment, update the elapsed time.
        - Track the time the effector's endpoint (fingertip) remains within the target.
        - Terminate early if the hold time exceeds the threshold (800ms).
        - Compute the reward according to:
             Rₗ = - y_pos * L1_norm(xₜ - xₜ′) - y_ctrl * ( (uₜ * f / ∥f∥₂²)² ),
          where:
             y_pos = 0 if ∥xₜ - xₜ′∥₂ < r, else 1
             y_ctrl = 1 if ∥xₜ - xₜ′∥₂ < r, else 0.03.
        """
        # Freeze actions during the cue delay.
        if self.elapsed < self.cue_delay:
            # Handle both numpy arrays and torch tensors
            if isinstance(action, np.ndarray):
                action = np.zeros_like(action)
            elif isinstance(action, th.Tensor):
                action = th.zeros_like(action)
        
        # Step the simulation using the parent method.
        obs, _, terminated, truncated, info = super().step(action, deterministic=deterministic)
        action = th.tensor(action, device=self.device)
        self.elapsed += self.dt
        
        # Compute the distance between the fingertip and the goal.
        # Assume the effector state info includes "states" with key "fingertip".
        fingertip = info["states"]["fingertip"]
        dist = th.norm(fingertip - self.goal[:, :2], dim=-1)  # L2 norm per batch
        # Update hold time: increment where distance < threshold, reset otherwise
        in_target = dist < self.distance_criteria
        self.hold_time = th.where(in_target, self.hold_time + self.dt, th.zeros_like(self.hold_time))
        
        #Termination conditions
        successful_terminations = self.hold_time >= self.hold_threshold
        unsuccessful_terminations = th.full_like(
            successful_terminations,
            self.elapsed >= self.max_ep_duration,
            dtype=th.bool,
            device=successful_terminations.device
)
        # Combine conditions (batched OR)
        terminated = successful_terminations | unsuccessful_terminations
        # Truncate only if time limit is reached (Gym convention)
        truncated = unsuccessful_terminations #th.full_like(terminated, terminated_due_to_time, dtype=th.bool)
        
        # Compute reward:
        # Compute reward components
        y_pos = th.where(in_target, th.tensor(0.0, device=action.device), th.tensor(1.0, device=action.device))
        y_ctrl = th.where(in_target, th.tensor(1.0, device=action.device), th.tensor(0.03, device=action.device))
        pos_error = th.sum(th.abs(fingertip - self.goal[:, :2]), dim=-1)
        
        # For the control term, assume f (maximum isometric contraction vector) is provided by the effector,
        # or use ones as a default. Adjust normalization as needed.
        f = th.tensor(self.effector.tobuild__muscle['max_isometric_force'], dtype=th.float32)
        norm_f_squared = th.norm(f.clone().detach(), p=2) ** 2
        f_expanded = f.expand_as(action)
        # Compute the control term: square of (uₜ * f / norm_f_squared)
        ctrl_term = th.sum((action * ((f_expanded) / norm_f_squared) ** 2), dim=-1)
        
        # Compute reward as described:
        reward = - y_pos * pos_error - y_ctrl * ctrl_term
        reward = reward.unsqueeze(-1)  # ensure shape is (batch_size, 1)
        #example: reward components: tensor([0.6731, 0.7756, 0.3219, 0.4017, 0.8585, 0.6772, 0.0000, 0.8300]) and ctrl_term: tensor([-0.0016,  0.0012, -0.0026,  0.0002, -0.0034, -0.0012, -0.0427, -0.0003])
        # Optionally, add reward components to info for debugging:
        info["reward_components"] = {"pos_error": pos_error, "ctrl_term": ctrl_term}
        
        return obs, reward, terminated, truncated, info


In [None]:
env = CustomTargetReach(
    effector=arm.effector, 
    obs_noise=0.0,
    proprioception_noise=0.0,
    vision_noise=0.0,
    action_noise=0.0,
    )

model = PPO(
    StableNetwork, env, 
    learning_rate=3e-4,
    n_steps=total_timesteps,  # Adjust based on your environment's requirements
    batch_size=batch_size,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    verbose=1,
    policy_kwargs={
        'device': DEVICE
    }
)
loss_callback = LossTrackingCallback(verbose=1)
progress_bar_callback = ProgressBarCallback()
callbacks = [
    loss_callback,
    progress_bar_callback,
]
# Train the model
model.learn(total_timesteps=100, callback=callbacks)



Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


KeyboardInterrupt: 

In [None]:
def plot_rollouts(model, env, n_trials=5):
    """Plot trajectories from multiple rollouts of the trained policy."""
    plt.figure(figsize=(10, 6))
    
    for trial in range(n_trials):
        obs = env.reset()
        done = False
        positions = []
        goals = []
        
        while not done:
            action, _states = model.predict(obs, deterministic=True)
            obs, reward, done, info = env.step(action)
            
            # Extract relevant position information - adjust based on your environment's obs structure
            positions.append(obs['proprioception'][:2])  # First 2 elements as position
            goals.append(obs['goal'][:2])  # First 2 elements as goal
            
        positions = np.array(positions)
        goals = np.array(goals)
        
        # Plot trajectory
        plt.plot(positions[:, 0], positions[:, 1], alpha=0.5, label=f'Trial {trial+1}')
        plt.scatter(goals[-1, 0], goals[-1, 1], marker='*', s=100, c='red')

    plt.title(f'Policy Trajectories ({n_trials} Trials)')
    plt.xlabel('X Position')
    plt.ylabel('Y Position')
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_losses(loss_history):
    """Plot training losses from the callback."""
    plt.figure(figsize=(10, 4))
    
    # Unpack loss components - adjust based on your callback's storage format
    policy_loss = [x['policy_loss'] for x in loss_history]
    value_loss = [x['value_loss'] for x in loss_history]
    entropy_loss = [x['entropy_loss'] for x in loss_history]
    
    x = np.arange(len(loss_history))
    
    plt.plot(x, policy_loss, label='Policy Loss')
    plt.plot(x, value_loss, label='Value Loss')
    plt.plot(x, entropy_loss, label='Entropy Loss')
    
    plt.title('Training Loss Curves')
    plt.xlabel('Update Step')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.show()



In [None]:
plot_rollouts(model, env, n_trials=5)
plot_losses(callback.losses)