#### Necessary module imports

GPU environment setup and imports for the NCAP model

In [None]:
import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

print('Installing dm_control...')
!pip install -q dm_control>=1.0.16

# Configure dm_control to use the EGL rendering backend (requires GPU)
%env MUJOCO_GL=egl

!echo Installed dm_control $(pip show dm_control | grep -Po "(?<=Version: ).+")
!pip install -q dm-acme[envs]
!mkdir output_videos

In [None]:
#@title Download and install tonic library for training agents

import contextlib
import io


# Clone the Neuromatch fork of the tonic RL library and change into it,
# presumably so that the later `import tonic` picks up this checkout.
# NOTE(review): re-running this cell from inside `tonic` clones again into a
# nested tonic/tonic directory and changes into that one.
!git clone https://github.com/neuromatch/tonic
%cd tonic

In [None]:
import numpy as np
import collections
import argparse
import os
import yaml
import typing as T
import imageio
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import pandas as pd
import seaborn as sns
import csv
import torch
import torch.nn
import dm_control as dm
import dm_control.suite.swimmer as swimmer
from dm_control.rl import control
import logging
from IPython.display import HTML
from dm_control.utils import rewards
from dm_control import suite
from dm_control.suite.wrappers import pixels
from matplotlib.animation import FuncAnimation
from acme import wrappers
import plotly.graph_objs as go
from torch import nn
from plotly.colors import DEFAULT_PLOTLY_COLORS
from plotly.subplots import make_subplots
import plotly.express as px

In [None]:
#@title Utility code for displaying videos
def write_video(
  filepath: os.PathLike,
  frames: T.Iterable[np.ndarray],
  fps: int = 60,
  macro_block_size: T.Optional[int] = None,
  quality: int = 10,
  verbose: bool = False,
  **kwargs,
):
  """
  Encodes a sequence of frames into a video file with imageio.

  Parameters:
  - filepath (os.PathLike): Destination path of the video file.
  - frames (Iterable[np.ndarray]): Images appended to the video, in order.
  - fps (int, optional): Playback frame rate given to the writer, defaults to 60.
  - macro_block_size (Optional[int], optional): Forwarded unchanged to
    imageio.get_writer.
  - quality (int, optional): Forwarded unchanged to imageio.get_writer,
    defaults to 10.
  - verbose (bool, optional): If True, print the destination path before
    any frame is written.
  - **kwargs: Any further options for imageio.get_writer.

  Returns:
  None. The video is written to `filepath` as a side effect.
  """
  writer_options = dict(
    fps=fps,
    macro_block_size=macro_block_size,
    quality=quality,
    **kwargs,
  )
  with imageio.get_writer(filepath, **writer_options) as writer:
    if verbose:
      print('Saving video to:', filepath)
    for image in frames:
      writer.append_data(image)

def display_video(
  frames: T.Iterable[np.ndarray],
  filename='output_videos/temp.mp4',
  fps=60,
  **kwargs,
):
  """
  Writes frames to a video file and returns an HTML5 video for notebook display.

  Parameters:
  - frames (Iterable[np.ndarray]): Frames of shape (height, width, channels).
    Any iterable is accepted, including one-shot generators, because the
    frames are materialized into a list first.
  - filename (str, optional): Path the video file is written to (made
    absolute), defaults to 'output_videos/temp.mp4'.
  - fps (int, optional): Frames per second for both the file and the
    animation, defaults to 60.
  - **kwargs: Additional keyword arguments passed to the write_video function.

  Returns:
  HTML object: An HTML video element that can be displayed in a Jupyter Notebook.

  Raises:
  ValueError: If `frames` is empty.
  """
  # Materialize once: the frames are used three times (file, first-frame size,
  # animation), which a generator could only serve once.
  frames = list(frames)
  if not frames:
    raise ValueError('display_video() requires at least one frame.')

  # Write video to a file.
  filepath = os.path.abspath(filename)
  write_video(filepath, frames, fps=fps, verbose=False, **kwargs)

  height, width, _ = frames[0].shape
  dpi = 70
  orig_backend = matplotlib.get_backend()
  matplotlib.use('Agg')  # Switch to headless 'Agg' to inhibit figure rendering.
  try:
    fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
  finally:
    # Always restore the caller's backend, even if figure creation fails.
    matplotlib.use(orig_backend)

  try:
    ax.set_axis_off()
    ax.set_aspect('equal')
    ax.set_position([0, 0, 1, 1])
    im = ax.imshow(frames[0])

    def update(frame):
      im.set_data(frame)
      return [im]

    anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
                                   interval=1000 / fps, blit=True, repeat=False)
    return HTML(anim.to_html5_video())
  finally:
    # The figure is only needed to render the video; release it so repeated
    # calls do not accumulate open figures.
    plt.close(fig)

In [None]:
def _resolve_checkpoint_path(path, checkpoint):
  """
  Returns the checkpoint prefix `play_model` should load, or None.

  Parameters:
  - path (str): Experiment directory; checkpoints are looked up in
    `<path>/checkpoints` as files whose names start with 'step_<id>'.
  - checkpoint: 'none' (load nothing), 'last', 'first', or a step id
    (anything `int()` accepts).

  Returns:
  `<path>/checkpoints/step_<id>` for the selected id, or None when nothing
  should or can be loaded (the reason is logged through `tonic.logger`).

  Raises:
  ValueError: If a specific id is requested that `int()` cannot parse.
  """
  if checkpoint == 'none':
    # Use no checkpoint, the agent is freshly created.
    tonic.logger.log('Not loading any weights')
    return None

  checkpoint_dir = os.path.join(path, 'checkpoints')
  if not os.path.isdir(checkpoint_dir):
    # Stop here: listing a missing directory would fail (or, with None,
    # silently list the current working directory).
    tonic.logger.error(f'{checkpoint_dir} is not a directory')
    return None

  # Checkpoint files are named 'step_<id>.<ext>'; collect the integer ids.
  checkpoint_ids = [
    int(file.split('.')[0][5:])
    for file in os.listdir(checkpoint_dir)
    if file.startswith('step_')
  ]
  if not checkpoint_ids:
    tonic.logger.error(f'No checkpoint found in {checkpoint_dir}')
    return None

  if checkpoint == 'last':
    checkpoint_id = max(checkpoint_ids)
  elif checkpoint == 'first':
    checkpoint_id = min(checkpoint_ids)
  else:
    checkpoint_id = int(checkpoint)
    if checkpoint_id not in checkpoint_ids:
      tonic.logger.error(f'Checkpoint {checkpoint_id} not found in {checkpoint_dir}')
      return None
  return os.path.join(checkpoint_dir, f'step_{checkpoint_id}')


def play_model(path, checkpoint='last',environment='default',seed=None, header=None):
  """
  Plays a model within an environment and renders the episode to a video.

  Parameters:
  - path (str): Experiment directory containing `config.yaml` and, optionally,
    a `checkpoints` sub-directory.
  - checkpoint (str): Which checkpoint to load: 'last', 'first', a specific
    step id, or 'none' to run the freshly initialized agent.
  - environment (str): Python expression that builds the environment;
    'default' uses the expression stored in the configuration file.
  - seed (int): Optional seed passed to the environment and to agent
    initialization.
  - header (str): Optional Python code executed (after the config's own
    header) before the agent is built, e.g. imports.

  Returns:
  tuple: (HTML video of one episode, list of connection-log records). Each
  record is a dict with keys 'timestep', 'activity' (1 for 'exc', -1
  otherwise), 'neuron' and 'param_value'; each log entry appears once.
  The video file is written to `<path>/video.mp4` and the episode reward
  is printed.
  """
  checkpoint_path = _resolve_checkpoint_path(path, checkpoint)

  # Load the experiment configuration.
  arguments_path = os.path.join(path, 'config.yaml')
  with open(arguments_path, 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
  config = argparse.Namespace(**config)

  # Run the headers first, e.g. to load an ML framework. This stays best
  # effort, but failures are reported instead of silently swallowed, and a
  # failing config header no longer prevents the user header from running.
  # NOTE: exec/eval run trusted experiment code only; never use untrusted input.
  for header_code in (getattr(config, 'header', None), header):
    if not header_code:
      continue
    try:
      exec(header_code)
    except Exception as error:
      tonic.logger.error(f'Failed to run header {header_code!r}: {error}')

  # Build the agent.
  agent = eval(config.agent)

  # Build the environment. Keep the source expression under its own name so the
  # builder lambda never evaluates `environment` after it is rebound below.
  environment_code = config.environment if environment == 'default' else environment
  environment = tonic.environments.distribute(lambda: eval(environment_code))
  if seed is not None:
    environment.seed(seed)

  # Initialize the agent.
  agent.initialize(
    observation_space=environment.observation_space,
    action_space=environment.action_space,
    seed=seed,
  )

  # Load the weights of the agent from a checkpoint.
  if checkpoint_path:
    agent.load(checkpoint_path)

  steps = 0  # NOTE(review): never incremented, so the agent always sees step 0.
  test_observations = environment.start()
  frames = [environment.render('rgb_array',camera_id=0, width=640, height=480)[0]]
  score, length = 0, 0

  neural_activity_values_list = []
  # The actor's connection log is cumulative (SwimmerModule only appends to
  # it), so copy just the entries added since the previous step. Re-copying
  # the whole log each step duplicated entries quadratically.
  n_copied = 0

  while True:
    # Select an action.
    actions = agent.test_step(test_observations, steps)
    assert not np.isnan(actions.sum())

    connections_log = agent.model.actor.get_connections_log_values()
    for timestep, activity_type, neuron, param_value in connections_log[n_copied:]:
      neural_activity_values_list.append({
          'timestep': timestep,
          'activity': 1 if activity_type == 'exc' else -1,
          'neuron': neuron,
          'param_value': param_value,
      })
    n_copied = len(connections_log)

    # Take a step in the environment.
    test_observations, infos = environment.step(actions)
    frames.append(environment.render('rgb_array',camera_id=0, width=640, height=480)[0])
    agent.test_update(**infos, steps=steps)

    score += infos['rewards'][0]
    length += 1

    if infos['resets'][0]:
      break

  video_path = os.path.join(path, 'video.mp4')
  print('Reward for the run: ', score)
  return display_video(frames,video_path), neural_activity_values_list

In [None]:
_SWIM_SPEED = 0.1

# Registers this task constructor in the dm_control swimmer domain's SUITE.
@swimmer.SUITE.add()
def swim(
  n_links=6,
  desired_speed=_SWIM_SPEED,
  time_limit=swimmer._DEFAULT_TIME_LIMIT,
  random=None,
  environment_kwargs=None,
):
  """Returns the Swim task for an n-link swimmer.

  Args:
    n_links: Number of links passed to `swimmer.get_model_and_assets`.
    desired_speed: Forward speed at which the `Swim` task gives full reward.
    time_limit: Episode time limit passed to `control.Environment`.
    random: Forwarded to the `Swim` task as `random`.
    environment_kwargs: Extra keyword arguments for `control.Environment`;
      None (the default) means no extra arguments.

  Returns:
    A `control.Environment` wrapping the swimmer physics and a `Swim` task.
  """
  # None instead of a `{}` default: a default dict would be one object shared
  # by every call.
  if environment_kwargs is None:
    environment_kwargs = {}
  model_string, assets = swimmer.get_model_and_assets(n_links)
  physics = swimmer.Physics.from_xml_string(model_string, assets=assets)
  task = Swim(desired_speed=desired_speed, random=random)
  return control.Environment(
    physics,
    task,
    time_limit=time_limit,
    control_timestep=swimmer._CONTROL_TIMESTEP,
    **environment_kwargs,
  )


class Swim(swimmer.Swimmer):
  """Swimmer task rewarding forward swimming at (or above) a desired speed.

  The target marker of the base swimmer task is made invisible.
  """

  def __init__(self, desired_speed=_SWIM_SPEED, **kwargs):
    # Remaining keyword arguments (e.g. `random`) go to swimmer.Swimmer.
    super().__init__(**kwargs)
    self._desired_speed = desired_speed

  def initialize_episode(self, physics):
    """Initialize as the base task, then make every target material transparent."""
    super().initialize_episode(physics)
    for material_name in ('target', 'target_default', 'target_highlight'):
      # Alpha channel 0 hides the material.
      physics.named.model.mat_rgba[material_name, 'a'] = 0

  def get_observation(self, physics):
    """Returns an ordered mapping of joint angles and body velocities."""
    return collections.OrderedDict(
      joints=physics.joints(),
      body_velocities=physics.body_velocities(),
    )

  def get_reward(self, physics):
    """Returns a smooth reward that is 0 when stopped or moving backwards, and rises linearly to 1
    when moving forwards at the desired speed (and stays 1 above it)."""
    target_speed = self._desired_speed
    # Negated component 1 of the `head_vel` sensor is taken as forward velocity.
    forward_velocity = -physics.named.data.sensordata['head_vel'][1]
    return rewards.tolerance(
      forward_velocity,
      bounds=(target_speed, float('inf')),
      margin=target_speed,
      value_at_margin=0.,
      sigmoid='linear',
    )

#### Train function for the NCAP model

In [None]:
import tonic
import tonic.torch
from google.colab import drive


# Mount Google Drive at /content/drive; `train` below writes its outputs under
# '/content/drive/My Drive/'.
drive.mount('/content/drive')

def train(
  header,
  agent,
  environment,
  name = 'test',
  trainer = 'tonic.Trainer()',
  before_training = None,
  after_training = None,
  parallel = 1,
  sequential = 1,
  seed = 0
):
  """
  Train a tonic agent and save its logs and the actor's connection log.

  Most arguments are Python source strings: `header`, `before_training` and
  `after_training` are run with `exec`; `agent`, `environment` and `trainer`
  are evaluated with `eval` to build the corresponding objects. All arguments
  are captured first and passed to `tonic.logger.initialize` as the run config.
  NOTE: these strings are executed as code; only pass trusted input.

  - header: Code executed first, e.g. to import an ML framework. Skipped if falsy.
  - agent: Expression that builds the tonic agent.
  - environment: Expression that builds one environment. It is evaluated by the
    builders passed to `tonic.environments.distribute` for both the training
    and the test environment.
  - name: Experiment name used in the output path (default 'test'). If falsy,
    the agent's `name` attribute (or its class name) is used instead, with a
    `-{parallel}x{sequential}` suffix when either value differs from 1.
  - trainer: Expression that builds the trainer (default 'tonic.Trainer()').
  - before_training: Python code to execute immediately before the training loop commences, suitable for setup actions needed after initialization but prior to training.
  - after_training: Python code to run once the training loop concludes, ideal for teardown or analytical purposes.
  - parallel: The count of environments to execute in parallel. Limited to 1 in a Colab notebook, but if additional resources are available, this number can be increased to expedite training.
  - sequential: The number of sequential steps the environment runs before sending observations back to the agent. This setting is useful for temporal batching. It can be disregarded for this tutorial's purposes.
  - seed: Random seed passed to `agent.initialize`.

  Outputs are written under
  '/content/drive/My Drive/data/experiments/tonic/<environment_name>/<name>',
  which assumes Google Drive is mounted at /content/drive. After training,
  the actor's connection log (if non-empty) is saved there as
  `neural_activity.csv`.

  Returns:
    None.
  """
  # Capture the arguments to save them, e.g. to play with the trained agent.
  # Must stay the first statement so that `locals()` holds only the arguments.
  args = dict(locals())

  # Run the header first, e.g. to load an ML framework.
  if header:
    exec(header)

  # Build the train and test environments.
  # Keep the source string under its own name: `environment` is rebound on the
  # next line, and the builder lambdas must keep evaluating the string.
  _environment = environment
  environment = tonic.environments.distribute(lambda: eval(_environment), parallel, sequential)
  test_environment = tonic.environments.distribute(lambda: eval(_environment))


  # Build the agent.
  agent = eval(agent)
  agent.initialize(
    observation_space=test_environment.observation_space,
    action_space=test_environment.action_space, seed=seed)

  # Choose a name for the experiment.
  if hasattr(test_environment, 'name'):
    environment_name = test_environment.name
  else:
    environment_name = test_environment.__class__.__name__
  if not name:
    if hasattr(agent, 'name'):
      name = agent.name
    else:
      name = agent.__class__.__name__
    if parallel != 1 or sequential != 1:
      name += f'-{parallel}x{sequential}'

  # Output directory on Google Drive; change the base path below to save
  # elsewhere, e.g. '/content/drive/My Drive/your_directory_name'.
  path = os.path.join('/content/drive/My Drive/', 'data', 'experiments', 'tonic', environment_name, name)
  print (f"path for saving logs/model --- {path}")
  tonic.logger.initialize(path, script_path=None, config=args)

  # Build the trainer.
  trainer = eval(trainer)
  trainer.initialize(
    agent=agent,
    environment=environment,
    test_environment=test_environment,
  )
  # Run some code before training.
  if before_training:
    exec(before_training)

  # Train.
  trainer.run()

  # Run some code after training.
  if after_training:
    exec(after_training)

  # Save the connection logs: the actor's accumulated log entries, presumably
  # (timestep, activity_type, neuron, param_value) tuples matching the columns below.
  neural_activity_list = agent.model.actor.get_connections_log_values()

  if neural_activity_list:
    neural_activity_file_path = os.path.join(path, 'neural_activity.csv')
    neural_activity_df = pd.DataFrame(neural_activity_list, columns=['timestep', 'activity_type', 'neuron', 'param_value'])
    neural_activity_df.to_csv(neural_activity_file_path, index=False)

### NCAP parameter initialization

#### Weight initialization distribution

In [None]:
weight_initalizing_type = "he"
clamped = True

In [None]:
# Weight constraints: project a weight tensor onto the sign its connection allows.
def excitatory(w, upper=None):
    """Clamp ``w`` to be non-negative, optionally capped at ``upper``."""
    return torch.clamp(w, min=0, max=upper)


def inhibitory(w, lower=None):
    """Clamp ``w`` to be non-positive, optionally floored at ``lower``."""
    return torch.clamp(w, min=lower, max=0)


def unsigned(w, lower=None, upper=None):
    """Leave the sign free; clamp only when at least one bound is given."""
    if lower is None and upper is None:
        return w
    return torch.clamp(w, min=lower, max=upper)

# Activation constraints
def graded(x):
    """Graded activation: clip every element of ``x`` into [0, 1]."""
    return torch.clamp(x, 0, 1)

def initialize_weights(shape=(1,), init_type='uniform', lower=-1., upper=1., clamped=True, type=None):
    """Create a learnable weight tensor from a named initialization scheme.

    Args:
      shape: Requested shape. 'xavier' and 'he' need at least two dimensions,
        so for them a 1-D shape ``(n,)`` is promoted to ``(1, n)``.
      init_type: 'uniform' (U[lower, upper]), 'xavier' (Xavier/Glorot
        uniform) or 'he' (Kaiming uniform with ReLU gain).
      lower, upper: Bounds for 'uniform' sampling and for optional clamping.
      clamped: If True, clamp the sampled values into [lower, upper].
      type: 'exc' forces values >= 0, 'inh' forces values <= 0; anything else
        leaves the sign unconstrained.

    Returns:
      nn.Parameter holding the initialized values.

    Raises:
      ValueError: If ``init_type`` is not a supported scheme.
    """
    if init_type not in ('uniform', 'xavier', 'he'):
        raise ValueError(f"Unknown initialization type: {init_type}")

    if init_type != 'uniform' and len(shape) < 2:
        # Fan-based schemes need a 2-D tensor: treat a vector as a single row.
        shape = (1, shape[0])
    param = nn.Parameter(torch.empty(shape))

    # nn.init functions fill the tensor in place.
    if init_type == 'uniform':
        nn.init.uniform_(param, a=lower, b=upper)
    elif init_type == 'xavier':
        nn.init.xavier_uniform_(param)
    else:
        nn.init.kaiming_uniform_(param, nonlinearity='relu')

    if clamped:
        param.data.clamp_(min=lower, max=upper)

    # Project onto the sign required by the connection type.
    if type == 'exc':
        param.data = excitatory(param.data)
    elif type == 'inh':
        param.data = inhibitory(param.data)
    return param

# Weight initialization: sign-typed front ends to `initialize_weights`.
def excitatory_weights(clamped, shape=(1,), init_type='uniform', lower=0., upper=1.):
    """Sample excitatory (non-negative) weights; requires ``lower >= 0``."""
    assert lower >= 0
    return initialize_weights(shape=shape, init_type=init_type, lower=lower,
                              upper=upper, clamped=clamped, type='exc')


def inhibitory_weights(clamped, shape=(1,), init_type='uniform', lower=-1., upper=0.):
    """Sample inhibitory (non-positive) weights; requires ``upper <= 0``."""
    assert upper <= 0
    return initialize_weights(shape=shape, init_type=init_type, lower=lower,
                              upper=upper, clamped=clamped, type='inh')


def unsigned_weights(clamped, shape=(1,), init_type='uniform', lower=-1., upper=1.):
    """Sample sign-unconstrained weights (bounds used as in `initialize_weights`)."""
    return initialize_weights(shape=shape, init_type=init_type, lower=lower,
                              upper=upper, clamped=clamped)

def excitatory_constant(shape=(1,), value=1.):
    """Learnable tensor of ``shape`` filled with ``value`` (default +1)."""
    filled = torch.full(shape, value)
    return nn.Parameter(filled)


def inhibitory_constant(shape=(1,), value=-1.):
    """Learnable tensor of ``shape`` filled with ``value`` (default -1)."""
    filled = torch.full(shape, value)
    return nn.Parameter(filled)


def unsigned_constant(shape=(1,), lower=-1., upper=1., p=0.5):
    """Learnable tensor whose entries are ``upper`` with probability ``p``, else ``lower``."""
    with torch.no_grad():
        values = torch.empty(shape).uniform_(0, 1)
        pick_upper = values < p
        values[pick_upper] = upper
        values[~pick_upper] = lower
        return nn.Parameter(values)

In [None]:
def plot_weights_init_distribution(
  weight_init_type: str = "uniform",
  n_samples: int = 1000,
  clamp: T.Optional[bool] = None,
):
  """Plot histograms of sampled excitatory, inhibitory and unsigned initial weights.

  Args:
    weight_init_type: Scheme passed as `init_type` to the `*_weights` helpers
      ('uniform', 'xavier' or 'he').
    n_samples: Number of weights sampled for each histogram (default 1000).
    clamp: Whether samples are clamped into their range. None (the default)
      uses the notebook-level `clamped` flag.
  """
  if clamp is None:
    clamp = clamped
  shape = (n_samples, 1)

  # (label, sampler, colour) for each subplot, left to right.
  panels = (
    ('Excitatory', excitatory_weights, 'blue'),
    ('Inhibitory', inhibitory_weights, 'green'),
    ('Unsigned', unsigned_weights, 'orange'),
  )

  plt.figure(figsize=(18, 6))
  for position, (label, sample_weights, color) in enumerate(panels, start=1):
    weights = sample_weights(clamped=clamp, shape=shape, init_type=weight_init_type)
    plt.subplot(1, 3, position)
    plt.hist(weights.detach().numpy(), bins=50, density=True, alpha=0.7,
             color=color, edgecolor='black')
    plt.title(f'{label} {weight_init_type} clamped {clamp} Weights')
    plt.xlabel('Value')
    plt.ylabel('Density')
    plt.grid(True)

  plt.tight_layout()
  plt.show()

plot_weights_init_distribution(weight_init_type=weight_initalizing_type)

In [None]:
# Notebook-level switches read as default arguments by `SwimmerModule.__init__`.
# Adds a contralateral feedback signal (from the B-neurons' proprioceptive input)
# that inhibits the muscles.
include_contra_forward_feedback_control = False
# Adds a self-feedback inhibitory (D) pathway; the misspelled name ("inclue") is
# referenced elsewhere, so keep it as is.
inclue_self_feedback_inhibition = False
# Drives the head joint's B-neurons with a square-wave oscillator.
# Original author note: "I expect the model to not do movement properly"
# NOTE(review): presumably refers to running with this flag disabled; confirm.
include_head_oscillators = True
# True: start weights at fixed constants (`excitatory_constant` etc.).
# False: sample them from `weight_initalizing_type`; use this when testing
# different weight initializations.
use_weight_constant_init = False

Define the Swimmer Module

In [None]:
class SwimmerModule(nn.Module):
    """C.-elegans-inspired neural circuit architectural prior (NCAP).

    For every joint, a dorsal and a ventral B-neuron drive a dorsal and a
    ventral muscle; the joint torque is the dorsal minus the ventral muscle
    activation. Optional B-neuron inputs are proprioception from the previous
    joint, a square-wave head oscillator, speed control and turn control. Two
    experimental pathways (contralateral feedback, self-feedback inhibition)
    can also be enabled.

    Weights read in `forward` are appended to `connections_log` as
    `(timestep, activity_type, connection_name, value)` tuples unless
    `forward(..., log_activity=False)` is requested. The log is never cleared
    by this class (`reset` only resets the timestep counter).
    """

    def __init__(
            self,
            n_joints: int,
            log_dir: str,
            log_file: str = None,
            n_turn_joints: int = 1,
            oscillator_period: int = 60,
            use_weight_sharing: bool = True,
            use_weight_constraints: bool = True,
            use_weight_constant_init: bool = use_weight_constant_init,
            include_proprioception: bool = True,
            include_head_oscillators: bool = include_head_oscillators,
            include_speed_control: bool = False,
            include_turn_control: bool = False,
            include_contra_feedback_control: bool = include_contra_forward_feedback_control,
            include_self_feedback_inh: bool = inclue_self_feedback_inhibition,
    ):
        """
        Args:
          n_joints: Number of joints (one torque output per joint).
          log_dir, log_file: Stored on the instance; not used by this class.
          n_turn_joints: Number of head joints that receive turn control.
          oscillator_period: Period, in timesteps, of the head oscillator.
          use_weight_sharing: Use one parameter per connection type for all
            joints instead of one per joint and side.
          use_weight_constraints: Keep excitatory weights >= 0 and inhibitory
            weights <= 0; otherwise weights are sign-unconstrained.
          use_weight_constant_init: Initialize weights to constants instead of
            sampling from the notebook-level `weight_initalizing_type`.
          include_*: Enable the corresponding input pathway.
        """
        super().__init__()
        self.n_joints = n_joints
        self.n_turn_joints = n_turn_joints
        self.oscillator_period = oscillator_period
        self.include_proprioception = include_proprioception
        self.include_head_oscillators = include_head_oscillators
        self.include_speed_control = include_speed_control
        self.include_turn_control = include_turn_control
        # Honor the constructor argument (the notebook-level flag is only its default).
        self.include_contra_feedback_control = include_contra_feedback_control
        self.include_self_feedback_inh = include_self_feedback_inh
        self.log_file = log_file
        self.log_dir = log_dir

        # Log activity
        self.connections_log = []

        # Timestep counter (for oscillations).
        self.timestep = 0

        # Weight sharing switch function.
        self.ws = lambda nonshared, shared: shared if use_weight_sharing else nonshared

        # Weight constraint functions, applied to the raw parameters in `forward`.
        if use_weight_constraints:
            self.exc = excitatory
            self.inh = inhibitory
        else:
            self.exc = unsigned
            self.inh = unsigned

        # Parameter factories: every call returns a fresh nn.Parameter.
        if use_weight_constant_init:
            if use_weight_constraints:
                exc_param, inh_param = excitatory_constant, inhibitory_constant
            else:
                exc_param = inh_param = unsigned_constant
        else:
            if use_weight_constraints:
                print(f'using wts from distribution {weight_initalizing_type} with range clamp {clamped}')
                exc_sample = excitatory_weights(init_type=weight_initalizing_type, clamped=clamped)
                inh_sample = inhibitory_weights(init_type=weight_initalizing_type, clamped=clamped)
            else:
                # `clamped` is a required argument of unsigned_weights.
                exc_sample = inh_sample = unsigned_weights(init_type=weight_initalizing_type, clamped=clamped)

            # NOTE: one value is sampled per sign and every parameter starts as
            # a copy of it (this also makes the per-joint path usable, which
            # previously tried to call the sampled tensor).
            def exc_param():
                return nn.Parameter(exc_sample.data.clone())

            def inh_param():
                return nn.Parameter(inh_sample.data.clone())

        # Learnable parameters.
        self.params = nn.ParameterDict()
        if use_weight_sharing:
            if self.include_proprioception:
                self.params['bneuron_prop'] = exc_param()
            if self.include_speed_control:
                self.params['bneuron_speed'] = inh_param()
            if self.include_turn_control:
                self.params['bneuron_turn'] = exc_param()
            if self.include_head_oscillators:
                self.params['bneuron_osc'] = exc_param()
            if self.include_self_feedback_inh:
                self.params['D_sfeedback'] = inh_param()
            self.params['muscle_ipsi'] = exc_param()
            self.params['muscle_contra'] = inh_param()
        else:
            for i in range(self.n_joints):
                if self.include_proprioception and i > 0:
                    self.params[f'bneuron_d_prop_{i}'] = exc_param()
                    self.params[f'bneuron_v_prop_{i}'] = exc_param()

                if self.include_self_feedback_inh and i > 0:
                    self.params[f'D_d_sfeedback_{i}'] = inh_param()
                    self.params[f'D_v_sfeedback_{i}'] = inh_param()

                if self.include_speed_control:
                    self.params[f'bneuron_d_speed_{i}'] = inh_param()
                    self.params[f'bneuron_v_speed_{i}'] = inh_param()

                if self.include_turn_control and i < self.n_turn_joints:
                    self.params[f'bneuron_d_turn_{i}'] = exc_param()
                    self.params[f'bneuron_v_turn_{i}'] = exc_param()

                if self.include_head_oscillators and i == 0:
                    self.params[f'bneuron_d_osc_{i}'] = exc_param()
                    self.params[f'bneuron_v_osc_{i}'] = exc_param()

                self.params[f'muscle_d_d_{i}'] = exc_param()
                self.params[f'muscle_d_v_{i}'] = inh_param()
                self.params[f'muscle_v_v_{i}'] = exc_param()
                self.params[f'muscle_v_d_{i}'] = inh_param()

    def reset(self):
        """Restart the oscillator clock (the connection log is kept)."""
        self.timestep = 0

    def log_activity(self, activity_type, neuron, param_value):
        """Logs an active connection between neurons."""
        self.connections_log.append((self.timestep, activity_type, neuron, param_value))

    def forward(
            self,
            joint_pos,
            right_control=None,
            left_control=None,
            speed_control=None,
            timesteps=None,
            log_activity=True,
    ):
        """Forward pass.

        Args:
          joint_pos (torch.Tensor): Joint positions in [-1, 1], shape (..., n_joints).
          right_control (torch.Tensor): Right turn control in [0, 1], shape (..., 1).
          left_control (torch.Tensor): Left turn control in [0, 1], shape (..., 1).
          speed_control (torch.Tensor): Speed control in [0, 1], 0 stopped, 1 fastest, shape (..., 1).
          timesteps (torch.Tensor): Timesteps in [0, max_env_steps], shape (..., 1).
            When None, the internal step counter drives the oscillator.
          log_activity (bool): If True (default), append every weight used to
            `connections_log`.

        Returns:
          (torch.Tensor): Joint torques in [-1, 1], shape (..., n_joints).
        """
        exc = self.exc
        inh = self.inh
        ws = self.ws

        def weight(nonshared_key, shared_key):
            # The parameter backing a connection under the active sharing mode.
            return self.params[ws(nonshared_key, shared_key)]

        def log(activity_type, nonshared_key, shared_key):
            # Log under the per-joint connection name with the value actually
            # used; works with and without weight sharing.
            if log_activity:
                self.log_activity(activity_type, nonshared_key,
                                  weight(nonshared_key, shared_key).item())

        # Separate into dorsal and ventral sensor values in [0, 1], shape (..., n_joints).
        joint_pos_d = joint_pos.clamp(min=0, max=1)
        joint_pos_v = joint_pos.clamp(min=-1, max=0).neg()

        # Convert speed signal from acceleration into brake.
        if self.include_speed_control:
            assert speed_control is not None
            speed_control = 1 - speed_control.clamp(min=0, max=1)

        joint_torques = []  # [shape (..., 1)]
        for i in range(self.n_joints):
            zeros = torch.zeros_like(joint_pos[..., 0, None])  # shape (..., 1)
            bneuron_d = bneuron_v = zeros
            D_d = D_v = zeros
            # Defined for every joint so joints without proprioceptive input
            # (the head) do not hit an unbound name when contra feedback is on.
            contra_feedback = zeros

            # B-neurons receive proprioceptive input from the previous joint to propagate waves down the body.
            if self.include_proprioception and i > 0:
                bneuron_d = bneuron_d + joint_pos_d[..., i - 1, None] * exc(weight(f'bneuron_d_prop_{i}', 'bneuron_prop'))
                bneuron_v = bneuron_v + joint_pos_v[..., i - 1, None] * exc(weight(f'bneuron_v_prop_{i}', 'bneuron_prop'))
                log('exc', f'bneuron_d_prop_{i}', 'bneuron_prop')
                log('exc', f'bneuron_v_prop_{i}', 'bneuron_prop')

                # Contralateral feedback signal for the dorsal and ventral muscles.
                if self.include_contra_feedback_control:
                    contra_feedback = bneuron_d + bneuron_v

            # D-neurons: B-neuron input plus self-feedback inhibition from the
            # previous joint. Only joints with a previous joint get the feedback
            # term (index i - 1 would otherwise wrap to the last joint, and no
            # per-joint parameter exists for joint 0).
            if self.include_self_feedback_inh:
                D_d = bneuron_d + D_d
                D_v = bneuron_v + D_v
                if i > 0:
                    D_d = D_d + joint_pos_d[..., i - 1, None] * inh(weight(f'D_d_sfeedback_{i}', 'D_sfeedback'))
                    D_v = D_v + joint_pos_v[..., i - 1, None] * inh(weight(f'D_v_sfeedback_{i}', 'D_sfeedback'))
                    log('inh', f'D_d_sfeedback_{i}', 'D_sfeedback')
                    log('inh', f'D_v_sfeedback_{i}', 'D_sfeedback')

            # Speed control unit modulates all B-neurons.
            if self.include_speed_control:
                bneuron_d = bneuron_d + speed_control * inh(weight(f'bneuron_d_speed_{i}', 'bneuron_speed'))
                bneuron_v = bneuron_v + speed_control * inh(weight(f'bneuron_v_speed_{i}', 'bneuron_speed'))
                log('inh', f'bneuron_d_speed_{i}', 'bneuron_speed')
                log('inh', f'bneuron_v_speed_{i}', 'bneuron_speed')

            # Turn control units modulate head B-neurons.
            if self.include_turn_control and i < self.n_turn_joints:
                assert right_control is not None
                assert left_control is not None
                turn_control_d = right_control.clamp(min=0, max=1)  # shape (..., 1)
                turn_control_v = left_control.clamp(min=0, max=1)
                bneuron_d = bneuron_d + turn_control_d * exc(weight(f'bneuron_d_turn_{i}', 'bneuron_turn'))
                bneuron_v = bneuron_v + turn_control_v * exc(weight(f'bneuron_v_turn_{i}', 'bneuron_turn'))
                log('exc', f'bneuron_d_turn_{i}', 'bneuron_turn')
                log('exc', f'bneuron_v_turn_{i}', 'bneuron_turn')

            # Oscillator units modulate first B-neurons: dorsal is on for the
            # first half of each period, ventral for the second half.
            if self.include_head_oscillators and i == 0:
                if timesteps is not None:
                    phase = timesteps.round().remainder(self.oscillator_period)
                    mask = phase < self.oscillator_period // 2
                    oscillator_d = torch.zeros_like(timesteps)  # shape (..., 1)
                    oscillator_v = torch.zeros_like(timesteps)  # shape (..., 1)
                    oscillator_d[mask] = 1.
                    oscillator_v[~mask] = 1.
                else:
                    phase = self.timestep % self.oscillator_period  # in [0, oscillator_period)
                    if phase < self.oscillator_period // 2:
                        oscillator_d, oscillator_v = 1.0, 0.0
                    else:
                        oscillator_d, oscillator_v = 0.0, 1.0
                bneuron_d = bneuron_d + oscillator_d * exc(weight(f'bneuron_d_osc_{i}', 'bneuron_osc'))
                bneuron_v = bneuron_v + oscillator_v * exc(weight(f'bneuron_v_osc_{i}', 'bneuron_osc'))
                log('exc', f'bneuron_d_osc_{i}', 'bneuron_osc')
                log('exc', f'bneuron_v_osc_{i}', 'bneuron_osc')

            # B-neuron activation.
            bneuron_d = graded(bneuron_d)
            bneuron_v = graded(bneuron_v)

            if self.include_contra_feedback_control:
                contra_feedback = graded(contra_feedback)

            if self.include_self_feedback_inh:
                # D_d and D_v activation
                D_d = graded(D_d)
                D_v = graded(D_v)

            # Muscles receive excitatory ipsilateral and inhibitory contralateral input.
            w_d_ipsi = exc(weight(f'muscle_d_d_{i}', 'muscle_ipsi'))
            w_d_contra = inh(weight(f'muscle_d_v_{i}', 'muscle_contra'))
            w_v_ipsi = exc(weight(f'muscle_v_v_{i}', 'muscle_ipsi'))
            w_v_contra = inh(weight(f'muscle_v_d_{i}', 'muscle_contra'))
            if self.include_contra_feedback_control:
                # Ipsilateral drive comes from the D-neurons when self feedback
                # is on; the contralateral input is the shared feedback signal.
                if self.include_self_feedback_inh:
                    ipsi_d, ipsi_v = D_d, D_v
                else:
                    ipsi_d, ipsi_v = bneuron_d, bneuron_v
                muscle_d = graded(ipsi_d * w_d_ipsi + contra_feedback * w_d_contra)
                muscle_v = graded(ipsi_v * w_v_ipsi + contra_feedback * w_v_contra)
            else:
                # NOTE(review): D-neurons computed above are unused on this path.
                muscle_d = graded(bneuron_d * w_d_ipsi + bneuron_v * w_d_contra)
                muscle_v = graded(bneuron_v * w_v_ipsi + bneuron_d * w_v_contra)
            log('exc', f'muscle_d_d_{i}', 'muscle_ipsi')
            log('exc', f'muscle_v_v_{i}', 'muscle_ipsi')
            log('inh', f'muscle_d_v_{i}', 'muscle_contra')
            log('inh', f'muscle_v_d_{i}', 'muscle_contra')

            # Joint torque from antagonistic contraction of dorsal and ventral muscles.
            joint_torque = muscle_d - muscle_v
            joint_torques.append(joint_torque)

        self.timestep += 1

        out = torch.cat(joint_torques, -1)  # shape (..., n_joints)
        return out

Swimmer Actor Wrapper

In [None]:
# A simple hand-written controller for speed and turn signals, used instead of a learned model.
def simple_controller(observations):
    """Derive (right, left, speed) control signals from raw observations.

    Each signal is 0.1 times one observation feature (features 0, 1 and 2,
    respectively). The feature axis is kept with size 1, so each output has
    shape (..., 1). That is the shape `SwimmerModule.forward` documents for
    its control inputs. Indexing with `observations[:, k]` dropped that axis,
    so a (batch,) signal broadcast against (batch, 1) B-neuron tensors into
    (batch, batch). Any number of leading dimensions is supported.
    """
    right = observations[..., 0:1] * 0.1
    left = observations[..., 1:2] * 0.1
    speed = observations[..., 2:3] * 0.1
    return right, left, speed

In [None]:
class SwimmerActor(nn.Module):
    """Actor that turns raw swimmer observations into joint actions.

    ``forward`` takes joint positions (the first ``action_size`` features) and a
    time feature (the last feature) from each observation. It normalises the
    joint positions by the dm_control joint limit and maps the time feature to
    a timestep. When configured, it derives (right, left, speed) control
    signals. It then calls ``swimmer`` and, if given, wraps the result with
    ``distribution``.

    Args:
      swimmer: Callable invoked as ``swimmer(joint_pos, timesteps=...,
        right_control=..., left_control=..., speed_control=...)``; its return
        value is used as the actions.
      controller: Optional callable mapping the raw observations to a
        ``(right, left, speed)`` tuple. A truthy non-callable value (e.g.
        ``True``) selects ``simple_controller`` for backward compatibility.
        Falsy means no control signals (all ``None``).
      distribution: Optional callable applied to the actions, e.g. a factory
        for a Normal distribution in a stochastic policy.
      timestep_transform: ``(low_in, high_in, low_out, high_out)`` linear map
        from the observation's time feature to a timestep. A falsy value
        passes the raw feature through.
    """

    def __init__(
            self,
            swimmer,
            controller=None,
            distribution=None,
            timestep_transform=(-1, 1, 0, 1000),
    ):
        super().__init__()
        self.swimmer = swimmer
        self.controller = controller
        self.distribution = distribution
        self.timestep_transform = timestep_transform

    def initialize(
            self,
            observation_space,
            action_space,
            observation_normalizer=None,
    ):
        """Record the action dimensionality from ``action_space``.

        ``observation_space`` and ``observation_normalizer`` are accepted but
        unused. Presumably they are required by the initialize() call tonic
        makes on actors (TODO confirm).
        """
        self.action_size = action_space.shape[0]

    def get_connections_log_values(self):
        """Return the ``connections_log`` attribute of the wrapped swimmer."""
        return self.swimmer.connections_log

    def forward(self, observations):
        """Compute actions (or an action distribution) from observations.

        Args:
          observations: Tensor whose last axis starts with ``action_size``
            joint positions and ends with a time feature. ``initialize`` must
            have been called first to set ``action_size``.

        Returns:
          The swimmer's actions, passed through ``distribution`` when one was
          configured.
        """
        joint_pos = observations[..., :self.action_size]
        timesteps = observations[..., -1, None]

        # Normalize joint positions by max joint angle (in radians).
        # In dm_control the limit is computed from n_bodies = n_joints + 1.
        joint_limit = 2 * np.pi / (self.action_size + 1)
        joint_pos = torch.clamp(joint_pos / joint_limit, min=-1, max=1)

        # Convert normalized time signal into timestep.
        if self.timestep_transform:
            low_in, high_in, low_out, high_out = self.timestep_transform
            timesteps = (timesteps - low_in) / (high_in - low_in) * (high_out - low_out) + low_out

        # Generate high-level control signals. Previously the configured
        # controller was ignored and the global simple_controller was always
        # called; now the given callable is used. A non-callable truthy value
        # keeps the old behaviour.
        if self.controller:
            controller = self.controller if callable(self.controller) else simple_controller
            right, left, speed = controller(observations)
        else:
            right, left, speed = None, None, None

        # Generate low-level action signals.
        actions = self.swimmer(
            joint_pos,
            timesteps=timesteps,
            right_control=right,
            left_control=left,
            speed_control=speed,
        )
        # Pass through distribution for stochastic policy.
        if self.distribution:
            actions = self.distribution(actions)

        return actions

Model builders for PPO and DDPG (D4PG) training

In [None]:
from tonic.torch import models, normalizers

def ppo_swimmer_model(
        activity_dir_file_path: str,
        n_joints=5,
        action_noise=0.1,
        critic_sizes=(64, 64),
        critic_activation=nn.Tanh,
        **swimmer_kwargs,
):
    """Build a tonic ``ActorCritic`` model for PPO training of the swimmer.

    The actor is a ``SwimmerActor`` around a ``SwimmerModule``. Its outputs are
    used as the mean of a Normal distribution with fixed scale
    ``action_noise``. The critic is an observation encoder, an MLP torso and a
    value head. A ``MeanStd`` observation normalizer is attached.

    Args:
      activity_dir_file_path: Passed to ``SwimmerModule`` as ``log_dir``.
      n_joints: Number of swimmer joints.
      action_noise: Scale (standard deviation) of the Normal action distribution.
      critic_sizes: Hidden layer sizes of the critic MLP.
      critic_activation: Activation class used inside the critic MLP.
      **swimmer_kwargs: Extra keyword arguments forwarded to ``SwimmerModule``.

    Returns:
      A ``models.ActorCritic`` instance.
    """
    # Components are created in the same order as the nested-call form, so any
    # random weight initialisation draws from the RNG identically.
    swimmer_module = SwimmerModule(n_joints=n_joints, log_dir=activity_dir_file_path, **swimmer_kwargs)

    def gaussian_policy(mean):
        # Fixed-scale Gaussian around the swimmer's output.
        return torch.distributions.normal.Normal(mean, action_noise)

    actor = SwimmerActor(swimmer=swimmer_module, distribution=gaussian_policy)
    critic = models.Critic(
        encoder=models.ObservationEncoder(),
        torso=models.MLP(critic_sizes, critic_activation),
        head=models.ValueHead(),
    )
    return models.ActorCritic(
        actor=actor,
        critic=critic,
        observation_normalizer=normalizers.MeanStd(),
    )


def d4pg_swimmer_model(
    activity_dir_file_path: str,
    n_joints=5,
    critic_sizes=(256, 256),
    critic_activation=nn.ReLU,
    **swimmer_kwargs,
):
    """Build a tonic ``ActorCriticWithTargets`` model for D4PG training.

    The actor is a deterministic ``SwimmerActor`` (no output distribution)
    around a ``SwimmerModule``. The critic encodes observation-action pairs,
    passes them through an MLP and ends in a distributional value head. A
    ``MeanStd`` observation normalizer is attached.

    Args:
      activity_dir_file_path: Passed to ``SwimmerModule`` as ``log_dir``.
      n_joints: Number of swimmer joints.
      critic_sizes: Hidden layer sizes of the critic MLP.
      critic_activation: Activation class used inside the critic MLP.
      **swimmer_kwargs: Extra keyword arguments forwarded to ``SwimmerModule``.

    Returns:
      A ``models.ActorCriticWithTargets`` instance.
    """
    # NOTE: swimmer kwargs are empty in current callers -- it is still open
    # which SwimmerModule modifications/params are worth using here.
    swimmer_module = SwimmerModule(n_joints=n_joints, log_dir=activity_dir_file_path, **swimmer_kwargs)
    actor = SwimmerActor(swimmer=swimmer_module)
    critic = models.Critic(
        encoder=models.ObservationActionEncoder(),
        torso=models.MLP(critic_sizes, critic_activation),
        # These values are for the control suite with 0.99 discount.
        head=models.DistributionalValueHead(-150., 150., 51),
    )
    return models.ActorCriticWithTargets(
        actor=actor,
        critic=critic,
        observation_normalizer=normalizers.MeanStd(),
    )

In [None]:
# Run configuration. `alogrithm` is the algorithm label ('ppo' or 'ddpg') and
# `steps` is the training length as written into the run name.
alogrithm = 'ddpg'  # set ppo or ddpg
steps = '2e5'
experiment_name = f'ncap_{alogrithm}_{weight_initalizing_type}_weight_init_steps_{steps}'
print(f'experiment name {experiment_name}')
# Directory handed to the swimmer model builders as their activity log dir.
activity_dir_file_path = os.path.join(
    '/content/drive/My Drive/',
    'data',
    'experiments',
    'tonic',
    "swimmer-swim",
    experiment_name,
)

In [None]:
# Choose the agent from `alogrithm` (set in the configuration cell) so the
# trained algorithm always matches the label in `experiment_name`. Previously
# D4PG was trained even when the run was named 'ppo'.
if alogrithm == 'ppo':
    agent_spec = 'tonic.torch.agents.PPO(model=ppo_swimmer_model(n_joints=5, activity_dir_file_path=activity_dir_file_path, critic_sizes=(256,256)))'
elif alogrithm == 'ddpg':
    agent_spec = 'tonic.torch.agents.D4PG(model=d4pg_swimmer_model(n_joints=5, activity_dir_file_path=activity_dir_file_path, critic_sizes=(128,128)))'
else:
    raise ValueError(f"Unknown algorithm {alogrithm!r}; expected 'ppo' or 'ddpg'.")

train('import tonic.torch',
      agent_spec,
      'tonic.environments.ControlSuite("swimmer-swim",time_feature=True)',
      name=experiment_name,
      # Train for the step count that `steps` writes into the run name
      # (identical to the previous hard-coded int(2e5) when steps == '2e5').
      trainer=f'tonic.Trainer(steps=int({steps}),save_steps=int(5e4))')

In [None]:
# Replay a saved run with `play_model` (defined elsewhere in this notebook);
# the next cell displays element [0] of the result as the video.
# NOTE(review): this replays the xavier `_noclamp` run, not the run trained
# above under `experiment_name` -- confirm this is intended.
model_results = play_model('/content/drive/My Drive/data/experiments/tonic/swimmer-swim/ncap_ddpg_xavier_weight_init_steps_2e5_noclamp/')

In [None]:
# video
# First element of `play_model`'s result. As the cell's last expression,
# Jupyter renders it.
model_results[0]

In [None]:
# Load the neural-activity log of one trained run from Drive.
path = '/content/drive/My Drive/data/experiments/tonic/swimmer-swim/ncap_ddpg_he_weight_init_steps_2e5/'
model_name = os.path.basename(path.rstrip('/'))  # run name = last folder of `path`
# NOTE(review): read_csv treats the first line as a header, and the columns
# are then renamed below. If the log is written without a header row, its
# first record is silently dropped -- confirm the file's format.
df = pd.read_csv(os.path.join(path, 'neural_activity.csv'))
df.columns = ["timestep", "neuron_type", "neuron", "value"]

In [None]:
# Keep only every `interval`-th timestep for the animation below.
interval = 1000
sampled_df = df[df['timestep'] % interval == 0]
# A fixed y-range over the sampled data, used so animation frames share one scale.
min_value, max_value = sampled_df['value'].min(), sampled_df['value'].max()

In [None]:
# Number of rows kept after subsampling; these rows feed the animation below.
len(sampled_df)

In [None]:
# Animated scatter of each neuron's logged value, one frame per sampled timestep.
fig = px.scatter(
    sampled_df,
    x="neuron",
    y="value",
    animation_frame="timestep",
    color="neuron_type",
    title="Neural Activity Over Time",
)

# Neurons are ordered on the x-axis by the total of their values (descending).
# The y-range is fixed to the sampled min/max so frames stay comparable.
fig.update_layout(
    xaxis_title="Neuron",
    yaxis_title="Value",
    legend_title="Neuron Type",
    xaxis={'categoryorder': 'total descending'},
    yaxis={'range': [min_value, max_value]},
)

# Save an interactive HTML copy next to the run's logs, then render inline.
html_file_path = os.path.join(path, 'neural_activity_animation.html')
fig.write_html(html_file_path)
fig.show()

In [None]:
# from collections import defaultdict
# # neural activities
# # Create an animation of the neural activity bar plot
# neural_activity_values_list = model_results[1]

# # Extract unique neuron names
# neuron_names = list({entry['neuron'] for entry in neural_activity_values_list})
# neuron_indices = {name: idx for idx, name in enumerate(neuron_names)}
# # Group entries by timestep
# timesteps = defaultdict(list)
# for entry in neural_activity_values_list:
#     timesteps[entry['timestep']].append(entry)

# # Sort timesteps for animation
# sorted_timesteps = sorted(timesteps.keys())

# # Initialize the plot
# fig, ax = plt.subplots()
# scatter = ax.scatter([], [], s=100)

# # Set plot limits
# ax.set_xlim(-0.5, len(neuron_names) - 0.5)
# ax.set_ylim(-1.5, 1.5)
# ax.set_xticks(range(len(neuron_names)))
# ax.set_xticklabels(neuron_names, rotation=90)
# ax.set_yticks([-1, 0, 1])
# ax.set_ylabel('Activity')

# # Update function for animation
# def update(frame):
#     timestep = sorted_timesteps[frame]
#     x = [neuron_indices[entry['neuron']] for entry in timesteps[timestep]]
#     y = [entry['activity'] for entry in timesteps[timestep]]

#     scatter.set_offsets(list(zip(x, y)))
#     ax.set_title(f'Timestep: {timestep}')

# # Create animation
# ani = FuncAnimation(fig, update, frames=len(sorted_timesteps), repeat=False)

# # Display animation in Jupyter notebook
# from IPython.display import HTML
# HTML(ani.to_jshtml())

Plotting Functions

In [None]:
def plot_performance(paths, title='Model Performance'):
    """
    Plots the performance of multiple models on the same axes using Plotly for interactive visualization.

    Reads ``log.csv`` from each experiment directory. For each model it draws
    the mean test episode score as a line, plus a translucent band between the
    min and max test episode scores. The x-axis is the running sum of the mean
    test episode length, on a logarithmic scale. Each line's legend entry uses
    the name of the last folder in the path (the model's name). Every model
    gets its own colour from Plotly's default qualitative palette.

    Parameters:
    - paths (list of str): Paths to the experiment directories.
    - title (str): Title of the figure. Defaults to 'Model Performance'.

    NOTE(review): the x-axis is the cumulative *test* episode length, which is
    not necessarily the number of training steps -- confirm which step count
    the plot should show.
    """
    # Same palette as plot_action_variance_DDPG. It is imported locally so this
    # function does not depend on a notebook-level import.
    from plotly.colors import DEFAULT_PLOTLY_COLORS

    fig = go.Figure()

    for index, path in enumerate(paths):
        # Extract the model name from the path
        model_name = os.path.basename(path.rstrip('/'))

        # Load data
        df = pd.read_csv(os.path.join(path, 'log.csv'))
        scores = df['test/episode_score/mean']
        lengths = df['test/episode_length/mean']
        scores_min = df['test/episode_score/min']
        scores_max = df['test/episode_score/max']
        steps = np.cumsum(lengths)

        # The old arithmetic colours (index*50 % 255, ...) started at black and
        # reached near-white by the 6th model, which is invisible on the white
        # template. Cycle through a real palette instead.
        line_color = DEFAULT_PLOTLY_COLORS[index % len(DEFAULT_PLOTLY_COLORS)]
        band_color = line_color.replace('rgb(', 'rgba(').replace(')', ', 0.2)')

        # Add line plot for mean scores
        fig.add_trace(go.Scatter(
            x=steps, y=scores,
            mode='lines',
            name=f"{model_name} mean",
            line=dict(color=line_color)
        ))

        # Shaded area between max (forward) and min (reversed) scores, closed into one polygon.
        fig.add_trace(go.Scatter(
            x=np.concatenate([steps, steps[::-1]]),
            y=np.concatenate([scores_max, scores_min[::-1]]),
            fill='toself',
            fillcolor=band_color,
            line=dict(color='rgba(255,255,255,0)'),
            hoverinfo="skip",
            showlegend=False
        ))

    fig.update_layout(
        title=title,
        xaxis=dict(title='Cumulative Time Steps', type='log', tickvals=[1e4, 1e5, 1e6], ticktext=['10^4', '10^5', '10^6']),
        yaxis=dict(title='Episode Avg Score'),
        legend_title='Models',
        template='plotly_white'
    )

    fig.show()

def plot_actor_critic_loss(paths, title='Actor/Critic Loss'):
    """
    Plots actor and critic losses of multiple models on one interactive Plotly figure.

    Reads ``log.csv`` from each experiment directory and draws the
    ``actor/loss`` column as a solid line and the ``critic/loss`` column as a
    dashed, semi-transparent line in the SAME colour. The x-axis is the running
    sum of the mean test episode length, on a logarithmic scale.

    Parameters:
    - paths (list of str): Paths to the experiment directories.
    - title (str): Title of the figure. Defaults to 'Actor/Critic Loss'.
    """
    # Same palette as plot_action_variance_DDPG, imported locally so this
    # function does not depend on a notebook-level import.
    from plotly.colors import DEFAULT_PLOTLY_COLORS

    fig = go.Figure()
    for index, path in enumerate(paths):
        # Extract the model name from the path
        model_name = os.path.basename(path.rstrip('/'))
        # Load data
        df = pd.read_csv(os.path.join(path, 'log.csv'))
        actor_loss = df['actor/loss']
        critic_loss = df['critic/loss']
        lengths = df['test/episode_length/mean']
        steps = np.cumsum(lengths)

        # One colour per model for both of its losses. Previously the critic
        # used the (index+1) colour, i.e. the NEXT model's actor colour, which
        # made the two models' curves impossible to tell apart.
        actor_color = DEFAULT_PLOTLY_COLORS[index % len(DEFAULT_PLOTLY_COLORS)]
        critic_color = actor_color.replace('rgb(', 'rgba(').replace(')', ', 0.5)')

        # Add line plot for actor loss
        fig.add_trace(go.Scatter(
            x=steps, y=actor_loss,
            mode='lines',
            name=f"{model_name} actor loss",
            line=dict(color=actor_color)
        ))

        # Add line plot for critic loss
        fig.add_trace(go.Scatter(
            x=steps, y=critic_loss,
            mode='lines',
            name=f"{model_name} critic loss",
            line=dict(dash='dash', color=critic_color)
        ))
    fig.update_layout(
        title=title,
        xaxis=dict(title='Cumulative Time Steps', type='log', tickvals=[1e4, 1e5, 1e6], ticktext=['10^4', '10^5', '10^6']),
        yaxis=dict(title='Actor/Critic Loss'),
        legend_title='Models',
        template='plotly_white'
    )
    fig.show()

def plot_entropy_kl_divergence_actor_performance_PPO(model_path: str, title='Actor Entropy and KL Divergence'):
    """
    Plots the actor entropy and actor KL divergence of a single PPO run.

    Both metrics (``actor/entropy`` and ``actor/kl``) are read from the run's
    ``log.csv``. They are drawn against the running sum of the mean test
    episode length, on a logarithmic x-axis.

    Parameters:
    - model_path (str): Path to the experiment directory.
    - title (str): Title of the figure.
    """
    run_name = os.path.basename(model_path.rstrip('/'))
    log = pd.read_csv(os.path.join(model_path, 'log.csv'))

    # (log column, legend suffix, line colour) for each plotted metric.
    metrics = [
        ('actor/entropy', 'actor entropy', 'blue'),
        ('actor/kl', 'actor kl divergence', 'red'),
    ]
    # Look up every column before plotting, so a missing column fails early.
    series = [(log[column], label, colour) for column, label, colour in metrics]
    x_steps = np.cumsum(log['test/episode_length/mean'])

    fig = go.Figure()
    for values, label, colour in series:
        fig.add_trace(go.Scatter(
            x=x_steps,
            y=values,
            mode='lines',
            name=f"{run_name} {label}",
            line=dict(color=colour),
        ))

    fig.update_layout(
        title=title,
        xaxis=dict(title='Cumulative Time Steps', type='log', tickvals=[1e4, 1e5, 1e6], ticktext=['10^4', '10^5', '10^6']),
        yaxis=dict(title='Value'),
        legend_title='Metrics',
        template='plotly_white',
    )
    fig.show()

In [None]:
def plot_action_variance_DDPG(paths, title='Action Variance'):
    """
    Plots the test-time action variance of multiple models on one Plotly figure.

    Reads ``log.csv`` from each experiment directory and plots the variance of
    the test actions against the running sum of the mean test episode length
    (logarithmic x-axis). The log stores a standard deviation
    (``test/action/std``), which is squared to obtain the variance.

    Parameters:
    - paths (list of str): Paths to the experiment directories.
    - title (str): Title of the figure. Defaults to 'Action Variance'.
    """
    fig = go.Figure()
    colors = DEFAULT_PLOTLY_COLORS
    for index, path in enumerate(paths):
        model_name = os.path.basename(path.rstrip('/'))
        df = pd.read_csv(os.path.join(path, 'log.csv'))
        # The logged column is a standard deviation. Square it so the plotted
        # quantity matches the "variance" in the function name and axis title
        # (previously the raw std was plotted under a variance label).
        action_variance = df['test/action/std'] ** 2
        lengths = df['test/episode_length/mean']
        steps = np.cumsum(lengths)

        fig.add_trace(go.Scatter(
            x=steps, y=action_variance,
            mode='lines',
            name=model_name,
            line=dict(color=colors[index % len(colors)])
        ))
    fig.update_layout(
        title=title,
        xaxis=dict(title='Cumulative Time Steps', type='log', tickvals=[1e4, 1e5, 1e6], ticktext=['10^4', '10^5', '10^6']),
        yaxis=dict(title='Action Variance'),
        legend_title='Models',
        template='plotly_white'
    )
    fig.show()

Comparing models by average episode score

In [None]:
# NOTE(review): the title mentions PPO and DDPG, but both runs listed here are
# DDPG runs (`_osc_off`, `_prop_off`) -- confirm the intended baseline runs.
basline_paths = [
    '/content/drive/My Drive/data/experiments/tonic/swimmer-swim/ncap_ddpg_constant_weight_init_steps_2e5_osc_off/',
    '/content/drive/My Drive/data/experiments/tonic/swimmer-swim/ncap_ddpg_constant_weight_init_steps_2e5_prop_off/',
]
# plot_performance builds and shows a Plotly figure itself. Matplotlib's
# tight_layout()/show() do not act on it and could only emit an empty figure,
# so they are not called.
plot_performance(basline_paths, title='Baseline NCAP models comparison with PPO and DDPG')

In [None]:
# Runs compared by average test episode score.
wt_init_paths = [
    '/content/drive/My Drive/data/experiments/tonic/swimmer-swim/ncap_ddpg_constant_weight_init_steps_2e5_osc_off/',
    '/content/drive/My Drive/data/experiments/tonic/swimmer-swim/ncap_ddpg_constant_weight_init_steps_2e5_prop_off/',
]
# plot_performance builds and shows a Plotly figure itself. Matplotlib's
# tight_layout()/show() do not act on it and could only emit an empty figure,
# so they are not called.
plot_performance(wt_init_paths, title='NCAP models comparison with movement parameters')

Plot actor and critic loss

In [None]:
# Runs that differ in weight initialisation (the last one is the `_noclamp` variant).
loss_paths = [
    f'/content/drive/My Drive/data/experiments/tonic/swimmer-swim/{run}/'
    for run in (
        'ncap_baseline_ddpg_constant_weight_init_steps_2e5',
        'ncap_ddpg_uniform_weight_init_steps_2e5',
        'ncap_ddpg_he_weight_init_steps_2e5',
        'ncap_ddpg_xavier_weight_init_steps_2e5',
        'ncap_ddpg_he_weight_init_steps_2e5_noclamp',
    )
]
plot_actor_critic_loss(loss_paths, title="Actor/Critic Loss : Different Weight Initialization")

Entropy and KL Divergence

In [None]:
#plot_entropy_kl_divergence_actor_performance_PPO('/content/drive/My Drive/data/experiments/tonic/swimmer-swim/ncap_ppo_baseline_steps_6e5/', title='NCAP PPO')

In [None]:
# Feedback variants compared by test-time action variance.
# NOTE(review): `_self_inf` may be a typo for `_self_inh` in the run directory
# name -- confirm the folder exists as written.
action_vr_path = [
    f'/content/drive/My Drive/data/experiments/tonic/swimmer-swim/{run}/'
    for run in (
        'ncap_ddpg_constant_weight_init_steps_2e5_contra_feedback',
        'ncap_ddpg_constant_weight_init_steps_2e5_contra_feedback_self_inf',
    )
]
plot_action_variance_DDPG(paths=action_vr_path, title="Feedback : Action Variance Comparisons")

### Testing the network visualization [not working at the moment]

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np


def draw_network(mode='NCAP', N=2, include_speed_control=False, include_turn_control=False, node_colors=None):
    """
    Draws a network graph for a swimmer model based on either NCAP or MLP architecture.

    Nodes are placed in columns by role:
    - x=1: oscillator `o` and joint inputs `q_i`.
    - x=1.5: their d/v split plus control nodes.
    - x=2: `b` neurons.
    - x=2.5: `m` muscles.
    - x=3: per-joint outputs.

    Joint 1's rows are at the top and joint N's at the bottom. In 'NCAP' mode
    the NCAP connections are added as edges, each labelled with its weight
    symbol and coloured green or orange (green appears to mark
    excitatory/positive and orange inhibitory/negative connections). Other
    modes draw the nodes only.

    Parameters:
    - mode (str): Determines the architecture type ('NCAP' or 'MLP'). Defaults to 'NCAP'.
    - N (int): Number of joints in the swimmer model. Defaults to 2.
    - include_speed_control (bool): If True, includes nodes for speed control in the graph.
    - include_turn_control (bool): If True, includes nodes for turn control in the graph.
    - node_colors (dict, optional): Maps node labels (e.g. '$b^d_1$') to
      matplotlib colours. Nodes not listed are drawn grey.
    """
    G = nx.DiGraph()

    # Raw string: '\o' is not a valid escape sequence and warns on modern Python.
    qdd = r'$\overset{..}{q}$'  # prefix of the per-joint output node labels

    # Oscillator/control rows sit above the joint rows (which span 0 .. 4N-2).
    n = 2 + N * 4

    nodes = dict()  # node label -> row (y coordinate)

    if include_speed_control:
        nodes['1-s'] = n + 7
    if include_turn_control:
        nodes['r'] = n + 5
        nodes['l'] = n + 3

    nodes['o'] = n - 1
    nodes['$o^d$'] = n - 0
    nodes['$o^v$'] = n - 2

    custom_node_positions = {}  # node label -> (x, y)
    custom_node_positions['o'] = (1, nodes['o'])
    custom_node_positions['$o^d$'] = (1.5, nodes['$o^d$'])
    custom_node_positions['$o^v$'] = (1.5, nodes['$o^v$'])

    if include_speed_control:
        custom_node_positions['1-s'] = (1.5, nodes['1-s'])
    if include_turn_control:
        custom_node_positions['r'] = (1.5, nodes['r'])
        custom_node_positions['l'] = (1.5, nodes['l'])

    for i in range(1, N + 1):
        row = 4 * (N - i)  # base row of joint i
        nodes[f'$q_{i}$'] = row + 1
        nodes[f'$q^d_{i}$'] = row + 2
        nodes[f'$q^v_{i}$'] = row
        nodes[f'$b^d_{i}$'] = row + 2
        nodes[f'$b^v_{i}$'] = row
        nodes[f'$m^d_{i}$'] = row + 2
        nodes[f'$m^v_{i}$'] = row
        nodes[qdd + f'$_{i}$'] = row + 1

        custom_node_positions[f'$q_{i}$'] = (1, nodes[f'$q_{i}$'])
        custom_node_positions[f'$q^d_{i}$'] = (1.5, nodes[f'$q^d_{i}$'])
        custom_node_positions[f'$q^v_{i}$'] = (1.5, nodes[f'$q^v_{i}$'])
        custom_node_positions[f'$b^d_{i}$'] = (2, nodes[f'$b^d_{i}$'])
        custom_node_positions[f'$b^v_{i}$'] = (2, nodes[f'$b^v_{i}$'])
        custom_node_positions[f'$m^d_{i}$'] = (2.5, nodes[f'$m^d_{i}$'])
        custom_node_positions[f'$m^v_{i}$'] = (2.5, nodes[f'$m^v_{i}$'])
        custom_node_positions[qdd + f'$_{i}$'] = (3, nodes[qdd + f'$_{i}$'])

    for node, layer in nodes.items():
        G.add_node(node, layer=layer)

    edge_labels = {}  # (src, dst) -> weight label drawn on the edge
    edge_colors = {}  # (src, dst) -> edge colour

    def connect(src, dst, label, color):
        # Record one connection's label and colour under the same key.
        edge_labels[(src, dst)] = label
        edge_colors[(src, dst)] = color

    if mode == 'NCAP':
        connect('o', '$o^d$', '+1', 'green')
        connect('o', '$o^v$', '-1', 'orange')
        connect('$o^d$', '$b^d_1$', 'o', 'green')
        connect('$o^v$', '$b^v_1$', 'o', 'green')

        if include_speed_control:
            connect('1-s', '$b^d_1$', 's, to all b', 'orange')
        if include_turn_control:
            connect('r', '$b^d_1$', 't', 'green')
            connect('l', '$b^v_1$', 't', 'green')

        for i in range(1, N + 1):
            if i < N:
                # Proprioception of joint i feeds the b neurons of joint i+1.
                connect(f'$q_{i}$', f'$q^d_{i}$', '+1', 'green')
                connect(f'$q_{i}$', f'$q^v_{i}$', '-1', 'orange')
                connect(f'$q^d_{i}$', f'$b^d_{i+1}$', 'p', 'green')
                connect(f'$q^v_{i}$', f'$b^v_{i+1}$', 'p', 'green')

            connect(f'$b^d_{i}$', f'$m^d_{i}$', 'i', 'green')
            connect(f'$b^d_{i}$', f'$m^v_{i}$', 'c', 'orange')
            connect(f'$b^v_{i}$', f'$m^v_{i}$', 'i', 'green')
            connect(f'$b^v_{i}$', f'$m^d_{i}$', 'c', 'orange')
            connect(f'$m^v_{i}$', qdd + f'$_{i}$', '-1', 'orange')
            connect(f'$m^d_{i}$', qdd + f'$_{i}$', '+1', 'green')

        G.add_edges_from(edge_labels)

    # Get colors for nodes
    color_map = [
        node_colors[node] if node_colors and node in node_colors else 'grey'  # default color
        for node in G.nodes()
    ]

    # Draw the graph
    pos = custom_node_positions  # using the custom positions
    plt.figure(figsize=(20, 16))
    draw_kwargs = dict(
        with_labels=True, node_color=color_map, node_size=7000, font_size=10,
        font_color='black', font_weight='bold', arrows=True,
    )
    if G.number_of_edges():
        # G.edges() is ordered by source node, not by the order the edges were
        # defined above, so colours are looked up per edge rather than passed
        # as a parallel list.
        draw_kwargs['edge_color'] = [edge_colors[edge] for edge in G.edges()]
    nx.draw(G, pos, **draw_kwargs)
    if edge_labels:
        nx.draw_networkx_edge_labels(G, pos=pos, edge_labels=edge_labels)
    plt.show()

In [None]:
# Map your neural activity data to the node names used in the graph
neural_activity_logs_path = '/content/drive/My Drive/data/experiments/tonic/swimmer-swim/ncap_test_ppo_uniform_weight_init_steps_1e5_connection_logs/neural_activity.csv'
df = pd.read_csv(neural_activity_logs_path)
neural_activity_list = df.to_dict(orient='records')

# Work-in-progress mapping from logged activity to node colours.
# NOTE(review): each `log` is a dict record, so the tuple unpacking below will
# not yield (timestamp, activity_type, neuron); read the fields by column name
# instead when re-enabling this.
# node_colors = {}
# for log in neural_activity_list:
#   timestamp, activity_type, neuron = log
#   if activity_type == 'exc':
#       color = 'red'
#   else:
#       color = 'blue'
#   if 'd_prop' in neuron:
#       node_number = neuron.split('_')[-1]
#       node_colors[f'$b^d_{node_number}$'] = color
#   if 'v_prop' in neuron:
#       node_number = neuron.split('_')[-1]
#       node_colors[f'$b^v_{node_number}$'] = color

# No `breakpoint()` here: it would stop every notebook run in the debugger.
# Enable the call below once `node_colors` above is built.
#draw_network('NCAP', N=6, include_speed_control=True, include_turn_control=True, node_colors=node_colors)