In [1]:
import numpy as np
import main_functions as mf
import sys
np.set_printoptions(threshold=sys.maxsize)
import plotly.express as px
import plotly.graph_objects as go
from scipy import linalg
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from importlib import reload

# Synthetic Data (Single A and Single Control Input)

We generate synthetic dynamics with a single $A$ (one latent state, `K = 1`) and a single control input (`D_control = 1`).

In [2]:
# Simulation settings for the synthetic dataset (single A, single control input).
D_control = 1   # number of control inputs
D_obs = 4       # presumably the observed dimensionality; passed to mf.create_slds
K = 1           # number of dynamics (a single A)
T = 1000        # number of time steps passed to dyns.generate
sigma = 0.01    # noise level passed to dyns.generate
SEED = 0        # change to draw a different synthetic dataset

# Seed NumPy's global RNG so the synthetic data (and every loss/plot below) are
# reproducible on Restart & Run All.
# NOTE(review): assumes main_functions draws from np.random's global state — confirm.
np.random.seed(SEED)

dyns = mf.create_slds(K, D_obs, D_control, fix_point_change=False)
X = dyns.generate(T=T, fix_point_change=False, sigma=sigma).squeeze()
# Time-shifted copies along axis 1 (time, per the plots below): X1[:, t] -> X2[:, t].
X1 = X[:,:-1]
X2 = X[:,1:]

In [3]:
# Plot the simulated trajectories: one line per dimension of X, time on the x-axis.
# These are the observed states (not the control input), so label them like the
# reconstruction plots below.
px.line(X.T, title='True Data').update_layout(
    xaxis_title='time', yaxis_title='state', showlegend=False
).show()

# Coefficients and Reconstruction from the dLDS Model
The coefficients are updated with Lasso (`update_c_type='lasso'`).

## Alpha = 0.1, Smoothness = 0, Iterations = 100

In [4]:
# dLDS fit #0: Lasso coefficient updates, no smoothness term.
alpha = 0.1          # regularization weight, passed as 'reg_term'
smooth = 0           # smoothness weight, passed as 'smooth_term'
n_iterations = 100   # maximum number of training iterations
dlds_model_refitted0 = mf.train_model_include_D(
    data=X,
    max_iter=n_iterations,
    params={
        'update_c_type': 'lasso',
        'reg_term': alpha,
        'smooth_term': smooth,
    },
    to_print=False,
)


Mean of empty slice.


invalid value encountered in scalar divide



mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
Arrived to max iter
[0.08189005117945339, 0.033918614721506715, 0.03391861472150672, 0.03391861472150673, 0.03391861472150673, 0.033918614721506715, 0.03393247087688931, 0.03418155335324336, 0.034219004318628586, 0.03421900431862859, 0.034219004318628586, 0.034219004318628586, 0.03270952251096894, 0.032713193465942474, 0.032713193465942474, 0.032713193465942474, 0.032713193465942474, 0.031939240056310575, 0.03176499914241953, 0.03176499914241953, 0.03176499914241953, 0.03176499914241953, 0.03176499914241952, 0.030477136950894632, 0.030518310859539287, 0.030518310859539287, 0.030518310859539287, 0.030518310859539287, 0.030518310859539287, 0.03256013246981331, 0.03260052879052212, 0.03260052879052212, 0.03260052879052211, 0.03260052879052212, 0.03260052879052212, 0.03213765851562424, 0.03233040221857183, 0.03233040221857182, 0.03233040221857182, 

In [5]:
# Coefficients of fit #0 over time, one line per mode.
(
    px.line(
        dlds_model_refitted0[0].T,
        title=f'coefficients with lasso [alpha={alpha}, smooth_term={smooth}] with {n_iterations} iterations',
    )
    .update_layout(xaxis_title='time', yaxis_title='magnitude', legend_title='mode')
    .show()
)

In [9]:
# Single-time-step reconstruction from fit #0 (coefficients = dlds_model_refitted0[0]).
# The title carries the hyperparameters, matching the coefficient plot above and the sweep.
recon = mf.create_reco(X, dlds_model_refitted0[0], dlds_model_refitted0[1])
px.line(
    recon.T,
    title=f'Single Time Step Reconstruction [alpha={alpha}, smooth_term={smooth}] with {n_iterations} iterations',
).update_layout(xaxis_title='time', yaxis_title='state', showlegend=False).show()

In [10]:
# Multi-time-step reconstruction from fit #0; element [0] of create_reco2's result is plotted.
recon2 = mf.create_reco2(X, dlds_model_refitted0[0], dlds_model_refitted0[1])
px.line(
    recon2[0].T,
    title=f'Multi Time Step Reconstruction [alpha={alpha}, smooth_term={smooth}] with {n_iterations} iterations',
).update_layout(xaxis_title='time', yaxis_title='state', showlegend=False).show()

## Alpha = 0.1, Smoothness = 10, Iterations = 100

In [11]:
# dLDS fit #1: same Lasso settings as fit #0, but with a smoothness weight of 10.
alpha = 0.1          # regularization weight, passed as 'reg_term'
smooth = 10          # smoothness weight, passed as 'smooth_term'
n_iterations = 100   # maximum number of training iterations
dlds_model_refitted1 = mf.train_model_include_D(
    data=X,
    max_iter=n_iterations,
    params={
        'update_c_type': 'lasso',
        'reg_term': alpha,
        'smooth_term': smooth,
    },
    to_print=False,
)


Mean of empty slice.


invalid value encountered in scalar divide



mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
Arrived to max iter
[0.08189005117945339, 0.027126347957560457, 0.027126347957560464, 0.027126347957560464, 0.02712634795756046, 0.02712634795756046, 0.02774871355118531, 0.027559669291295642, 0.027559669291295646, 0.027559669291295642, 0.02755966929129565, 0.027559669291295646, 0.0287179029050808, 0.028666564054183805, 0.028666564054183802, 0.028666564054183802, 0.028666564054183802, 0.028666564054183802, 0.025537612380803758, 0.02571430239847998, 0.02571430239847998, 0.02571430239847998, 0.025714302398479973, 0.02571430239847998, 0.02502009274525557, 0.024634948930078195, 0.024634948930078202, 0.024634948930078195, 0.024634948930078195, 0.024634948930078195, 0.026116884266712256, 0.026307273566886488, 0.02630727356688649, 0.026307273566886488, 0.02630727356688649, 0.02630727356688649, 0.02794967393837207, 0.02787406543212166, 0.02787406543212166, 0.027874065

In [12]:
# Coefficients of fit #1 over time, one line per mode.
(
    px.line(
        dlds_model_refitted1[0].T,
        title=f'coefficients with lasso [alpha={alpha}, smooth_term={smooth}] with {n_iterations} iterations',
    )
    .update_layout(xaxis_title='time', yaxis_title='magnitude', legend_title='mode')
    .show()
)

In [13]:
# Single-time-step reconstruction from fit #1 (coefficients = dlds_model_refitted1[0]).
recon = mf.create_reco(X, dlds_model_refitted1[0], dlds_model_refitted1[1])
(
    px.line(
        recon.T,
        title=f'Single Time Step Reconstruction [alpha={alpha}, smooth_term={smooth}] with {n_iterations} iterations',
    )
    .update_layout(xaxis_title='time', yaxis_title='state', showlegend=False)
    .show()
)

In [14]:
# Multi-time-step reconstruction from fit #1; element [0] of create_reco2's result is plotted.
# The title carries the hyperparameters, matching the single-step plot above.
recon2 = mf.create_reco2(X, dlds_model_refitted1[0], dlds_model_refitted1[1])
px.line(
    recon2[0].T,
    title=f'Multi Time Step Reconstruction [alpha={alpha}, smooth_term={smooth}] with {n_iterations} iterations',
).update_layout(xaxis_title='time', yaxis_title='state', showlegend=False).show()

## Sweeping over other values of alpha and smoothness

In [16]:
# Sweep every (alpha, smooth) combination: refit the dLDS model with Lasso
# coefficient updates, then plot the coefficients and both reconstructions.
# NOTE: 25 fits x 3 figures is slow; staticPlot keeps the rendered figures light.

config = {'staticPlot': True}  # render the sweep's figures as static (non-interactive) plots

alphas = [0.005, 0.01, 0.1, 0.5, 1]
smooths = [0, 0.1, 0.5, 1, 10]
n_iterations = 100

# Entry i of every results list belongs to the hyperparameters in sweep_params[i].
sweep_params = []
all_coefficients = []
reconstructions_singlestep = []
reconstructions_multistep = []
for alpha in alphas:
    for smooth in smooths:
        run_label = f'[alpha={alpha}, smooth_term={smooth}] with {n_iterations} iterations'

        # Do all the (slow, interruptible) work first ...
        dlds_model_refitted = mf.train_model_include_D(data=X, max_iter=n_iterations, params={'update_c_type':'lasso','reg_term':alpha, 'smooth_term':smooth}, to_print=False)
        recon = mf.create_reco(X, dlds_model_refitted[0], dlds_model_refitted[1])
        recon2 = mf.create_reco2(X, dlds_model_refitted[0], dlds_model_refitted[1])

        # ... then record the results together, so an interrupted sweep cannot leave
        # the lists with different lengths (misaligned with each other).
        sweep_params.append((alpha, smooth))
        all_coefficients.append(dlds_model_refitted[0])
        reconstructions_singlestep.append(recon)
        reconstructions_multistep.append(recon2[0])

        # plot the coefficients
        px.line(dlds_model_refitted[0].T, title=f'coefficients with lasso {run_label}').update_layout(xaxis_title='time', yaxis_title='magnitude', legend_title='mode').show(config=config)

        # plot the single time step reconstruction
        px.line(recon.T, title=f'Single Time Step Reconstruction {run_label}').update_layout(xaxis_title='time', yaxis_title='state', showlegend=False).show(config=config)

        # plot the multi time step reconstruction
        px.line(recon2[0].T, title=f'Multi Time Step Reconstruction {run_label}').update_layout(xaxis_title='time', yaxis_title='state', showlegend=False).show(config=config)


Mean of empty slice.


invalid value encountered in scalar divide



mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
Arrived to max iter
[0.08189005117945339, 0.04409952981704032, 0.04409952981704035, 0.044099529817040364, 0.04409952981704038, 0.04409952981704036, 0.03933751085061304, 0.0386998177996783, 0.03869981779967833, 0.03869981779967833, 0.038699817799678314, 0.03869981779967832, 0.03708616095179143, 0.036939514627140506, 0.03693951462714051, 0.03693951462714049, 0.0369395146271405, 0.0369395146271405, 0.025750562043393883, 0.02656013638906802, 0.026560136389068012, 0.026560136389068026, 0.02656013638906801, 0.026560136389068033, 0.026370084255484982, 0.023502834068058524, 0.02350283406805854, 0.023502834068058528, 0.023502834068058528, 0.023502834068058535, 0.027388318310955007, 0.02903571190137321, 0.02903571190137322, 0.029035711901373225, 0.02903571190137323, 0.029035711901373225, 0.02838551700989181, 0.02826826079891101, 0.028268260798911017, 0.02826826079891102


Mean of empty slice.


invalid value encountered in scalar divide



mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
Arrived to max iter
[0.08189005117945339, 0.016875673972277617, 0.016875673972277627, 0.016875673972277624, 0.016875673972277624, 0.016875673972277624, 0.017448733927070463, 0.017114099525709497, 0.0171140995257095, 0.0171140995257095, 0.017114099525709497, 0.017114099525709497, 0.01574295049573065, 0.015795630431441755, 0.015795630431441752, 0.01579563043144175, 0.01579563043144175, 0.015795630431441752, 0.014883763411237398, 0.014527768442407251, 0.01452776844240725, 0.014527768442407251, 0.01452776844240725, 0.014527768442407251, 0.01629712520746612, 0.015140718840402863, 0.015140718840402863, 0.015140718840402861, 0.015140718840402861, 0.015140718840402861, 0.018502420495447747, 0.01897412788487486, 0.01897412788487487, 0.018974127884874868, 0.018974127884874875, 0.018974127884874868, 0.02134384462246428, 0.021191655884550253, 0.021191655884550256, 0.02119


Mean of empty slice.


invalid value encountered in scalar divide



mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
Arrived to max iter
[0.08189005117945339, 0.019439856416887415, 0.019439856416887433, 0.01943985641688743, 0.01943985641688744, 0.019439856416887433, 0.022112195454066674, 0.021850505127113763, 0.021850505127113777, 0.02185050512711377, 0.021850505127113763, 0.02185050512711377, 0.021585400118861377, 0.021583643551725068, 0.021583643551725068, 0.021583643551725054, 0.021583643551725058, 0.016967324785250724, 0.016767736889478916, 0.016767736889478926, 0.016767736889478926, 0.016767736889478923, 0.016767736889478923, 0.017410980318400902, 0.01693643618651749, 0.01693643618651748, 0.01693643618651749, 0.01693643618651749, 0.01693643618651749, 0.019487398607564, 0.021031057673332347, 0.021031057673332337, 0.021031057673332333, 0.02103105767333234, 0.02103105767333234, 0.02460652804692488, 0.023814739003944256, 0.023814739003944232, 0.02381473900394424, 0.


Mean of empty slice.


invalid value encountered in scalar divide



mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
Arrived to max iter
[0.08189005117945339, 0.02261298683148975, 0.02261298683148977, 0.02261298683148977, 0.022612986831489785, 0.02261298683148977, 0.024940374626427786, 0.024946054148033844, 0.02494605414803386, 0.024946054148033865, 0.02494605414803385, 0.027877360265982785, 0.02322467433434121, 0.023224674334341215, 0.023224674334341236, 0.023224674334341215, 0.02322467433434123, 0.022734760395013764, 0.02397070313948167, 0.023970703139481665, 0.023970703139481662, 0.02397070313948167, 0.02397070313948167, 0.019427523546064952, 0.020693574491123492, 0.020693574491123485, 0.0206935744911235, 0.020693574491123478, 0.020693574491123495, 0.019120812244147852, 0.01754863508726626, 0.017548635087266253, 0.01754863508726625, 0.017548635087266257, 0.017548635087266257, 0.024589375789637144, 0.0213813109699274, 0.021381310969927427, 0.021381310969927434, 0.021381310


Mean of empty slice.


invalid value encountered in scalar divide



mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
Arrived to max iter
[0.08189005117945339, 0.028610969512894803, 0.02861096951289482, 0.028610969512894827, 0.02861096951289484, 0.02861096951289482, 0.028395676770721374, 0.02865208220414766, 0.028652082204147677, 0.028652082204147687, 0.028652082204147673, 0.02865208220414768, 0.03209970047768274, 0.03257817322512635, 0.03257817322512636, 0.03257817322512632, 0.032578173225126336, 0.032578173225126336, 0.02601032281948827, 0.025326255460254122, 0.02532625546025412, 0.025326255460254132, 0.02532625546025411, 0.025326255460254136, 0.023570098896823828, 0.02116569322472485, 0.021165693224724856, 0.02116569322472485, 0.02116569322472485, 0.021165693224724853, 0.02400086804224511, 0.026203739578008266, 0.026203739578008273, 0.026203739578008266, 0.026203739578008277, 0.026203739578008277, 0.029198350605595343, 0.029386826055030575, 0.02938682605503058, 0.029386826


Mean of empty slice.


invalid value encountered in scalar divide



mixed F
mixed F
mixed F
mixed F
mixed F
mixed F
mixed F


KeyboardInterrupt: 