In [1]:
import numpy as np
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display, clear_output

# Load data
dat = np.load('/home/maria/LuckyMouse2/pixel_transformer_neuro/data/processed/hybrid_neural_responses_reduced.npy')
num_neurons, num_images = dat.shape

# Convert to probabilities
prob_matrix = dat / 50.0
epsilon = 1e-10
prob_matrix_safe = prob_matrix + epsilon

# Compute TV distance from uniform
uniform_dist = np.full((num_images,), 1.0 / num_images)
tv_dists = []
for neuron_probs in prob_matrix_safe:
    neuron_dist = neuron_probs / np.sum(neuron_probs)
    tv = 0.5 * np.sum(np.abs(uniform_dist - neuron_dist))
    tv_dists.append(tv)
tv_dists = np.array(tv_dists)

# Bin setup
num_bins = 30
bin_edges = np.histogram_bin_edges(tv_dists, bins=num_bins)
bin_indices = np.digitize(tv_dists, bins=bin_edges) - 1

bin_to_neuron_indices = {i: [] for i in range(num_bins)}
for idx, bin_idx in enumerate(bin_indices):
    if 0 <= bin_idx < num_bins:
        bin_to_neuron_indices[bin_idx].append(idx)

bin_counts = [len(bin_to_neuron_indices[i]) for i in range(num_bins)]
bin_labels = [f"{bin_edges[i]:.2f}–{bin_edges[i+1]:.2f}" for i in range(num_bins)]

# Output widget for detail plot
output_plot = widgets.Output()

def on_bar_click(trace, points, state):
    with output_plot:
        clear_output(wait=True)
        if points.point_inds:
            bin_idx = points.point_inds[0]
            neuron_ids = bin_to_neuron_indices.get(bin_idx, [])
            if neuron_ids:
                neuron_id = neuron_ids[0]
                sorted_events = np.sort(dat[neuron_id])
                fig_detail = go.Figure()
                fig_detail.add_trace(go.Bar(
                    x=list(range(len(sorted_events))),
                    y=sorted_events,
                    marker_color='darkgreen'
                ))
                fig_detail.update_layout(
                    title=f"Sorted Event Counts for First Neuron in TV Bin {bin_labels[bin_idx]}",
                    xaxis_title="Sorted Stimulus Index",
                    yaxis_title="Event Count",
                    height=300
                )
                fig_detail.show()
            else:
                print("No neurons in this bin.")

# TV histogram
fig = go.FigureWidget(
    data=[go.Bar(x=bin_labels, y=bin_counts, marker_color='lightgreen')],
    layout=dict(
        title="Total Variation Distance from Uniform Distribution (per Neuron)",
        xaxis_title="TV Distance Bin",
        yaxis_title="Number of Neurons",
        xaxis_tickangle=45,
        bargap=0.1
    )
)

fig.data[0].on_click(on_bar_click)

# Display both
display(fig)
display(output_plot)


FigureWidget({
    'data': [{'marker': {'color': 'lightgreen'},
              'type': 'bar',
              'uid': '32e380fc-b646-4b94-8760-74f5af77dae0',
              'x': [0.07–0.09, 0.09–0.11, 0.11–0.14, 0.14–0.16, 0.16–0.18,
                    0.18–0.21, 0.21–0.23, 0.23–0.25, 0.25–0.28, 0.28–0.30,
                    0.30–0.32, 0.32–0.35, 0.35–0.37, 0.37–0.39, 0.39–0.42,
                    0.42–0.44, 0.44–0.46, 0.46–0.48, 0.48–0.51, 0.51–0.53,
                    0.53–0.55, 0.55–0.58, 0.58–0.60, 0.60–0.62, 0.62–0.65,
                    0.65–0.67, 0.67–0.69, 0.69–0.72, 0.72–0.74, 0.74–0.76],
              'y': [31, 139, 245, 380, 451, 572, 711, 1000, 1465, 2031, 2827,
                    3700, 4830, 4915, 4468, 3560, 2736, 2018, 1214, 828, 503, 262,
                    139, 88, 52, 20, 14, 2, 5, 2]}],
    'layout': {'bargap': 0.1,
               'template': '...',
               'title': {'text': 'Total Variation Distance from Uniform Distribution (per Neuron)'},
               '

Output()