In [None]:
!git clone -b fix-jax-compatability https://github.com/google/deluca.git

In [None]:
%cd deluca

!pip install -e .

This cell creates a linear dynamical system (LDS) environment and plots its signal under random activations.

In [None]:
#%load_ext autoreload
#%autoreload 2
import jax
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import sys
sys.path.append("../../")

from deluca.core import Env
from deluca.core import field
from deluca.envs import _lds as lds



# System dimensions: observation size, action (input) size and hidden-state size.
d_obs = 2
d_action = 3
d_hidden = 10

# LDS with random initialization; its dimensions are supplied to `init` below.
env = lds.LDS()
obs = env.init(d_in=d_action, d_hidden=d_hidden, d_out=d_obs)  # return value kept as `obs`
env.show_me_the_signal(1000)


Now I'd like to plot the signal under different activations, but instead of using a built-in function, I want to generate them with agents.

In [None]:
from deluca.core import Agent
from deluca.agents._grc import GRC
from deluca.agents._random import SimpleRandom
from deluca.agents._lqg import LQG
from deluca.agents._zero import Zero
from deluca.agents._sfc import SFC



def plot_lds_with_a_given_agent(env, agent, name_of_agent, traj_length = 1000, window_size=50, key = jax.random.PRNGKey(0)):
  """Roll out `agent` on `env` and plot a sliding-window mean of the signal.

  At every step the current action is fed to the environment, the first
  coordinate of the returned observation is recorded as the "loss", the agent
  chooses its next action and is updated with (obs, action).

  Args:
    env: environment, called as `env(action, key)`; the returned observation
      must support `obs[0, 0]`.
    agent: controller, called as `agent(obs)` to produce an action and updated
      via `agent.update(obs, action)`.
    name_of_agent: label used in the plot title.
    traj_length: number of environment steps to simulate.
    window_size: length of the sliding window used to smooth the signal.
    key: JAX PRNG key passed to the environment on every step.

  Returns:
    np.ndarray of length `traj_length - window_size + 1`; entry j is the mean
    of the recorded signal over steps j .. j + window_size - 1.

  Raises:
    ValueError: if `window_size` is not in [1, traj_length].
  """
  if not 1 <= window_size <= traj_length:
    raise ValueError(
        f"window_size must be in [1, traj_length={traj_length}], got {window_size}")

  # The first action is computed from an all-zero observation.
  # NOTE(review): this reads the notebook-level global `d_obs`.
  obs = np.zeros(shape = (d_obs,1))
  action = agent(obs)
  losses = np.zeros(traj_length)

  for i in range(traj_length):
    # NOTE(review): the same `key` is passed on every step; if the env uses it
    # to draw noise, consider jax.random.split per step — confirm against env.
    obs = env(action, key)
    action = agent(obs)
    losses[i] = obs[0,0] # first coordinate of the observation
    agent.update(obs,action)

  # Mean over every full window ('valid' mode keeps only complete windows):
  # window_losses[j] == mean(losses[j : j + window_size]).
  window_losses = np.convolve(losses, np.ones(window_size) / window_size, mode='valid')

  plt.plot(window_losses)
  plt.xlabel('Time step (sliding window)')
  plt.ylabel('Loss')
  plt.title(f"Loss of {name_of_agent} on LDS")
  plt.show()
  return window_losses



env = lds.LDS()  # LDS with random initialization; dimensions are supplied to env.init
d_obs = 2
d_action = 3
d_hidden = 10

# (display name, factory) for each agent to compare. Each factory is called only
# after env.init, in case init (re)sets env.A/B/C, which GRC and SFC are built from.
agent_factories = [
  ("Random agent", lambda: SimpleRandom(d_action)),
  ("GRC", lambda: GRC(env.A, env.B, env.C)),
  ("SFC", lambda: SFC(env.A, env.B, env.C)),
]

for agent_name, make_agent in agent_factories:
  # Re-initialise the environment before each agent (presumably resetting its state).
  obs = env.init(d_in=d_action, d_hidden=d_hidden, d_out=d_obs)
  agent = make_agent()
  agent.init()
  plot_lds_with_a_given_agent(env, agent, agent_name)



In [None]:
# prompt: I want the same cell as above, but plot the figures side by side
%load_ext autoreload
%autoreload 2
from deluca.core import Agent
from deluca.agents._grc import GRC
from deluca.agents._random import SimpleRandom
from deluca.agents._lqg import LQG
from deluca.agents._zero import Zero
from deluca.agents._sfc import SFC

import matplotlib.pyplot as plt
import numpy as np

def plot_lds_with_a_given_agent(env, agent, name_of_agent, traj_length=1000, window_size=50, ax=None):
    """Roll out `agent` on `env` and draw a sliding-window mean of the signal on `ax`.

    At every step the current action is fed to the environment, the first
    coordinate of the returned observation is recorded as the "loss", the agent
    chooses its next action and is updated with (obs, action).

    Args:
        env: environment, called as `env(action)`; the returned observation
            must support `obs[0, 0]`.
        agent: controller, called as `agent(obs)` to produce an action and
            updated via `agent.update(obs, action)`.
        name_of_agent: label used in the subplot title.
        traj_length: number of environment steps to simulate.
        window_size: length of the sliding window used to smooth the signal.
        ax: matplotlib Axes to draw on; defaults to the current Axes.

    Returns:
        np.ndarray of length `traj_length - window_size + 1`; entry j is the
        mean of the recorded signal over steps j .. j + window_size - 1.

    Raises:
        ValueError: if `window_size` is not in [1, traj_length].
    """
    if not 1 <= window_size <= traj_length:
        raise ValueError(
            f"window_size must be in [1, traj_length={traj_length}], got {window_size}")
    if ax is None:
        # Fall back to the current Axes instead of failing on `None.plot`.
        ax = plt.gca()

    # The first action is computed from an all-zero observation.
    # NOTE(review): this reads the notebook-level global `d_obs`.
    obs = np.zeros(shape=(d_obs, 1))
    action = agent(obs)
    losses = np.zeros(traj_length)

    for i in range(traj_length):
        obs = env(action)
        action = agent(obs)
        losses[i] = obs[0, 0]  # first coordinate of the observation
        agent.update(obs, action)

    # Mean over every full window ('valid' mode keeps only complete windows):
    # window_losses[j] == mean(losses[j : j + window_size]).
    window_losses = np.convolve(losses, np.ones(window_size) / window_size, mode="valid")

    ax.plot(window_losses)
    ax.set_xlabel('Time step (sliding window)')
    ax.set_ylabel('Loss')
    ax.set_title(f"Loss of {name_of_agent} on LDS")
    return window_losses


env = lds.LDS()  # LDS with random initialization; dimensions are supplied to env.init
d_obs = 2
d_action = 3
d_hidden = 10

# (display name, factory, extra args for agent.init) for each subplot, in order.
# Factories run after env.init, in case init (re)sets env.A/B/C, which GRC and
# SFC are built from. Zero.init is called with d_action; the others take no args.
agent_specs = [
    ("GRC", lambda: GRC(env.A, env.B, env.C), ()),
    ("Random agent", lambda: SimpleRandom(d_action), ()),
    ("Zero agent", lambda: Zero(d_action), (d_action,)),
    ("SFC", lambda: SFC(env.A, env.B, env.C), ()),
]

fig, axes = plt.subplots(2, 2, figsize=(12, 8))
axes = axes.flatten()

for ax, (agent_name, make_agent, init_args) in zip(axes, agent_specs):
    # Re-initialise the environment before each agent so no agent continues from
    # the state left behind by the previous one.
    obs = env.init(d_in=d_action, d_hidden=d_hidden, d_out=d_obs)
    agent = make_agent()
    agent.init(*init_args)
    plot_lds_with_a_given_agent(env, agent, agent_name, ax=ax)

plt.tight_layout()
plt.show()