In [1]:
import numpy as np
import main_functions as mf
import sys
np.set_printoptions(threshold=sys.maxsize)
import plotly.express as px
import plotly.graph_objects as go
from controlsignalpath import ControlSignalPath
from models import ControlModel
import controlsignalpath

## Synthetic Data

We will generate dynamics with a single dynamics matrix A (one discrete state, `K=1`) and a single control input.

In [2]:
# Build the synthetic system (one discrete state, 4 observed dimensions,
# 1 control input) and simulate 1000 time steps with noise enabled.
slds_kwargs = dict(K=1, D_obs=4, D_control=1, fix_point_change=False)
dyns = mf.create_slds(**slds_kwargs)
trajectory = dyns.generate(T=1000, fix_point_change=False, add_noise=True)
X = trajectory.squeeze()

# One-step-ahead pairs: X1 drops the last column of X, X2 drops the first,
# so column t of X2 is the successor of column t of X1.
X1, X2 = X[:, :-1], X[:, 1:]

In [3]:
# Plot the simulated observations over time. X is transposed so that time
# runs along the x-axis and each observed dimension becomes one line.
# The y-axis shows observations, not control inputs, so it is labelled as such.
true_data_fig = px.line(X.T, title='True Data')
true_data_fig.update_layout(
    xaxis_title='time',
    yaxis_title='observation',
    legend_title='observed dimension',
    showlegend=False,
)
true_data_fig.show()

In [4]:
# Ground-truth control input stored on the simulator, transposed before plotting.
controls = dyns.u_.T
control_layout = dict(
    xaxis_title='time',
    yaxis_title='input',
    legend_title='control signal',
    showlegend=False,
)
# The figure is the cell's last expression, so the notebook renders it.
px.line(controls, title='True Control Signal').update_layout(**control_layout)

## Initialization

In [5]:
# --- Fitting configuration ---
# num_iter:        number of iterations of the algorithm
# D_control:       dimension of the control space
# control_density: density of control inputs
num_iter = 100
D_control = 1
control_density = 0.1

# X1 is laid out as (observation dimension, time points).
D_obs, T = X1.shape

## Model Fitting

In [6]:
# Fit the control model on the one-step-ahead pairs (X1 -> X2).
# D_control comes from the configuration cell rather than a repeated literal,
# so changing the control dimension there actually reaches the fit.
CM = ControlModel(control_density)
CM.fit(X1, X2, num_iter, D_control=D_control)

Iteration 1/100
Iteration 2/100
Iteration 3/100
Iteration 4/100
Iteration 5/100
Iteration 6/100
Iteration 7/100
Iteration 8/100
Iteration 9/100
Iteration 10/100
Iteration 11/100
Iteration 12/100
Iteration 13/100
Iteration 14/100
Iteration 15/100
Iteration 16/100
Iteration 17/100
Iteration 18/100
Iteration 19/100
Iteration 20/100
Iteration 21/100
Iteration 22/100
Iteration 23/100
Iteration 24/100
Iteration 25/100
Iteration 26/100
Iteration 27/100
Iteration 28/100
Iteration 29/100
Iteration 30/100
Iteration 31/100
Iteration 32/100
Iteration 33/100
Iteration 34/100
Iteration 35/100
Iteration 36/100
Iteration 37/100
Iteration 38/100
Iteration 39/100
Iteration 40/100
Iteration 41/100
Iteration 42/100
Iteration 43/100
Iteration 44/100
No values above threshold 2.8405012759717514 in control input 0


In [7]:
# First entry of the recorded control signals, i.e. the starting point of the fit.
init_fig = px.line(CM.all_U[0].T, title='Control Signal Initialization')
init_fig.update_layout(
    xaxis_title='time',
    yaxis_title='input',
    legend_title='control signal',
    showlegend=False,
)
init_fig.show()

In [8]:
# Overlay every recorded control signal from entry 0 through entry `idx`.
# NOTE(review): 44 appears to match the last iteration logged by the fit above; confirm.
idx = 44
signals_so_far = CM.all_U[:idx + 1].squeeze()
all_signals_fig = px.line(signals_so_far.T, title=f'all Control Signals up to {idx}')
all_signals_fig.update_layout(xaxis_title='time', yaxis_title='input', legend_title='U')
all_signals_fig.show()

## Calculate the best control signal

In [9]:
# Hand the full data X and the per-iteration A, B and U histories from the
# fitted model to the path object used to select the best control signal.
control_signal = ControlSignalPath(
    X,
    CM.all_A,
    CM.all_B,
    CM.all_U,
    num_iter,
)

In [10]:
# Run the selection with window=2 and keep the resulting signal for the
# comparison and reconstruction cells that follow.
selection_kwargs = {'window': 2}
control_signal.calc_best_control_signal(**selection_kwargs)
best_U = control_signal.U

In [17]:
# Error curves recorded by the selection step, one point per iteration.
error_colors = {0: 'red', 1: 'blue'}
errors_fig = px.line(control_signal.errors, color_discrete_map=error_colors)
errors_fig.update_layout(
    xaxis_title='iteration',
    yaxis_title=f'RMSE (window={control_signal.window})',
    title='Errors',
)
errors_fig.show()

In [18]:
# Score recorded for each iteration by the selection step.
scores_fig = px.line(control_signal.scores, title='AIC Scores for all iterations')
scores_fig.update_layout(
    xaxis_title='iteration',
    yaxis_title=f'AIC (window={control_signal.window})',
    showlegend=False,
)
scores_fig.show()

In [19]:
# Overlay the selected control signal on the simulator's ground truth.
time_axis = np.arange(T)
best_trace = go.Scatter(x=time_axis, y=best_U.squeeze(), mode='lines', name='Best U')
truth_trace = go.Scatter(x=time_axis, y=controls.squeeze(), mode='lines', name='Ground Truth')
comparison_layout = {
    "xaxis": {"title": "time"},
    "yaxis": {"title": "input"},
    "title": "Best Control Signal vs Ground Truth",
}
fig = go.Figure(data=[best_trace, truth_trace], layout=comparison_layout)
fig.show()

In [20]:
# Inspect the recorded control signal at one specific entry of the history.
idx = 31
signal_at_idx = CM.all_U[idx].squeeze()
single_fig = px.line(signal_at_idx.T, title=f'Control Signal after {idx} iterations')
single_fig.update_layout(xaxis_title='time', yaxis_title='input', legend_title='U')
single_fig.show()

## Visualize

In [22]:
# Roll the fitted linear model forward from a random start:
#     x[i] = A @ x[i-1] + B @ u[i-1]
# using the A and B recorded at the iteration with the highest score.
# NOTE(review): AIC is conventionally minimised; confirm that
# control_signal.scores is stored so that the maximum is the best iteration.
best_iter = int(np.argmax(control_signal.scores))  # computed once, not on every step
A_best = CM.all_A[best_iter]
B_best = CM.all_B[best_iter]

X3 = np.zeros((T, D_obs, 1))
# Unseeded random start, so the reconstructed trajectory differs between runs.
initial_conditions = np.random.randn(CM.all_A[0].shape[0])
X3[0] = np.array(initial_conditions).reshape(-1, 1)
# Iterate through the final index T-1. The previous bound of T-1 (exclusive)
# left the last row of X3 at its zero initial value.
for i in range(1, T):
    X3[i] = A_best @ X3[i-1] + B_best @ best_U[:, i-1].reshape(-1, 1)

In [23]:
# Collect each observed dimension of the reconstruction as one row
# (D_obs rows, one column per time step), then plot with time on the x-axis.
per_dimension = [X3[:, k, 0] for k in range(D_obs)]
dynamics = np.stack(per_dimension, axis=0)
axis_labels = {'index': 'time', 'value': 'state (e.g. F)'}
fig = px.line(
    dynamics.T,
    title='Reconstructed SLDS with controls Dynamics',
    labels=axis_labels,
)
fig.show()