In [1]:
import pandas as pd
import visualizations
import plotly.graph_objects as go

### Extended Data 6b

In [None]:
fig_panel = 'ExtData6b'
agg_data = pd.read_csv('../ExtendedData6/{}.csv'.format(fig_panel))

# Use the 10% of cells cutoff. `.copy()` detaches the filtered rows from the full
# frame so the column writes below do not trigger pandas' SettingWithCopyWarning.
cutoff_fraction = 10
agg_data = agg_data[agg_data['CutoffFraction'] == cutoff_fraction].copy()
assert not agg_data.empty, 'no rows with CutoffFraction == {} in {}.csv'.format(cutoff_fraction, fig_panel)

groupby = 'Ensemble'
plot_var = 'FractionOfGadCells'

# Scale to percent to match the '% of GAD Cells' y-axis label.
agg_data[plot_var] *= 100

# Category order and per-category colour kept side by side so they cannot drift apart.
ensemble_colors = {
    'Aversive': 'rgb(236,34,42)',
    'Overlap': 'rgb(144,96,129)',
    'Neutral': 'rgb(0,50,255)',
    'Remaining': 'rgb(255,255,255)',
}
colors = list(ensemble_colors.values())
agg_data[groupby] = pd.Categorical(agg_data[groupby], categories=list(ensemble_colors))

visualizations.plotMeanData(agg_data=agg_data, groupby=groupby, plot_var=plot_var, colors=colors, opacity=1,
                            y_title='% of GAD Cells<br>in Ensemble', y_range=(0, 75), plot_datalines=True,
                            plot_title=fig_panel)

### Extended Data 6c

In [None]:
fig_panel = 'ExtData6c'
agg_data = pd.read_csv('../ExtendedData6/{}.csv'.format(fig_panel))

# Use the 10% of cells cutoff. `.copy()` detaches the filtered rows from the full
# frame so the column writes below do not trigger pandas' SettingWithCopyWarning.
cutoff_fraction = 10
agg_data = agg_data[agg_data['CutoffFraction'] == cutoff_fraction].copy()
assert not agg_data.empty, 'no rows with CutoffFraction == {} in {}.csv'.format(cutoff_fraction, fig_panel)

groupby = 'Ensemble'
plot_var = 'FractionOfEnsembleCells'

# Scale to percent to match the '% of Ensemble' y-axis label.
agg_data[plot_var] *= 100

# Category order and per-category colour kept side by side so they cannot drift apart.
ensemble_colors = {
    'Aversive': 'rgb(236,34,42)',
    'Overlap': 'rgb(144,96,129)',
    'Neutral': 'rgb(0,50,255)',
    'Remaining': 'rgb(255,255,255)',
}
colors = list(ensemble_colors.values())
agg_data[groupby] = pd.Categorical(agg_data[groupby], categories=list(ensemble_colors))

visualizations.plotMeanData(agg_data=agg_data, groupby=groupby, plot_var=plot_var, colors=colors, opacity=1,
                            y_title='% of Ensemble<br>Putative GAD+', y_range=(0, 30), plot_datalines=True,
                            plot_title=fig_panel)

### Extended Data 6d

In [None]:
fig_panel = 'ExtData6d'
agg_data = pd.read_csv('../ExtendedData6/{}.csv'.format(fig_panel))

# Mean population activity around bursts: one thin grey trace per mouse, plus the
# mean ± SEM over all rows sharing a time point (presumably one row per mouse).
plot_col = 'AverageActivity'

# Select the plotted column *before* aggregating: DataFrameGroupBy.mean()/.sem()
# default to numeric_only=False in pandas >= 2.0 and raise TypeError on any
# non-numeric column (e.g. a string 'Mouse' identifier).
summary = (
    agg_data.groupby('Time')[plot_col]
    .agg(['mean', 'sem'])  # 'sem' uses ddof=1, same as the DataFrame .sem() it replaces
    .reset_index()
    .sort_values('Time')
)
time = summary['Time']
mean = summary['mean']
sem = summary['sem']

fig = go.Figure()

# Per-mouse traces, in order of first appearance in the file.
for mouse, mouse_data in agg_data.groupby('Mouse', sort=False):
    mouse_data = mouse_data.sort_values('Time')
    fig.add_trace(go.Scattergl(x=mouse_data['Time'], y=mouse_data[plot_col], mode='lines',
                               line=dict(color='slategrey', width=1), showlegend=False))

# SEM band: an invisible upper edge, then an invisible lower edge filled back to it ('tonexty').
fig.add_trace(go.Scatter(x=time, y=(mean + sem), mode='lines', fill=None,
                         line=dict(color='black', width=0), hoverinfo='skip', showlegend=False))
fig.add_trace(go.Scatter(x=time, y=(mean - sem), mode='lines', fill='tonexty',
                         line=dict(color='black', width=0), hoverinfo='skip', showlegend=False))
# Mean trace.
fig.add_trace(go.Scattergl(x=time, y=mean, mode='lines', line=dict(color='black', width=2), showlegend=False))

fig.update_layout(template='simple_white', width=500, height=500, font=dict(size=20, family='Arial'), title_text=fig_panel,
                  xaxis_title='Time (sec)', yaxis_title='Mean Population Activity<br>Around Bursts (Z-Score)')
fig.show()

### Extended Data 6e

In [None]:
fig_panel = 'ExtData6e'
agg_data = pd.read_csv('../ExtendedData6/{}.csv'.format(fig_panel))

# Mean locomotion around bursts: one thin grey trace per mouse, plus the
# mean ± SEM over all rows sharing a time point (presumably one row per mouse).
plot_col = 'AverageLocomotion'

# Select the plotted column *before* aggregating: DataFrameGroupBy.mean()/.sem()
# default to numeric_only=False in pandas >= 2.0 and raise TypeError on any
# non-numeric column (e.g. a string 'Mouse' identifier).
summary = (
    agg_data.groupby('Time')[plot_col]
    .agg(['mean', 'sem'])  # 'sem' uses ddof=1, same as the DataFrame .sem() it replaces
    .reset_index()
    .sort_values('Time')
)
time = summary['Time']
mean = summary['mean']
sem = summary['sem']

fig = go.Figure()

# Per-mouse traces, in order of first appearance in the file.
for mouse, mouse_data in agg_data.groupby('Mouse', sort=False):
    mouse_data = mouse_data.sort_values('Time')
    fig.add_trace(go.Scattergl(x=mouse_data['Time'], y=mouse_data[plot_col], mode='lines',
                               line=dict(color='slategrey', width=1), showlegend=False))

# SEM band: an invisible upper edge, then an invisible lower edge filled back to it ('tonexty').
fig.add_trace(go.Scatter(x=time, y=(mean + sem), mode='lines', fill=None,
                         line=dict(color='black', width=0), hoverinfo='skip', showlegend=False))
fig.add_trace(go.Scatter(x=time, y=(mean - sem), mode='lines', fill='tonexty',
                         line=dict(color='black', width=0), hoverinfo='skip', showlegend=False))
# Mean trace.
fig.add_trace(go.Scattergl(x=time, y=mean, mode='lines', line=dict(color='black', width=2), showlegend=False))

fig.update_layout(template='simple_white', width=500, height=500, font=dict(size=20, family='Arial'), title_text=fig_panel,
                  xaxis_title='Time (sec)', yaxis_title='Mean Locomotion<br>Around Bursts (A.U.)')
fig.show()

### Extended Data 6f

In [None]:
fig_panel = 'ExtData6f'
agg_data = pd.read_csv(f'../ExtendedData6/{fig_panel}.csv')

groupby, plot_var = 'EnsembleCombo', 'Fraction'
# NOTE(review): unlike panels 6b/6c, 'Fraction' is not multiplied by 100 here even
# though the axis is labelled '%' — confirm the CSV already stores percentages.

# Category order and per-category colour, paired so they cannot drift apart.
ensemble_colors = {
    'Neutral': 'rgb(0,50,255)',
    'Aversive': 'rgb(236,34,42)',
    'Overlap': 'rgb(144,96,129)',
}
colors = list(ensemble_colors.values())
agg_data[groupby] = pd.Categorical(agg_data[groupby], categories=list(ensemble_colors))

visualizations.plotMeanData(
    agg_data=agg_data,
    groupby=groupby,
    plot_var=plot_var,
    colors=colors,
    opacity=1,
    y_title='% of Events',
    y_range=(0, 25),
    plot_datalines=True,
    plot_title=fig_panel,
)

### Extended Data 6g

In [None]:
fig_panel = 'ExtData6g'
agg_data = pd.read_csv(f'../ExtendedData6/{fig_panel}.csv')

groupby = 'EnsembleCombo'
plot_var = 'Fraction'

# Pairwise ensemble combinations first, then the three-way combination.
combo_order = [
    'Overlap x Neutral',
    'Overlap x Aversive',
    'Neutral x Aversive',
    'Overlap x Neutral x Aversive',
]
agg_data[groupby] = pd.Categorical(agg_data[groupby], categories=combo_order)

# A single colour string rather than a per-category list — presumably applied to
# every category by visualizations.plotMeanData; confirm there.
colors = 'slategrey'

plot_kwargs = dict(
    y_title='% of Events',
    y_range=(0, 25),
    plot_width=500,
    plot_height=600,
    plot_datalines=True,
    plot_title=fig_panel,
)
visualizations.plotMeanData(agg_data=agg_data, groupby=groupby, plot_var=plot_var, colors=colors, opacity=1,
                            **plot_kwargs)

### Extended Data 6h

In [None]:
fig_panel = 'ExtData6h'
agg_data = pd.read_csv(f'../ExtendedData6/{fig_panel}.csv')

time_var, plot_var = 'Bin', 'Participation'
colors = ['black']

# Make each bin a discrete category; with no explicit `categories`, pandas uses the
# sorted unique values found in the data.
agg_data[time_var] = pd.Categorical(agg_data[time_var])

x_label = 'Cells Active During Offline<br>(Sorted by Chemotag Response)'
visualizations.plotAcrossTime(
    agg_data=agg_data,
    time_var=time_var,
    plot_var=plot_var,
    colors=colors,
    add_hline=False,
    y_title='% Burst Participation',
    x_title=x_label,
    plot_datalines=True,
    title=fig_panel,
)

### Extended Data 6i

In [None]:
fig_panel = 'ExtData6i'
agg_data = pd.read_csv(f'../ExtendedData6/{fig_panel}.csv')

groupby = 'TrueVsShuffle'
plot_var = 'Accuracy'

# 'True' vs. 'Shuffle' (presumably a shuffled-label control); colours follow the same order.
condition_order = ['True', 'Shuffle']
colors = ['steelblue', 'slategrey']
agg_data[groupby] = pd.Categorical(agg_data[groupby], categories=condition_order)

visualizations.plotMeanData(
    agg_data=agg_data,
    groupby=groupby,
    plot_var=plot_var,
    colors=colors,
    opacity=1,
    y_title='Accuracy',
    y_range=(0, 1),
    plot_datalines=True,
    plot_title=fig_panel,
)

### Extended Data 6j

In [None]:
fig_panel = 'ExtData6j'
agg_data = pd.read_csv(f'../ExtendedData6/{fig_panel}.csv')

time_var, groupby, plot_var = 'Bin', 'TrueVsShuffle', 'Accuracy'

# 'True' vs. 'Shuffle' conditions, in this order; colours follow the same order.
condition_order = ['True', 'Shuffle']
colors = ['steelblue', 'slategrey']

# Bins become discrete categories (sorted unique values); conditions get a fixed order.
agg_data[time_var] = pd.Categorical(agg_data[time_var])
agg_data[groupby] = pd.Categorical(agg_data[groupby], categories=condition_order)

x_label = 'Cells Active During Offline<br>(Sorted by Chemotag Response)'
visualizations.plotAcrossTime(
    agg_data=agg_data,
    time_var=time_var,
    groupby=groupby,
    plot_var=plot_var,
    colors=colors,
    add_hline=False,
    y_title='SVM Accuracy',
    x_title=x_label,
    plot_datalines=False,
    title=fig_panel,
)