# Summary of results on the small index set extracted from the zebrafish data

# 10 non-overlapping subpopulations (4 z-planes each; the last one has 5 z-planes)

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error



# --- experiment configuration (10 non-overlapping subpopulations) ---------
run = '_real_stitch'             # suffix identifying the fitted run in file names
ns = (2, 5, 10, 20)              # latent dimensionalities that were fitted

data_path = '../fits/lsfm/grid_quick/'
idx_str = 'small'
idx_file = '{}idx_{}.npy'.format(data_path, idx_str)
idx_fish = np.load(idx_file)

T = 1200                         # number of rows (time points) of the (T, p) traces
p = len(idx_fish)                # number of indices in the chosen index set
lag_range = np.arange(10)        # time lags 0, 1, ..., 9
kl_ = lag_range.max() + 1        # one past the largest lag; used to trim trace ends
snr = (0., 0.)
verbose = True

def get_corrs(Qs,Om,lag_range,pars,idx_a,idx_b,traces,mmap):
    """Correlate observed and model-reconstructed time-lagged covariances.

    For each lag position m the model covariance between the units in
    ``idx_a`` (rows) and ``idx_b`` (columns) is rebuilt as
    ``C[idx_a] @ X_m @ C[idx_b].T``, where ``X_m`` is the m-th n-by-n row
    block of ``pars['X']``. At lag 0 the noise variances ``R`` are added on
    the entries that pair a unit with itself. Only entries selected by the
    boolean mask ``Om[m]`` enter the Pearson correlation.

    Parameters
    ----------
    Qs : list of 2-D arrays, or None when ``mmap`` is true
        Observed covariances; ``Qs[m]`` has shape (len(idx_a), len(idx_b)).
    Om : list of boolean arrays, ``Om[m]`` shaped like ``Qs[m]``.
    lag_range : 1-D array of time lags; position m corresponds to lag_range[m].
    pars : dict with 'C' (p, n), 'X' (at least len(lag_range)*n rows, n cols)
        and 'R' (indexable by unit index).
    idx_a, idx_b : 1-D integer index arrays.
    traces : unused; kept for call-site compatibility.
    mmap : if true, ``Q`` for lag ``lag_range[m]`` is read from the raw
        float64 memmap file ``data_path + 'Qs_' + str(lag)``. Here
        ``data_path`` is the module-level global.

    Returns
    -------
    ndarray of shape (len(lag_range),) with one correlation per lag.
    """
    kl = len(lag_range)
    _, n = pars['C'].shape
    pa, pb = idx_a.size, idx_b.size
    # units contained in both index sets, and their positions within idx_a / idx_b
    # NOTE(review): the diagonal update below assumes idx_a and idx_b are sorted,
    # so the k-th shared unit sits at the k-th selected row and column -- confirm.
    idx_ab = np.intersect1d(idx_a, idx_b)
    idx_a_ab = np.where(np.isin(idx_a, idx_ab))[0]
    idx_b_ab = np.where(np.isin(idx_b, idx_ab))[0]
    out = np.zeros(kl)
    for m in range(kl):
        m_ = lag_range[m]
        Qrec = pars['C'][idx_a,:].dot(pars['X'][m*n:(m+1)*n, :]).dot(pars['C'][idx_b,:].T)
        if m_ == 0:
            Qrec[np.ix_(idx_a_ab, idx_b_ab)] += np.diag(pars['R'][idx_ab])
        if mmap:
            Q = np.memmap(data_path+'Qs_'+str(m_), dtype=np.float64, mode='r', shape=(pa,pb))
        else:
            Q = Qs[m]
        # correlate against Q (not Qs[m]) so the memmap branch is actually used
        out[m] = np.corrcoef(Qrec[Om[m]], Q[Om[m]])[0,1]
        if mmap:
            del Q
    return out

# For each latent dimensionality n, load the fit and compute three things:
#   obs_corrs[ni, :] - corr. of reconstructed vs. observed covariances (entries in Om)
#   loss[ni]         - the (partial) training loss f_l2_Hankel_nl
#   stc_corrs[ni, :] - corr. of reconstructed vs. data covariances on entries NOT in Om
obs_corrs = np.zeros((len(ns), len(lag_range)))
stc_corrs = np.zeros((len(ns), len(lag_range)))
loss      = np.zeros( len(ns) )

# The observed lagged covariances do not depend on n, so load them once.
Qs = [np.load(data_path+'Qs_'+str(lag)+'.npy') for lag in lag_range]
# (T, p) raw float64 traces (the file name suggests z-scored); sliced per lag below
y_zscore = np.memmap(data_path+'y_' + idx_str + '_zscore', dtype=np.float64, mode='r', shape=(T,p))

for ni, n in enumerate(ns):

    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + 'snr' + str(int(np.mean(snr)//1)) + '_run' + str(run)

    # Results are a pickled dict inside an object array; only load trusted result files.
    load_file = np.load(data_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()
    mmap = load_file['mmap']
    # Bind the stored snr to a separate name so the global ``snr`` used above
    # to build file names is not overwritten between iterations.
    y, x, snr_fit, idx_a, idx_b = load_file['y'], load_file['x'], load_file['snr'], load_file['idx_a'], load_file['idx_b']
    pars_true, pars_est, obs_scheme = load_file['pars_true'], load_file['pars_est'], load_file['obs_scheme']
    W, Om = load_file['W'], load_file['Om']
    W = obs_scheme.comp_coocurrence_weights(lag_range, sso=True, idx_a=idx_a, idx_b=idx_b) if W is None else W
    pa = len(idx_a)

    print('(T,p,n)', (T,p,n))

    obs_corrs[ni,:] = get_corrs(Qs,Om,lag_range,pars_est,idx_a,idx_b,None,False)
    loss[ni] = f_l2_Hankel_nl(C=pars_est['C'],
                                   X=pars_est['X'],
                                   R=pars_est['R'],
                                   Qs=Qs,
                                   Om=Om,
                                   lag_range=lag_range,
                                   ms=range(len(lag_range)),
                                   idx_a=idx_a,
                                   idx_b=idx_b)

    # mi indexes X / Om / stc_corrs; m is the lag value used to shift the traces
    for mi, m in enumerate(lag_range):
        # ya is shifted by m rows relative to yb; both have T - kl_ rows
        ya = y_zscore[m:m-kl_, idx_a].copy()
        yb = y_zscore[:-kl_, idx_b].copy()
        # np.cov stacks the pa rows of ya.T over the pb rows of yb.T, so the
        # (pa, pb) cross-covariance block starts at column pa (not len(idx_b))
        Qgd = np.cov(ya.T, yb.T)[:pa, pa:]
        del ya, yb
        Qest = pars_est['C'][idx_a,:].dot(pars_est['X'][mi*n:(mi+1)*n,:]).dot(pars_est['C'][idx_b,:].T)
        unobserved = ~Om[mi]
        stc_corrs[ni,mi] = np.corrcoef(Qgd[unobserved], Qest[unobserved])[0,1]
        del Qest, Qgd

del y_zscore


In [None]:
import seaborn

plt.figure(figsize=(14,8))

# (1) observed instantaneous (lag-0) covariances
plt.subplot(2,3,1)
plt.imshow(Qs[0], interpolation='none')
plt.colorbar()
plt.grid(False)   # a string such as 'off' is truthy and would switch the grid on
plt.title('observed inst. covariances')

# (2) lag-0 covariances computed from the full traces
plt.subplot(2,3,2)
y_ = np.memmap(data_path+'y_' + idx_str + '_zscore', dtype=np.float64, mode='r', shape=(T,p))
ya = y_[:-kl_, idx_a].copy()
yb = y_[:-kl_, idx_b].copy()
del y_
# the (len(idx_a), len(idx_b)) cross block of the stacked covariance starts at column len(idx_a)
Qgd = np.cov(ya.T, yb.T)[:len(idx_a), len(idx_a):]
del ya, yb
plt.imshow(Qgd, interpolation='none')
plt.colorbar()
del Qgd
plt.grid(False)
plt.title('full instantaneous covariances')

# (3) lag-0 covariances reconstructed from this run's n = 10 fit
plt.subplot(2,3,3)
n_show, m_show = 10, 0
file_name = 'p' + str(p) + 'n' + str(n_show) + 'T' + str(T) + 'snr' + str(int(np.mean(snr)//1)) + '_run' + str(run)
# pickled dict inside an object array; only load trusted result files
pars_show = np.load(data_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()['pars_est']
Qest = pars_show['C'][idx_a,:].dot(pars_show['X'][m_show*n_show:(m_show+1)*n_show,:]).dot(pars_show['C'][idx_b,:].T)
del pars_show
plt.imshow(Qest, interpolation='none')
plt.colorbar()
del Qest
plt.grid(False)
plt.title('reconstructed inst. covariances')

# (4) training loss as a function of latent dimensionality
plt.subplot(2,3,4)
plt.plot(ns, loss, 'o-')
plt.ylabel('(partial) training loss')
plt.xlabel('latent dim. n')
plt.box(False)
plt.title('training loss vs. latent dim.')
plt.xticks(ns)

# (5) correlations on observed (co-observed) entries, per lag
plt.subplot(2,3,5)
plt.plot(lag_range, obs_corrs.T)
plt.xlabel('time lag')
plt.ylabel('corr. of covariances')
plt.legend(['n = ' + str(k) for k in ns], loc=3, frameon=False)
plt.box(False)
plt.axis([0,np.max(lag_range), -0.25, 1.])
plt.title('corr. of observed covariances ')

# (6) correlations on stitched (never co-observed) entries, per lag
plt.subplot(2,3,6)
plt.plot(lag_range, stc_corrs.T)
plt.xlabel('time lag')
plt.ylabel('corr. of covariances')
plt.box(False)
plt.axis([0,np.max(lag_range), -0.25, 1.])
plt.title('corr. of stitched covariances ')
plt.savefig(data_path + 'res_summary_10subpops_zebrafish_small.pdf')
plt.show()

# 2 overlapping subpopulations (z-planes 1:21 and 21:41)

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error

# --- experiment configuration (2 overlapping subpopulations) --------------
run = '_overlap_stitch'          # suffix identifying the fitted run in file names
ns = (5, 10, 50)                 # latent dimensionalities that were fitted

data_path = '../fits/lsfm/grid_quick/'
idx_str = 'small'
idx_file = '{}idx_{}.npy'.format(data_path, idx_str)
idx_fish = np.load(idx_file)

T = 1200                         # number of rows (time points) of the (T, p) traces
p = len(idx_fish)                # number of indices in the chosen index set
lag_range = np.arange(10)        # time lags 0, 1, ..., 9
kl_ = lag_range.max() + 1        # one past the largest lag; used to trim trace ends
snr = (0., 0.)
verbose = True

def get_corrs(Qs,Om,lag_range,pars,idx_a,idx_b,traces,mmap):
    """Correlate observed and model-reconstructed time-lagged covariances.

    For each lag position m the model covariance between the units in
    ``idx_a`` (rows) and ``idx_b`` (columns) is rebuilt as
    ``C[idx_a] @ X_m @ C[idx_b].T``, where ``X_m`` is the m-th n-by-n row
    block of ``pars['X']``. At lag 0 the noise variances ``R`` are added on
    the entries that pair a unit with itself. Only entries selected by the
    boolean mask ``Om[m]`` enter the Pearson correlation.

    Parameters
    ----------
    Qs : list of 2-D arrays, or None when ``mmap`` is true
        Observed covariances; ``Qs[m]`` has shape (len(idx_a), len(idx_b)).
    Om : list of boolean arrays, ``Om[m]`` shaped like ``Qs[m]``.
    lag_range : 1-D array of time lags; position m corresponds to lag_range[m].
    pars : dict with 'C' (p, n), 'X' (at least len(lag_range)*n rows, n cols)
        and 'R' (indexable by unit index).
    idx_a, idx_b : 1-D integer index arrays.
    traces : unused; kept for call-site compatibility.
    mmap : if true, ``Q`` for lag ``lag_range[m]`` is read from the raw
        float64 memmap file ``data_path + 'Qs_' + str(lag)``. Here
        ``data_path`` is the module-level global.

    Returns
    -------
    ndarray of shape (len(lag_range),) with one correlation per lag.
    """
    kl = len(lag_range)
    _, n = pars['C'].shape
    pa, pb = idx_a.size, idx_b.size
    # units contained in both index sets, and their positions within idx_a / idx_b
    # NOTE(review): the diagonal update below assumes idx_a and idx_b are sorted,
    # so the k-th shared unit sits at the k-th selected row and column -- confirm.
    idx_ab = np.intersect1d(idx_a, idx_b)
    idx_a_ab = np.where(np.isin(idx_a, idx_ab))[0]
    idx_b_ab = np.where(np.isin(idx_b, idx_ab))[0]
    out = np.zeros(kl)
    for m in range(kl):
        m_ = lag_range[m]
        Qrec = pars['C'][idx_a,:].dot(pars['X'][m*n:(m+1)*n, :]).dot(pars['C'][idx_b,:].T)
        if m_ == 0:
            Qrec[np.ix_(idx_a_ab, idx_b_ab)] += np.diag(pars['R'][idx_ab])
        if mmap:
            Q = np.memmap(data_path+'Qs_'+str(m_), dtype=np.float64, mode='r', shape=(pa,pb))
        else:
            Q = Qs[m]
        # correlate against Q (not Qs[m]) so the memmap branch is actually used
        out[m] = np.corrcoef(Qrec[Om[m]], Q[Om[m]])[0,1]
        if mmap:
            del Q
    return out

# For each latent dimensionality n, load the fit and compute three things:
#   obs_corrs[ni, :] - corr. of reconstructed vs. observed covariances (entries in Om)
#   loss[ni]         - the (partial) training loss f_l2_Hankel_nl
#   stc_corrs[ni, :] - corr. of reconstructed vs. data covariances on entries NOT in Om[0]
obs_corrs = np.zeros((len(ns), len(lag_range)))
stc_corrs = np.zeros((len(ns), len(lag_range)))
loss      = np.zeros( len(ns) )

# The observed lagged covariances do not depend on n, so load them once.
Qs = [np.load(data_path+'Qs_'+str(lag)+'.npy') for lag in lag_range]
# (T, p) raw float64 traces (the file name suggests z-scored); sliced per lag below
y_zscore = np.memmap(data_path+'y_' + idx_str + '_zscore', dtype=np.float64, mode='r', shape=(T,p))

for ni, n in enumerate(ns):

    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + 'snr' + str(int(np.mean(snr)//1)) + '_run' + str(run)

    # Results are a pickled dict inside an object array; only load trusted result files.
    load_file = np.load(data_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()
    mmap = load_file['mmap']
    # Bind the stored snr to a separate name so the global ``snr`` used above
    # to build file names is not overwritten between iterations.
    y, x, snr_fit, idx_a, idx_b = load_file['y'], load_file['x'], load_file['snr'], load_file['idx_a'], load_file['idx_b']
    pars_true, pars_est, obs_scheme = load_file['pars_true'], load_file['pars_est'], load_file['obs_scheme']
    W, Om = load_file['W'], load_file['Om']
    W = obs_scheme.comp_coocurrence_weights(lag_range, sso=True, idx_a=idx_a, idx_b=idx_b) if W is None else W
    pa = len(idx_a)

    print('(T,p,n)', (T,p,n))

    obs_corrs[ni,:] = get_corrs(Qs,Om,lag_range,pars_est,idx_a,idx_b,None,False)
    loss[ni] = f_l2_Hankel_nl(C=pars_est['C'],
                                   X=pars_est['X'],
                                   R=pars_est['R'],
                                   Qs=Qs,
                                   Om=Om,
                                   lag_range=lag_range,
                                   ms=range(len(lag_range)),
                                   idx_a=idx_a,
                                   idx_b=idx_b)

    # Entries never co-observed at lag 0. We take Om[0] for all lags: Om[1] etc. is
    # 'fully observed' because there are only two subpops and we switch between them
    # a couple of times. This is something one would not take into account when
    # fitting systems individually, so we reduce to instantaneous co-observations.
    unobserved = ~Om[0]

    # mi indexes X / stc_corrs; m is the lag value used to shift the traces
    for mi, m in enumerate(lag_range):
        # ya is shifted by m rows relative to yb; both have T - kl_ rows
        ya = y_zscore[m:m-kl_, idx_a].copy()
        yb = y_zscore[:-kl_, idx_b].copy()
        # np.cov stacks the pa rows of ya.T over the pb rows of yb.T, so the
        # (pa, pb) cross-covariance block starts at column pa (not len(idx_b))
        Qgd = np.cov(ya.T, yb.T)[:pa, pa:]
        del ya, yb
        Qest = pars_est['C'][idx_a,:].dot(pars_est['X'][mi*n:(mi+1)*n,:]).dot(pars_est['C'][idx_b,:].T)
        stc_corrs[ni,mi] = np.corrcoef(Qgd[unobserved], Qest[unobserved])[0,1]
        del Qest, Qgd

del y_zscore


In [None]:
import seaborn

plt.figure(figsize=(14,8))

# (1) observed instantaneous (lag-0) covariances
plt.subplot(2,3,1)
plt.imshow(Qs[0], interpolation='none')
plt.colorbar()
plt.grid(False)   # a string such as 'off' is truthy and would switch the grid on
plt.title('observed inst. covariances')

# (2) lag-0 covariances computed from the full traces
plt.subplot(2,3,2)
y_ = np.memmap(data_path+'y_' + idx_str + '_zscore', dtype=np.float64, mode='r', shape=(T,p))
ya = y_[:-kl_, idx_a].copy()
yb = y_[:-kl_, idx_b].copy()
del y_
# the (len(idx_a), len(idx_b)) cross block of the stacked covariance starts at column len(idx_a)
Qgd = np.cov(ya.T, yb.T)[:len(idx_a), len(idx_a):]
del ya, yb
plt.imshow(Qgd, interpolation='none')
plt.colorbar()
del Qgd
plt.grid(False)
plt.title('full instantaneous covariances')

# (3) lag-0 covariances reconstructed from THIS run's n = 10 fit
#     (the file name is built from ``run`` rather than hard-coding '_real_stitch')
plt.subplot(2,3,3)
n_show, m_show = 10, 0
file_name = 'p' + str(p) + 'n' + str(n_show) + 'T' + str(T) + 'snr' + str(int(np.mean(snr)//1)) + '_run' + str(run)
# pickled dict inside an object array; only load trusted result files
pars_show = np.load(data_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()['pars_est']
Qest = pars_show['C'][idx_a,:].dot(pars_show['X'][m_show*n_show:(m_show+1)*n_show,:]).dot(pars_show['C'][idx_b,:].T)
del pars_show
plt.imshow(Qest, interpolation='none')
plt.colorbar()
del Qest
plt.grid(False)
plt.title('reconstructed inst. covariances')

# (4) training loss as a function of latent dimensionality
plt.subplot(2,3,4)
plt.plot(ns, loss, 'o-')
plt.ylabel('(partial) training loss')
plt.xlabel('latent dim. n')
plt.box(False)
plt.title('training loss vs. latent dim.')
plt.xticks(ns)

# (5) correlations on observed (co-observed) entries, per lag
plt.subplot(2,3,5)
plt.plot(lag_range, obs_corrs.T)
plt.xlabel('time lag')
plt.ylabel('corr. of covariances')
plt.legend(['n = ' + str(k) for k in ns], loc=3, frameon=False)
plt.box(False)
plt.axis([0,np.max(lag_range), 0.75, 1.])
plt.title('corr. of observed covariances ')

# (6) correlations on stitched (never co-observed) entries, per lag
plt.subplot(2,3,6)
plt.plot(lag_range, stc_corrs.T)
plt.xlabel('time lag')
plt.ylabel('corr. of covariances')
plt.box(False)
plt.axis([0,np.max(lag_range), 0.75, 1.])
plt.title('corr. of stitched covariances ')
plt.savefig(data_path + 'res_summary_2subpops_zebrafish_small.pdf')
plt.show()