In [1]:
import ipywidgets as widgets
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
import torch
from IPython.display import display
from plotly.subplots import make_subplots

# Default Plotly renderer for figures shown without an explicit renderer: a combined
# "jupyterlab" + static "png" output (NOTE(review): the PNG part presumably needs
# kaleido installed — confirm). update_plots calls fig.show("notebook") explicitly,
# so the interactive grid does not use this default.
pio.renderers.default = "jupyterlab+png"

In [2]:
def _tidy_rsa(rsa, noise_ceiling=None, noise_corrected=False):
    """Summarise per-model RSA scores into one tidy DataFrame on a complete model x layer grid.

    Parameters
    ----------
    rsa : dict
        Maps a model name to anything ``pd.DataFrame`` accepts, with one column per
        layer and one row per repeated RSA score (NOTE(review): rows are presumably
        bootstrap/split repeats — confirm against the code that writes the results).
        Layer names whose first two characters are ``p1``/``p2`` are assigned to that
        processing path, and their first three characters (e.g. ``"p1_"``) are stripped.
    noise_ceiling : array-like or None
        Noise-ceiling samples; only their median is used here.
    noise_corrected : bool
        If True and ``noise_ceiling`` is not None, divide both the RSA summary and
        its error by the median noise ceiling.

    Returns
    -------
    df : pandas.DataFrame
        One row per (Model_Path, Layer) pair, including pairs a model does not have
        (filled with NaN). Columns: ``Model_Path``, ``Layer``, ``RSA`` (median over
        rows), ``Error`` (standard deviation over rows), ``Model``, ``Path``,
        ``Layer_Index``.
    model_names : numpy.ndarray
        Unique ``Model_Path`` labels in first-appearance order.
    layer_order : dict
        Maps each layer name to its integer x position (first-appearance order).
    """
    path_prefixes = ['p1', 'p2']
    frames = []
    for species, scores in rsa.items():
        scores_df = pd.DataFrame(scores)
        summary = pd.concat([scores_df.median(axis=0), scores_df.std(axis=0)], axis=1)
        summary.columns = ['RSA', 'Error']
        raw_layers = list(summary.index)
        summary['Layer'] = raw_layers
        summary['Model'] = species
        summary['Path'] = [layer[:2] if layer[:2] in path_prefixes else '' for layer in raw_layers]
        summary['Layer'] = [layer[3:] if layer[:2] in path_prefixes else layer for layer in raw_layers]
        frames.append(summary.reset_index(drop=True))

    # ignore_index gives a clean RangeIndex so the boolean .loc below aligns unambiguously.
    df = pd.concat(frames, ignore_index=True)
    # The pixel layer is relabelled as one 'Baseline' model; exact duplicate rows
    # (presumably the same pixel baseline contributed by every model) are dropped.
    df.loc[df['Layer'] == 'pixel', 'Model'] = 'Baseline'
    df = df.drop_duplicates(keep='first')
    df['Model_Path'] = df['Model'] + df['Path'].replace({'': '', 'p1': ' P1', 'p2': ' P2'})

    if noise_ceiling is not None and noise_corrected:
        nc_median = np.median(noise_ceiling)
        # Scale the error bars too, so they stay in the same units as the scores.
        df['RSA'] /= nc_median
        df['Error'] /= nc_median

    model_names = df['Model_Path'].unique()
    # Unique layers only: the raw column repeats each layer once per model, which
    # would duplicate every (model, layer) row in the reindexed grid.
    layers = df['Layer'].unique()
    complete_index = pd.MultiIndex.from_product([model_names, layers], names=['Model_Path', 'Layer'])
    df = df.set_index(['Model_Path', 'Layer']).reindex(complete_index).reset_index()

    layer_order = {layer: i for i, layer in enumerate(layers)}
    df['Layer_Index'] = df['Layer'].map(layer_order)
    return df, model_names, layer_order


def plot_rsa(path, area, noise_corrected=False, show_legend=False):
    """Plot per-layer RSA scores of every model for one brain area.

    Loads a result dict from ``path`` with keys ``'rsa'`` (per-model score tables,
    see ``_tidy_rsa``) and ``'nc'`` (noise-ceiling samples or None). It draws one
    marker series per model/path with error bars. The series are shifted
    horizontally ("dodged") so series at the same layer do not overlap.

    Parameters
    ----------
    path : str or path-like
        File readable by ``torch.load``.
    area : str
        Used as the figure title.
    noise_corrected : bool
        Divide scores and errors by the median noise ceiling instead of drawing
        the noise-ceiling band.
    show_legend : bool
        Whether the traces appear in the legend.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        Traces are added in this order: one scatter per model, then (when the noise
        ceiling is drawn) the median line, the +SD line with ``fill='tonexty'``, and
        the -SD line. update_plots relies on this order when it re-adds the traces.
    layer_order : dict
        Layer name -> integer x position.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted result files.
    results = torch.load(path)
    noise_ceiling = results['nc']
    df, model_names, layer_order = _tidy_rsa(results['rsa'], noise_ceiling, noise_corrected)
    x_positions = list(layer_order.values())
    n_layers = len(layer_order)

    # Per-model horizontal offsets (in layer units), hand-tuned to the expected model
    # order; some offsets are shared. NOTE(review): presumably models sharing an
    # offset never cover the same layers — confirm.
    dodge_amount = 0.4
    dp = np.linspace(-dodge_amount / 2, dodge_amount / 2, 4)
    dodge_positions = [0, dp[0], dp[0], dp[1], dp[3], dp[2], dp[3]]
    colors = ['black', 'darkblue', 'blue', 'deepskyblue', 'darkred', 'red', 'orange']

    fig = go.Figure()
    for i, model_path in enumerate(model_names):
        model_df = df[df['Model_Path'] == model_path]
        fig.add_trace(go.Scatter(
            # Offsets cycle like the colours, so more than seven models don't raise IndexError.
            x=model_df['Layer_Index'] + dodge_positions[i % len(dodge_positions)],
            y=model_df['RSA'],
            error_y=dict(type='data', array=model_df['Error'], visible=True),
            mode='markers',
            name=model_path,
            marker=dict(color=colors[i % len(colors)]),
            showlegend=show_legend,
        ))

    if noise_ceiling is not None and not noise_corrected:
        med = np.median(noise_ceiling)
        sd = np.std(noise_ceiling)
        fig.add_trace(go.Scatter(x=x_positions, y=[med] * n_layers, mode='lines',
                                 name='Noise Ceiling Median', line=dict(dash='dash', color='black'),
                                 showlegend=show_legend))
        # 'tonexty' fills towards the trace added just before it (the median line), so
        # this figure on its own shades only the upper half of the band. update_plots
        # re-adds these traces in reverse order, which makes the fill span -SD..+SD.
        fig.add_traces([
            go.Scatter(x=x_positions, y=[med + sd] * n_layers, fill='tonexty', mode='lines',
                       line=dict(color='gray'), name='Noise Ceiling SD', showlegend=show_legend),
            go.Scatter(x=x_positions, y=[med - sd] * n_layers, fill=None, mode='lines',
                       line=dict(color='gray'), showlegend=False),
        ])

    fig.update_layout(title=f'{area}',
                      xaxis=dict(tickvals=x_positions, ticktext=list(layer_order.keys())))
    return fig, layer_order

In [None]:
# Analysis grid: one subplot per visual area. Depth, stimulus and noise handling are
# chosen interactively; cre_line and seed only select which result file is loaded
# (both appear in the result file name).
# NOTE(review): the displayed figure depends on widget state; the defaults below are
# what a fresh "Run All" shows.
areas = ['VISl', 'VISp', 'VISpm', 'VISal', 'VISam']
depths = [175, 275]
stim_types = ['natural_movie_one', 'natural_scenes']
noise_options = ['Noise Ceiling', 'Noise Corrected']
cre_line = 'Cux2-CreERT2'
seed = 42

# Selection controls; each starts on the first entry of its option list.
depth_dropdown = widgets.Dropdown(
    options=depths,
    value=depths[0],
    description='Depth:',
)
stim_dropdown = widgets.Dropdown(
    options=stim_types,
    value=stim_types[0],
    description='Stimulus:',
)
noise_toggle = widgets.ToggleButtons(
    options=noise_options,
    description='Noise:',
)

def update_plots(change):
    """Redraw the RSA grid for the current widget selections.

    Used as the ``observe`` callback of the depth, stimulus and noise widgets.
    ``change`` is the change notification from ``observe``, or None for the
    initial draw; it is not used. Loads one result file per entry of ``areas``,
    places each area's ``plot_rsa`` traces in a grid of three subplots per row, and
    renders the figure into the ``output`` widget. Areas whose result file is
    missing are left empty and listed below the figure.
    """
    ncols = 3
    nrows = max(1, -(-len(areas) // ncols))  # ceiling division
    stim_type = stim_dropdown.value
    depth = depth_dropdown.value
    noise_corrected = noise_toggle.value == 'Noise Corrected'
    y_title = 'RDM similarity<br>(noise corrected)' if noise_corrected else 'RDM similarity'

    fig = make_subplots(rows=nrows, cols=ncols,
                        subplot_titles=areas + [''] * (nrows * ncols - len(areas)),
                        vertical_spacing=0.05)
    fig.update_xaxes(ticklabelstep=1)

    missing = []
    legend_pending = True  # show the legend only for the first area that loads
    for idx, area in enumerate(areas):
        row, col = idx // ncols + 1, idx % ncols + 1
        # Only the lowest subplot in each column shows x tick labels and the axis title.
        has_area_below = idx + ncols < len(areas)
        fig.update_xaxes(title_text='' if has_area_below else 'Layer',
                         showticklabels=not has_area_below, row=row, col=col)
        if col == 1:
            fig.update_yaxes(title_text=y_title, showticklabels=True, row=row, col=col)

        path = f'../results/{area}/{stim_type}/{depth}_{cre_line}_{seed}.pt'
        try:
            subfig, layer_order = plot_rsa(path, area, noise_corrected=noise_corrected,
                                           show_legend=legend_pending)
        except FileNotFoundError:
            missing.append(path)
            continue
        legend_pending = False

        # With the noise ceiling drawn, plot_rsa's traces end with the median, +SD
        # ('tonexty') and -SD lines. Re-adding them reversed puts -SD before +SD so the
        # fill spans the whole band, and draws the model markers last (on top).
        traces = subfig.data if noise_corrected else reversed(subfig.data)
        for trace in traces:
            fig.add_trace(trace, row=row, col=col)
        # Tick labels come from this area's own layer order.
        fig.update_xaxes(tickvals=list(layer_order.values()), ticktext=list(layer_order.keys()),
                         row=row, col=col)

    fig.update_layout(height=600 * nrows, width=1600,
                      title_text=f'Representational Similarity Analysis between Mouse Visual Areas and ANNs trained with SSL<br>'
                                 f'cre_line: {cre_line}, depth: {depth}, stimulus: {stim_type}',
                      title_x=0.5,
                      legend_title='Model')
    with output:
        output.clear_output(wait=True)
        fig.show("notebook")
        if missing:
            print('No results found for:\n' + '\n'.join(missing))
        
# Output area that update_plots renders into.
output = widgets.Output()

# Any change of a selection triggers a full redraw.
for control in (depth_dropdown, stim_dropdown, noise_toggle):
    control.observe(update_plots, names='value')

# Show the controls stacked above the figure output.
widgets_container = widgets.VBox([depth_dropdown, stim_dropdown, noise_toggle])
display(widgets_container, output)

# First draw using the default selections.
update_plots(None)

VBox(children=(Dropdown(description='Depth:', options=(175, 275), value=175), Dropdown(description='Stimulus:'…

Output()