# Running with braincoder: classical vs. MCMC fits on simulated data

In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import yaml
import pickle
import argparse
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
print('GPU devices available:', tf.config.list_physical_devices('GPU'))

from scipy import io
import braincoder
from braincoder.utils.visualize import *
import prfpy_csenf

from os.path import join as opj
from braincoder.models import ContrastSensitivity, ContrastSensitivityWithHRF
from braincoder.hrf import SPMHRFModel, CustomHRFModel, HRFModel
from braincoder.stimuli import ContrastSensitivityStimulus
from braincoder.bprf_mcmc import BPRF
from braincoder.optimize import ParameterFitter
%matplotlib inline

In [4]:
# Read the run settings from config.yml; safe_load only builds plain YAML types.
config_path = "config.yml"
with open(config_path, 'r') as config_stream:
    config = yaml.safe_load(config_stream)

In [5]:
# Display the parsed config.yml contents (the bare trailing expression is rendered by Jupyter).
config

In [None]:
# Load the simulation outputs from the ./output directory:
#   ground_truth.pkl -> 'bounds', 'data', 'parameters' (ground truth of the simulation)
#   cfitter.pkl      -> 'cfitter' (classical fitter) and its stored 'cfit_time'
#   bfitter.pkl      -> 'bfitter' (presumably the BPRF MCMC fitter) and 'bfit_time'
# NOTE(review): pickle.load can execute arbitrary code; only open pickles you produced.
sim_file = './output'


def _load_sim_pickle(fname):
    """Return the object unpickled from ``fname`` inside the ``sim_file`` directory."""
    with open(opj(sim_file, fname), 'rb') as handle:
        return pickle.load(handle)


gt = _load_sim_pickle('ground_truth.pkl')
bounds = gt['bounds']
data = gt['data']
parameters = gt['parameters']

cf = _load_sim_pickle('cfitter.pkl')
cfitter = cf['cfitter']
cfit_time = cf['cfit_time']
# Classical-fit estimates and the forward model they were fitted with.
refined_pars = cfitter.estimated_parameters
model = cfitter.model

bf = _load_sim_pickle('bfitter.pkl')
bfitter = bf['bfitter']
bfit_time = bf['bfit_time']

# Report the stored timings: classical first, then MCMC.
for elapsed in (cfit_time, bfit_time):
    print(elapsed)

In [None]:
# Overview of the simulated time series: one line per column, legend suppressed.
data_ax = data.plot(legend=False)
sns.despine(ax=data_ax)
# Shown as the cell output; presumably (n_timepoints, n_units), since columns are
# indexed per unit later (data.iloc[:, idx]).
data.shape

In [None]:
# Plot the log-probability trace recorded by the MCMC fitter (presumably one value
# per sampling step, possibly per chain; confirm). Useful for judging convergence
# and choosing a burn-in.
plt.plot(bfitter.mcmc_stats['log_prob'])

In [None]:
# Ground truth (x) vs. classical-fit estimates (y), one panel per parameter.
# Points are coloured by the classical fit's rsq with the colour scale fixed to [0, 1];
# the dashed line is the identity and both axes are limited to the parameter bounds.
param_names = parameters.columns
rsq = cfitter.get_rsq(parameters=refined_pars)

# Size the grid from the number of parameters instead of a fixed 2x4 layout
# (8 parameters still gives the original 2x4, 20x10 figure).
n_cols = 4
n_rows = max(1, (len(param_names) + n_cols - 1) // n_cols)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
axs = axs.flatten()

for ax, param in zip(axs, param_names):
    ax.scatter(
        parameters[param],
        refined_pars[param],
        c=rsq,
        alpha=0.6,
        # cmap='viridis',
        vmin=0, vmax=1,
    )
    corr = np.corrcoef(parameters[param], refined_pars[param])[0, 1]
    ax.set_title(f'{param} (corr={corr:.2f})')
    # x holds the ground-truth parameters loaded from ground_truth.pkl.
    ax.set_xlabel('Ground truth')
    ax.set_ylabel('Refined Parameters')
    ax.plot([bounds[param][0], bounds[param][1]],
            [bounds[param][0], bounds[param][1]], 'k--')
    ax.set_xlim(bounds[param])
    ax.set_ylim(bounds[param])

# Hide panels left over when the parameter count does not fill the grid.
for ax in axs[len(param_names):]:
    ax.set_visible(False)

# despine once for the whole figure (it acts on every axes of the figure).
sns.despine(fig=fig)
plt.tight_layout()
plt.show()

In [None]:
# List which diagnostics the MCMC fitter recorded, then plot its step-size trace.
print(bfitter.mcmc_stats.keys())
# `.numpy()` suggests 'step_size' is stored as a TensorFlow tensor (unlike 'log_prob',
# which is plotted directly above). Presumably; confirm.
plt.plot(bfitter.mcmc_stats['step_size'].numpy())



In [None]:
# MCMC traces for one unit: one panel per model parameter, with the ground truth
# (black dotted) and the classical fit (red dotted) drawn as horizontal references.
idx = 20
trace_labels = list(bfitter.model_labels)

# Size the grid from the number of labels instead of a fixed 3x3 layout
# (up to 9 labels gives the original 3x3, 12x4 figure).
n_cols = 3
n_rows = max(1, (len(trace_labels) + n_cols - 1) // n_cols)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows / 3))
fig.suptitle(f'idx = {idx}')
axs = axs.flatten()

for ax, p in zip(axs, trace_labels):
    ax.plot(
        bfitter.mcmc_sampler[idx][p]
    )
    # idx is used positionally, matching data.iloc[:, idx] and cpreds.iloc[:, idx]
    # below; plain [idx] on a Series would be a label lookup.
    ax.axhline(parameters[p].iloc[idx], linestyle=':', color='k', label='ground truth')
    ax.axhline(refined_pars[p].iloc[idx], linestyle=':', color='r', label='classical fit')
    ax.set_ylabel(p)
    ax.set_ylim(bounds[p])

# Hide panels left over when the labels do not fill the grid.
for ax in axs[len(trace_labels):]:
    ax.set_visible(False)
sns.despine(fig=fig)

# One legend, on the last populated panel.
if trace_labels:
    axs[len(trace_labels) - 1].legend()
# Figure.set_tight_layout is discouraged and documented for bool/dict only;
# tight_layout() applies the same layout directly.
fig.tight_layout()

# Time course of the same unit: observed data (black dotted), predictions from the
# MCMC samples (faint green), and the classical-fit prediction (red dotted).
plt.figure()
plt.plot(data.iloc[:, idx], ':k')
preds = bfitter.get_predictions(parameters=bfitter.mcmc_sampler[idx])
plt.plot(preds, alpha=0.1, color='g')
cpreds = model.predict(parameters=refined_pars)
plt.plot(cpreds.iloc[:, idx], ':r')

In [None]:
# Posterior pair plot for a single unit after discarding MCMC burn-in, with the
# ground-truth values overlaid.
idx = 0
burn_in = 100
# MCMC samples for unit `idx`, without the first `burn_in` rows.
this_mcmc_pars = bfitter.mcmc_sampler[idx].iloc[burn_in:, :]
# rsq of each retained sample. It is not used by the plot below but is kept in the
# notebook namespace (e.g. for colouring points).
rsq = bfitter.get_rsq_for_idx(idx=idx, parameters=this_mcmc_pars)

sns_plot = sns.pairplot(
    # Same post-burn-in samples as this_mcmc_pars, so the plot and rsq agree.
    this_mcmc_pars,
    # corner=True,
    diag_kind='kde',
)
# Density contours on the lower triangle. Matplotlib greyscale colours must be
# strings ('0.2'); a bare float is rejected as an invalid colour.
sns_plot.map_lower(sns.kdeplot, levels=4, color='0.2')

# edit_pair_plot (braincoder.utils.visualize) is presumably drawing reference lines at
# the ground-truth values and applying the limits in `bounds`; confirm.
edit_pair_plot(
    sns_plot.axes,
    lines_dict=parameters.iloc[idx, :].to_dict(),
    lim_dict=bounds,
    color='g', linestyle=':', lw=4, label='Truth',
)
# Optional extra overlays:
# edit_pair_plot(
#     sns_plot.axes, init_pars.iloc[idx,:].to_dict(), color='c', linestyle=':', lw=4,
# )
# edit_pair_plot(
#     sns_plot.axes, refined_pars.iloc[idx,:].to_dict(), color='r', linestyle=':', lw=4,
# )
# Legend goes on whichever axes is current after the calls above.
plt.legend()
plt.show()

In [None]:
# Display the complete MCMC sample table for the current `idx`, including the burn-in rows.
bfitter.mcmc_sampler[idx]

In [None]:
# Display the samples kept after dropping the first `burn_in` rows.
this_mcmc_pars