In [1]:
%matplotlib inline
import numpy as np
import lib.io.stan
import matplotlib.pyplot as plt
import os

In [2]:
# Input dataset and output locations for this experiment.
data_dir = 'datasets/id001_ac'
results_dir = 'results/exp10'

# Make sure the results directory and its two sub-directories exist.
for out_dir in (results_dir, f'{results_dir}/logs', f'{results_dir}/figures'):
    os.makedirs(out_dir, exist_ok=True)

# Network archive: provides the 'SC' and 'gain_mat' arrays.
network = np.load(f'{data_dir}/AC_network.npz')
gain_mat = network['gain_mat']

# Scale SC by its largest entry, then set its main diagonal to zero.
SC = network['SC']
SC = SC / np.max(SC)
diag_idx = np.arange(SC.shape[0])
SC[diag_idx, diag_idx] = 0

# Archive holding the fitting target (read later via its 'fit_trgt' key).
syn_data = np.load(f'{data_dir}/AC_fit_trgt.npz')

In [8]:
# Sizes passed to the Stan model, taken from the arrays loaded above.
nn = SC.shape[0]                     # first dimension of the connectivity matrix
ns = gain_mat.shape[0]               # first dimension of the gain matrix
nt = syn_data['fit_trgt'].shape[0]   # first dimension of the fitting target
I1 = 3.1

# Compile the Stan model before any job is queued.
# NOTE(review): absolute, user-specific script path — consider making it configurable.
stan_fname = 'vep-snsrfit'
lib.io.stan.create_process(['bash','/home/anirudhnihalani/scripts/stancompile.sh', stan_fname],block=True)

# Settings and directories that do not change across the sweep.
nchains = 8
rfiles_dir = f'{data_dir}/Rfiles'
job_dir = 'tmp'
os.makedirs(rfiles_dir, exist_ok=True)
# The generated job scripts are written here; create it so a fresh checkout
# does not fail with FileNotFoundError on the first open(..., 'w').
os.makedirs(job_dir, exist_ok=True)

# Read the SLURM template once; it is filled in per sigma with str.format.
with open('vep-snsrfit.sh','r') as fd:
    slurm_template = fd.read()

# One R-dump input file and one SLURM job per value of sigma.
for sigma in np.arange(0.1,2.1,0.1):
    data = {'nn':nn, 'ns':ns, 'nt':nt, 'I1':I1, 'SC':SC, 'gain': gain_mat,
            'sigma':sigma, 'seeg':syn_data['fit_trgt']}

    # Single tag shared by the data file and the job script so their names
    # always agree (one decimal place hides the float error of np.arange).
    sigma_tag = f'{sigma:0.1f}'

    input_Rfile = f'fit_data_sigma{sigma_tag}.R'
    lib.io.stan.rdump(f'{rfiles_dir}/{input_Rfile}',data)

    # Positional template fields: input dir, results dir, input file, chains, sigma.
    slurm_script = slurm_template.format(rfiles_dir, results_dir, input_Rfile, nchains, sigma)
    job_script = f'{job_dir}/vep-snsrfit-sigma{sigma_tag}.sh'
    with open(job_script,'w') as fd:
        fd.write(slurm_script)
    # Submitted with block=False, as in the original cell.
    lib.io.stan.create_process(['sbatch',job_script],block=False)


/home/anirudhnihalani/vep.stan
make: `/home/anirudhnihalani/vep.stan/vep-snsrfit' is up to date.


In [None]:
import lib.io.stan
# To pick up edits to lib.io.stan without restarting the kernel:
# import importlib
# importlib.reload(lib.io.stan)

# NOTE(review): this reads results/exp9, while results_dir above is results/exp10 — confirm intended.
csv_fname = 'results/exp9/samples_sigma0.1_chain4.csv'
nwarmup = 500
nsampling = 500
ignore_warmup = True

# Sampler diagnostics and x0 for all sampling iterations.
variables_of_interest = ['lp__','accept_stat__','stepsize__','treedepth__','n_leapfrog__','divergent__','energy__','x0']
pstr_samples_1 = lib.io.stan.read_samples(csv_fname,nwarmup,nsampling,ignore_warmup,variables_of_interest)

# 10 samples of the hidden state variables x and z from the same file.
# The plotting cell reads pstr_samples_2['x'], so this must be defined here;
# leaving it commented out made that cell fail with NameError on a fresh kernel.
# Separate names keep nsampling / variables_of_interest at the values set above.
nsampling_states = 10
state_variables = ['x','z']
pstr_samples_2 = lib.io.stan.read_samples(csv_fname,nwarmup,nsampling_states,ignore_warmup,state_variables)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

plt.figure(figsize=(20,10))
plt.subplot(211)
plt.violinplot(pstr_samples_1['x0'][:,:]);
xtick_labels = []
for i in range(84):
    if(i%2 == 0):
        xtick_labels.append(str(i+1))
    else:
        xtick_labels.append('')
plt.xticks(np.r_[1:85],xtick_labels);
plt.xlabel('Region#',fontsize=15);
plt.ylabel('$x_0$',fontsize=15);

# Plot the HMC convergence diagnostics
plt.figure(figsize=(20,10))
plt.subplot(4,2,1)
plt.plot(pstr_samples_1['lp__'])
plt.xlabel('Iteration')
plt.ylabel('log prob.')

plt.subplot(4,2,2)
plt.plot(pstr_samples_1['energy__'])
plt.xlabel('Iteration')
plt.ylabel('energy')

plt.subplot(4,2,3)
plt.plot(pstr_samples_1['accept_stat__'])
plt.xlabel('Iteration')
plt.ylabel('accept stat.')

plt.subplot(4,2,4)
plt.plot(pstr_samples_1['stepsize__'])
plt.xlabel('Iteration')
plt.ylabel('step size')

plt.subplot(4,2,5)
plt.plot(pstr_samples_1['treedepth__'])
plt.xlabel('Iteration')
plt.ylabel('tree depth')

plt.subplot(4,2,6)
plt.plot(pstr_samples_1['n_leapfrog__'])
plt.xlabel('Iteration')
plt.ylabel('n_leapfrog')

plt.subplot(4,2,7)
plt.plot(pstr_samples_1['divergent__'])
plt.xlabel('Iteration')
plt.ylabel('divergent')

plt.tight_layout();

# Mean and 2*std of source activity(x) estimated from posterior samples
plt.figure(figsize=(15,20))
x_mean = np.mean(pstr_samples_2['x'], axis = 0)
x_std = np.std(pstr_samples_2['x'], axis = 0)
nt = x_mean.shape[0]
nn = x_mean.shape[1]
for i in range(nn):
    plt.plot(x_mean[:,i]+4*i)
    plt.fill_between(np.r_[0:nt], x_mean[:,i] - 2*x_std[:,i] + 4*i, x_mean[:,i] + 2*x_std[:,i] + 4*i,alpha=0.1)
plt.title('source activity(x)',fontsize=15);
plt.xlabel('time',fontsize=15);
plt.ylabel('Region#',fontsize=15);
plt.yticks(np.mean(x_mean,axis=0) + 4*np.r_[0:nn], np.r_[1:nn+1]);