In [None]:
import numpy as np
import scipy as sp
import matplotlib
import matplotlib.pyplot as plt

import ssm_timeSeries as ts  # my self-written time series overhead
import ssm_fit               # my self-written library for state-space model fitting
import ssm_scripts

import random
from datetime import datetime

# Draw a fresh integer seed from the wall clock and hand it to NumPy's global
# generator, so every later np.random call in this notebook (e.g. the random
# choice of observed units) can be replayed from the single value `rngSeed`.
# random.seed() only accepts None/int/float/str/bytes/bytearray on
# Python >= 3.11, so the POSIX timestamp (a float) is passed rather than the
# datetime object itself.
random.seed(datetime.now().timestamp())
rngSeed = random.randint(0, 1000)
np.random.seed(rngSeed)

%matplotlib inline


""" Set problem size """

# Dimensions handed to ssm_scripts.run further down.
# (xDim / uDim are presumably the latent-state and input sizes - confirm
#  against ssm_scripts.)
xDim = 8
yDim = 800      # number of rows kept from the loaded traces
uDim = 0
T    = 100000   # number of columns (time steps) kept from the loaded traces

# Row indices drawn at random without replacement. The result is left
# unsorted on purpose: the shuffled order is what mixes the modes.
idxy = np.random.choice(800, yDim, replace=False)

# A fixed draw that can be re-enabled to pin the selection:
#idxy = np.array([29, 35, 73, 66, 55, 72, 57, 97, 94, 71, 92,  4,  3, 34,  9, 77, 47,
#       70, 98, 21, 13, 88,  7,  1, 41, 78, 93, 39, 27, 61, 52,  2, 79, 86,
#       49, 14,  5, 26, 82, 46, 84, 51, 10, 33, 17, 85, 18, 43, 64, 32, 91,
#       50, 45, 48,  8, 36, 59, 24, 89, 81, 12, 96, 53, 63, 74, 37, 99, 25,
#       67, 75,  0, 20, 68, 31, 87, 80,  6, 40, 90, 11, 44, 22, 42, 65, 76,
#       54, 23, 16, 19, 58, 69, 28, 56, 38, 30, 62, 60, 15, 83, 95])


from scipy.io import loadmat

# Read the 'y' variable from the .mat file, keep the selected rows and the
# first T columns, and append a singleton third axis.
matContents = loadmat('/home/mackelab/Desktop/Projects/Stitching/data/clustered_networks/calcium_traces/calcTraces800.mat')
yReal = matContents['y'][np.ix_(idxy, range(T))].reshape(yDim, T, 1)
Trial = 1
del matContents   # release the raw file contents; only yReal is needed below

""" Set observation protocol """
        
# Candidate sub-populations, addressed by position in this list:
#   0 -> empty, 1 -> every index below yDim, 2 -> indices 0..50,
#   3 -> indices 50..yDim-1 (populations 2 and 3 share index 50).
subpops = [
    [],
    list(range(yDim)),
    list(range(51)),
    list(range(50, yDim)),
]

# obsTime and obsPops are parallel lists: entry k of obsPops is the index
# into `subpops` that goes with entry k of obsTime.
obsTime = [1]   # the protocol opens with
obsPops = [1]   # the full population

temporalStitchingOrder = 2
firstThirdEnd = int(T/3) + 1

# First third of the data: at every multiple of temporalStitchingOrder add an
# 'empty' entry immediately followed by a 'full population' entry.
for i in range(1, firstThirdEnd):
    if i % temporalStitchingOrder != 0:
        continue
    obsTime += [i, i + 1]
    obsPops += [0, 1]

# Drop trailing entries that reach past the first third ...
while obsTime[-1] >= firstThirdEnd:
    del obsTime[-1]
    del obsPops[-1]
# ... and close the first third with the empty sub-population.
obsTime.append(firstThirdEnd)
obsPops.append(0)

# Remaining two thirds: sub-population 2, then sub-population 3.
obsTime.extend([int(2*T/3), int(T)])
obsPops.extend([2, 3])

obsScheme = dict(subpops=subpops, obsTime=obsTime, obsPops=obsPops)

# Options forwarded unchanged to ssm_scripts.run.
# NOTE(review): key meanings are defined by ssm_scripts / ssm_fit, not here;
# covConvEps = 1e-50 presumably makes the convergence check inactive so that
# maxIter decides when fitting stops - confirm.
fitOptions = dict(
    ifUseB=False,
    maxIter=100,
    ifPlotProgress=True,
    covConvEps=1e-50,
    ifTraceParamHist=False,
    ifFitA=True,
    ifInitCwithPCA=True,
)

""" Fit the model, save results """

# Target given to ssm_scripts.run through its saveFile argument.
sf = '/home/mackelab/Desktop/Projects/Stitching/results/test_problems/LDS_save'

# Fit on the loaded traces: data is passed via y, with an empty x and no
# true / initial parameters supplied.
yOut, xOut, u, learnedPars, initPars, truePars = ssm_scripts.run(
    xDim, yDim, uDim, T,
    obsScheme, fitOptions,
    y=yReal, x=[],
    truePars=None, initPars=None,
    saveFile=sf,
)

(800, 800)

In [None]:
%matplotlib inline
# Correlation coefficients between the rows of the first slice of yOut.
corry = np.corrcoef(yOut[:, :, 0])

fig = plt.figure(figsize=(16, 12))

# Left panel: the correlation matrix as an image, with its colour scale.
axMatrix = fig.add_subplot(1, 2, 1)
corrImage = axMatrix.imshow(corry, interpolation='none')
fig.colorbar(corrImage, ax=axMatrix)

# Right panel: every matrix entry, sorted in ascending order.
axSorted = fig.add_subplot(1, 2, 2)
axSorted.plot(np.sort(corry.reshape(yDim*yDim,)))

In [None]:
# Unpack the parameters produced by the fitting cell.
[A,B,Q,mu0,V0,C,d,R] = learnedPars

# Stationary latent covariance: the solution Pi of  Pi = A Pi A^T + Q.
Pi = sp.linalg.solve_discrete_lyapunov(A, Q)

# Model-implied covariances of the observations.
# (np.diag(R) assumes R holds the diagonal of the observation-noise
#  covariance as a 1-D array - TODO confirm against ssm_fit.)
covy_h    = np.dot(np.dot(C, Pi), C.transpose()) + np.diag(R)   # cov(y_t,     y_t)
covy_tl_h = np.dot(np.dot(C, np.dot(A, Pi)), C.transpose())      # cov(y_{t+1}, y_t)

# Empirical covariances: stack y_t (top half) on y_{t+1} (bottom half) and
# take the joint covariance over the T-1 available time-step pairs.
y_tl = np.zeros([2*yDim, T-1])
y_tl[:yDim, :] = yOut[:, 0:T-1, 0]
y_tl[yDim:, :] = yOut[:, 1:T, 0]
covy = np.cov(y_tl)

covy_e = covy[:yDim, :yDim]
# Rows index y_{t+1}, columns index y_t, i.e. the lower-left block of covy.
# This matches the orientation of covy_tl_h = C A Pi C^T; the upper-right
# block would be its transpose.
covy_tl_e = covy[yDim:, :yDim]

# Figure 1: model-implied (left column) against empirical (right column)
# covariances, instantaneous on top and time-lagged below.
plt.figure(1,figsize=(15,15))
plt.subplot(2,2,1)
plt.imshow(covy_h, interpolation='none')
plt.colorbar()
plt.title('inst. covariances est.')
plt.subplot(2,2,2)
plt.imshow(covy_e, interpolation='none')
plt.title('inst. covariances true')
plt.colorbar()
plt.subplot(2,2,3)
plt.imshow(covy_tl_h, interpolation='none' )
plt.colorbar()
plt.title('time-lagged. covariances est.')
plt.subplot(2,2,4)
plt.imshow(covy_tl_e, interpolation='none' )
plt.colorbar()
plt.title('time-lagged. covariances true')

# Figure 2: the remaining fitted parameters. Only estimates are available
# here (the fit is run with truePars=None), so no true/est legend is drawn.
plt.figure(2,figsize=(15,15))
plt.subplot(2,2,1)
plt.plot(d)
plt.title('d')
plt.subplot(2,2,2)
plt.plot(R)
plt.title('R')
plt.subplot(2,2,3)
plt.plot(np.sort(np.linalg.eig(A)[0]))
plt.title('eig(A)')

In [None]:
%matplotlib inline
# Pair mask over (row, column) index pairs: start with every pair marked
# True, then clear each pair whose two indices belong to a common
# sub-population. What stays True is a pair that no sub-population contains.
# NOTE(review): subpops[1] as defined in this notebook spans every index, so
# presumably no entry stays True here - confirm that this is intended.
idxStitched = np.full((yDim, yDim), True, dtype=bool)
for members in obsScheme['subpops']:
    if len(members) > 0:
        idxStitched[np.ix_(members, members)] = False


plt.imshow(idxStitched, interpolation='none')

In [None]:

    
%matplotlib inline
def scatterCovariances(figNum, covEmp, covEst, pairMask):
    """Scatter the masked entries of covEmp (x-axis) against covEst (y-axis).

    Draws into the pyplot figure numbered figNum and returns the title object.
    """
    plt.figure(figNum, figsize=(15,8))
    plt.plot(covEmp[pairMask], covEst[pairMask], '.')
    return plt.title('emp vs. stitched')


# Complement of the stitched-pair mask.
notStitched = np.invert(idxStitched)

# Figures 1-2 use the pairs outside idxStitched, figures 3-4 the pairs inside;
# odd figures show instantaneous, even figures time-lagged covariances.
scatterCovariances(1, covy_e,    covy_h,    notStitched)
scatterCovariances(2, covy_tl_e, covy_tl_h, notStitched)
scatterCovariances(3, covy_e,    covy_h,    idxStitched)
scatterCovariances(4, covy_tl_e, covy_tl_h, idxStitched)
