In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os
import psutil
import time
from scipy.io import loadmat
from scipy.io import savemat # store results for comparison with Matlab code   
from ssidid import ObservationScheme
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om


print('loading data')
spikes = loadmat('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/gb_net/spikes_20trials_10msBins')
spikes = spikes['spikes_out']  # MATLAB cell array, one spike-count matrix per trial

num_iter = 0
pi_method = 'heuristic'

print('concatenating trials')
p, T = 1000, 600  # neurons to keep; T is re-derived from the data below

n = 10  # latent dimensionality

print('number trials: ', spikes.size)
# Subsample p of the first 1000 neurons (hard-coded pool size; assumes the
# recording has at least that many). A dedicated seeded generator makes the
# selection reproducible without touching the global stream, which is
# re-seeded below for the column shuffle.
num_neurons_pool = 1000
idx_n = np.sort(np.random.RandomState(0).choice(num_neurons_pool, size=p, replace=False))
# np.float was removed in NumPy 1.24; it was an alias of float64.
y = np.vstack([spikes[i][0].T[:, idx_n] for i in range(spikes.size)]).astype(np.float64)
del spikes

# Take the real dimensions from the concatenated (time x neurons) array.
T, p = y.shape
idx_shuffle = np.arange(p)
np.random.seed(0)
np.random.shuffle(idx_shuffle)
y = y[:, idx_shuffle]




# Raster of time bins 1000-1999, one row per (shuffled) neuron.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(y[1000:2000, :].T, aspect='auto', interpolation='None', cmap='gray')
plt.show()

# Fully observed time-lagged statistics for lags 0 .. 2n-1 over all neurons.
lag_range = np.arange(2 * n)
idx_a = np.arange(p)
idx_b = np.arange(p)
W = ObservationScheme(p=p, T=T).comp_coocurrence_weights(
    lag_range, sso=True, idx_a=idx_a, idx_b=idx_b)
Qs, _ = f_l2_Hankel_comp_Q_Om(
    n=n, y=y, lag_range=lag_range, obs_scheme=ObservationScheme(p=p, T=T),
    idx_a=idx_a, idx_b=idx_b, W=W, sso=True,
    mmap=False, data_path=None, ts=None, ms=None)
del W  # only needed while computing Qs

# First entry of Qs (presumably the lag-0 matrix).
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(Qs[0], interpolation='None')
plt.show()

In [None]:
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid.icml_scripts import run_default
from scipy import linalg as la

dtype = np.float64  # np.float was removed in NumPy 1.24; it aliased float64
lag_range = np.arange(20)
overlap = 1000  # with overlap >= p both subpopulations cover every neuron
sso = True
mmap, data_path = False, None
idx_a, idx_b = np.arange(p), np.arange(p)
# Two subpopulations sharing `overlap` neurons in the middle of the index range.
sub_pops = (np.arange((p+overlap)//2), np.arange((p-overlap)//2, p))

def principal_angle(A, B):
    """Principal angles between the column spaces of A and B.

    A and B need not be orthonormal: both are orthonormalised internally with
    scipy.linalg.orth. 1-D inputs (arrays or sequences) are treated as single
    column vectors.

    Returns
    -------
    ndarray
        Principal angles divided by pi/2 (0 = same direction, 1 = orthogonal),
        in ascending order.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim < 2) else A
    B = np.atleast_2d(B).T if (B.ndim < 2) else B
    A = la.orth(A)
    B = la.orth(B)
    # Singular values of A^T B are the cosines of the principal angles; the
    # singular vectors are not needed.
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # Clip round-off above 1 so arccos stays real.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)

# Observation scheme: the subpopulations are observed in turn, with obs_time
# holding the end point of each observation block.
reps = 1
obs_pops = np.tile(np.arange(len(sub_pops)), reps)
obs_time = np.linspace(0, T, len(obs_pops) + 1)[1:].astype(int)
obs_scheme = ObservationScheme(p=p, T=T,
                               sub_pops=sub_pops,
                               obs_pops=obs_pops,
                               obs_time=obs_time)

W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
partially_overlapping = overlap < p
if partially_overlapping:
    # Zero the weights between groups 0 and 1 for every lag.
    for lag_idx in range(len(lag_range)):
        W[lag_idx][0, 1] = 0
        W[lag_idx][1, 0] = 0

print('computing time-lagged covariances')
_, Om = f_l2_Hankel_comp_Q_Om(n=n, y=y, lag_range=lag_range, obs_scheme=obs_scheme,
                              idx_a=idx_a, idx_b=idx_b, W=W, sso=sso,
                              mmap=mmap, data_path=data_path, ts=None, ms=None)
if partially_overlapping:
    # Exclude the cross-group blocks from the observed-entry masks.
    grp0, grp1 = obs_scheme.idx_grp[0], obs_scheme.idx_grp[1]
    for lag_idx in range(len(lag_range)):
        Om[lag_idx][np.ix_(grp0, grp1)] = False
        Om[lag_idx][np.ix_(grp1, grp0)] = False


In [None]:
rnd_seed = np.random.get_state()

parametrization = 'nl'
sso = True

# settings for the quick initial SGD fitting phase (large step, batch size 1)
batch_size, max_zip_size, max_iter = 1, 100, 100
a, b1, b2, e = 0.1, 0.9, 0.99, 1e-8
a_R = 1 * a
a_decay = 0.95

# The former `pars_track` callback (and its `proj_errors` buffer) compared
# against `pars_true`, which does not exist for real data and would raise
# NameError; it was also never passed to run_bad, so it has been removed.

# Random initialisation; the RNG state is saved so the fit can be reproduced.
rnd_seed_fit = np.random.get_state()
pars_est = {
    'C': np.asarray(la.orth(np.random.normal(size=(p, n))), dtype=dtype) * np.sqrt(p) / np.sqrt(n),
    'A': np.asarray(np.diag(np.linspace(0.89, 0.91, n)), dtype=dtype),
    'Q': np.asarray(np.eye(n), dtype=dtype),
    'R': 2*np.ones(p, dtype=dtype)
}
_, pars_est, traces, Qs, Om, W, t = run_bad(lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
                                            obs_scheme=obs_scheme, pars_init=pars_est,
                                            parametrization=parametrization, sso=sso,
                                            Qs=Qs, Om=Om, W=W,
                                            alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter, a_decay=a_decay,
                                            batch_size=batch_size, verbose=True, max_epoch_size=max_zip_size,
                                            pars_track=None, save_every=np.inf, data_path=data_path)

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')

In [None]:
rnd_seed = np.random.get_state()

parametrization = 'nl'
sso = True

# SGD refinement: continue from the current pars_est with a smaller step
# size and larger mini-batches.
batch_size, max_zip_size, max_iter = 10, 5000, 100
a, b1, b2, e = 0.0001, 0.99, 0.99, 1e-8
a_R = 1 * a
a_decay = 0.98

fit_kwargs = dict(
    lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, pars_init=pars_est,
    parametrization=parametrization, sso=sso,
    Qs=Qs, Om=Om, W=W,
    alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter, a_decay=a_decay,
    batch_size=batch_size, verbose=True, max_epoch_size=max_zip_size,
    pars_track=None, save_every=np.inf, data_path=data_path,
)
_, pars_est, traces, Qs, Om, W, t = run_bad(**fit_kwargs)

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')

In [None]:
rnd_seed = np.random.get_state()

parametrization = 'nl'
sso = True

# SGD refinement: continue from the current pars_est with a smaller step
# size and larger mini-batches.
batch_size, max_zip_size, max_iter = 20, 2000, 100
a, b1, b2, e = 0.00001, 0.99, 0.99, 1e-8
a_R = 1 * a
a_decay = 0.98

fit_kwargs = dict(
    lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, pars_init=pars_est,
    parametrization=parametrization, sso=sso,
    Qs=Qs, Om=Om, W=W,
    alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter, a_decay=a_decay,
    batch_size=batch_size, verbose=True, max_epoch_size=max_zip_size,
    pars_track=None, save_every=np.inf, data_path=data_path,
)
_, pars_est, traces, Qs, Om, W, t = run_bad(**fit_kwargs)

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')

In [None]:
rnd_seed = np.random.get_state()

parametrization = 'nl'
sso = True

# SGD refinement: continue from the current pars_est with a smaller step
# size and larger mini-batches.
batch_size, max_zip_size, max_iter = 50, 1000, 100
a, b1, b2, e = 0.00001, 0.99, 0.99, 1e-8
a_R = 1 * a
a_decay = 0.98

fit_kwargs = dict(
    lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, pars_init=pars_est,
    parametrization=parametrization, sso=sso,
    Qs=Qs, Om=Om, W=W,
    alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter, a_decay=a_decay,
    batch_size=batch_size, verbose=True, max_epoch_size=max_zip_size,
    pars_track=None, save_every=np.inf, data_path=data_path,
)
_, pars_est, traces, Qs, Om, W, t = run_bad(**fit_kwargs)

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')

In [None]:
rnd_seed = np.random.get_state()

parametrization = 'nl'
sso = True

# SGD refinement: continue from the current pars_est with a smaller step
# size and larger mini-batches.
batch_size, max_zip_size, max_iter = 100, 500, 100
a, b1, b2, e = 0.00001, 0.99, 0.99, 1e-8
a_R = 1 * a
a_decay = 0.98

fit_kwargs = dict(
    lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, pars_init=pars_est,
    parametrization=parametrization, sso=sso,
    Qs=Qs, Om=Om, W=W,
    alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter, a_decay=a_decay,
    batch_size=batch_size, verbose=True, max_epoch_size=max_zip_size,
    pars_track=None, save_every=np.inf, data_path=data_path,
)
_, pars_est, traces, Qs, Om, W, t = run_bad(**fit_kwargs)

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')

In [None]:
# Save to a dedicated variable rather than overwriting the global `data_path`,
# which the fitting calls above receive (None here); clobbering it would
# silently change their behaviour if those cells are re-run.
results_path = '/home/marcel/Desktop/Projects/Stitching/results/icml_spikes/'
save_dict = {'p' : p,'n' : n,'T' : T, 'lag_range' : lag_range,
             'obs_scheme' : obs_scheme, 'y' : None,
             'pars_est' : pars_est, 'Qs' : Qs, 'Om' : Om, 'W' : W, 'idx_a' : idx_a, 'idx_b' : idx_b,
             'rnd_seed' : rnd_seed_fit, 'idx_shuffle' : idx_shuffle
            }
file_name = 'n' + str(n) + 'spikes_fullyObs'
# The dict is stored as a pickled 0-d object array under key 'arr_0'; read it
# back with np.load(..., allow_pickle=True)['arr_0'].item().
np.savez(results_path + file_name, save_dict)

# fit stitching SSID

In [None]:
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from ssidid.icml_scripts import run_default
from scipy import linalg as la


p, n = 1000, 10
dtype = np.float64  # np.float was removed in NumPy 1.24; it aliased float64
lag_range = np.arange(20)
overlap = 100  # number of neurons shared by the two subpopulations
sso = True
mmap, data_path = False, None
idx_a, idx_b = np.arange(p), np.arange(p)
# Two subpopulations sharing `overlap` neurons in the middle of the index range.
sub_pops = (np.arange((p+overlap)//2), np.arange((p-overlap)//2, p))

def principal_angle(A, B):
    """Principal angles between the column spaces of A and B.

    A and B need not be orthonormal: both are orthonormalised internally with
    scipy.linalg.orth. 1-D inputs (arrays or sequences) are treated as single
    column vectors.

    Returns
    -------
    ndarray
        Principal angles divided by pi/2 (0 = same direction, 1 = orthogonal),
        in ascending order.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim < 2) else A
    B = np.atleast_2d(B).T if (B.ndim < 2) else B
    A = la.orth(A)
    B = la.orth(B)
    # Singular values of A^T B are the cosines of the principal angles; the
    # singular vectors are not needed.
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # Clip round-off above 1 so arccos stays real.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)

# Observation scheme: the subpopulations are observed in turn, with obs_time
# holding the end point of each observation block.
reps = 1
obs_pops = np.tile(np.arange(len(sub_pops)), reps)
obs_time = np.linspace(0, T, len(obs_pops) + 1)[1:].astype(int)
obs_scheme = ObservationScheme(p=p, T=T,
                               sub_pops=sub_pops,
                               obs_pops=obs_pops,
                               obs_time=obs_time)

W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
partially_overlapping = overlap < p
if partially_overlapping:
    # Zero the weights between groups 0 and 1 for every lag.
    for lag_idx in range(len(lag_range)):
        W[lag_idx][0, 1] = 0
        W[lag_idx][1, 0] = 0

print('computing time-lagged covariances')
_, Om = f_l2_Hankel_comp_Q_Om(n=n, y=y, lag_range=lag_range, obs_scheme=obs_scheme,
                              idx_a=idx_a, idx_b=idx_b, W=W, sso=sso,
                              mmap=mmap, data_path=data_path, ts=None, ms=None)
if partially_overlapping:
    # Exclude the cross-group blocks from the observed-entry masks.
    grp0, grp1 = obs_scheme.idx_grp[0], obs_scheme.idx_grp[1]
    for lag_idx in range(len(lag_range)):
        Om[lag_idx][np.ix_(grp0, grp1)] = False
        Om[lag_idx][np.ix_(grp1, grp0)] = False


In [None]:
rnd_seed = np.random.get_state()

parametrization = 'nl'
sso = True

# settings for the quick initial SGD fitting phase (large step, batch size 1)
batch_size, max_zip_size, max_iter = 1, 100, 100
a, b1, b2, e = 0.1, 0.9, 0.99, 1e-8
a_R = 1 * a
a_decay = 0.95

# The former `pars_track` callback (and its `proj_errors` buffer) compared
# against `pars_true`, which does not exist for real data and would raise
# NameError; it was also never passed to run_bad, so it has been removed.

# Random initialisation; the RNG state is saved so the fit can be reproduced.
rnd_seed_fit = np.random.get_state()
pars_est = {
    'C': np.asarray(la.orth(np.random.normal(size=(p, n))), dtype=dtype) * np.sqrt(p) / np.sqrt(n),
    'A': np.asarray(np.diag(np.linspace(0.89, 0.91, n)), dtype=dtype),
    'Q': np.asarray(np.eye(n), dtype=dtype),
    'R': 2*np.ones(p, dtype=dtype)
}
_, pars_est, traces, Qs, Om, W, t = run_bad(lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
                                            obs_scheme=obs_scheme, pars_init=pars_est,
                                            parametrization=parametrization, sso=sso,
                                            Qs=Qs, Om=Om, W=W,
                                            alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter, a_decay=a_decay,
                                            batch_size=batch_size, verbose=True, max_epoch_size=max_zip_size,
                                            pars_track=None, save_every=np.inf, data_path=data_path)

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')

In [None]:

parametrization = 'nl'
sso = True

# SGD refinement: continue from the current pars_est with a smaller step
# size and larger mini-batches.
batch_size, max_zip_size, max_iter = 10, 5000, 100
a, b1, b2, e = 0.0001, 0.99, 0.99, 1e-8
a_R = 1 * a
a_decay = 0.98

fit_kwargs = dict(
    lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, pars_init=pars_est,
    parametrization=parametrization, sso=sso,
    Qs=Qs, Om=Om, W=W,
    alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter, a_decay=a_decay,
    batch_size=batch_size, verbose=True, max_epoch_size=max_zip_size,
    pars_track=None, save_every=np.inf, data_path=data_path,
)
_, pars_est, traces, Qs, Om, W, t = run_bad(**fit_kwargs)

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')

In [None]:

parametrization = 'nl'
sso = True

# SGD refinement: continue from the current pars_est with a smaller step
# size and larger mini-batches.
batch_size, max_zip_size, max_iter = 20, 2000, 100
a, b1, b2, e = 0.00001, 0.99, 0.99, 1e-8
a_R = 1 * a
a_decay = 0.98

fit_kwargs = dict(
    lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, pars_init=pars_est,
    parametrization=parametrization, sso=sso,
    Qs=Qs, Om=Om, W=W,
    alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter, a_decay=a_decay,
    batch_size=batch_size, verbose=True, max_epoch_size=max_zip_size,
    pars_track=None, save_every=np.inf, data_path=data_path,
)
_, pars_est, traces, Qs, Om, W, t = run_bad(**fit_kwargs)

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')

In [None]:

parametrization = 'nl'
sso = True

# SGD refinement: continue from the current pars_est with a smaller step
# size and larger mini-batches.
batch_size, max_zip_size, max_iter = 50, 1000, 100
a, b1, b2, e = 0.00001, 0.99, 0.99, 1e-8
a_R = 1 * a
a_decay = 0.98

fit_kwargs = dict(
    lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, pars_init=pars_est,
    parametrization=parametrization, sso=sso,
    Qs=Qs, Om=Om, W=W,
    alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter, a_decay=a_decay,
    batch_size=batch_size, verbose=True, max_epoch_size=max_zip_size,
    pars_track=None, save_every=np.inf, data_path=data_path,
)
_, pars_est, traces, Qs, Om, W, t = run_bad(**fit_kwargs)

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')

In [None]:

parametrization = 'nl'
sso = True

# Final SGD refinement: largest mini-batches and twice as many iterations.
batch_size, max_zip_size, max_iter = 100, 500, 200
a, b1, b2, e = 0.00001, 0.99, 0.99, 1e-8
a_R = 1 * a
a_decay = 0.98

fit_kwargs = dict(
    lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, pars_init=pars_est,
    parametrization=parametrization, sso=sso,
    Qs=Qs, Om=Om, W=W,
    alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter, a_decay=a_decay,
    batch_size=batch_size, verbose=True, max_epoch_size=max_zip_size,
    pars_track=None, save_every=np.inf, data_path=data_path,
)
_, pars_est, traces, Qs, Om, W, t = run_bad(**fit_kwargs)

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')

In [None]:
# Save to a dedicated variable rather than overwriting the global `data_path`,
# which the fitting calls above receive (None here); clobbering it would
# silently change their behaviour if those cells are re-run.
results_path = '/home/marcel/Desktop/Projects/Stitching/results/icml_spikes/'
save_dict = {'p' : p,'n' : n,'T' : T, 'lag_range' : lag_range,
             'obs_scheme' : obs_scheme, 'y' : None,
             'pars_est' : pars_est, 'Qs' : Qs, 'Om' : Om, 'W' : W, 'idx_a' : idx_a, 'idx_b' : idx_b,
             'rnd_seed' : rnd_seed_fit, 'idx_shuffle' : idx_shuffle
            }
file_name = 'n' + str(n) + 'spikes_2sp_stitched'
# The dict is stored as a pickled 0-d object array under key 'arr_0'; read it
# back with np.load(..., allow_pickle=True)['arr_0'].item().
np.savez(results_path + file_name, save_dict)

# load saved SSID fits for comparison

In [None]:
data_path = '/home/marcel/Desktop/Projects/Stitching/results/icml_spikes/'


def load_saved_results(file_name, path=None):
    """Return the dict stored by ``np.savez(path + file_name, save_dict)``.

    np.savez stored the dict as a pickled 0-d object array under 'arr_0', so
    ``allow_pickle=True`` is required (NumPy >= 1.16.3 refuses pickles by
    default). Unpickling can execute code: only use this on trusted files.
    ``path`` defaults to the current global ``data_path``. The .npz handle is
    closed before returning.
    """
    folder = data_path if path is None else path
    with np.load(folder + file_name + '.npz', allow_pickle=True) as npz:
        return npz['arr_0'].item()


# Covariances of the fully observed data are the reference.
file_name = 'n' + str(n) + 'spikes_fullyObs'
load_file = load_saved_results(file_name)
Qs = load_file['Qs']

# Stitched SSID estimate and the masks/weights it was fitted with.
file_name = 'n' + str(n) + 'spikes_2sp_stitched'
load_file = load_saved_results(file_name)
pars_est = load_file['pars_est']
Om, W, idx_a, idx_b = load_file['Om'], load_file['W'], load_file['idx_a'], load_file['idx_b']
print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, None, False, None)

# Same comparison on the complement of the masks Om.
nOm = [np.invert(Om[m]) for m in range(len(lag_range))]
print_slim(Qs, nOm, lag_range, pars_est, idx_a, idx_b, None, False, None)

# fit EM on fully observed data (random initialisations)

In [None]:
%matplotlib inline
import os
os.chdir("/home/marcel/Desktop/Projects/Stitching/code/pyRRHDLDS/core")
import ssm_scripts
import ssm_fit
from ssm_scripts import setup_fit_lds
from scipy import linalg as la

dtype = np.float64  # np.float was removed in NumPy 1.24; it aliased float64
def principal_angle(A, B):
    """Principal angles between the column spaces of A and B.

    A and B need not be orthonormal: both are orthonormalised internally with
    scipy.linalg.orth. 1-D inputs (arrays or sequences) are treated as single
    column vectors.

    Returns
    -------
    ndarray
        Principal angles divided by pi/2 (0 = same direction, 1 = orthogonal),
        in ascending order.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim < 2) else A
    B = np.atleast_2d(B).T if (B.ndim < 2) else B
    A = la.orth(A)
    B = la.orth(B)
    # Singular values of A^T B are the cosines of the principal angles; the
    # singular vectors are not needed.
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # Clip round-off above 1 so arccos stays real.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)


# Fully observed setting: with overlap >= p both subpopulations cover every neuron.
overlap = 1000
sub_pops = (np.arange((p+overlap)//2), np.arange((p-overlap)//2, p))

max_iter_EM = 100
eps_cov = 1e-20

reps = 3  # number of random restarts

for r in range(reps):
    # Random initialisation; the RNG state is saved so each restart can be reproduced.
    rnd_seed_fit = np.random.get_state()
    pars_init = {
        'C': np.asarray(la.orth(np.random.normal(size=(p, n))), dtype=dtype) * np.sqrt(p) / np.sqrt(n),
        'A': np.asarray(np.diag(np.linspace(0.89, 0.91, n)), dtype=dtype),
        'Q': np.asarray(np.eye(n), dtype=dtype),
        'R': 2*np.ones(p, dtype=dtype)
    }
    # EM: subpopulation 0 then subpopulation 1 (obs_time presumably holds the
    # end point of each observation block).
    obs_scheme = {'sub_pops': [sub_pops[0], sub_pops[1]],
                  'obs_pops': [0, 1],
                  'obs_time': [T//2, T]}
    fit_lds = setup_fit_lds(y=y.T.reshape(p, T, 1),
                            u=None,
                            max_iter=max_iter_EM,
                            epsilon=np.log(1.0001),
                            eps_cov=eps_cov,
                            plot_flag=False,
                            trace_pars_flag=True,
                            trace_stats_flag=False,
                            diag_R_flag=True,
                            use_A_flag=True,
                            use_B_flag=False)

    # fit the model to data
    print('fitting model to data')
    pars_init['d'] = np.zeros(p, dtype=dtype)
    pars_init['mu0'] = np.zeros(n, dtype=dtype)
    pars_init['V0'] = np.eye(n, dtype=dtype)

    pars_hat = pars_init
    t = time.time()
    pars_hat['B'] = np.empty((n, 0), dtype=dtype)  # no inputs (use_B_flag=False)
    pars_hat, ll = fit_lds(x_dim=n,
                           pars=pars_hat,
                           obs_scheme=obs_scheme,
                           save_file=None)

    likes = np.array(ll)
    elapsed_time = time.time() - t
    print('elapsed time for fitting is')
    print(elapsed_time)
    # Stationary latent covariance: Pi = A Pi A^T + Q.
    pars_hat['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_hat['A'],
                                                       pars_hat['Q'])

    t = time.time() - t
    print('fitting time: ', t)

    plt.figure()
    plt.plot(likes)
    plt.show()

    # Evaluate the parameters of the last traced iteration. Clamp the index in
    # case EM stopped early (epsilon criterion) and traced fewer iterates.
    i = min(max_iter_EM, len(pars_hat['As'])) - 1
    pars_est = {
        'A': pars_hat['As'][i],
        'Q': pars_hat['Qs'][i],
        'C': pars_hat['Cs'][i],
        'R': pars_hat['Rs'][i]
    }
    pars_est['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_est['A'],
                                                       pars_est['Q'])
    # Lagged latent covariances A^m Pi from the SAME iterate as pars_est
    # (previously built from pars_hat's A and Pi, mixing two parameter sets).
    pars_est['X'] = np.vstack([np.linalg.matrix_power(pars_est['A'], m).dot(pars_est['Pi'])
                               for m in lag_range])
    print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, None, False, data_path)

    nOm = [np.invert(Om[m]) for m in range(len(lag_range))]
    print_slim(Qs, nOm, lag_range, pars_est, idx_a, idx_b, None, False, data_path)

    data_path = '/home/marcel/Desktop/Projects/Stitching/results/icml_spikes/'
    save_dict = {'p' : p,'n' : n,'T' : T,
                 'obs_scheme' : obs_scheme, 'y' : None,
                 'pars_est' : pars_hat, 'Qs' : Qs, 'Om' : Om, 'W' : W, 'idx_a' : idx_a, 'idx_b' : idx_b,
                 'rnd_seed' : rnd_seed_fit, 'eps_cov' : eps_cov, 'idx_shuffle' : idx_shuffle
                }
    file_name = 'n' + str(n) + 'spikes_EM_fullyObs' + '_init' + str(r)
    np.savez(data_path + file_name, save_dict)

# fit stitching EM

In [None]:
%matplotlib inline
import os
os.chdir("/home/marcel/Desktop/Projects/Stitching/code/pyRRHDLDS/core")
import ssm_scripts
import ssm_fit
from ssm_scripts import setup_fit_lds
from scipy import linalg as la

dtype = np.float64  # np.float was removed in NumPy 1.24; it aliased float64
def principal_angle(A, B):
    """Principal angles between the column spaces of A and B.

    A and B need not be orthonormal: both are orthonormalised internally with
    scipy.linalg.orth. 1-D inputs (arrays or sequences) are treated as single
    column vectors.

    Returns
    -------
    ndarray
        Principal angles divided by pi/2 (0 = same direction, 1 = orthogonal),
        in ascending order.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim < 2) else A
    B = np.atleast_2d(B).T if (B.ndim < 2) else B
    A = la.orth(A)
    B = la.orth(B)
    # Singular values of A^T B are the cosines of the principal angles; the
    # singular vectors are not needed.
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # Clip round-off above 1 so arccos stays real.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)


# Stitching setting: the two subpopulations share only `overlap` neurons.
overlap = 100
sub_pops = (np.arange((p+overlap)//2), np.arange((p-overlap)//2, p))

max_iter_EM = 100
eps_cov = 1e-20

reps = 3  # number of random restarts

for r in range(reps):
    # Random initialisation; the RNG state is saved so each restart can be reproduced.
    rnd_seed_fit = np.random.get_state()
    pars_init = {
        'C': np.asarray(la.orth(np.random.normal(size=(p, n))), dtype=dtype) * np.sqrt(p) / np.sqrt(n),
        'A': np.asarray(np.diag(np.linspace(0.89, 0.91, n)), dtype=dtype),
        'Q': np.asarray(np.eye(n), dtype=dtype),
        'R': 2*np.ones(p, dtype=dtype)
    }
    # EM: subpopulation 0 then subpopulation 1 (obs_time presumably holds the
    # end point of each observation block).
    obs_scheme = {'sub_pops': [sub_pops[0], sub_pops[1]],
                  'obs_pops': [0, 1],
                  'obs_time': [T//2, T]}
    fit_lds = setup_fit_lds(y=y.T.reshape(p, T, 1),
                            u=None,
                            max_iter=max_iter_EM,
                            epsilon=np.log(1.0001),
                            eps_cov=eps_cov,
                            plot_flag=False,
                            trace_pars_flag=True,
                            trace_stats_flag=False,
                            diag_R_flag=True,
                            use_A_flag=True,
                            use_B_flag=False)

    # fit the model to data
    print('fitting model to data')
    pars_init['d'] = np.zeros(p, dtype=dtype)
    pars_init['mu0'] = np.zeros(n, dtype=dtype)
    pars_init['V0'] = np.eye(n, dtype=dtype)

    pars_hat = pars_init
    t = time.time()
    pars_hat['B'] = np.empty((n, 0), dtype=dtype)  # no inputs (use_B_flag=False)
    pars_hat, ll = fit_lds(x_dim=n,
                           pars=pars_hat,
                           obs_scheme=obs_scheme,
                           save_file=None)

    likes = np.array(ll)
    elapsed_time = time.time() - t
    print('elapsed time for fitting is')
    print(elapsed_time)
    # Stationary latent covariance: Pi = A Pi A^T + Q.
    pars_hat['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_hat['A'],
                                                       pars_hat['Q'])

    t = time.time() - t
    print('fitting time: ', t)

    plt.figure()
    plt.plot(likes)
    plt.show()

    # Evaluate the parameters of the last traced iteration. Clamp the index in
    # case EM stopped early (epsilon criterion) and traced fewer iterates.
    i = min(max_iter_EM, len(pars_hat['As'])) - 1
    pars_est = {
        'A': pars_hat['As'][i],
        'Q': pars_hat['Qs'][i],
        'C': pars_hat['Cs'][i],
        'R': pars_hat['Rs'][i]
    }
    pars_est['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_est['A'],
                                                       pars_est['Q'])
    # Lagged latent covariances A^m Pi from the SAME iterate as pars_est
    # (previously built from pars_hat's A and Pi, mixing two parameter sets).
    pars_est['X'] = np.vstack([np.linalg.matrix_power(pars_est['A'], m).dot(pars_est['Pi'])
                               for m in lag_range])
    print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, None, False, data_path)

    nOm = [np.invert(Om[m]) for m in range(len(lag_range))]
    print_slim(Qs, nOm, lag_range, pars_est, idx_a, idx_b, None, False, data_path)

    data_path = '/home/marcel/Desktop/Projects/Stitching/results/icml_spikes/'
    save_dict = {'p' : p,'n' : n,'T' : T,
                 'obs_scheme' : obs_scheme, 'y' : None,
                 'pars_est' : pars_hat, 'Qs' : Qs, 'Om' : Om, 'W' : W, 'idx_a' : idx_a, 'idx_b' : idx_b,
                 'rnd_seed' : rnd_seed_fit, 'eps_cov' : eps_cov, 'idx_shuffle' : idx_shuffle
                }
    file_name = 'n' + str(n) + 'spikes_EM_2sp_stitched' + '_init' + str(r)
    np.savez(data_path + file_name, save_dict)

# fit SSID-initialised EM

In [None]:
%matplotlib inline
import os
os.chdir("/home/marcel/Desktop/Projects/Stitching/code/pyRRHDLDS/core")
import ssm_scripts
import ssm_fit
from ssm_scripts import setup_fit_lds
from scipy import linalg as la

dtype = np.float64  # np.float was removed in NumPy 1.24; it aliased float64
def principal_angle(A, B):
    """Principal angles between the column spaces of A and B.

    A and B need not be orthonormal: both are orthonormalised internally with
    scipy.linalg.orth. 1-D inputs (arrays or sequences) are treated as single
    column vectors.

    Returns
    -------
    ndarray
        Principal angles divided by pi/2 (0 = same direction, 1 = orthogonal),
        in ascending order.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim < 2) else A
    B = np.atleast_2d(B).T if (B.ndim < 2) else B
    A = la.orth(A)
    B = la.orth(B)
    # Singular values of A^T B are the cosines of the principal angles; the
    # singular vectors are not needed.
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # Clip round-off above 1 so arccos stays real.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)


# Stitching setting: the two subpopulations share only `overlap` neurons.
overlap = 100
sub_pops = (np.arange((p+overlap)//2), np.arange((p-overlap)//2, p))

max_iter_EM = 20
eps_cov = 1e-20

reps = 3

for r in range(reps):
    # NOTE(review): the initialisation is loaded from file rather than drawn at
    # random, so the restarts presumably yield identical fits; rnd_seed_fit is
    # only stored for bookkeeping.
    rnd_seed_fit = np.random.get_state()

    data_path = '/home/marcel/Desktop/Projects/Stitching/results/icml_spikes/'

    # Results were written with np.savez as a pickled dict under 'arr_0', which
    # needs allow_pickle=True on NumPy >= 1.16.3; `with` closes the .npz handle.
    file_name = 'n' + str(n) + 'spikes_fullyObs'
    with np.load(data_path + file_name + '.npz', allow_pickle=True) as npz:
        load_file = npz['arr_0'].item()
    Qs = load_file['Qs']  # fully observed covariances (reference)

    # Stitched SSID estimate used to initialise EM.
    file_name = 'n' + str(n) + 'spikes_2sp_stitched'
    with np.load(data_path + file_name + '.npz', allow_pickle=True) as npz:
        load_file = npz['arr_0'].item()
    pars_hat = load_file['pars_est']
    if pars_hat['B'] is None:
        # Symmetrise Pi and shift its spectrum so the Cholesky factorisation succeeds.
        pars_hat['Pi'] = (pars_hat['Pi'] + pars_hat['Pi'].T) / 2
        l = np.min((np.real(np.linalg.eigvals(pars_hat['Pi'])).min(), 0))
        pars_hat['B'] = np.linalg.cholesky(pars_hat['Pi'] + (1e-10 - l) * np.eye(n))
    if pars_hat['A'] is None:
        # Regress the stacked lag blocks X_1..X_{L-1} on X_0..X_{L-2}.
        # NOTE(review): if X_m = A^m Pi (the convention used just below), this
        # recovers Pi^{-1} A Pi rather than A -- confirm against the SSID code.
        n_lags = len(lag_range)
        pars_hat['A'] = np.linalg.lstsq(pars_hat['X'][:(n_lags-1)*n, :],
                                        pars_hat['X'][n:n_lags*n, :],
                                        rcond=None)[0]
    # Lagged latent covariances X_m = A^m Pi for every lag.
    pars_hat['X'] = np.vstack([np.linalg.matrix_power(pars_hat['A'], m).dot(pars_hat['Pi'])
                               for m in lag_range])
    # Innovation covariance consistent with the stationary covariance, i.e. the
    # same relation solve_discrete_lyapunov enforces below:
    #   Pi = A Pi A^T + Q   =>   Q = Pi - A Pi A^T   (symmetrised for round-off).
    # The previous expression Pi + A Pi Pi did not satisfy this relation.
    # NOTE(review): Q can be indefinite if the SSID A and Pi are inconsistent.
    Q_init = pars_hat['Pi'] - pars_hat['A'].dot(pars_hat['Pi']).dot(pars_hat['A'].T)
    pars_hat['Q'] = (Q_init + Q_init.T) / 2

    Om, W, idx_a, idx_b = load_file['Om'], load_file['W'], load_file['idx_a'], load_file['idx_b']
    print_slim(Qs, Om, lag_range, pars_hat, idx_a, idx_b, None, False, None)

    nOm = [np.invert(Om[m]) for m in range(len(lag_range))]
    print_slim(Qs, nOm, lag_range, pars_hat, idx_a, idx_b, None, False, None)

    # EM: subpopulation 0 then subpopulation 1 (obs_time presumably holds the
    # end point of each observation block).
    obs_scheme = {'sub_pops': [sub_pops[0], sub_pops[1]],
                  'obs_pops': [0, 1],
                  'obs_time': [T//2, T]}
    fit_lds = setup_fit_lds(y=y.T.reshape(p, T, 1),
                            u=None,
                            max_iter=max_iter_EM,
                            epsilon=np.log(1.0001),
                            eps_cov=eps_cov,
                            plot_flag=False,
                            trace_pars_flag=True,
                            trace_stats_flag=False,
                            diag_R_flag=True,
                            use_A_flag=True,
                            use_B_flag=False)

    # fit the model to data
    print('fitting model to data')
    pars_hat['d'] = np.zeros(p, dtype=dtype)
    pars_hat['mu0'] = np.zeros(n, dtype=dtype)
    pars_hat['V0'] = pars_hat['X'][:n, :]  # lag-0 block, i.e. Pi

    t = time.time()
    # EM presumably treats 'B' as the (absent) input matrix (use_B_flag=False),
    # so the SSID factor stored under 'B' above is replaced here.
    pars_hat['B'] = np.empty((n, 0), dtype=dtype)
    pars_hat, ll = fit_lds(x_dim=n,
                           pars=pars_hat,
                           obs_scheme=obs_scheme,
                           save_file=None)

    likes = np.array(ll)
    elapsed_time = time.time() - t
    print('elapsed time for fitting is')
    print(elapsed_time)
    pars_hat['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_hat['A'],
                                                       pars_hat['Q'])

    t = time.time() - t
    print('fitting time: ', t)

    plt.figure()
    plt.plot(likes)
    plt.show()

    # Evaluate the parameters of the last traced iteration. Clamp the index in
    # case EM stopped early (epsilon criterion) and traced fewer iterates.
    i = min(max_iter_EM, len(pars_hat['As'])) - 1
    pars_est = {
        'A': pars_hat['As'][i],
        'Q': pars_hat['Qs'][i],
        'C': pars_hat['Cs'][i],
        'R': pars_hat['Rs'][i]
    }
    pars_est['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_est['A'],
                                                       pars_est['Q'])
    pars_est['X'] = np.vstack([np.linalg.matrix_power(pars_est['A'], m).dot(pars_est['Pi'])
                               for m in lag_range])
    print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, None, False, data_path)

    nOm = [np.invert(Om[m]) for m in range(len(lag_range))]
    print_slim(Qs, nOm, lag_range, pars_est, idx_a, idx_b, None, False, data_path)

    data_path = '/home/marcel/Desktop/Projects/Stitching/results/icml_spikes/'
    save_dict = {'p' : p,'n' : n,'T' : T,
                 'obs_scheme' : obs_scheme, 'y' : None,
                 'pars_est' : pars_hat, 'Qs' : Qs, 'Om' : Om, 'W' : W, 'idx_a' : idx_a, 'idx_b' : idx_b,
                 'rnd_seed' : rnd_seed_fit, 'eps_cov' : eps_cov, 'idx_shuffle' : idx_shuffle
                }
    file_name = 'n' + str(n) + 'spikes_ssidEM_2sp_stitched' + '_init' + str(r)
    np.savez(data_path + file_name, save_dict)