Skip to content

Commit

Permalink
Update plot_oc.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lenasal committed Sep 20, 2023
1 parent 602f33d commit 7ca319a
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions neurolib/control/optimal_control/oc_utils/plot_oc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

colors = ["red", "blue", "green", "orange"]


def plot_oc_singlenode(
duration,
dt,
Expand All @@ -15,29 +14,40 @@ def plot_oc_singlenode(
plot_state_vars=[0, 1],
plot_control_vars=[0, 1],
):
"""Plot target and controlled dynamics for a network with a single node.
"""Plot target and controlled dynamics for a single node.
:param duration: Duration of simulation (in ms).
:type duration: float
:param dt: Time discretization (in ms).
:type dt: float
:param state: The state of the system controlled with the found oc-input.
:type state: np.ndarray
:param target: The target state.
:type target: np.ndarray
:param control: The control signal found by the oc-algorithm.
:type control: np.ndarray
:param orig_input: The inputs that were used to generate target time series.
:type orig_input: np.ndarray
:param cost_array: Array of costs in optimization iterations.
:param color_x: Color used for plots of x-population variables.
:param color_y: Color used for plots of y-population variables.
:type cost_array: np.ndarray, optional
:param plot_state_vars: List of indices of state variables that should be plotted
:type plot_state_vars: List, optional
:param plot_control_vars: List of indices of control variables that should be plotted
:type plot_control_vars: List, optional
"""
fig, ax = plt.subplots(3, 1, figsize=(8, 6), constrained_layout=True)

# Plot the target (dashed line) and unperturbed activity
t_array = np.arange(0, duration + dt, dt)

# Plot the controlled state and the initial/ original state (dashed line)
for v in plot_state_vars:
ax[0].plot(t_array, state[0, v, :], label="state var " + str(v), color=colors[v])
ax[0].plot(t_array, target[0, v, :], linestyle="dashed", label="target var " + str(v), color=colors[v])
ax[0].legend(loc="upper right")
ax[0].set_title("Activity without stimulation and target activity")

# Plot the target control signal (dashed line) and "initial" zero control signal
# Plot the computed control signal and the initial/ original control signal (dashed line)
for v in plot_control_vars:
ax[1].plot(t_array, control[0, v, :], label="stimulation var " + str(v), color=colors[v])
ax[1].plot(t_array, orig_input[0, v, :], linestyle="dashed", label="input var " + str(v), color=colors[v])
Expand All @@ -63,37 +73,48 @@ def plot_oc_network(
plot_state_vars=[0, 1],
plot_control_vars=[0, 1],
):
"""Plot target and controlled dynamics for a network with a single node.
"""Plot target and controlled dynamics for a network of N nodes.
:param N: Number of nodes in the network.
:type N: int
:param duration: Duration of simulation (in ms).
:type duration: float
:param dt: Time discretization (in ms).
:type dt: float
:param state: The state of the system controlled with the found oc-input.
:type state: np.ndarray
:param target: The target state.
:type target: np.ndarray
:param control: The control signal found by the oc-algorithm.
:type control: np.ndarray
:param orig_input: The inputs that were used to generate target time series.
:type orig_input: np.ndarray
:param cost_array: Array of costs in optimization iterations.
:param step_array: Number of iterations in the step-size algorithm in each optimization iteration.
:param color_x: Color used for plots of x-population variables.
:param color_y: Color used for plots of y-population variables.
:type cost_array: np.ndarray, optional
:param step_array: Array of step sizes in optimization iterations.
:type step_array: np.ndarray, optional
:param plot_state_vars: List of indices of state variables that should be plotted
:type plot_state_vars: List, optional
:param plot_control_vars: List of indices of control variables that should be plotted
:type plot_control_vars: List, optional
"""

t_array = np.arange(0, duration + dt, dt)
fig, ax = plt.subplots(3, N, figsize=(12, 8), constrained_layout=True)

# Plot the controlled state and the initial/ original state (dashed line)
for n in range(N):
for v in plot_state_vars:
ax[0, n].plot(t_array, state[n, v, :], label="state var " + str(v), color=colors[v])
ax[0, n].plot(t_array, target[n, v, :], linestyle="dashed", label="target var " + str(v), color=colors[v])
# ax[0, n].legend(loc="upper right")
ax[0, n].set_title(f"Activity and target, node %s" % (n))

# Plot the target control signal (dashed line) and "initial" zero control signal
# Plot the computed control signal and the initial/ original control signal (dashed line)
for v in plot_control_vars:
ax[1, n].plot(t_array, control[n, v, :], label="stimulation var " + str(v), color=colors[v])
ax[1, n].plot(
t_array, orig_input[n, v, :], linestyle="dashed", label="input var " + str(v), color=colors[v]
)
# ax[1, n].legend(loc="upper right")
ax[1, n].set_title(f"Stimulation and input, node %s" % (n))

ax[2, 0].plot(cost_array)
Expand Down

0 comments on commit 7ca319a

Please sign in to comment.