In [None]:
import sys
path2cpp_pkg = "/Users/mariusmahiout/Documents/repos/ising_core/build"
sys.path.append(path2cpp_pkg)
import ising

import os
os.chdir("/Users/mariusmahiout/Documents/repos/ising_core/python/src")
import preprocessing as pre
import model_eval as eval
import utils as utils
import misc_plotting as misc_plotting
import isingfitter as fitter
os.chdir("../..")
print(os.getcwd())

import numpy as np
import matplotlib.pyplot as plt
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objs as go
from IPython.display import display, HTML
from ipywidgets import HBox, VBox, widgets
import scipy
import pandas as pd

# Switch plotly to offline notebook mode and inject a MathJax <script> tag into
# the notebook page.
# NOTE(review): presumably the MathJax script is what lets the $$...$$ LaTeX
# axis titles and legend labels used below render — confirm.
plotly.offline.init_notebook_mode()
display(HTML(
    '<script type="text/javascript" async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-MML-AM_SVG"></script>'
))

In [None]:
import plotly.graph_objects as go

def get_hex_colors(num_colors, cmap_name="plasma"):
    """Sample ``num_colors`` hex colour strings from a matplotlib colormap.

    The colormap is evaluated at ``num_colors + 1`` evenly spaced points on
    [0, 1]; the first point (the start of the colormap) is discarded and the
    remaining ones are converted to ``#rrggbb`` strings. Each channel is
    scaled by 255 and truncated to an integer; alpha is ignored.

    Args:
        num_colors: Number of colours to return.
        cmap_name: Name understood by ``plt.get_cmap``.

    Returns:
        List of ``num_colors`` lowercase hex strings.
    """
    palette = plt.get_cmap(cmap_name)
    sample_points = np.linspace(0, 1, num_colors + 1)
    # Drop the first sample so the palette never includes the colormap's start.
    rgba_rows = palette(sample_points)[1:]

    def _to_hex(rgba):
        red, green, blue = (int(channel * 255) for channel in rgba[:3])
        return f"#{red:02x}{green:02x}{blue:02x}"

    return [_to_hex(rgba) for rgba in rgba_rows]

def plot_distributions(
    fig: go.FigureWidget,
    labels: list,
    colors: list,
    obs_datas: list,
    row: int,
    col: int,
    num_bins: int,
    ftr_symbol: str
):
    """Add one line trace per data set to a subplot, all sharing the same bins.

    The bin edges span the pooled range of every array in ``obs_datas`` so
    that the curves are directly comparable.

    Args:
        fig: Figure (created with ``make_subplots``) to draw into.
        labels: Legend name for each data set.
        colors: Line colour for each data set.
        obs_datas: 1-D arrays of observations, one per data set.
        row: Subplot row (1-based).
        col: Subplot column (1-based).
        num_bins: Number of equal-width bins across the pooled range.
        ftr_symbol: Text used as the x-axis title.
    """
    all_data = np.concatenate(obs_datas)
    min_val, max_val = np.min(all_data), np.max(all_data)
    # linspace yields exactly num_bins + 1 edges ending at max_val. The
    # previous np.arange call with a float step could round up to one extra
    # edge, silently adding a bin beyond the data range.
    bin_edges = np.linspace(min_val, max_val, num_bins + 1)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Only the top-left subplot contributes legend entries; traces in other
    # subplots share the legend group of the same label.
    showlegend = row == 1 and col == 1
    for obs_data, label, color in zip(obs_datas, labels, colors):
        # Only the second return value is plotted; presumably it is the
        # per-bin relative frequency for the given edges — TODO confirm.
        _, distr = utils.get_distr_and_range(
            obs_data, is_discrete=False, value_range=bin_edges
        )
        fig.add_trace(
            go.Scatter(
                x=bin_centers,
                y=distr,
                mode="lines",
                name=label,
                line=dict(color=color),
                marker=dict(color=color, size=5),
                legendgroup=label,
                showlegend=showlegend,
            ),
            row=row,
            col=col,
        )
    fig.update_xaxes(title_text=ftr_symbol, row=row, col=col)
    fig.update_yaxes(title_text="Rel. Freq.", row=row, col=col)

def get_all_recording_dcorrs(all_samples: list, dts: list) -> list:
    """Collect flattened delayed correlations for every sample and delay.

    Args:
        all_samples: Objects exposing ``getDelayedCorrs(dt)``.
        dts: Delays passed to ``getDelayedCorrs``.

    Returns:
        A list with one entry per sample; each entry is a list with one
        flattened array per delay, in the order of ``dts``.
    """
    per_sample = []
    for sample in all_samples:
        flattened = []
        for dt in dts:
            flattened.append(sample.getDelayedCorrs(dt).flatten())
        per_sample.append(flattened)
    return per_sample

def plot_ftr_distribution(fig, labels, colors, ftr_vals, num_bins, ftr_symbol, row):
    """Plot every group in ``ftr_vals`` into column 1 of subplot row ``row``.

    Each element of ``ftr_vals`` is itself a list of arrays (one per label)
    and is forwarded to ``plot_distributions`` as ``obs_datas``.
    """
    shared_kwargs = dict(row=row, col=1, num_bins=num_bins, ftr_symbol=ftr_symbol)
    for group in ftr_vals:
        plot_distributions(fig, labels, colors, group, **shared_kwargs)


def get_ks_results(ftr_vals, ftr_range):
    """Run pairwise two-sample Kolmogorov-Smirnov tests within one group.

    Only the first group, ``ftr_vals[0]``, is used; entry ``i`` of that group
    is compared against entry ``j`` for every pair of positions in
    ``ftr_range`` (including ``i == j``).

    Args:
        ftr_vals: Nested list; ``ftr_vals[0][k]`` is the data for label ``k``.
        ftr_range: Labels used as both the index and the columns of the output.

    Returns:
        ``(stats, pvals)``: two square DataFrames holding the KS statistic
        and the p-value for every pair.
    """
    size = len(ftr_range)
    ks_stat = np.zeros((size, size))
    ks_pval = np.zeros((size, size))

    for row_idx in range(size):
        first = ftr_vals[0][row_idx]
        for col_idx in range(size):
            second = ftr_vals[0][col_idx]
            ks_stat[row_idx, col_idx], ks_pval[row_idx, col_idx] = (
                scipy.stats.ks_2samp(first, second)
            )

    frame_kwargs = dict(columns=ftr_range, index=ftr_range)
    return pd.DataFrame(ks_stat, **frame_kwargs), pd.DataFrame(ks_pval, **frame_kwargs)



In [None]:
# Shared configuration for the cells below.
# NOTE(review): sample_names is not referenced by any cell visible in this
# notebook — confirm whether it is still needed.
sample_names = ["performing1", "performing2", "observing1", "observing2"]
# Output directory for this analysis, handed to utils.make_dir (presumably
# creates it if missing — TODO confirm). Later cells rebind `path` to the
# full file name of the figure they save.
path = "analyses/empirical_obs/"
utils.make_dir(path)


bin_width = 50 # ms
# Forwarded as num_units to pre.get_recording_sample; presumably None means
# "keep all units" — TODO confirm against the preprocessing module.
num_units = None


mouse_name = "Angie"
# angie performing 1
recording_fname = "RESULTS_Angie_20170825_1220_allbeh_1000s.mat"

In [None]:
# For delayed covariances
# Load the recording twice: once as-is and once with is_shuffled_bins=True,
# then compare the distributions of delayed correlations at several delays.
sample_original = pre.get_recording_sample(
    fname=recording_fname,
    mouse_name=mouse_name,
    bin_width=bin_width,
    num_units=num_units,
)
sample_shuffled_bins = pre.get_recording_sample(
    fname=recording_fname,
    mouse_name=mouse_name,
    bin_width=bin_width,
    num_units=num_units,
    is_shuffled_bins=True,
)


num_rows = 2
num_cols = 1

fig = go.FigureWidget(
    make_subplots(
        rows=num_rows,
        cols=num_cols,
    )
)

dts = [1, 10, 100]
labels_dts = [rf"$$\Large \delta t = {dt}$$" for dt in dts]
colors = get_hex_colors(len(dts), "magma_r")

# Each result is a one-element list (one recording) holding one flattened
# array per delay, which is the nesting plot_ftr_distribution expects.
dcorrs = get_all_recording_dcorrs([sample_original], dts)
dcorrs_shuffled = get_all_recording_dcorrs([sample_shuffled_bins], dts)

# Row 1: original recording; row 2: shuffled bins.
num_bins = 50
plot_ftr_distribution(fig, labels_dts, colors, dcorrs, num_bins, ftr_symbol=r"$$\Large D_{ij}(\delta t) $$", row=1)
plot_ftr_distribution(fig, labels_dts, colors, dcorrs_shuffled, num_bins, ftr_symbol=r"$$\Large D_{ij}(\delta t) $$", row=2)

fig.update_layout(
    height=400 * num_rows, width=500 * num_cols, margin=dict(l=100, t=40, b=70)
)
# Same x-range on both rows so the two panels are directly comparable.
fig.update_xaxes(range=[-0.005, 0.01], row=1, col=1)
fig.update_xaxes(range=[-0.005, 0.01], row=2, col=1)
fig.show()

path = "analyses/empirical_obs/" + "dcorrs_dt_sens.pdf"


def save_figure(b, fig=fig, path=path):
    """Button callback: write this cell's figure to this cell's output path.

    ``fig`` and ``path`` are bound as default arguments when the function is
    defined. Later cells rebind the globals ``fig`` and ``path``; without the
    binding, clicking this cell's button afterwards would save a different
    figure to a different file.

    Args:
        b: The button instance passed by ipywidgets (unused).
    """
    fig.write_image(path)
    print(f"Saved to {path}")


if path is not None:  # to-do: also check if is notebook
    button = widgets.Button(description="Save Figure")
    button.on_click(save_figure)
    display(button)

In [None]:
# Scratch check: square root of the length of the first flattened
# delayed-correlation array. Presumably this recovers the number of units,
# assuming getDelayedCorrs returns a square units-by-units matrix — TODO confirm.
np.sqrt(dcorrs[0][0].shape)

In [None]:
# Pairwise KS tests between the delayed-correlation distributions at each
# delay, for the original and the shuffled recording.
stats_d, pvals_d = get_ks_results(dcorrs, labels_dts)
stats_ds, pvals_ds = get_ks_results(dcorrs_shuffled, labels_dts)

# Show DataFrame floats in scientific notation from here on.
pd.options.display.float_format = "{:.2e}".format

print("ORIGINAL", "Stats: ", sep="\n")
display(stats_d)
print("\nP-values: ")
display(pvals_d)

print("", "---------------", "SHUFFLED", "Stats: ", sep="\n")
display(stats_ds)
print("\nP-values: ")
display(pvals_ds)

In [None]:
# NOTE(review): samples_bw is only created in the bin-width sensitivity cell
# that follows, so on a fresh kernel (Restart & Run All) this cell raises
# NameError. Move it below that cell or delete it.
samples_bw[0].getNumUnits()

In [None]:
# for bin-width sensitivity
# Load the same recording at several time-bin widths (ms) and compare the
# distributions of means and pairwise correlations across widths.
bin_widths = [50, 150, 250]
labels_bw = [str(bw) + " ms" for bw in bin_widths]
samples_bw = [
    pre.get_recording_sample(
        fname=recording_fname,
        mouse_name=mouse_name,
        bin_width=bw,
    )
    for bw in bin_widths
]

num_rows = 2
num_cols = 1

fig = go.FigureWidget(
    make_subplots(
        rows=num_rows,
        cols=num_cols,
    )
)

colors = get_hex_colors(len(labels_bw), "magma_r")

# The samples are wrapped in an outer list so they form a single group;
# plot_ftr_distribution and get_ks_results both index one grouping level.
means_bw = utils.get_all_recording_means([samples_bw])
pcorrs_bw = utils.get_all_recording_pcorrs([samples_bw])

# Row 1: means; row 2: pairwise correlations.
num_bins = 25
plot_ftr_distribution(fig, labels_bw, colors, means_bw, num_bins, ftr_symbol=r"$$\Large \langle \sigma_i \rangle$$", row=1)
plot_ftr_distribution(fig, labels_bw, colors, pcorrs_bw, num_bins, ftr_symbol=r"$$\Large \langle \sigma_i \sigma_j \rangle$$", row=2)

fig.update_layout(
    height=400 * num_rows, width=500 * num_cols, margin=dict(l=100, t=40, b=70)
)

fig.show()


path = "analyses/empirical_obs/" + "bin_width_sens.pdf"


def save_figure(b, fig=fig, path=path):
    """Button callback: write this cell's figure to this cell's output path.

    ``fig`` and ``path`` are bound as default arguments when the function is
    defined. Other cells rebind the globals ``fig`` and ``path``; without the
    binding, clicking this cell's button after running them would save a
    different figure to a different file.

    Args:
        b: The button instance passed by ipywidgets (unused).
    """
    fig.write_image(path)
    print(f"Saved to {path}")


if path is not None:  # to-do: also check if is notebook
    button = widgets.Button(description="Save Figure")
    button.on_click(save_figure)
    display(button)

In [None]:
# Pairwise KS tests across bin widths, separately for the means and for the
# pairwise correlations.
stats_m_bw, pvals_m_bw = get_ks_results(means_bw, labels_bw)
stats_chi_bw, pvals_chi_bw = get_ks_results(pcorrs_bw, labels_bw)

print("Means:", "---------------", "Stats: ", sep="\n")
display(stats_m_bw)
print("\nP-values: ")
display(pvals_m_bw)

print("Corrs:", "---------------", "Stats: ", sep="\n")
display(stats_chi_bw)
print("\nP-values: ")
display(pvals_chi_bw)

In [None]:
def get_ks_sensitivity_results(samples: list, labels: list):
    """Pairwise two-sample KS tests on the flattened states of each sample.

    Every sample's ``getStates()`` array is flattened and compared against
    every other sample's (including itself). As a side effect the global
    pandas float display format is switched to two-digit scientific notation.

    Args:
        samples: Objects exposing ``getStates()``.
        labels: One label per sample; used as index and columns of the output
            and to determine how many samples are compared.

    Returns:
        ``(test_stats, pvals)``: two square DataFrames holding the KS
        statistic and the p-value for every pair.
    """

    def _plain(value):
        # Unwrap numpy scalar types into native Python numbers.
        return value.item() if isinstance(value, np.generic) else value

    all_states = [sample.getStates() for sample in samples]
    size = len(labels)
    stat_grid = np.zeros((size, size))
    pval_grid = np.zeros((size, size))

    for first in range(size):
        left = all_states[first].flatten()
        for second in range(size):
            ks_stat, ks_pval = scipy.stats.ks_2samp(left, all_states[second].flatten())
            stat_grid[first, second] = _plain(ks_stat)
            pval_grid[first, second] = _plain(ks_pval)

    stat_frame = pd.DataFrame(stat_grid, index=labels, columns=labels)
    pval_frame = pd.DataFrame(pval_grid, index=labels, columns=labels)

    pd.options.display.float_format = "{:.2e}".format

    return stat_frame, pval_frame

In [None]:
# KS tests on the raw (flattened) states across bin widths, using the
# get_ks_sensitivity_results defined in this notebook.
stats_bw, pvals_bw = get_ks_sensitivity_results(labels=labels_bw, samples=samples_bw)

print("Stats:")
display(stats_bw)
print("\nP-values:")
display(pvals_bw)

In [None]:
def get_partitioned_sample(
    fname,
    mouse_name,
    bin_width=50,
    num_subsamples=1,
    data_dir="data",
):
    """Load a recording and return it as a list of ``ising.Sample`` partitions.

    Pipeline: load the recording, extract its states, reduce the time
    resolution to ``bin_width``, split into ``num_subsamples`` non-overlapping
    subsamples, convert each with ``pre.binary2ising`` and wrap the transposed
    array in ``ising.Sample``.

    Args:
        fname: Recording file name passed to ``pre.load_recordings``.
        mouse_name: Mouse identifier passed to ``pre.load_recordings``.
        bin_width: Bin width passed to ``pre.reduce_time_resolution``.
        num_subsamples: Number of non-overlapping subsamples to create.
        data_dir: Data directory passed to ``pre.load_recordings``.

    Returns:
        List of ``ising.Sample`` objects, one per subsample.
    """
    recording = pre.load_recordings(fname, mouse_name, data_dir)
    binned_states = pre.reduce_time_resolution(
        pre.get_recording_states(recording), bin_width
    )
    partitions = pre.get_nonoverlapping_subsamples(binned_states, num_subsamples)
    ising_partitions = list(map(pre.binary2ising, partitions))
    return [ising.Sample(part.T) for part in ising_partitions]

In [None]:
# for sys-size sensitivity
# Split the recording into 3 non-overlapping subsamples with
# get_partitioned_sample (bin_width left at its default of 50).
# NOTE(review): whether the split is over units or over time is decided by
# pre.get_nonoverlapping_subsamples, which is not visible here — confirm.
samples_ss = get_partitioned_sample(
    fname=recording_fname, 
    mouse_name=mouse_name, 
    num_subsamples=3
)
labels_ss = [f"Sample {i+1}" for i in range(len(samples_ss))]

# Cell output: getNumUnits() of the second subsample.
samples_ss[1].getNumUnits()

In [None]:



# Compare the distributions of means and pairwise correlations across the
# non-overlapping subsamples in samples_ss.
num_rows = 2
num_cols = 1

fig = go.FigureWidget(
    make_subplots(
        rows=num_rows,
        cols=num_cols,
    )
)

colors = get_hex_colors(len(labels_ss), "magma_r")

# The samples are wrapped in an outer list so they form a single group;
# plot_ftr_distribution and get_ks_results both index one grouping level.
means_ss = utils.get_all_recording_means([samples_ss])
pcorrs_ss = utils.get_all_recording_pcorrs([samples_ss])

# Row 1: means; row 2: pairwise correlations.
num_bins = 25
plot_ftr_distribution(fig, labels_ss, colors, means_ss, num_bins, ftr_symbol=r"$$\Large \langle \sigma_i \rangle$$", row=1)
plot_ftr_distribution(fig, labels_ss, colors, pcorrs_ss, num_bins, ftr_symbol=r"$$\Large \langle \sigma_i \sigma_j \rangle$$", row=2)

fig.update_layout(
    height=400 * num_rows, width=500 * num_cols, margin=dict(l=100, t=40, b=70)
)

fig.show()


path = "analyses/empirical_obs/" + "subsample_sens.pdf"


def save_figure(b, fig=fig, path=path):
    """Button callback: write this cell's figure to this cell's output path.

    ``fig`` and ``path`` are bound as default arguments when the function is
    defined. Other cells rebind the globals ``fig`` and ``path``; without the
    binding, clicking this cell's button after running them would save a
    different figure to a different file.

    Args:
        b: The button instance passed by ipywidgets (unused).
    """
    fig.write_image(path)
    print(f"Saved to {path}")


if path is not None:  # to-do: also check if is notebook
    button = widgets.Button(description="Save Figure")
    button.on_click(save_figure)
    display(button)

In [None]:
# Scratch arithmetic. NOTE(review): presumably units per subsample times the
# number of subsamples — confirm, or delete this cell.
91*3

In [None]:
# Pairwise KS tests across subsamples, separately for the means and for the
# pairwise correlations.
stats_m_ss, pvals_m_ss = get_ks_results(means_ss, labels_ss)
stats_chi_ss, pvals_chi_ss = get_ks_results(pcorrs_ss, labels_ss)

print("Means:", "---------------", "Stats: ", sep="\n")
display(stats_m_ss)
print("\nP-values: ")
display(pvals_m_ss)

print("Corrs:", "---------------", "Stats: ", sep="\n")
display(stats_chi_ss)
print("\nP-values: ")
display(pvals_chi_ss)

In [None]:
# KS tests on the raw states across subsamples.
# NOTE(review): this calls utils.get_ks_sensitivity_results, whereas the
# bin-width cell calls the function of the same name defined in this
# notebook — confirm the two agree.
stats_ss, pvals_ss = utils.get_ks_sensitivity_results(labels=labels_ss, samples=samples_ss)

print("Stats:")
display(stats_ss)
print("\nP-values:")
# NOTE(review): the values shown here are multiplied by 100 although the
# heading just says "P-values:" — confirm this is intended.
display(pvals_ss * 100)

In [None]:
# for system-size sensitivity
# Load the same recording with different numbers of units and compare the
# distributions of means and pairwise correlations across system sizes.
sys_sizes = [50, 100, 200]
labels_sz = [rf"$N = {size}$" for size in sys_sizes]
samples_sz = [
    pre.get_recording_sample(
        fname=recording_fname,
        mouse_name=mouse_name,
        bin_width=50,
        num_units=size,
    )
    for size in sys_sizes
]

num_rows = 2
num_cols = 1

fig = go.FigureWidget(
    make_subplots(
        rows=num_rows,
        cols=num_cols,
    )
)

colors = get_hex_colors(len(labels_sz), "magma_r")

# The samples are wrapped in an outer list so they form a single group;
# plot_ftr_distribution indexes one grouping level.
means_sz = utils.get_all_recording_means([samples_sz])
pcorrs_sz = utils.get_all_recording_pcorrs([samples_sz])

# Row 1: means; row 2: pairwise correlations.
num_bins = 25
plot_ftr_distribution(fig, labels_sz, colors, means_sz, num_bins, ftr_symbol=r"$$\Large \langle \sigma_i \rangle$$", row=1)
plot_ftr_distribution(fig, labels_sz, colors, pcorrs_sz, num_bins, ftr_symbol=r"$$\Large \langle \sigma_i \sigma_j \rangle$$", row=2)

fig.update_layout(
    height=400 * num_rows, width=500 * num_cols, margin=dict(l=100, t=40, b=70)
)

fig.show()


path = "analyses/empirical_obs/" + "sys_size_sens.pdf"


def save_figure(b, fig=fig, path=path):
    """Button callback: write this cell's figure to this cell's output path.

    ``fig`` and ``path`` are bound as default arguments when the function is
    defined. Other cells rebind the globals ``fig`` and ``path``; without the
    binding, clicking this cell's button after running them would save a
    different figure to a different file.

    Args:
        b: The button instance passed by ipywidgets (unused).
    """
    fig.write_image(path)
    print(f"Saved to {path}")


if path is not None:  # to-do: also check if is notebook
    button = widgets.Button(description="Save Figure")
    button.on_click(save_figure)
    display(button)

In [None]:
# KS tests on the raw states across system sizes, via the utils module's
# get_ks_sensitivity_results.
stats_sz, pvals_sz = utils.get_ks_sensitivity_results(labels=labels_sz, samples=samples_sz)

print("Stats:")
display(stats_sz)
print("\nP-values:")
display(pvals_sz)

In [None]:
# Scratch: scale three hard-coded values by 100.
# NOTE(review): presumably p-values copied by hand from an output above and
# converted to percent — confirm, or compute them from the DataFrame instead.
np.array([1.38e-01,	8.43e-03,	9.64e-01]) * 100