In [1]:
import os
os.chdir("/Users/mariusmahiout/Documents/repos/ising_core/python")
import src.preprocessing as pre
import src.model_eval as eval
import src.utils as utils
os.chdir("..")

In [18]:
analysis_name = "recording_test"

# Recording parameters shared by the analysis path and the metadata; defined once
# so the two calls below cannot silently disagree.
NUM_UNITS = 214
BIN_WIDTH = 50

analysis_path = utils.get_analysis_path(analysis_name, NUM_UNITS, BIN_WIDTH)

# NOTE(review): `samples` is not defined by any cell above this one, so this cell
# raises NameError on a fresh kernel (Restart & Run All). It only works when
# `samples` is left over from an earlier or deleted cell.
# TODO: load the sample explicitly here (presumably via pre.get_recording_sample).
metadata = utils.get_metadata(
    num_units=NUM_UNITS,
    is_empirical_analysis=True,
    num_bins=samples[0].getStates().shape[0],
    bin_width=BIN_WIDTH,
)


In [3]:
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objs as go
from ipywidgets import HBox, VBox, widgets
import matplotlib.pyplot as plt

In [53]:
def get_colors_from_colormap(num_colors, cmap_name='viridis'):
    """Return `num_colors` evenly spaced hex colours sampled from a matplotlib colormap.

    The colormap is sampled at `num_colors + 1` evenly spaced positions on [0, 1].
    The final sample (position 1.0, the colormap's end colour) is discarded.

    Args:
        num_colors: number of colours to return.
        cmap_name: name of a matplotlib colormap (default 'viridis').

    Returns:
        List of '#rrggbb' strings. The alpha channel of each sample is ignored, and
        each channel is truncated (not rounded) to an integer in 0..255.
    """
    colormap = plt.get_cmap(cmap_name)
    positions = np.linspace(0, 1, num_colors + 1)
    # Drop the last sample so the colormap's end colour is never used.
    rgba_samples = colormap(positions)[:-1]

    hex_colors = []
    for red, green, blue, _alpha in rgba_samples:
        hex_colors.append('#%02x%02x%02x' % (int(red * 255), int(green * 255), int(blue * 255)))
    return hex_colors

In [54]:
def plot_histogram(fig: go.FigureWidget, labels: list, colors: list, obs_datas: list, row: int, col: int, num_bins: int):
    """Overlay relative-frequency histograms of several datasets in one subplot.

    All datasets share one set of `num_bins` equal-width bins spanning the pooled
    range of `obs_datas`, so bars of different datasets line up. Each bar's height
    is the fraction of that dataset's values falling in the bin, so the heights of
    one dataset sum to 1.

    Args:
        fig: figure (built with `make_subplots`) that the bar traces are added to.
        labels: legend label per dataset, aligned with `obs_datas`.
        colors: bar colour per dataset, aligned with `obs_datas`.
        obs_datas: list of array-likes of observations, one per dataset.
        row, col: 1-based position of the subplot to draw into.
        num_bins: number of histogram bins.

    Only traces drawn into subplot (1, 1) are shown in the legend. Traces in every
    subplot use `legendgroup=label`, so a dataset is grouped under a single legend
    entry. Does nothing when `obs_datas` is empty.
    """
    if len(obs_datas) == 0:
        return

    all_data = np.concatenate([np.ravel(obs_data) for obs_data in obs_datas])
    # histogram_bin_edges returns exactly num_bins + 1 edges. np.arange with a float
    # step can return one edge too many. It also widens a zero-width range (all
    # values identical) instead of producing a zero bin width.
    bin_edges = np.histogram_bin_edges(all_data, bins=num_bins)
    bin_width = bin_edges[1] - bin_edges[0]
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    showlegend = (row == 1 and col == 1)
    for obs_data, label, color in zip(obs_datas, labels, colors):
        counts, _ = np.histogram(obs_data, bins=bin_edges)
        # Relative frequency per bin (same as density * bin_width for equal bins).
        # An empty dataset gives all-zero bars instead of NaNs.
        rel_freqs = counts / max(counts.sum(), 1)
        fig.add_trace(
            go.Bar(
                x=bin_centers,
                y=rel_freqs,
                width=bin_width,
                marker=dict(color=color, opacity=0.75),
                name=label,
                legendgroup=label,
                showlegend=showlegend,
            ),
            row=row,
            col=col,
        )

In [66]:
num_bins = 30
num_rows = 2
num_cols = 4

bin_widths = [50, 100, 150, 200]

# PPC, mouse Seven: one recording file per subplot column.
seven_fnames = [
    "Resulttable_Seven_1504_1813.mat",  # performing
    "Resulttable_Seven_1504_1850.mat",  # performing
    "Resulttable_Seven_1504_1825.mat",  # observing
    "Resulttable_Seven_1504_1837.mat",  # observing
]
# One list per recording, holding one sample per bin width (same order as bin_widths).
samples1, samples2, samples3, samples4 = [
    [pre.get_recording_sample(fname=fname, mouse_name="Seven", bin_width=bin_width) for bin_width in bin_widths]
    for fname in seven_fnames
]

colors = get_colors_from_colormap(len(samples1), 'winter')

labels = [str(bin_width) + " ms" for bin_width in bin_widths]

samples_per_recording = [samples1, samples2, samples3, samples4]
means1, means2, means3, means4 = [
    [sample.getMeans() for sample in samples] for samples in samples_per_recording
]
pcorrs1, pcorrs2, pcorrs3, pcorrs4 = [
    [sample.getPairwiseCorrs().flatten() for sample in samples] for samples in samples_per_recording
]

fig = go.FigureWidget(make_subplots(rows=num_rows, cols=num_cols))

# Single layout call. The margins leave room for the shared axis titles added below.
fig.update_layout(
    height=400 * num_rows,
    width=400 * num_cols,
    margin=dict(l=80, t=40, b=70),
)

# Row 1: getMeans() histograms; row 2: getPairwiseCorrs() histograms; one column per recording.
for col, means in enumerate([means1, means2, means3, means4], start=1):
    plot_histogram(fig, labels, colors, means, row=1, col=col, num_bins=num_bins)
for col, pcorrs in enumerate([pcorrs1, pcorrs2, pcorrs3, pcorrs4], start=1):
    plot_histogram(fig, labels, colors, pcorrs, row=2, col=col, num_bins=num_bins)

# Shared x-axis title below the bottom row (positions are in 'paper' coordinates, 0..1).
# NOTE(review): the x-values in these panels are histogram bin centres of the observables,
# not unit indices. "Unit, i" looks like a leftover label from another figure; confirm the intended text.
fig.add_annotation(
    text="Unit, i",
    xref="paper",
    yref="paper",
    x=0.475,
    y=-0.1,
    showarrow=False,
    font=dict(size=22, color="black"),
    align="center",
)

# Shared y-axis title, rotated to run vertically along the left edge.
fig.add_annotation(
    text="Relative frequency",
    xref="paper",
    yref="paper",
    x=0.001,
    y=0.5,
    showarrow=False,
    font=dict(size=22, color="black"),
    textangle=-90,
)

display(fig)

In [59]:
# Re-plot a subset of the recordings from the most recently executed recording cell.
# List the subplot columns (1-based, one per recording) to draw; this replaces the
# previously commented-out plot_histogram calls.
# NOTE(review): both recording cells reuse the names means1..4 / pcorrs1..4, so this
# cell shows whichever cell ran last, not the cell directly above it. Its saved output
# (In [59]) matches the Angie cell's output (In [57]), not the Seven cell's (In [66]).
columns_to_plot = [1]

fig = go.FigureWidget(make_subplots(rows=num_rows, cols=num_cols))

fig.update_layout(
    height=400 * num_rows,
    width=400 * num_cols,
)

means_per_recording = [means1, means2, means3, means4]
pcorrs_per_recording = [pcorrs1, pcorrs2, pcorrs3, pcorrs4]

# Row 1: getMeans() histograms; row 2: getPairwiseCorrs() histograms.
for col in columns_to_plot:
    plot_histogram(fig, labels, colors, means_per_recording[col - 1], row=1, col=col, num_bins=num_bins)
for col in columns_to_plot:
    plot_histogram(fig, labels, colors, pcorrs_per_recording[col - 1], row=2, col=col, num_bins=num_bins)

display(fig)

FigureWidget({
    'data': [{'legendgroup': '50 ms',
              'marker': {'color': '#440154', 'opacity': 0.75},
              'name': '50 ms',
              'showlegend': True,
              'type': 'bar',
              'uid': '6deb82c7-b2be-4ab2-92f9-eb3fb28427a8',
              'width': 0.027393939393939394,
              'x': array([-0.98630303, -0.95890909, -0.93151515, -0.90412121, -0.87672727,
                          -0.84933333, -0.82193939, -0.79454545, -0.76715152, -0.73975758,
                          -0.71236364, -0.6849697 , -0.65757576, -0.63018182, -0.60278788,
                          -0.57539394, -0.548     , -0.52060606, -0.49321212, -0.46581818,
                          -0.43842424, -0.4110303 , -0.38363636, -0.35624242, -0.32884848,
                          -0.30145455, -0.27406061, -0.24666667, -0.21927273, -0.19187879]),
              'xaxis': 'x',
              'y': array([2.38392932e-02, 2.14636993e-01, 3.34833708e-01, 2.40393432e-01,
                  

In [57]:
num_bins = 30
num_rows = 2
num_cols = 4

bin_widths = [50, 100, 150, 200]

# M2, mouse Angie: one recording file per subplot column.
# NOTE(review): this cell overwrites samples1..4, means1..4, pcorrs1..4, colors and
# labels defined by the mouse Seven cell; cells that read them depend on execution order.
angie_fnames = [
    "RESULTS_Angie_20170825_1158_allbeh_1000s.mat",
    "RESULTS_Angie_20170825_1220_allbeh_1000s.mat",
    "RESULTS_Angie_20170825_1232_allbeh_1000s.mat",
    "RESULTS_Angie_20170825_1248_allbeh_1000s.mat",
]
# One list per recording, holding one sample per bin width (same order as bin_widths).
samples1, samples2, samples3, samples4 = [
    [pre.get_recording_sample(fname=fname, mouse_name="Angie", bin_width=bin_width) for bin_width in bin_widths]
    for fname in angie_fnames
]

colors = get_colors_from_colormap(len(samples1), 'viridis')

labels = [str(bin_width) + " ms" for bin_width in bin_widths]

samples_per_recording = [samples1, samples2, samples3, samples4]
means1, means2, means3, means4 = [
    [sample.getMeans() for sample in samples] for samples in samples_per_recording
]
pcorrs1, pcorrs2, pcorrs3, pcorrs4 = [
    [sample.getPairwiseCorrs().flatten() for sample in samples] for samples in samples_per_recording
]

fig = go.FigureWidget(make_subplots(rows=num_rows, cols=num_cols))

fig.update_layout(
    height=400 * num_rows,
    width=400 * num_cols,
)

# Row 1: getMeans() histograms; row 2: getPairwiseCorrs() histograms; one column per recording.
for col, means in enumerate([means1, means2, means3, means4], start=1):
    plot_histogram(fig, labels, colors, means, row=1, col=col, num_bins=num_bins)
for col, pcorrs in enumerate([pcorrs1, pcorrs2, pcorrs3, pcorrs4], start=1):
    plot_histogram(fig, labels, colors, pcorrs, row=2, col=col, num_bins=num_bins)

display(fig)

FigureWidget({
    'data': [{'legendgroup': '50 ms',
              'marker': {'color': '#440154', 'opacity': 0.75},
              'name': '50 ms',
              'showlegend': True,
              'type': 'bar',
              'uid': '9ec776f1-4964-403f-9cd4-f9d18bfdbd63',
              'width': 0.027393939393939394,
              'x': array([-0.98630303, -0.95890909, -0.93151515, -0.90412121, -0.87672727,
                          -0.84933333, -0.82193939, -0.79454545, -0.76715152, -0.73975758,
                          -0.71236364, -0.6849697 , -0.65757576, -0.63018182, -0.60278788,
                          -0.57539394, -0.548     , -0.52060606, -0.49321212, -0.46581818,
                          -0.43842424, -0.4110303 , -0.38363636, -0.35624242, -0.32884848,
                          -0.30145455, -0.27406061, -0.24666667, -0.21927273, -0.19187879]),
              'xaxis': 'x',
              'y': array([2.38392932e-02, 2.14636993e-01, 3.34833708e-01, 2.40393432e-01,
                  

In [None]:
# TODO:
# (1) Plot histograms of the means and pairwise correlations.
#     Note: decide on the approach first. Histograms should show only the empirical data,
#     while scatter plots should compare the analytic and simulation-based observables.
#     For the scatter plots, we could in principle create a dummy Sample object. This would
#     need a Python class whose interface mimics the C++ Sample class, i.e. one that exposes
#     getMeans() and getPairwiseCorrs().
#
# (2) Plot the min, mean, and max delayed correlation as a function of the delay (as a curve;
#     the generalized plotting function can be reused for this).
