In [None]:
import numpy as np
import xarray as xr
import pandas as pd
from scipy import stats

from pathta import Study

from frites.conn import conn_ii
from frites.core import gcmi_1d_cc
from frites import set_mpl_style

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl

import json
# Project configuration; its 'export' section is read when the figure is saved.
with open("../config.json", 'r') as f:
    cfg = json.load(f)

# Base plotting theme provided by frites.
set_mpl_style('frites')

# Larger fonts for tick labels, titles and axis labels on top of the theme.
plt.rcParams.update({
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'axes.titlesize': 25,
    'axes.labelsize': 24,
})

# **Data simulation**
---

In [None]:
###############################################################################
# number of simulated trials (rows of the simulated signals)
n_epochs = 200
# number of time samples per trial
n_times = 300
# sampling frequency; sample indices are divided by it to build the time axis
sf = 64.0
###############################################################################


def compute_mi(x, y=None):
    """Mutual information between `x` and `y`, computed with `gcmi_1d_cc`.

    Thin wrapper that exposes `y` as a keyword argument so the regressor can
    be supplied through the `kwargs` of `xr.apply_ufunc`.

    Parameters
    ----------
    x : array_like
        First variable, forwarded as-is to `gcmi_1d_cc`.
    y : array_like | None
        Second variable, forwarded as-is to `gcmi_1d_cc`. The default is
        passed through unchanged; NOTE(review): `gcmi_1d_cc` presumably
        requires an actual array here - confirm.

    Returns
    -------
    mi
        Whatever `gcmi_1d_cc` returns for the pair (x, y).
    """
    mi = gcmi_1d_cc(x, y)
    return mi

def simulate_data(ii_type='redundancy'):
    """Simulate two noisy signals and measure their information about `pe`.

    Two signals, 'X' and 'Y', are drawn as uniform noise of shape
    (n_epochs, n_times). A 50-sample Hanning bump starting at sample 100,
    scaled on each trial by a random regressor `pe`, is then added:

    * 'redundancy': to every trial of X (gain 1.0) and of Y (gain 1.5);
    * 'synergy': to the first half of the trials of X (gain 1.5) and to the
      second half of the trials of Y (gain 4.0).

    Any other value of `ii_type` leaves both signals as pure noise.

    The function reads the module-level constants `n_epochs`, `n_times` and
    `sf`, and draws from the global NumPy random state.

    Parameters
    ----------
    ii_type : {'redundancy', 'synergy'}
        Type of interaction to inject.

    Returns
    -------
    ii
        Output of `conn_ii` for the X/Y pair with `pe` as regressor.
    mi : xr.DataArray
        Output of `compute_mi` applied along the 'trials' dimension, i.e. one
        value per roi and time point.
    """
    bump = np.hanning(50).reshape(1, -1)
    # keep the draw order (x, y, pe) so a given global seed yields the same data
    x = np.random.rand(n_epochs, n_times)
    y = np.random.rand(n_epochs, n_times)
    pe = np.random.rand(n_epochs)
    sl = slice(100, 100 + bump.shape[1])

    # regressor as a column so it broadcasts against the (1, 50) bump
    pe_col = pe.reshape(-1, 1)

    if ii_type == 'redundancy':
        x[:, sl] += 1. * pe_col * bump
        y[:, sl] += 1.5 * pe_col * bump
    elif ii_type == 'synergy':
        # Split the trials in two halves derived from n_epochs instead of a
        # hard-coded index, so both signals keep carrying the effect when the
        # number of trials changes.
        half = n_epochs // 2
        x[:half, sl] += 1.5 * pe_col[:half] * bump
        y[half:, sl] += 4. * pe_col[half:] * bump

    # the trials coordinate holds the regressor values, times are in seconds
    xy = xr.DataArray(
        np.stack((x, y), axis=1), dims=('trials', 'roi', 'times'),
        coords=(pe, ['X', 'Y'], np.arange(n_times) / sf)
    )

    # compute mi
    mi = xr.apply_ufunc(
        compute_mi, xy, input_core_dims=[['trials']], vectorize=True,
        kwargs={'y': pe}
    )

    # compute ii
    ii = conn_ii(xy, roi='roi', times='times', y=pe, mi_type='cc', verbose=False)

    return ii, mi


# define redundant and synergistic interactions
# Seed the global NumPy RNG that simulate_data draws from, so that re-running
# this cell (or the whole notebook) reproduces the same simulated data.
np.random.seed(0)
ii_red, mi_red = simulate_data(ii_type='redundancy')
ii_syn, mi_syn = simulate_data(ii_type='synergy')




In [None]:
# Quick sanity check: display the shape of the redundancy II values.
ii_red.data.shape

In [None]:
###############################################################################
# text properties of the panel letters
nb_kw = dict(size=27, weight='bold')
###############################################################################

fig, axs = plt.subplots(2, 2, figsize=(18, 10), sharex=True, sharey='col')

# Largest absolute II value over both simulations, reduced to a plain float.
# The `.max()` results are array objects; scaling one of them in place after
# handing it to `Normalize` may change the color limits as well.
ii_absmax = max(float(abs(ii_red).max()), float(abs(ii_syn).max()))

# symmetric diverging color scale centered on zero
cmap = plt.get_cmap('coolwarm')
norm = mpl.colors.Normalize(vmin=-ii_absmax, vmax=ii_absmax)
cm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)

# y-axis limits of the II panels: color range plus 10% of headroom
ylim = 1.1 * ii_absmax


def lineplot(ax, x, y):
    """Draw y(x) on `ax` as consecutive two-point segments colored by value.

    The color cycle of `ax` is replaced by one color per sample of `y`, taken
    from the module-level `cm` mappable, so the i-th segment (samples i and
    i + 1) is drawn with the color of `y[i]`.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on; its property cycle is overwritten.
    x, y : sequences of equal length
        Abscissa and ordinate values.
    """
    sample_colors = [cm.to_rgba(value) for value in y]
    ax.set_prop_cycle(color=sample_colors)

    n_segments = len(x) - 1
    for start in range(n_segments):
        stop = start + 2
        ax.plot(x[start:stop], y[start:stop], lw=4)

# redundancy (top row)

# A, left: MI of each roi about the regressor
plt.sca(axs[0, 0])
mi_red.plot(x='times', hue='roi')
axs[0, 0].set(xlabel="", ylabel="Local MI [bits]")
axs[0, 0].text(-0.15, 1.05, 'A', transform=axs[0, 0].transAxes, **nb_kw)

# A, right: II over time, colored by value
plt.sca(axs[0, 1])
lineplot(axs[0, 1], ii_red['times'].data, ii_red.data.squeeze())
axs[0, 1].set_title("Redundant II", fontweight="bold")
axs[0, 1].set(xlabel="", ylabel="II [bits]")
axs[0, 1].set_ylim(-ylim, ylim)

# synergy (bottom row)

# B, left: MI of each roi about the regressor
plt.sca(axs[1, 0])
mi_syn.plot(x='times', hue='roi')
axs[1, 0].set(xlabel="Times [s]", ylabel="Local MI [bits]")
axs[1, 0].text(-0.15, 1.05, 'B', transform=axs[1, 0].transAxes, **nb_kw)

# B, right: II over time, colored by value
plt.sca(axs[1, 1])
lineplot(axs[1, 1], ii_syn['times'].data, ii_syn.data.squeeze())
axs[1, 1].set_title("Synergistic II", fontweight="bold")
axs[1, 1].set_ylim(-ylim, ylim)
axs[1, 1].set_ylabel("II [bits]")
axs[1, 1].set_xlabel("Times [s]")

In [None]:
# Export settings from the project configuration: target folder and the
# keyword arguments forwarded to `savefig`.
_export = cfg['export']
save_to = _export['review']
cfg_export = _export['cfg']

fig_path = f'{save_to}/supp_sim_local_ii.png'
fig.savefig(fig_path, **cfg_export)