In [None]:
%matplotlib inline
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import bz2

import glob, os, psutil, time
os.chdir('../core')
from text import progprint_xrange
from utility import get_subpop_stats, gen_data
from SSID_Hankel_loss import run_bad, plot_slim, print_slim, plot_outputs_l2_gradient_test
os.chdir('../dev')

# Directory holding the memory-mapped arrays this notebook reads and writes.
data_path = '../fits/lsfm/grid_quick/'

# Indices of the time points (stack files) to use.
Ts = np.arange(1200)

# Full stack extent per axis, paired with the stride used to subsample it.
nx, dx = 41, 4
ny, dy = 1024, 32
nz, dz = 2048, 64
x_sel = np.arange(0, nx, dx)
y_sel = np.arange(0, ny, dy)
z_sel = np.arange(0, nz, dz)

# n, k and l are handed to run_bad in the fitting cells; n is also the number
# of PCA components kept for the PCA-based initialisation.
n = 20
k, l = 3, 3

# T = number of time points, p = number of grid points kept after subsampling.
T = Ts.shape[0]
p = x_sel.size * y_sel.size * z_sel.size
print('(p,n,k+l,T) = ', (p,n,k+l,T), '\n')

# When True, large arrays are kept on disk as np.memmap files under data_path.
mmap = True


In [None]:
# One-off preprocessing, disabled by default (flip the guard to rebuild the
# cached arrays): subsample every raw stack onto the coarse grid, store each
# frame as one row of `y`, then z-score every column into `y_zscore`.
if False:
    # np.float was removed from NumPy; np.float64 is the dtype it used to alias.
    if mmap:
        y = np.memmap(data_path+'y', dtype=np.float64, mode='w+', shape=(T,p))
    else:
        y = np.empty(shape=(T,p))

    for i in progprint_xrange(len(Ts), perline=10):
        t = Ts[i]

        filename = '/home/mackelab/data/dOMR0_20150414_112406/TM' + ("%05d" % t) + '_CM0_CHN00.stack.bz2'
        # The context manager closes the compressed file even if the read or
        # the reshape fails (the original left the handle open).
        # NOTE(review): the payload is interpreted as float16 -- confirm this
        # matches the sample format actually stored in the stack files.
        with bz2.BZ2File(filename) as dfile:
            stack = np.frombuffer(dfile.read(), dtype=np.float16).reshape(nx,ny,nz)[np.ix_(x_sel,y_sel,z_sel)]
        y[i,:] = stack.reshape(-1)
        if mmap:
            # Deleting the memmap flushes the freshly written row to disk;
            # re-open it writable for the next frame.
            del y
            y = np.memmap(data_path+'y', dtype=np.float64, mode='r+', shape=(T,p))

    del stack

    if mmap:
        # Flush the raw data, then map it read-only next to the z-score output.
        del y
        y_raw = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p))
        y = np.memmap(data_path+'y_zscore', dtype=np.float64, mode='w+', shape=(T,p))
    else:
        # `y` is rebound on the next line, so the raw array can be kept by
        # reference instead of paying for a full copy.
        y_raw = y
        y = np.zeros((T,p))

    # z-score column-wise in blocks of 1000 columns to bound peak memory; the
    # final block is simply shorter when p is not a multiple of the block size.
    chunk = 1000
    for start in range(0, p, chunk):
        stop = min(start + chunk, p)
        y[:, start:stop] = stats.zscore(y_raw[:, start:stop])

    if mmap:
        # Make sure the z-scored data is on disk before later cells re-open it.
        y.flush()

In [None]:
# Build two overlapping subpopulations by cutting the grid along its y-axis.

# Flat indices 0..p-1 arranged on the subsampled (x, y, z) grid.
tmp = np.arange(p).reshape(len(x_sel), len(y_sel), len(z_sel))

# The first group keeps y-slices up to and including the middle one, the
# second starts one slice before the middle, so the two groups share slices.
y_mid = len(y_sel) // 2
sub_pops = (
    tmp[:, :y_mid+1, :].ravel(),
    tmp[:, y_mid-1:, :].ravel(),
)

# Observation scheme consumed by run_bad: group 0 is paired with T//2 and
# group 1 with T.
obs_pops = np.arange(2)
obs_time = np.array((T//2, T))

In [None]:
from sklearn.decomposition import PCA

# ---- Stage 1: fit on the first subpopulation only --------------------------

# Random (sorted) subsets of at most 1000 columns of the first subpopulation
# between which the time-lagged covariances are computed.
p1 = len(sub_pops[0])
pa = np.min((p1,1000))
pb = np.min((p1,1000))
idx_a = np.sort(np.random.choice(p1, pa, replace=False))
idx_b = np.sort(np.random.choice(p1, pb, replace=False))


Qs = []
if mmap:
    # np.float was removed from NumPy; np.float64 is the dtype it aliased.
    # Fancy indexing copies the selected columns into RAM, so from here on
    # `y` is an in-memory array restricted to the first subpopulation.
    # NOTE: this restriction only happens on the mmap path; with mmap=False
    # `y` is used exactly as the earlier cells left it.
    y = np.memmap(data_path+'y_zscore', dtype=np.float64, mode='r', shape=(T,p))
    y = y[:,sub_pops[0]]
for m in range(k+l):
    print('computing time-lagged covariance for lag ', str(m))
    Qs.append(None)
    if mmap:
        Q = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64,
                      mode='w+', shape=(pa,pb))
    else:
        Q = np.empty((pa,pb))
    # np.cov stacks both variable sets; the upper-right (pa x pb) block is the
    # cross-covariance between the lag-m samples of idx_a and idx_b.
    Q[:] = np.cov(y[m:m-(k+l),idx_a].T, y[:-(k+l),idx_b].T)[:pa,pa:]
    if mmap:
        # Deleting the writable memmap flushes it; re-open it read-only.
        del Q
        Qs[m] = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64,
                          mode='r', shape=(pa,pb))
    else:
        Qs[m] = Q
# The original re-read `y` from disk at this point; `y` is already the
# in-memory column subset and was not modified, so the reload is dropped.


obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = \
    get_subpop_stats(sub_pops=(np.arange(len(sub_pops[0])),), p=len(sub_pops[0]), verbose=False)


# PCA of the first subpopulation: cumulative explained variance, and a PCA
# based initial guess for the fit.
pca = PCA()
pca.fit(y)
# Pad with zeros so the curve has one value per dimension (p1 in total).
explained_variance_ratio_ = np.hstack((pca.explained_variance_ratio_, np.zeros(np.max((p1-T,0)))))
plt.plot(range(1,p1+1), np.cumsum(explained_variance_ratio_)/np.sum(explained_variance_ratio_))
# plt.hold was removed in matplotlib 3.0; successive plot calls overlay by
# default. The zoomed curve is stretched over this plot's own x-range (p1);
# the original used the full-population size p here.
n_top = min(20, pca.explained_variance_ratio_.size)
plt.plot(np.linspace(0,p1+1,n_top),
         np.cumsum(pca.explained_variance_ratio_[:n_top])/np.sum(pca.explained_variance_ratio_),
         'r--')
plt.legend(('cum. var. expl.', 'first 20, x-axis rescaled'))

pars_pca = {}
pars_pca['C'] = pca.components_[:n].T
pars_pca['Pi'] = np.diag(pca.explained_variance_[:n])


t = time.time()
linearity, stable, sym_psd = 'False', False, False
verbose = True

# A single iteration with alpha=0 is used only to obtain parameter
# containers (pars_est1) that are then overwritten with the PCA guess.
pars_init, pars_est1, traces = run_bad(k=k,l=l,n=n,y=y, Om = None, Qs=Qs, idx_a=idx_a,idx_b=idx_b,
                                      sub_pops=(np.arange(p1),),
                                      idx_grp=idx_grp,co_obs=co_obs,obs_idx=obs_idx,
                                      obs_pops=np.array([0]),obs_time=np.array([T]),
                                      linearity=linearity,stable=stable,init='default',alpha=0,
                                      max_iter=1,batch_size=1,max_zip_size=1)


# settings for fitting algorithm (first pass)
batch_size, max_zip_size, max_iter = 1, 100, 50
a, b1, b2, e = 0.001, 0.9, 0.99, 1e-8
a_R = a

pars_est1['C'] = pars_pca['C'].copy()
pars_est1['X'][:n,:] = pars_pca['Pi'].copy()
_, pars_est1, traces = run_bad(k=k,l=l,n=n,y=y, Om = None, Qs=Qs, idx_a=idx_a,idx_b=idx_b,
                                      sub_pops=(np.arange(p1),),
                                      idx_grp=idx_grp,co_obs=co_obs,obs_idx=obs_idx,
                                      obs_pops=np.array([0]),obs_time=np.array([T]),
                                      linearity=linearity,stable=stable,init=pars_est1,
                                      alpha=a,alpha_R=a_R,b1=b1,b2=b2,e=e,max_iter=max_iter,batch_size=batch_size,
                                      verbose=verbose, sym_psd=sym_psd, max_zip_size=max_zip_size)

# second pass, warm-started from the first, with larger batches
batch_size, max_zip_size, max_iter = 5, 20, 50
a, b1, b2, e = 0.001, 0.9, 0.99, 1e-8
a_R = a

_, pars_est1, traces = run_bad(k=k,l=l,n=n,y=y, Om = None, Qs=Qs, idx_a=idx_a,idx_b=idx_b,
                                      sub_pops=(np.arange(p1),),
                                      idx_grp=idx_grp,co_obs=co_obs,obs_idx=obs_idx,
                                      obs_pops=np.array([0]),obs_time=np.array([T]),
                                      linearity=linearity,stable=stable,init=pars_est1,
                                      alpha=a,alpha_R=a_R,b1=b1,b2=b2,e=e,max_iter=max_iter,batch_size=batch_size,
                                      verbose=verbose, sym_psd=sym_psd, max_zip_size=max_zip_size)


plot_slim(Qs,k,l,pars_est1,idx_a,idx_b,traces,mmap,data_path)


print(y.shape)

In [None]:

# ---- Stage 2: joint fit of both subpopulations ------------------------------
# Relies on pars_est1, linearity, stable, verbose and sym_psd from the
# previous cell.

obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = \
    get_subpop_stats(sub_pops=sub_pops, p=p, verbose=False)


# Random (sorted) subsets of at most 1000 columns of the full population.
pa = np.min((p,1000))
pb = np.min((p,1000))
idx_a = np.sort(np.random.choice(p, pa, replace=False))
idx_b = np.sort(np.random.choice(p, pb, replace=False))


Qs = []
if mmap:
    # np.float was removed from NumPy; np.float64 is the dtype it aliased.
    y = np.memmap(data_path+'y_zscore', dtype=np.float64, mode='r', shape=(T,p))
for m in range(k+l):
    print('computing time-lagged covariance for lag ', str(m))
    Qs.append(None)
    if mmap:
        Q = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64,
                      mode='w+', shape=(pa,pb))
    else:
        Q = np.empty((pa,pb))
    # Upper-right (pa x pb) block of the stacked covariance = cross-covariance
    # between the lag-m samples of idx_a and idx_b.
    Q[:] = np.cov(y[m:m-(k+l),idx_a].T, y[:-(k+l),idx_b].T)[:pa,pa:]
    if mmap:
        # Deleting the writable memmap flushes it; re-open it read-only.
        del Q
        Qs[m] = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64,
                          mode='r', shape=(pa,pb))
    else:
        Qs[m] = Q
if mmap:
    # Drop the mapping used for the covariance pass and open a fresh one.
    del y
    y = np.memmap(data_path+'y_zscore', dtype=np.float64, mode='r', shape=(T,p))


batch_size, max_zip_size, max_iter = 1, 100, 50
a, b1, b2, e = 0.001, 0.9, 0.99, 1e-8
a_R = a

t = time.time()
# A single iteration with alpha=0 is used only to obtain full-size parameter
# containers, which are then seeded from the stage-1 estimate.
pars_init, pars_0, traces = run_bad(k=k,l=l,n=n,y=y, Om = None, Qs=Qs, idx_a=idx_a,idx_b=idx_b,
                                      sub_pops=sub_pops,
                                      idx_grp=idx_grp,co_obs=co_obs,obs_idx=obs_idx,
                                      obs_pops=obs_pops,obs_time=obs_time,
                                      linearity=linearity,stable=stable,init='default',alpha=0,
                                      max_iter=1,batch_size=1,max_zip_size=1)
# Zero 'C', then copy the stage-1 values into the rows / entries belonging to
# the first subpopulation; 'X' is taken over as a whole.
pars_0['C'] *= 0
pars_0['C'][sub_pops[0], :] = pars_est1['C'].copy()
pars_0['R'][sub_pops[0]] = pars_est1['R'].copy()
pars_0['X'] = pars_est1['X'].copy()
pars_est = pars_0.copy()
del pars_0


_, pars_est, traces = run_bad(k=k,l=l,n=n,y=y, Om = None, Qs=Qs, idx_a=idx_a,idx_b=idx_b,
                                      sub_pops=sub_pops,
                                      idx_grp=idx_grp,co_obs=co_obs,obs_idx=obs_idx,
                                      obs_pops=obs_pops,obs_time=obs_time,
                                      linearity=linearity,stable=stable,init=pars_est,
                                      alpha=a,alpha_R=a_R,b1=b1,b2=b2,e=e,max_iter=max_iter,batch_size=batch_size,
                                      verbose=verbose, sym_psd=sym_psd, max_zip_size=max_zip_size)
plot_slim(Qs,k,l,pars_est,idx_a,idx_b,traces,mmap,data_path)

# second pass, warm-started from the first, with larger batches
batch_size, max_zip_size, max_iter = 10, 10, 50
a, b1, b2, e = 0.001, 0.9, 0.99, 1e-8
a_R = a


_, pars_est, traces = run_bad(k=k,l=l,n=n,y=y, Om = None, Qs=Qs, idx_a=idx_a,idx_b=idx_b,
                                      sub_pops=sub_pops,
                                      idx_grp=idx_grp,co_obs=co_obs,obs_idx=obs_idx,
                                      obs_pops=obs_pops,obs_time=obs_time,
                                      linearity=linearity,stable=stable,init=pars_est,
                                      alpha=a,alpha_R=a_R,b1=b1,b2=b2,e=e,max_iter=max_iter,batch_size=batch_size,
                                      verbose=verbose, sym_psd=sym_psd, max_zip_size=max_zip_size)
plot_slim(Qs,k,l,pars_est,idx_a,idx_b,traces,mmap,data_path)



print('fitting time was ', time.time() - t, 's')
print('\n')
print(psutil.virtual_memory())
print(psutil.swap_memory())



In [None]:


# Further refinement of the joint fit, warm-started from the current pars_est,
# with a smaller step size (a = 0.0002).
batch_size = 10
max_zip_size = 10
max_iter = 50
a = 0.0002
b1 = 0.9
b2 = 0.99
e = 1e-8
a_R = a

_, pars_est, traces = run_bad(
    k=k, l=l, n=n, y=y, Om=None, Qs=Qs,
    idx_a=idx_a, idx_b=idx_b,
    sub_pops=sub_pops,
    idx_grp=idx_grp, co_obs=co_obs, obs_idx=obs_idx,
    obs_pops=obs_pops, obs_time=obs_time,
    linearity=linearity, stable=stable,
    init=pars_est,
    alpha=a, alpha_R=a_R, b1=b1, b2=b2, e=e,
    max_iter=max_iter, batch_size=batch_size,
    verbose=verbose, sym_psd=sym_psd, max_zip_size=max_zip_size,
)
plot_slim(Qs, k, l, pars_est, idx_a, idx_b, traces, mmap, data_path)

# Elapsed time is measured from the `t` set when the joint fit started.
print('fitting time was ', time.time() - t, 's')
print('\n')
print(psutil.virtual_memory())
print(psutil.swap_memory())


In [None]:


# Final refinement of the joint fit, warm-started from the current pars_est,
# with larger batches, a smaller step size (a = 0.0001) and b1 raised to 0.99.
batch_size = 20
max_zip_size = 10
max_iter = 50
a = 0.0001
b1 = 0.99
b2 = 0.99
e = 1e-8
a_R = a

_, pars_est, traces = run_bad(
    k=k, l=l, n=n, y=y, Om=None, Qs=Qs,
    idx_a=idx_a, idx_b=idx_b,
    sub_pops=sub_pops,
    idx_grp=idx_grp, co_obs=co_obs, obs_idx=obs_idx,
    obs_pops=obs_pops, obs_time=obs_time,
    linearity=linearity, stable=stable,
    init=pars_est,
    alpha=a, alpha_R=a_R, b1=b1, b2=b2, e=e,
    max_iter=max_iter, batch_size=batch_size,
    verbose=verbose, sym_psd=sym_psd, max_zip_size=max_zip_size,
)
plot_slim(Qs, k, l, pars_est, idx_a, idx_b, traces, mmap, data_path)

# Elapsed time is measured from the `t` set when the joint fit started.
print('fitting time was ', time.time() - t, 's')
print('\n')
print(psutil.virtual_memory())
print(psutil.swap_memory())


In [None]:

# ---- Baseline: PCA on zero-masked data vs. the fitted model -----------------
# Compares how well each reproduces the covariance between columns that
# belong to only one of the two subpopulations.

# Columns belonging to exactly one subpopulation (computed once, reused below).
only_in_first = np.setdiff1d(np.arange(p), sub_pops[1])   # not in subpop #2
only_in_second = np.setdiff1d(np.arange(p), sub_pops[0])  # not in subpop #1

# mode='c' is copy-on-write: the zeroing below stays in memory and is never
# written back to the y_zscore file.
# np.float was removed from NumPy; np.float64 is the dtype it aliased.
y_masked = np.memmap(data_path+'y_zscore', dtype=np.float64, mode='c', shape=(T,p))

# Before obs_time[0] everything outside subpop #1 is zeroed, afterwards
# everything outside subpop #2.
y_masked[:obs_time[0], only_in_second] = 0
y_masked[obs_time[0]:, only_in_first] = 0

pca = PCA()
pca.fit(y_masked)
del y_masked

# Pad with zeros so the curve has one value per dimension (p in total).
explained_variance_ratio_ = np.hstack((pca.explained_variance_ratio_, np.zeros(np.max((p-T,0)))))
plt.plot(range(1,p+1), np.cumsum(explained_variance_ratio_)/np.sum(explained_variance_ratio_))
# plt.hold was removed in matplotlib 3.0; successive plot calls overlay by default.
n_top = np.min((20,p))
plt.plot(np.linspace(0,p+1,n_top),
         np.cumsum(pca.explained_variance_ratio_[:n_top])/np.sum(pca.explained_variance_ratio_),
         'r--')
plt.legend(('cum. var. expl.', 'first 20, x-axis rescaled'))

pars_pca = {}
pars_pca['C'] = pca.components_[:n].T
pars_pca['Pi'] = np.diag(pca.explained_variance_[:n])


# Subsample (for memory reasons) at most 1000 columns from each exclusive set.
# Clamping to the set size keeps np.random.choice(replace=False) valid on
# smaller grids and matches the min(., 1000) convention of the other cells.
pa, pb = min(only_in_first.size, 1000), min(only_in_second.size, 1000)
idx_a = only_in_first[np.sort(np.random.choice(only_in_first.size, pa, replace=False))]
idx_b = only_in_second[np.sort(np.random.choice(only_in_second.size, pb, replace=False))]


Qs = []
if mmap:
    y = np.memmap(data_path+'y_zscore', dtype=np.float64, mode='r', shape=(T,p))
for m in range(k+l):
    print('computing time-lagged covariance for lag ', str(m))
    Qs.append(None)
    if mmap:
        Q = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64,
                      mode='w+', shape=(pa,pb))
    else:
        Q = np.empty((pa,pb))
    # Upper-right (pa x pb) block of the stacked covariance = cross-covariance
    # between the lag-m samples of idx_a and idx_b.
    Q[:] = np.cov(y[m:m-(k+l),idx_a].T, y[:-(k+l),idx_b].T)[:pa,pa:]
    if mmap:
        # Deleting the writable memmap flushes it; re-open it read-only.
        del Q
        Qs[m] = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64,
                          mode='r', shape=(pa,pb))
    else:
        Qs[m] = Q
if mmap:
    # Drop the mapping used for the covariance pass and open a fresh one.
    del y
    y = np.memmap(data_path+'y_zscore', dtype=np.float64, mode='r', shape=(T,p))


# `fig_path` is not assigned in any cell shown in this notebook; fall back to
# data_path so a fresh Restart & Run All does not fail with a NameError.
try:
    fig_path
except NameError:
    fig_path = data_path

plt.figure(figsize=(10, 10))
for m in range(1):
    print('computing time-lagged covariance for lag ', str(m))

    # Covariance predicted by the fitted parameters (red) ...
    Q = (pars_est['C'][idx_a,:].dot(pars_est['X'][:n,:]).dot(pars_est['C'][idx_b,:].T))
    plt.plot(Qs[m][:], Q[:], 'r.')
    plt.title('corr. est. us:' + str(np.corrcoef(Qs[m][:].reshape(-1), Q.reshape(-1))[0,1]))

    # ... and by the PCA baseline (blue), both against the empirical values.
    Q = pars_pca['C'][idx_a,:].dot(pars_pca['Pi']).dot(pars_pca['C'][idx_b,:].T)
    plt.plot(Qs[m][:], Q[:], 'b.')
    plt.plot([-1,1], [-1,1], 'k--')  # identity line for reference

    plt.xlabel('corr. est. PCA:' + str(np.corrcoef(Qs[m][:].reshape(-1), Q.reshape(-1))[0,1]))
    plt.savefig(fig_path + 'sticthed_covs.png', bbox_inches="tight");
    plt.show()
    
