# ST-Stochastic descent

- using SGD on $C$ and the lag-wise latent covariances $\{X_m\}_{m=0}^{k+l-1}$

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os

os.chdir('../core')
import SSID_Hankel_loss 
from utility import get_subpop_stats, draw_sys, gen_data, gen_pars
from SSID_Hankel_loss import run_bad, plot_outputs_l2_gradient_test, l2_bad_sis_setup
import ssm_scripts
import ssm_fit
os.chdir('../dev')

import psutil
import time

#np.random.seed(0)

# --- problem dimensions ------------------------------------------------------
# p, n, nr are handed to draw_sys below (n and nr also size the eigenvalue
# arrays); k and l set the number of time lags (k+l) used further down;
# T is the number of time steps requested from gen_data.
p, n, nr = 1000, 10, 2
k, l = 3, 3
T = 500

# --- storage / bookkeeping ---------------------------------------------------
mmap = False                      # if True, large arrays are kept as np.memmap files
save_file = 'test'
data_path = '../fits/'
chunksize = np.min((p, 2000))

verbose = True

# --- subpopulations: two groups, each covering all p variables ---------------
sub_pops = (np.arange(p), np.arange(p))

(obs_idx, idx_grp, co_obs,
 _, _, _,
 Om,
 _, _) = get_subpop_stats(sub_pops=sub_pops, p=p, verbose=False)

# --- eigenvalues passed to draw_sys ------------------------------------------
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.8, 0.99, 0.8, 0.99
n_cplx = (n - nr) // 2

# nr real values, evenly spaced between eig_m_r and eig_M_r
ev_r = np.linspace(eig_m_r, eig_M_r, nr)

# n_cplx complex values: uniformly random phase, evenly spaced modulus
ev_c = np.exp(2j * np.pi * np.random.uniform(size=n_cplx))
ev_c = np.linspace(eig_m_c, eig_M_c, n_cplx) * ev_c

# --- draw the ground-truth system --------------------------------------------
# calc_stats is only True in the T == np.inf setting.
pars_true, Qs, _ = draw_sys(
    p=p, n=n, k=k, l=l, nr=nr,
    ev_r=ev_r, ev_c=ev_c,
    calc_stats=T==np.inf,
    return_masked=False,
    mmap=mmap, chunksize=chunksize, data_path=data_path,
)

# extra entries: zero vectors for 'd' and 'mu0', and a copy of 'Pi' as 'V0'
pars_true['d'] = np.zeros(p)
pars_true['mu0'] = np.zeros(n)
pars_true['V0'] = pars_true['Pi'].copy()

# --- sorted random index subsets (at most 1000 each) used for the covariances -
pa = pb = np.min((p, 1000))
idx_a = np.sort(np.random.choice(p, pa, replace=False))
idx_b = np.sort(np.random.choice(p, pb, replace=False))

# Generate data and empirical time-lagged covariances (finite T only).
#
# For every lag m in 0..k+l-1, Qs[m] is set to the empirical cross-covariance
# between y[m:m-(k+l), idx_a] and y[:-(k+l), idx_b], a (pa, pb) matrix.
if T == np.inf:
    # No samples are drawn in this setting: x and y are empty placeholders and
    # Qs is left exactly as returned by draw_sys.
    x, y = np.zeros((n, 0)), np.zeros((p, 0))
else:
    print('computing empirical covariances')
    x, y = gen_data(pars=pars_true, T=T, mmap=mmap,
                    chunksize=chunksize, data_path=data_path)
    for m in range(k+l):
        print('computing time-lagged covariance for lag ', str(m))
        if mmap:
            # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
            Q = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64, mode='w+', shape=(pa,pb))
        else:
            Q = np.empty((pa,pb))
        # np.cov stacks the pa variables of the first argument on top of the
        # pb variables of the second, giving a (pa+pb, pa+pb) matrix. The
        # cross block therefore starts at column pa (the original used pb,
        # which is only correct when pa == pb).
        Q[:] = np.cov(y[m:m-(k+l),idx_a].T, y[:-(k+l),idx_b].T)[:pa,pa:]
        if mmap:
            # Drop the writable map and reopen the same file read-only.
            del Q
            Qs[m] = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64, mode='r', shape=(pa,pb))
        else:
            Qs[m] = Q
# Store the experiment setup so it can be reloaded later; no estimate exists
# yet, hence pars_est=None.
pars_init = 'default'
setup_record = dict(
    pars_init=pars_init,
    pars_true=pars_true,
    pars_est=None,
    sub_pops=sub_pops,
    p=p, n=n, T=T, k=k, l=l,
    idx_a=idx_a, idx_b=idx_b,
    x=x,
)
np.savez(data_path + save_file, **setup_record)

print('(p,n,k,l,T)', (p,n,k,l,T))

# Report current RAM and swap usage.
print('\n')
for memory_report in (psutil.virtual_memory, psutil.swap_memory):
    print(memory_report())


# Load test dynamic texture

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os

os.chdir('../core')
import SSID_Hankel_loss 
from utility import get_subpop_stats, draw_sys, gen_data, gen_pars
from SSID_Hankel_loss import run_bad, plot_outputs_l2_gradient_test, l2_bad_sis_setup
os.chdir('../dev')

import psutil
import time

from scipy.io import loadmat
# Load the 'fire' dynamic texture and turn it into a (T, p) data matrix y.
# data is indexed as data[t, :, :, :]; each frame is averaged over its last
# axis (presumably the colour channels -- confirm against the .mat file) and
# then subsampled on a grid of every 8th entry along both remaining axes.
data = loadmat('/home/mackelab/Desktop/Projects/Stitching/data/dynamic_textures/fire.mat')['data']
T = data.shape[0]

# Subsampling grid, built once instead of once per frame.
rows = np.arange(0, data.shape[1], 8)
cols = np.arange(0, data.shape[2], 8)
grid = np.ix_(rows, cols)

# Number of retained pixels. Taken from the grid itself so that it also
# matches when a frame dimension is not a multiple of 8 (the original
# shape[1] * shape[2] // 64 only agrees in the divisible case).
p = rows.size * cols.size

#  reshape data
y = np.zeros((T, p))
for t in range(T):
    y[t,:] = np.ravel(np.mean(data[t,:,:,:],axis=2)[grid])


# set fitting parameters
k,l = 2,2          # k+l time lags are used below
n = 27
mmap = False       # if True, covariances are stored as np.memmap files
verbose = True
pars_init = 'default'

# create subpopulations: two groups, each covering all p variables
sub_pops = (np.arange(0,p), np.arange(0,p))
obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = \
    get_subpop_stats(sub_pops=sub_pops, p=p, verbose=False)

# generate (subsampled) covariance matrices
# idx_a / idx_b: sorted random subsets (at most 1000 entries each) of the p
# variables; Qs[m] is the empirical (pa, pb) cross-covariance at lag m.
pa, pb = np.min((p,1000)), np.min((p,1000))
idx_a = np.sort(np.random.choice(p, pa, replace=False))
idx_b = np.sort(np.random.choice(p, pb, replace=False))
Qs = []
for m in range(k+l):
    Qs.append(None)
    print('computing time-lagged covariance for lag ', str(m))
    if mmap:
        # NOTE(review): data_path is not assigned in this cell; it has to
        # exist already (e.g. from the simulation cell) when mmap is True.
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        Q = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64, mode='w+', shape=(pa,pb))
    else:
        Q = np.empty((pa,pb))
    # np.cov returns the joint (pa+pb, pa+pb) covariance of both variable
    # sets; the cross block starts at column pa (the original used pb, which
    # is only correct when pa == pb).
    Q[:] = np.cov(y[m:m-(k+l),idx_a].T, y[:-(k+l),idx_b].T)[:pa,pa:]
    if mmap:
        # Drop the writable map and reopen the same file read-only.
        del Q
        Qs[m] = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64, mode='r', shape=(pa,pb))
    else:
        Qs[m] = Q
        
    

# Start fitting

In [None]:
# First fitting pass: batch_size=1 for max_iter=20 iterations, followed by
# diagnostic plots of reconstructed vs. empirical covariances and of the loss.
batch_size = 1
max_zip_size = 100
max_iter  = 20

# Optimiser settings handed to run_bad as alpha, b1, b2, e
# (presumably Adam-type step size, moment decays and epsilon -- confirm
# against run_bad).
a, b1, b2, e = 0.1, 0.9, 0.99, 1e-8

t = time.time()

# NOTE(review): linearity is the *string* 'False', not the boolean; kept as
# is since run_bad's expected values are not visible here.
linearity = 'False'
stable = False
sym_psd = False

pars_init, pars_est, traces = run_bad(k=k,l=l,n=n,y=y, Qs=Qs,Om=Om,idx_a=idx_a, idx_b=idx_b,
                                      sub_pops=sub_pops,idx_grp=idx_grp,co_obs=co_obs,obs_idx=obs_idx,
                                      linearity=linearity,stable=stable,init=pars_init,
                                      alpha=a,b1=b1,b2=b2,e=e,max_iter=max_iter,batch_size=batch_size,
                                      verbose=verbose, sym_psd=sym_psd, max_zip_size=max_zip_size)

print('fitting time was ', time.time() - t)

print('\n')
print(psutil.virtual_memory())
print(psutil.swap_memory())

# --- reconstructed vs. empirical covariances, one panel per time lag ---------
pars = pars_est
plt.figure(figsize=(20,50))
# plt.subplot needs integer grid sizes; np.ceil alone returns a float.
n_rows = int(np.ceil( (k+l)/2 ))
for m in range(0,k+l):
    # model covariance at lag m: C[idx_a] X_m C[idx_b]^T
    Qrec = pars['C'][idx_a,:].dot(pars['X'][m*n:(m+1)*n, :]).dot(pars['C'][idx_b,:].T)
    if m == 0:
        # the diagonal term built from R only enters at lag 0
        Qrec = Qrec + np.diag(pars['R'])[np.ix_(idx_a,idx_b)]

    # 'box-forced' was removed from Matplotlib; 'box' is its replacement.
    plt.subplot(n_rows, 2, m+1, adjustable='box')

    if mmap:
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        Q = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64, mode='r', shape=(pa,pb))
    else:
        Q = Qs[m]

    plt.plot(Q.reshape(-1), Qrec.reshape(-1), '.')
    corr = np.corrcoef( Qrec.reshape(-1), (Qs[m]).reshape(-1) )[0,1]
    plt.title('m = ' + str(m) + ', corr = ' + str(corr))

    if mmap:
        del Q

    plt.xlabel('true covs')
    plt.ylabel('est. covs')
plt.show()

# --- loss trace --------------------------------------------------------------
plt.figure(figsize=(20,10))
plt.plot(traces[0])
plt.xlabel('iteration count')
plt.ylabel('target loss')
plt.title('loss function vs. iterations')
plt.show()

# Detailed comparison only for small problems.
# NOTE(review): needs pars_true, which is not assigned in the dynamic-texture
# cell -- confirm it exists before running this on real data.
if p <= 1000:
    plot_outputs_l2_gradient_test(pars_true=pars_true, pars_init=pars_init,
                                  pars_est=pars_est, k=k, l=l, Qs=Qs,
                                  Om=Om, traces = traces, idx_a=idx_a, idx_b=idx_b,
                                  linearity=linearity, idx_grp = idx_grp, co_obs = co_obs,
                                  if_flip = True, m = 0)

In [None]:
# Show the stored covariance Qs[m] (left) next to the model covariance
# C X_m C^T computed from the current estimate (right), for lag m.
m = 0
pars = pars_est

C_est = pars['C']
X_m = pars['X'][m*n:(m+1)*n,:]
panels = (Qs[m], C_est.dot(X_m).dot(C_est.T))

for panel_no, image in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_no)
    plt.imshow(image, interpolation='none')
plt.show()


In [None]:
# Second fitting pass: continues from the current pars_est with a batch size
# of T-(k+l) for max_iter=100 iterations, followed by the same diagnostic
# plots as after the first pass.
batch_size = T-(k+l)
max_zip_size = 1000
max_iter  = 100

# Optimiser settings handed to run_bad as alpha, b1, b2, e
# (presumably Adam-type step size, moment decays and epsilon -- confirm
# against run_bad).
a, b1, b2, e = 0.1, 0.9, 0.99, 1e-8

t = time.time()

# NOTE(review): linearity is the *string* 'False', not the boolean; kept as
# is since run_bad's expected values are not visible here.
linearity = 'False'
stable = False
sym_psd = False

# init=pars_est warm-starts from the previous result; the first return value
# is discarded so pars_init keeps referring to the original initialisation.
_, pars_est, traces = run_bad(k=k,l=l,n=n,y=y, Qs=Qs,Om=Om,idx_a=idx_a, idx_b=idx_b,
                                      sub_pops=sub_pops,idx_grp=idx_grp,co_obs=co_obs,obs_idx=obs_idx,
                                      linearity=linearity,stable=stable,init=pars_est,
                                      alpha=a,b1=b1,b2=b2,e=e,max_iter=max_iter,batch_size=batch_size,
                                      verbose=verbose, sym_psd=sym_psd, max_zip_size=max_zip_size)

print('fitting time was ', time.time() - t)

print('\n')
print(psutil.virtual_memory())
print(psutil.swap_memory())

# --- reconstructed vs. empirical covariances, one panel per time lag ---------
pars = pars_est
plt.figure(figsize=(20,20))
# plt.subplot needs integer grid sizes; np.ceil alone returns a float.
n_rows = int(np.ceil( (k+l)/2 ))
for m in range(0,k+l):
    # model covariance at lag m: C[idx_a] X_m C[idx_b]^T
    Qrec = pars['C'][idx_a,:].dot(pars['X'][m*n:(m+1)*n, :]).dot(pars['C'][idx_b,:].T)
    if m == 0:
        # the diagonal term built from R only enters at lag 0
        Qrec = Qrec + np.diag(pars['R'])[np.ix_(idx_a,idx_b)]

    # 'box-forced' was removed from Matplotlib; 'box' is its replacement.
    plt.subplot(n_rows, 2, m+1, adjustable='box')

    if mmap:
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        Q = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64, mode='r', shape=(pa,pb))
    else:
        Q = Qs[m]

    plt.plot(Q.reshape(-1), Qrec.reshape(-1), '.')
    corr = np.corrcoef( Qrec.reshape(-1), (Qs[m]).reshape(-1) )[0,1]
    plt.title('m = ' + str(m) + ', corr = ' + str(corr))

    if mmap:
        del Q

    plt.xlabel('true covs')
    plt.ylabel('est. covs')
plt.show()

# --- loss trace --------------------------------------------------------------
plt.figure(figsize=(20,10))
plt.plot(traces[0])
plt.xlabel('iteration count')
plt.ylabel('target loss')
plt.title('loss function vs. iterations')
plt.show()

# Detailed comparison only for small problems.
# NOTE(review): needs pars_true, which is not assigned in the dynamic-texture
# cell -- confirm it exists before running this on real data.
if p <= 1000:
    plot_outputs_l2_gradient_test(pars_true=pars_true, pars_init=pars_init,
                                  pars_est=pars_est, k=k, l=l, Qs=Qs,
                                  Om=Om, traces = traces, idx_a=idx_a, idx_b=idx_b,
                                  linearity=linearity, idx_grp = idx_grp, co_obs = co_obs,
                                  if_flip = True, m = 0)

In [None]:
# Same covariance scatter plots as for the estimate, but using the
# ground-truth parameters: one panel per time lag m = 0..k+l-1.
pars = pars_true
plt.figure(figsize=(20,50))
# plt.subplot needs integer grid sizes; np.ceil alone returns a float.
n_rows = int(np.ceil( (k+l)/2 ))
for m in range(0,k+l):
    if 'X' in pars:
        # lag-m latent covariance taken from the stacked 'X' entry
        Qrec = pars['C'][idx_a,:].dot(pars['X'][m*n:(m+1)*n, :]).dot(pars['C'][idx_b,:].T)
    else:
        # no 'X' stored: use A^m Pi as the lag-m latent covariance
        Qrec = pars['C'][idx_a,:].dot(np.linalg.matrix_power(pars['A'],m).dot(pars['Pi'])).dot(pars['C'][idx_b,:].T)
    if m == 0:
        # the diagonal term built from R only enters at lag 0
        Qrec = Qrec + np.diag(pars['R'])[np.ix_(idx_a,idx_b)]

    # 'box-forced' was removed from Matplotlib; 'box' is its replacement.
    plt.subplot(n_rows, 2, m+1, adjustable='box')

    if mmap:
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        Q = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64, mode='r', shape=(pa,pb))
    else:
        Q = Qs[m]

    plt.plot(Q.reshape(-1), Qrec.reshape(-1), '.')
    corr = np.corrcoef( Qrec.reshape(-1), (Qs[m]).reshape(-1) )[0,1]
    plt.title('m = ' + str(m) + ', corr = ' + str(corr))

    if mmap:
        del Q

    plt.xlabel('true covs')
    plt.ylabel('est. covs')
plt.show()

In [None]:
# Inspect the estimate at a single time lag m: print the loss value, show the
# model and stored covariances as images, print their squared error, and
# scatter them against each other (plus true vs. estimated R).
pars = pars_est
m = 0

if 'X' not in pars:
    #X_s = SSID_Hankel_loss.s_X_l2_Hankel_fully_obs(pars['C'], pars['R'], Qs, k, l, idx_grp, co_obs)
    #pars['X'] = X_s
    # Build the stacked lag covariances A^m Pi, m = 0..k+l-1.
    # Note: this writes into the pars_est dict itself.
    pars['X'] = np.vstack([np.linalg.matrix_power(pars['A'], m_).dot(pars['Pi']) for m_ in range(k+l)])
print(SSID_Hankel_loss.f_l2_Hankel_nl(pars['C'],pars['X'],None,pars['R'],
                                k,l,Qs,idx_grp,co_obs,idx_a,idx_b))
Q_est = pars['C'][idx_a,:].dot(pars['X'][m*n:(m+1)*n, :]).dot(pars['C'][idx_b,:].T)
if m == 0:
    # Add R[i] exactly where row and column refer to the same variable i.
    # pos_a / pos_b are the positions of the shared variables inside idx_a /
    # idx_b. The original added along the leading diagonal, which is only
    # correct when idx_a and idx_b are identical. This matches the
    # np.diag(R)[np.ix_(idx_a, idx_b)] term used in the other cells.
    shared, pos_a, pos_b = np.intersect1d(idx_a, idx_b, return_indices=True)
    Q_est[pos_a, pos_b] += pars['R'][shared]
plt.imshow(Q_est, interpolation='none')
plt.show()
plt.imshow(Qs[m], interpolation='none')
plt.show()

print( np.sum( (Q_est - Qs[m])**2 ) )

plt.subplot(1,2,1)
plt.plot(Q_est.reshape(-1), Qs[m].reshape(-1), '.')
# plt.hold was removed in Matplotlib 3.0; successive plot calls already draw
# into the same axes, so the diagonal entries are overlaid in red as before.
plt.plot(np.diag(Q_est), np.diag(Qs[m]), 'r.')
plt.subplot(1,2,2)
plt.plot(pars_true['R'], pars_est['R'], '.')
plt.show()

# Loading data (+intermediate results)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy.optimize import fmin_bfgs, check_grad
import glob, os

os.chdir('../core')
import SSID_Hankel_loss 
from utility import get_subpop_stats, draw_sys, gen_data
from SSID_Hankel_loss import run_bad, plot_outputs_l2_gradient_test, l2_bad_sis_setup
os.chdir('../dev')

import psutil
import time

# load ancient code for drawing from LDS ...
os.chdir('../../../../pyRRHDLDS/core')
import ssm_scripts
import ssm_fit
os.chdir('../../code_le_stitch/iterSSID/python/dev')


from scipy.io import savemat # store results for comparison with Matlab code   

# Reload a stored experiment (test.npz) together with the memory-mapped
# covariances 'Qs_<m>' and data 'y' from the fits directory.
os.chdir('../fits/')

mmap = True
data_path = '../fits/'

# pars_true / pars_est / pars_init were saved as Python objects, which NumPy
# stores pickled; since NumPy 1.16.3 np.load refuses to read those unless
# allow_pickle=True is given.
# SECURITY: unpickling can execute arbitrary code -- only load files that
# were produced by this notebook.
save_file = np.load('test.npz', allow_pickle=True)
p,n,T,k,l = save_file['p'], save_file['n'], save_file['T'], save_file['k'], save_file['l']
# .tolist() on the loaded object arrays hands back the stored Python objects
pars_true = save_file['pars_true'].tolist()
pars_est, pars_init = save_file['pars_est'].tolist(), save_file['pars_init'].tolist()
idx_a, idx_b = save_file['idx_a'], save_file['idx_b']
pa, pb = len(idx_a), len(idx_b)
#pa, pb = p,p

# Open the covariance files read-only, one (pa, pb) map per lag.
# np.float was removed in NumPy 1.24; np.float64 is the same dtype.
Qs = []
for m in range(k+l):
    print('loading time-lagged covariance for lag ', str(m))
    Qs.append(np.memmap('Qs_'+str(m), dtype=np.float64, mode='r', shape=(pa,pb)))

y = np.memmap('y', dtype=np.float64, mode='r', shape=(T,p))


# two groups, each covering all p variables
sub_pops = (np.arange(p), np.arange(p))
obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = \
    get_subpop_stats(sub_pops=sub_pops, p=p, verbose=False)

chunksize = 5000
max_zip_size = 5000

verbose=True

print('\n')
print(psutil.virtual_memory())
print(psutil.swap_memory())


In [None]:
# Display the first 100 rows of y.
y[:100]

In [None]:
# Covariance scatter plots for the current estimate, evaluated on all p
# variables (idx_a and idx_b are overwritten with the full index range).
pars = pars_est
idx_a = np.arange(p)
idx_b = np.arange(p)
plt.figure(figsize=(20,50))
# plt.subplot needs integer grid sizes; np.ceil alone returns a float.
n_rows = int(np.ceil( (k+l)/2 ))
for m in range(0,k+l):
    # model covariance at lag m: C[idx_a] X_m C[idx_b]^T
    Qrec = pars['C'][idx_a,:].dot(pars['X'][m*n:(m+1)*n, :]).dot(pars['C'][idx_b,:].T)
    if m == 0:
        # the diagonal term built from R only enters at lag 0
        Qrec = Qrec + np.diag(pars['R'])[np.ix_(idx_a,idx_b)]

    # 'box-forced' was removed from Matplotlib; 'box' is its replacement.
    plt.subplot(n_rows, 2, m+1, adjustable='box')

    if mmap:
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        # NOTE(review): shape still uses pa, pb from the loading cell while
        # idx_a / idx_b now span all p variables -- confirm pa == pb == p.
        Q = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64, mode='r', shape=(pa,pb))
    else:
        Q = Qs[m]

    plt.plot(Q.reshape(-1), Qrec.reshape(-1), '.')
    #if m == 0:
    #    plt.hold(True)
    #    idx_ab = np.intersect1d(idx_a,idx_b)
    #    plt.plot(Q[idx12,idx12], np.diag(pars['C'][idx12,:].dot(pars['X'][m*n:(m+1)*n, :]).dot(pars['C'][idx12,:].T)), 'r.')
    corr = np.corrcoef( Qrec.reshape(-1), (Qs[m][np.ix_(idx_a,idx_b)]).reshape(-1) )[0,1]
    plt.title('m = ' + str(m) + ', corr = ' + str(corr))

    if mmap:
        del Q

    plt.xlabel('true covs')
    plt.ylabel('est. covs')
plt.show()


# Additional turns

In [None]:
# Additional fitting turn: continues from the current pars_est for another
# max_iter=100 iterations and redraws the diagnostic plots.
batch_size = T-k-l # batch_size = 1 (size-1 mini-batches), p (column mini-batches), None (full gradients)
max_zip_size = (T-k-l)//batch_size
max_iter  = 100

# Optimiser settings handed to run_bad as alpha, b1, b2, e
# (presumably Adam-type step size, moment decays and epsilon -- confirm
# against run_bad).
a, b1, b2, e = 0.1, 0.9, 0.99, 1e-8
lag_range  = 1

t = time.time()

# NOTE(review): linearity is the *string* 'False', not the boolean; kept as
# is since run_bad's expected values are not visible here.
linearity = 'False'
stable = False
sym_psd = False
# init=pars_est warm-starts from the previous result; the first return value
# is discarded so pars_init keeps referring to the original initialisation.
_, pars_est, traces = run_bad(k=k,l=l,n=n,y=y, Qs=Qs,Om=Om,Qs_full=None,idx_a=idx_a, idx_b=idx_b,
                                      sub_pops=sub_pops,idx_grp=idx_grp,co_obs=co_obs,obs_idx=obs_idx,
                                      linearity=linearity,stable=stable,init=pars_est,
                                      alpha=a,b1=b1,b2=b2,e=e,max_iter=max_iter,batch_size=batch_size,
                                      verbose=verbose, sym_psd=sym_psd, lag_range = lag_range, max_zip_size=max_zip_size)

print('fitting time was ', time.time() - t)

print('\n')
print(psutil.virtual_memory())
print(psutil.swap_memory())

# --- reconstructed vs. empirical covariances, one panel per time lag ---------
pars = pars_est
plt.figure(figsize=(20,50))
# plt.subplot needs integer grid sizes; np.ceil alone returns a float.
n_rows = int(np.ceil( (k+l)/2 ))
for m in range(0,k+l):
    # model covariance at lag m: C[idx_a] X_m C[idx_b]^T
    Qrec = pars['C'][idx_a,:].dot(pars['X'][m*n:(m+1)*n, :]).dot(pars['C'][idx_b,:].T)
    if m == 0:
        # the diagonal term built from R only enters at lag 0
        Qrec = Qrec + np.diag(pars['R'])[np.ix_(idx_a,idx_b)]

    # 'box-forced' was removed from Matplotlib; 'box' is its replacement.
    plt.subplot(n_rows, 2, m+1, adjustable='box')

    if mmap:
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        Q = np.memmap(data_path+'Qs_'+str(m), dtype=np.float64, mode='r', shape=(pa,pb))
    else:
        Q = Qs[m]

    plt.plot(Q.reshape(-1), Qrec.reshape(-1), '.')
    #if m == 0:
    #    plt.hold(True)
    #    idx_ab = np.intersect1d(idx_a,idx_b)
    #    plt.plot(Q[idx12,idx12], np.diag(pars['C'][idx12,:].dot(pars['X'][m*n:(m+1)*n, :]).dot(pars['C'][idx12,:].T)), 'r.')
    corr = np.corrcoef( Qrec.reshape(-1), (Qs[m]).reshape(-1) )[0,1]
    plt.title('m = ' + str(m) + ', corr = ' + str(corr))

    if mmap:
        del Q

    plt.xlabel('true covs')
    plt.ylabel('est. covs')
plt.show()

# --- loss trace --------------------------------------------------------------
plt.figure(figsize=(20,10))
plt.plot(traces[0])
plt.xlabel('iteration count')
plt.ylabel('target loss')
plt.title('loss function vs. iterations')
plt.show()

# Detailed comparison only for small problems.
if p <= 1000:
    plot_outputs_l2_gradient_test(pars_true=pars_true, pars_init=pars_init,
                                  pars_est=pars_est, k=k, l=l, Qs=Qs,
                                  Qs_full=Qs, Om=Om, traces = traces, idx_a=idx_a, idx_b=idx_b,
                                  linearity=linearity, idx_grp = idx_grp, co_obs = co_obs,
                                  if_flip = True, m = 0)

# just one more turn...

In [None]:
from scipy.io import savemat # store results for comparison with Matlab code   

# Write the fitted parameters, problem sizes and optimisation traces to an
# .npz archive in the fits directory.
os.chdir('../fits/')

save_file = 'test_time_Tinf_p10000n100r1_2'

# The data y itself is deliberately not stored (#y=y).
results = dict(
    pars_init=pars_init,
    pars_true=pars_true,
    pars_est=pars_est,
    p=p,
    n=n,
    T=T,
    k=k,
    l=l,
    traces=traces,
    batch_size=batch_size,
    linearity=linearity,
)
np.savez(save_file, **results)

# Oops... (recovering broken run)