In [1]:
# ------------------------------------------------------------------------------ #
# @Author:        F. Paul Spitzner
# @Email:         paul.spitzner@ds.mpg.de
# @Created:       2024-03-23 16:05:29
# @Last Modified: 2024-03-23 16:06:19
# ------------------------------------------------------------------------------ #
# Needs the data from `bayes_analysis.ipynb`
# ------------------------------------------------------------------------------ #

import os
import pymc as pm
import nutpie
import numpy as np
import pandas as pd
import arviz as az
import sys
import matplotlib.pyplot as plt

import logging

# Root logger: WARNING and above, with timestamp, level, logger and function name.
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s | %(levelname)-8s | %(name)-s | %(funcName)-s | %(message)s",
)

# The notebook's own logger is more verbose than the root default.
log = logging.getLogger("notebook")
log.setLevel(logging.DEBUG)

# Parent of the current working directory goes onto the import path so that the
# project package `ana` (imported below) can be found.
extra_path = os.path.abspath(os.pardir)
sys.path.insert(len(sys.path), extra_path)
log.info(f"project directory: {extra_path}")

from ana import utility as utl
from ana import plot_helper as ph
from ana import bayesian_models as bm

ph.log.setLevel("DEBUG")
utl.log.setLevel("DEBUG")

# TODO: get these settings from the legacy utility file to recreate plots
plot_settings = utl.get_default_plot_settings()
plt.rcParams.update(plot_settings['rcparams'])
plot_settings['imgdir'] = '../img'

# Directory holding the `bayes_{model}_{measure}_{stimulus}.nc` files written by
# `bayes_analysis.ipynb`. Set the BAYES_DATA_DIR environment variable to point
# elsewhere (e.g. BAYES_DATA_DIR=../dat for a repository-local copy); the default
# is the location the figures were produced from. (Previously a `../dat`
# assignment here was immediately overwritten by this path, i.e. dead code.)
data_dir = os.path.abspath(
    os.environ.get(
        "BAYES_DATA_DIR",
        "/data.nst/lucas/projects/mouse_visual_timescales_predictability/paper_code/experiment_analysis/data",
    )
)
log.info(f"data directory: {data_dir}")


2024-05-21 15:39:10,662 | INFO     | notebook | <module> | project directory: /data.nst/lucas/projects/mouse_visual_timescales_predictability/paper_code_repo/experiment_analysis


In [None]:
# Plot hierarchical parameters 


In [None]:
## posterior visualization (non-hierarchical parameters)
#%%
from itertools import product
from tqdm.auto import tqdm

models = ['sgm', 'lm']
measures = ["tau_double", "tau_R", "R_tot"]
stimuli = ["natural_movie_three", "spontaneous", "natural_movie_one_more_repeats"]

# number of trailing draws (per chain) shown in the posterior plots
n_plot_draws = 1000

# tuples are (measure, stimulus, model) — the loop below unpacks in this order
combinations = list(product(measures, stimuli, models))
# pointwise PSIS-LOO results keyed "{model}_{measure}_{stimulus}"
loos = dict()

# NOTE(review): `f_transf` and `f_transf_log` are not defined in this notebook;
# they are leftovers from the legacy plotting script and must be defined before
# this cell runs — TODO confirm their definitions.

# per-panel labels: symbol for the y label, description for the x label
var_symbols = [r'$\mathrm{θ_{rf}}$', r'$θ_{\log \nu}$', 'ε', r'$\alpha$']
var_descriptions = [
    r'$\exp(\mathrm{θ_{rf}})$ (responsiveness)',
    r'$\exp(θ_{\log \nu})$ (log fir. rate)',
    'ε (scale)',
    r'$\alpha$ (shape)',
]

# `model_name` (not `model`) so the loop variable does not clobber the name `model`
for meas, stim, model_name in tqdm(combinations):
    key = f"{model_name}_{meas}_{stim}"
    idata = az.from_netcdf(f"{data_dir}/bayes_{key}.nc")
    loos[key] = az.loo(idata, pointwise=True)
    posterior = idata.posterior

    # (variable, transform) per panel; R_tot has no shape parameter `alpha`,
    # and `alpha` is plotted untransformed
    panels = [("b_sign_rf", f_transf), ("b_log_fr", f_transf), ("epsilon", f_transf_log)]
    if meas != "R_tot":
        panels.append(("alpha", None))
    n_axes = len(panels)
    fig_height = 1.5 * n_axes  # 4.5 for R_tot, 6 otherwise

    fig1, axes = plt.subplots(n_axes, 1, figsize=(plot_settings["panel_width"], fig_height))
    fig1.subplots_adjust(left=0.01, right=0.9, top=0.9, bottom=0.1, hspace=1)
    # reference line at exp(θ) = 1 for the two exp-transformed coefficients
    axes[0].axvline(x=1, ls="--", lw=2, color="0.0")
    axes[1].axvline(x=1, ls="--", lw=2, color="0.0")

    for ax, (var, transform) in zip(axes, panels):
        plot_kwargs = dict(ax=ax, point_estimate='median', hdi_prob=0.95)
        if transform is not None:
            plot_kwargs["transform"] = transform
        # assumes scalar variables, i.e. values of shape (chain, draw) — TODO confirm
        utl.plot_posterior(posterior[var].values[:, -n_plot_draws:], **plot_kwargs)

    for ax, symbol, description in zip(axes, var_symbols, var_descriptions):
        ax.set_ylabel('p({} | E)'.format(symbol))
        ax.set_xlabel(description)
        ax.set_title('')
        utl.make_plot_pretty(ax)

    axes[0].set_title(f"{meas}, {stim} ({model_name})", ha='center')

    # save every figure (the legacy code saved only once, after the loop) and
    # close it so the loop does not accumulate open figures
    utl.save_plot(plot_settings, f"bayes_{model_name}_posterior", stimulus=stim, measure=meas)
    plt.close(fig1)

# with model:
#     fig0, ax0 = plt.subplots(1, 1, figsize=(plot_settings["panel_width"], 2))
#     fig0.subplots_adjust(left=0.01, right=0.9, top=0.9, bottom=0.1, hspace=1)
#
#     utl.plot_posterior(trc[-250:], var_names=['mu_area'],
#                        ax=ax0, point_estimate='median',
#                        hdi_prob=0.95,
#                        transform=f_transf)
#
#     ax0.set_ylabel('p({$θ_3$} | E)')
#     ax0.set_xlabel('$θ_3$ (hier. score)')
#     ax0.set_title('')
#     utl.make_plot_pretty(ax0)
#
#     axes[0].set_title(measure_name, ha='center')
#
# utl.save_plot(plot_settings, f"{__file__[:-3]}_posterior_hierarchy_score", measure=args.measure)

# posterior visualization (hierarchical parameters)
#
# Ported from the legacy script, which built its InferenceData with
# `az.from_pymc3(trace=trc, prior=prior_pred_samples, model=model)`. That cannot
# run here: `trc`/`prior_pred_samples` are not defined in this notebook and
# `model` is a string left over from the loop above. Instead, every
# InferenceData written by `bayes_analysis.ipynb` is reloaded and the
# hierarchical parameters are plotted for each model that contains them.
#
# NOTE(review): `f_transf_int`, `f_transf_log` and `f_transf_log_int` are not
# defined in this notebook (leftovers from the legacy script) and must be
# defined before this cell runs — TODO confirm their definitions.

n_hier_draws = 1000  # trailing draws per chain shown in the plots
_HIER_PREFIXES = ("mu", "sigma", "eff_session")


def _has_hierarchical_param(post, param):
    """Return True if `post` holds the mu_/sigma_/eff_session_ variables of `param`."""
    return all(f"{prefix}_{param}" in post for prefix in _HIER_PREFIXES)


def _plot_hierarchical_param(post, param, width_factor, wspace,
                             mu_transform, sigma_transform, session_transform):
    """Plot the posteriors of one hierarchical parameter on a new figure.

    Layout: top-left ``mu_<param>``, top-right ``sigma_<param>``, bottom (full
    width) one density per session of ``eff_session_<param>``, each session in
    the next color of the active matplotlib color cycle.

    Parameters
    ----------
    post : posterior group (xarray Dataset) of an InferenceData.
    param : variable-name suffix, e.g. ``"hierarchy_slope"`` or ``"intercept"``.
    width_factor : figure width in units of ``plot_settings["panel_width"]``.
    wspace : horizontal spacing between the two top panels.
    mu_transform, sigma_transform, session_transform : transforms handed to
        ``utl.plot_posterior`` for the respective panels.

    Returns
    -------
    fig, (ax_mu, ax_sigma, ax_sessions); axis labels are left to the caller.
    """
    fig = plt.figure(figsize=(plot_settings["panel_width"] * width_factor, 3))
    ax_mu = fig.add_subplot(2, 2, 1)
    ax_sigma = fig.add_subplot(2, 2, 2)
    ax_sessions = fig.add_subplot(2, 1, 2)
    fig.subplots_adjust(left=0.01, right=0.9, top=0.9, bottom=0.1, hspace=0.6, wspace=wspace)

    for ax, prefix, transform in ((ax_mu, "mu", mu_transform),
                                  (ax_sigma, "sigma", sigma_transform)):
        utl.plot_posterior(post[f"{prefix}_{param}"].values[:, -n_hier_draws:],
                           ax=ax, point_estimate='median',
                           hdi_prob=0.95,
                           transform=transform)

    # assumes dims (chain, draw, session), as indexed by the legacy script
    session_effects = post[f"eff_session_{param}"].values[:, -n_hier_draws:]
    # seaborn's default palette (used by the legacy script) is the active
    # matplotlib color cycle, so read it directly instead of importing seaborn
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    for session_idx in range(session_effects.shape[-1]):
        utl.plot_posterior(session_effects[..., session_idx],
                           ax=ax_sessions,
                           hdi=False,
                           color=colors[session_idx % len(colors)],
                           transform=session_transform)
    return fig, (ax_mu, ax_sigma, ax_sessions)


for meas, stim, model_name in combinations:
    key = f"{model_name}_{meas}_{stim}"
    post = az.from_netcdf(f"{data_dir}/bayes_{key}.nc").posterior

    # slope
    if _has_hierarchical_param(post, "hierarchy_slope"):
        fig0, (ax1, ax2, ax3) = _plot_hierarchical_param(
            post, "hierarchy_slope", width_factor=1.7, wspace=0.4,
            mu_transform=f_transf_log,
            sigma_transform=f_transf_log,
            session_transform=f_transf_log,
        )
        ax1.set_xlabel(r'mean slope $\mu_{\theta_{\mathrm{hs}}}$')
        ax1.set_ylabel(r'$p\left(\mu_{\theta_{\mathrm{hs}}} | E\right)$')
        ax2.set_xlabel(r'std slope $\sigma_{\theta_{\mathrm{hs}}}$')
        ax2.set_ylabel(r'$p\left(\sigma_{\theta_{\mathrm{hs}}} | E\right)$')
        ax3.set_xlabel(r'hierarchy score slope $\theta_{\mathrm{hs}}$')
        ax3.set_ylabel(r'$p\left(\theta_{\mathrm{hs}} | E\right)$')
        for ax in (ax1, ax2, ax3):
            utl.make_plot_pretty(ax)
        # reference line at x = 0 (as in the legacy script)
        ax1.axvline(x=0, ls="--", lw=2, color="0.0")
        ax3.axvline(x=0, ls="--", lw=2, color="0.0")
        utl.save_plot(plot_settings, f"bayes_{model_name}_posterior_slope",
                      stimulus=stim, measure=meas)
        plt.close(fig0)
    else:
        log.info(f"{key}: no hierarchy_slope parameters, skipping slope plot")

    # intercept
    if _has_hierarchical_param(post, "intercept"):
        fig0, (ax1, ax2, ax3) = _plot_hierarchical_param(
            post, "intercept", width_factor=1.6, wspace=0.3,
            mu_transform=f_transf_log_int,
            sigma_transform=f_transf_log,
            session_transform=f_transf_int,
        )
        ax1.set_xlabel(r'mean intercept $\mu_{\theta_0}$')
        ax1.set_ylabel(r'$p\left(\mu_{\theta_0} | E\right)$')
        ax2.set_xlabel(r'std intercept $\sigma_{\theta_0}$')
        ax2.set_ylabel(r'$p\left(\sigma_{\theta_0} | E\right)$')
        if meas == "R_tot":
            ax3.set_xlabel(r'intercept $\exp(\theta_0)$')
        else:
            ax3.set_xlabel(r'intercept $\exp(\theta_0)$ (ms)')
        ax3.set_ylabel(r'$p\left(\theta_0 | E\right)$')
        for ax in (ax1, ax2, ax3):
            utl.make_plot_pretty(ax)
        utl.save_plot(plot_settings, f"bayes_{model_name}_posterior_intercept",
                      stimulus=stim, measure=meas)
        plt.close(fig0)
    else:
        log.info(f"{key}: no intercept hierarchy parameters, skipping intercept plot")