In [1]:
import gymnasium 
import numpy as np
from stable_baselines3 import SAC

import PyFlyt.gym_envs
from PyFlyt.gym_envs import FlattenWaypointEnv #needed for waypoints
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import Ring Attractor components
import sys
import os

from src.utils.policy_warp import create_sac_ring_attractor
from src.utils.attractors import RingAttractorConfig

In [2]:
# Create the env
env = gymnasium.make(
        "PyFlyt/QuadX-Waypoints-v4",
        sparse_reward=False,
        num_targets=4,
        goal_reach_distance=0.3,
        max_duration_seconds=15.0,
        flight_mode = 1, 
        angle_representation="quaternion",
        render_mode = "human",
    )
env = FlattenWaypointEnv(env, context_length=2)


In [3]:
 # Create baseline SAC model first
base_model = SAC(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    policy_kwargs=dict(net_arch=[256, 256]),
    verbose=1,
    tensorboard_log="./tensorboard_logs/ring_attractor_sac"
)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [4]:
# Configure Ring Attractor for quadcopter control
ring_config = {
    'layer_type': 'multi',
    'input_dim': 256,  # Should match the last layer of the MLP
    'control_axes': ['roll_rate', 'pitch_rate', 'yaw_rate',"thrust"],
    'ring_axes': ['roll_rate', 'pitch_rate', 'yaw_rate'],  # Spatial axes
    'config': RingAttractorConfig(
        num_excitatory=16,
        tau=8.0,
        beta=12.0,
        lambda_decay=0.8,
        trainable_structure=True,
        connectivity_strength=0.1,
        cross_coupling_factor=0.05
    )
}

In [5]:
ring_model = create_sac_ring_attractor(
            base_model=base_model,
            layer_config=ring_config,
            device="cpu"
        )

INFO:src.utils.control_layers:Initialized MultiAxisRingAttractorLayer: ring_axes=['roll_rate', 'pitch_rate', 'yaw_rate'], linear_axes=['thrust']
INFO:src.utils.policy_warp:Wrapped SAC policy with Ring Attractor
INFO:src.utils.policy_warp:Created Ring Attractor model for stable_baselines3/SAC on cpu


Before modification:
layer_config type: <class 'dict'>
layer_config: {'layer_type': 'multi', 'input_dim': 256, 'control_axes': ['roll_rate', 'pitch_rate', 'yaw_rate', 'thrust'], 'ring_axes': ['roll_rate', 'pitch_rate', 'yaw_rate'], 'config': <src.utils.attractors.RingAttractorConfig object at 0x00000171E4432440>}
output_shape: 256
After modification:
layer_config: {'layer_type': 'multi', 'input_dim': 256, 'control_axes': ['roll_rate', 'pitch_rate', 'yaw_rate', 'thrust'], 'ring_axes': ['roll_rate', 'pitch_rate', 'yaw_rate'], 'config': <src.utils.attractors.RingAttractorConfig object at 0x00000171E4432440>, 'output_dim': 256}
'output_dim' in layer_config: True


In [6]:
ring_model.learn(total_timesteps=1000, progress_bar=True)


[A                             [A
Logging to ./tensorboard_logs/ring_attractor_sac\SAC_7


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 93.2     |
|    ep_rew_mean     | -101     |
| time/              |          |
|    episodes        | 4        |
|    fps             | 4        |
|    time_elapsed    | 83       |
|    total_timesteps | 373      |
| train/             |          |
|    actor_loss      | -6.16    |
|    critic_loss     | 21.7     |
|    ent_coef        | 0.923    |
|    ent_coef_loss   | -0.523   |
|    learning_rate   | 0.0003   |
|    n_updates       | 272      |
---------------------------------


KeyboardInterrupt: 

In [None]:
env.close()

In [None]:
dir(ring_model.policy.actor)

In [None]:
ring_model.policy.actor

In [None]:
ring_model.policy.critic


In [None]:
base = SAC("MlpPolicy", env,verbose=1)

In [None]:
layers = list (base.policy.actor.latent_pi.children())


In [None]:
output_shape = None
for layer in reversed(layers):
    if hasattr(layer, 'out_features'):  # Linear layer
        output_shape = layer.out_features
        break
    elif hasattr(layer, 'out_channels'):  # Conv layer
        output_shape = layer.out_channels
        break
    elif hasattr(layer, 'hidden_size'):  # RNN/LSTM
        output_shape = layer.hidden_size
        break


In [None]:
ring_model.policy

In [None]:
output_shape
