## Prepare Rfiles (Stan input data and initial values in R dump format)

In [None]:
import numpy as np
import lib.io.stan
import glob
import matplotlib.pyplot as plt
import os

In [None]:
# Target number of time points kept per seizure after downsampling.
npts = 150
data_root_dir = 'datasets/retro'
res_root_dir = 'results/exp10/exp10.65.5'

# Filename prefix shared by the input .npz files and the emitted .R data files.
_OBS_PREFIX = 'obs_data_'


def _downsample_step(n_samples, n_points):
    """Return the row-slicing step that reduces ``n_samples`` rows to about ``n_points``.

    The step is clamped to at least 1 so recordings shorter than ``n_points``
    are kept at full resolution instead of failing on a zero slice step.
    """
    return max(1, n_samples // n_points)


def _build_fit_data(szr_data, ds_freq):
    """Assemble the Stan data dictionary for one seizure.

    Parameters
    ----------
    szr_data : mapping
        Must provide ``'slp'`` (2-D; rows are subsampled, columns are treated
        as channels), ``'SC'`` and ``'gain'`` (2-D, unpacked as ``(ns, nn)``).
    ds_freq : int
        Step used to subsample the rows of ``slp``.

    Returns
    -------
    dict
        Keys ``slp``, ``snsr_pwr``, ``SC``, ``gain``, ``nt``, ``ns``, ``nn``.
    """
    fit_data = dict()
    fit_data['slp'] = szr_data['slp'][::ds_freq, :]
    # Mean power per column (channel) of the downsampled signal.
    fit_data['snsr_pwr'] = (fit_data['slp'] ** 2).mean(axis=0)
    fit_data['SC'] = szr_data['SC']
    fit_data['gain'] = szr_data['gain']
    fit_data['nt'] = fit_data['slp'].shape[0]
    fit_data['ns'], fit_data['nn'] = fit_data['gain'].shape
    return fit_data


def _make_params_init(nn):
    """Return Stan initial values: zeros for the ``*_star_star`` terms and ``alpha = 1``."""
    params_init = dict()
    params_init['x0_star_star'] = np.zeros(nn)
    params_init['amplitude_star_star'] = 0.0
    params_init['offset_star_star'] = 0.0
    params_init['K_star_star'] = 0.0
    params_init['tau0_star_star'] = 0.0
    params_init['alpha'] = 1.0
    return params_init


def _plot_observed_data(fit_data, fig_path):
    """Save a two-panel figure (SLP traces, per-channel power) to ``fig_path``."""
    fig = plt.figure(figsize=(25, 13))
    try:
        plt.subplot(211)
        plt.plot(fit_data['slp'], color='black', alpha=0.3)
        plt.xlabel('Time', fontsize=12)
        plt.ylabel('SLP', fontsize=12)

        plt.subplot(212)
        plt.bar(np.r_[1:fit_data['ns'] + 1], fit_data['snsr_pwr'], color='black', alpha=0.3)
        # The x-axis indexes SEEG channels 1..ns, not time.
        plt.xlabel('Channel', fontsize=12)
        plt.ylabel('Power', fontsize=12)
        plt.title('SEEG channel power', fontweight='bold')
        plt.savefig(fig_path)
    finally:
        # Close even if plotting/saving fails, so figures don't accumulate.
        plt.close(fig)


for patient_dir in sorted(glob.glob(os.path.join(data_root_dir, 'id*'))):
    patient_id = os.path.basename(patient_dir)
    results_dir = os.path.join(res_root_dir, patient_id)
    for subdir in ('Rfiles', 'figures', 'logs', 'results'):
        os.makedirs(os.path.join(results_dir, subdir), exist_ok=True)
    # Require the '_' separator so every match has a non-crashing seizure name.
    szr_pattern = os.path.join(patient_dir, 'stan', 'fit_target_lpf0.3', _OBS_PREFIX + '*.npz')
    for szr_path in sorted(glob.glob(szr_pattern)):
        szr_stem = os.path.splitext(os.path.basename(szr_path))[0]
        szr_name = szr_stem[len(_OBS_PREFIX):]
        print(szr_path)
        print(szr_name)
        # Preprocessed SEEG data (i.e. SLP) without any downsampling; the
        # context manager closes the .npz even if a key lookup fails.
        with np.load(szr_path) as szr_data_full:
            ds_freq = _downsample_step(szr_data_full['slp'].shape[0], npts)
            fit_data = _build_fit_data(szr_data_full, ds_freq)
        lib.io.stan.rdump(os.path.join(results_dir, 'Rfiles', f'obs_data_{szr_name}.R'), fit_data)
        # NOTE(review): params_init.R is rewritten for every seizure of the
        # patient, so the last seizure's nn wins — confirm nn is per-patient.
        params_init = _make_params_init(fit_data['nn'])
        lib.io.stan.rdump(os.path.join(results_dir, 'Rfiles', 'params_init.R'), params_init)
        # NOTE(review): the figure name omits the '_' used in the .R filename;
        # kept unchanged in case downstream tooling expects this name.
        _plot_observed_data(fit_data, os.path.join(results_dir, 'figures', f'observed_data{szr_name}.png'))