In [76]:
import os
import torch as th
import numpy as np
import json 
from datetime import datetime
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import motornet as mn
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cross_decomposition import CCA
import math
from torch.distributions import Normal
from gymnasium import Wrapper
from stable_baselines3.common.callbacks import BaseCallback


In [10]:
from dynamics_analysis import *
from environment import *
from networks import *
from neural_activity import *
from utils import *
from training import *

In [77]:
class L1LossCallback(BaseCallback):
    def __init__(self, verbose=0, log_path="./logs/"):
        super().__init__(verbose)
        self.log_path = log_path
        self.losses = []

    def _on_step(self) -> bool:
        # Info dict for each parallel env
        infos = self.locals.get("infos", [])

        for info in infos:
            if 'goal' in info and 'fingertip' in info:
                goal = np.array(info['goal'])
                fingertip = np.array(info['fingertip'])

                l1_loss = np.abs(goal - fingertip).sum()
                self.losses.append(l1_loss)

                if self.verbose > 0:
                    print(f"L1 loss: {l1_loss:.4f}")

        return True

    def _on_training_end(self):
        # Optional: Save the loss values for later
        os.makedirs(self.log_path, exist_ok=True)
        np.save(os.path.join(self.log_path, "l1_losses.npy"), np.array(self.losses))

In [None]:
class MotorNetBatchLoadWrapper(Wrapper):
    """
    Ensures actions have shape (batch, n_muscles) and supplies
    default zero loads for endpoint and joints on every step().
    """

    def __init__(self, env: mn.environment.RandomTargetReach):
        super().__init__(env)
        # Probe once to figure out load shapes:
        obs, info = env.reset(options={"batch_size": 1})
        geom_state = info["states"]["geometry"]  # (1, n_points, 2)
        joint_state = info["states"]["joint"]    # (1, n_joints)

        # For endpoint_load we need one value per coordinate axis (x,y)
        self.endpoint_dim = geom_state.shape[2]  # = 2
        # For joint_load we need one per joint
        self.n_joints = info["states"]["joint"].shape[1] // 2
        print(f"Detected: endpoint_dim={self.endpoint_dim}, n_joints={self.n_joints}, n_muscles={self.n_muscles}")
        # For actions we need one per muscle
        self.n_muscles   = env.n_muscles


    def step(self, action):
        action = np.array(action, dtype=np.float32)
        if action.ndim == 1:
            action = action[None, :]  # shape (1, act_dim)

        batch = action.shape[0]

        endpoint_load = np.zeros((batch, self.endpoint_dim), dtype=np.float32)
        joint_load = np.zeros((batch, self.n_joints), dtype=np.float32)

        obs, _, terminated, truncated, info = self.env.step(
            action=action,
            endpoint_load=endpoint_load,
            joint_load=joint_load
        )

        target = info["goal"]
        fingertip = info["states"]["fingertip"]
        reward = -np.linalg.norm(fingertip - target, axis=1)  # shape (batch,)

        # Use the first sample
        obs = obs[0]
        reward = float(reward[0])
        terminated = bool(terminated)
        truncated = bool(truncated)
        info = {k: (v[0] if isinstance(v, np.ndarray) and v.shape[0] == batch else v) for k, v in info.items()}

        return obs, reward, terminated, truncated, info
     

In [None]:
class CustomACPolicy(ActorCriticPolicy):
    def __init__(self, observation_space, action_space, lr_schedule, **kwargs):
        device = kwargs.get('device', 'cpu')
        kwargs.pop('device')
        super().__init__(observation_space, action_space, lr_schedule, **kwargs)
        self._build_network(device)
        self._hidden_states = None

    def _build_network(self, device):
        input_dim = self.observation_space.shape[0]
        self.action_dim = self.action_space.shape[0]
        self.hidden_dim = 64  # Match your GRU hidden size

        # Actor (GRU-based) and Critic
        self.actor = Policy(input_dim, self.hidden_dim, self.action_dim, device)
        self.critic = Critic(input_dim, device)

    def reset_hidden(self, batch_size=1):
        self._hidden_states = self.actor.init_hidden(batch_size)

    def forward(self, obs, deterministic=False):
        if self._hidden_states is None or self._hidden_states.size(1) != obs.shape[0]:
            self.reset_hidden(obs.shape[0])
        
        mean, new_hidden = self.actor(obs, self._hidden_states)
        self._hidden_states = new_hidden.detach()

        # Define action distribution
        std = th.ones_like(mean) * 0.1  # or learnable
        dist = th.distributions.Normal(mean, std)

        # Sample action
        actions = dist.mean if deterministic else dist.sample()
        log_probs = dist.log_prob(actions).sum(dim=-1, keepdim=True)

        values = self.critic(obs)

        return actions, values, log_probs

    def evaluate_actions(self, obs, actions):
        if self._hidden_states is None or self._hidden_states.size(1) != obs.shape[0]:
            self.reset_hidden(obs.shape[0])

        mean, new_hidden = self.actor(obs, self._hidden_states)
        self._hidden_states = new_hidden.detach()

        std = th.ones_like(mean) * 0.1
        dist = th.distributions.Normal(mean, std)

        log_probs = dist.log_prob(actions).sum(dim=-1, keepdim=True)
        entropy = dist.entropy().sum(dim=-1, keepdim=True)
        values = self.critic(obs)

        if values.ndim == 1:
            values = values.unsqueeze(1)

        return values, log_probs, entropy
    
    
def PPO_train():
    # Create the MotorNet environment
    effector = mn.effector.RigidTendonArm26(muscle=mn.muscle.MujocoHillMuscle())
    env = mn.environment.RandomTargetReach(effector=effector, max_ep_duration=5.)

    # 2. Wrap it
    wrapped_env = MotorNetBatchLoadWrapper(env)

    #sanity check
    print("Wrapped action_space:", wrapped_env.action_space)
    print("Wrapped observation_space:", wrapped_env.observation_space)
    
    # Instantiate the custom PPO model
    model = PPO(
        CustomACPolicy,
        wrapped_env,
        learning_rate=3e-4,
        n_steps=10000,  # Adjust based on your environment's requirements
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        verbose=1,
        policy_kwargs={
            'device': 'cuda:0' if th.cuda.is_available() else 'cpu'
        }
    )
    callback = L1LossCallback(verbose=1)
    # Train the model
    model.learn(total_timesteps=1000, callback=callback)
    
    return model

In [79]:
trained_model = PPO_train()


Detected: endpoint_dim=6, n_joints=2, n_muscles=6
Wrapped action_space: Box(0.0, 1.0, (6,), float32)
Wrapped observation_space: Box(-inf, inf, (16,), float32)
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 501      |
|    ep_rew_mean     | -180     |
| time/              |          |
|    fps             | 160      |
|    iterations      | 1        |
|    time_elapsed    | 62       |
|    total_timesteps | 10000    |
---------------------------------


In [None]:
def evaluate_pretrained(policy, env, batch_size):
    """Evaluation function with hidden state management"""
    # Reset hidden states
    policy.hidden_states = policy.actor.init_hidden(batch_size)
    
    # Initialize environment
    obs, info = env.reset(options={"batch_size": batch_size})
    terminated = np.zeros(batch_size, dtype=bool)
    xy = [info["states"]["fingertip"][:, None, :]]
    tg = [info["goal"][:, None, :]]

    # Run evaluation episode
    while not terminated.all():
        obs = obs.to(policy.actor.device)  # Move observation to the same device as the model
        action, _, _, _ = policy(obs, deterministic=True)
        xy.append(info["states"]["fingertip"][:, None, :])
        tg.append(info["goal"][:, None, :])

    # Plot results
    xy = th.cat(xy, axis=1).detach().numpy()
    tg = th.cat(tg, axis=1).detach().numpy()
    plot_simulations(xy=xy, target_xy=tg)

In [85]:
# Create the MotorNet environment
effector = mn.effector.RigidTendonArm26(muscle=mn.muscle.MujocoHillMuscle())
env = mn.environment.RandomTargetReach(effector=effector, max_ep_duration=5.)

# 2. Wrap it
wrapped_env = MotorNetBatchLoadWrapper(env)
evaluate_pretrained(trained_model.policy, wrapped_env, batch_size=1)


  logger.warn(


Detected: endpoint_dim=6, n_joints=2, n_muscles=6


RuntimeError: Input and hidden tensors are not at the same device, found input tensor at cpu and hidden tensor at cuda:0