In [1]:
import pandas as pd
import numpy as np

In [2]:
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

In [3]:
# Per-seizure, per-band summary table; the head below suggests one row per
# seizure file (file_ID) and band.
data = pd.read_json(
    "../all_mean_Swiss-Short.json",
    orient="records",
)
data.head()

Unnamed: 0,file_ID,pat_id,sez_id,band,sez_length,no_of_elec,pre_mean,pre_cv,sez_mean,sez_cv,post_mean,post_cv
0,p1s1,1,1,delta,125,47,3.270011,0.13605,4.095254,0.095119,3.741837,0.072901
1,p1s1,1,1,theta,125,47,2.484213,0.148351,2.799338,0.164442,2.670549,0.104244
2,p1s1,1,1,alpha,125,47,1.92694,0.161039,1.90572,0.20408,1.933093,0.163607
3,p1s1,1,1,beta,125,47,1.983676,0.120551,1.820028,0.216146,1.908682,0.129175
4,p1s1,1,1,lgamma,125,47,0.803513,0.242805,0.792638,0.69383,0.731817,0.279288


In [4]:
def custom_nanstd(data):
    """NaN-aware sample standard deviation, used as a groupby aggregator.

    Parameters
    ----------
    data : array-like or pandas Series of numbers (NaN marks a missing value).

    Returns
    -------
    * more than one non-NaN value: ``np.nanstd(data, ddof=1)``;
    * exactly one non-NaN value: that value itself;
    * no non-NaN values: ``np.nan``.
    """
    values = np.asarray(data)
    observed = values[~np.isnan(values)]
    n_observed = observed.size
    if n_observed > 1:
        return np.nanstd(data, ddof=1)
    if n_observed == 1:
        # NOTE(review): a lone observation is passed through unchanged rather
        # than reported as 0 or NaN -- confirm this is intended for a spread.
        return observed[0]
    return np.nan

In [5]:
# One row per (band, patient): per-seizure means are averaged, and per-seizure
# CVs are summarised with custom_nanstd (sample std; a lone value passes through).
# NOTE(review): the resulting *_cv columns are a spread of CVs, not a CV of the
# patient-level means -- confirm this is the statistic intended downstream.
pat_grouped_data = (
    data
    .groupby(['band', 'pat_id'])
    .agg(
        pre_mean=('pre_mean', 'mean'),
        pre_cv=('pre_cv', custom_nanstd),
        sez_mean=('sez_mean', 'mean'),
        sez_cv=('sez_cv', custom_nanstd),
        post_mean=('post_mean', 'mean'),
        post_cv=('post_cv', custom_nanstd),
        sezs=('file_ID', 'unique'),           # seizure file IDs in the group
        sez_length=('sez_length', 'mean'),
        no_of_elec=('no_of_elec', 'unique'),  # array of distinct electrode counts
        no_of_sez=('sez_id', 'count'),        # non-null sez_id rows in the group
    )
    .reset_index()
)
pat_grouped_data.head()

Unnamed: 0,band,pat_id,pre_mean,pre_cv,sez_mean,sez_cv,post_mean,post_cv,sezs,sez_length,no_of_elec,no_of_sez
0,alpha,1,1.800235,0.040728,1.862464,0.06404,1.729868,0.063876,"[p1s1, p1s2, p1s3, p1s4, p1s5, p1s6, p1s7, p1s...",71.076923,[47],13
1,alpha,2,1.411227,0.055227,2.57,0.088385,0.662849,0.169988,"[p2s1, p2s2, p2s3, p2s4]",223.25,[42],4
2,alpha,3,1.514411,0.0258,3.2737,0.078552,0.790268,0.016501,"[p3s1, p3s2]",99.0,[98],2
3,alpha,4,2.350336,0.087058,1.945237,0.084471,2.296282,0.083568,"[p4s1, p4s2, p4s3, p4s4, p4s5, p4s6, p4s7, p4s...",97.571429,[62],14
4,alpha,5,1.418611,0.016419,1.93539,0.045092,1.445931,0.009804,"[p5s1, p5s2, p5s3, p5s4, p5s5, p5s6, p5s7, p5s...",98.5,[54],10


In [6]:
# One row per band across patients: patient-level means are averaged, and the
# patient-level *_cv values are summarised with custom_nanstd.
# NOTE(review): band-level *_cv is therefore a spread of per-patient spreads --
# confirm that is the intended statistic.
band_grouped_data = (
    pat_grouped_data
    .groupby(['band'])
    .agg(
        pre_mean=('pre_mean', 'mean'),
        pre_cv=('pre_cv', custom_nanstd),
        sez_mean=('sez_mean', 'mean'),
        sez_cv=('sez_cv', custom_nanstd),
        post_mean=('post_mean', 'mean'),
        post_cv=('post_cv', custom_nanstd),
        no_of_pats=('pat_id', 'count'),  # non-null pat_id rows per band
    )
    .reset_index()
)
band_grouped_data

Unnamed: 0,band,pre_mean,pre_cv,sez_mean,sez_cv,post_mean,post_cv,no_of_pats
0,alpha,1.611853,0.022758,2.403951,0.034503,1.307114,0.163595,16
1,beta,1.675046,0.020895,2.485724,0.023995,1.232187,0.207746,16
2,delta,2.864191,0.022369,3.533913,0.01989,3.268874,0.035079,16
3,hgamma,0.052867,0.911383,0.292305,0.588536,0.073173,1.050253,16
4,lgamma,0.539356,0.061809,1.099668,0.064668,0.384157,0.475253,16
5,theta,2.11642,0.023048,2.940845,0.028241,1.947055,0.085602,16


In [7]:
# Per-band plotting metadata: legend symbol and semi-transparent RGBA colour.
# Row order matches the panel order of the Figure 2 subplots.
bandds = pd.DataFrame(
    [
        ("delta",  "δ",  "rgba(139,0,0,0.5)"),
        ("theta",  "θ",  "rgba(255,69,0,0.5)"),
        ("alpha",  "α",  "rgba(0,205,0,0.5)"),
        ("beta",   "β",  "rgba(0,206,209,0.5)"),
        ("lgamma", "Lγ", "rgba(105,89,205,0.5)"),
        ("hgamma", "Hγ", "rgba(238,122,233,0.5)"),
    ],
    columns=["name", "sym", "color"],
)

bandds

Unnamed: 0,name,sym,color
0,delta,δ,"rgba(139,0,0,0.5)"
1,theta,θ,"rgba(255,69,0,0.5)"
2,alpha,α,"rgba(0,205,0,0.5)"
3,beta,β,"rgba(0,206,209,0.5)"
4,lgamma,Lγ,"rgba(105,89,205,0.5)"
5,hgamma,Hγ,"rgba(238,122,233,0.5)"


In [8]:
def mappi(fig, bd, ridx, cidx, p1, p2, p3, *, draw_brackets=None,
          leg=0.05, rise=0.05, span_leg=0.15, span_rise=0.10, label_pad=0.08):
    """Plot one band's pre/seizure/post mean trajectory, with p-value brackets, into a subplot.

    Band-level means are read from the notebook global ``band_grouped_data``
    and the trace symbol/colour from the global ``bandds``.

    Parameters
    ----------
    fig : plotly figure built with ``make_subplots``; modified in place.
    bd : band name, e.g. "delta"; must appear in ``band_grouped_data['band']``
        and ``bandds['name']``.
    ridx, cidx : 1-based subplot row and column.
    p1, p2, p3 : labels for the pre-vs-seizure, seizure-vs-post and
        pre-vs-post brackets respectively.
    draw_brackets : bool or None (keyword-only). ``None`` keeps the original
        rule of skipping brackets when ``ridx * cidx == 6`` (the row-3/col-2
        panel of a 3x2 grid, whose brackets are placed by hand).
    leg : gap between a data point and the bottom of its bracket leg
        (adjacent-pair brackets).
    rise : height of the bracket bar above the higher of its two legs.
    span_leg, span_rise : same as ``leg``/``rise`` for the pre-vs-post
        bracket, measured from the largest of the three means.
    label_pad : vertical gap between a bracket's bar and its label.

    Returns
    -------
    The same ``fig`` object.

    Raises
    ------
    IndexError : if ``bd`` is missing from ``band_grouped_data`` or ``bandds``.
    """
    x_coords = [0, 1, 2]
    phase_cols = ['pre_mean', 'sez_mean', 'post_mean']
    band_rows = band_grouped_data.loc[band_grouped_data['band'] == bd, phase_cols]
    if band_rows.empty:
        raise IndexError(f"band {bd!r} not found in band_grouped_data")
    means = band_rows.iloc[0].values
    # Look the band's style up once and reuse it for legend name, marker and line.
    style = bandds.loc[bandds['name'] == bd].iloc[0]

    fig.add_trace(go.Scatter(x=x_coords, y=means, mode='lines+markers',
                             name=style['sym'],
                             marker=dict(color=style['color'], size=10),
                             line=dict(color=style['color'], width=3)),
                  row=ridx, col=cidx)
    fig.update_xaxes(tickvals=x_coords, ticktext=["pre", "seizure", "post"],
                     row=ridx, col=cidx)

    if draw_brackets is None:
        draw_brackets = ridx * cidx != 6
    if not draw_brackets:
        return fig

    pre, sez, post = means
    # Anchor each bar above the higher of the means it compares, so a bracket
    # never cuts through the data when the seizure mean is not the largest.
    top_pre_sez = max(pre, sez) + (leg + rise)
    top_sez_post = max(sez, post) + (leg + rise)
    span_base = max(pre, sez, post) + span_leg
    span_top = span_base + span_rise
    brackets = [
        ([0, 0, 0.99, 0.99], [pre + leg, top_pre_sez, top_pre_sez, sez + leg], p1),
        ([1.01, 1.01, 2, 2], [sez + leg, top_sez_post, top_sez_post, post + leg], p2),
        ([0, 0, 2, 2], [span_base, span_top, span_top, span_base], p3),
    ]
    for xs, ys, label in brackets:
        # Consecutive points give the left leg, the horizontal bar, the right leg.
        for x0, x1, y0, y1 in zip(xs, xs[1:], ys, ys[1:]):
            fig.add_shape(type="line", xref="x", yref="y",
                          x0=x0, x1=x1, y0=y0, y1=y1,
                          line=dict(color='rgba(0,0,0,1)', width=1.5),
                          row=ridx, col=cidx)
        fig.add_annotation(text=label, name="p-value", xref="x",
                           x=(xs[0] + xs[2]) / 2, y=ys[1] + label_pad,
                           showarrow=False,
                           font=dict(size=30, color="blue", family="Times new Roman"),
                           row=ridx, col=cidx)
    return fig

In [10]:
# Figure 2: one panel per band showing band-level mean AE at pre / seizure /
# post with p-value brackets. mappi skips brackets for the (3,2) hgamma panel;
# they are placed by hand further down with smaller offsets.
fig = make_subplots(
    rows=3,
    cols=2,
    vertical_spacing=0.05,
    horizontal_spacing=0.03,
    shared_xaxes=True,
    subplot_titles=(
        "Delta (δ) Band",
        "Theta (θ) Band",
        "Alpha (α) Band",
        "Beta (β) Band",
        "Lgamma (Lγ) Band",
        "Hgamma (Hγ) Band",
    ),
)

# (band, row, col, p pre-vs-seizure, p seizure-vs-post, p pre-vs-post)
# NOTE(review): the hgamma p3 ("ns") is not drawn; the hand-placed bracket
# below labels that comparison "0.371".
panel_specs = [
    ("delta", 1, 1, "<0.001", "ns", "<0.001"),
    ("theta", 1, 2, "<0.001", "<0.001", "ns"),
    ("alpha", 2, 1, "<0.001", "<0.001", "ns"),
    ("beta", 2, 2, "<0.001", "<0.001", "<0.001"),
    ("lgamma", 3, 1, "<0.001", "<0.001", "ns"),
    ("hgamma", 3, 2, "<0.001", "<0.001", "ns"),
]
for band_name, row_idx, col_idx, p_first, p_second, p_span in panel_specs:
    mappi(fig, band_name, row_idx, col_idx, p_first, p_second, p_span)

# Hand-placed brackets for the hgamma panel (row 3, col 2).
means = band_grouped_data.query('band=="hgamma"')[['pre_mean', 'sez_mean', 'post_mean']].iloc[0].values
p_value_brackets = [
    {'x_coords': [0, 0, 0.99, 0.99],
     'y_coords': [means[0] + 0.03, means[1] + 0.04, means[1] + 0.04, means[1] + 0.03],
     'label': "<0.001"},
    {'x_coords': [1.01, 1.01, 2, 2],
     'y_coords': [means[1] + 0.03, means[1] + 0.04, means[1] + 0.04, means[2] + 0.03],
     'label': "<0.001"},
    {'x_coords': [0, 0, 2, 2],
     'y_coords': [means[1] + 0.09, means[1] + 0.11, means[1] + 0.11, means[1] + 0.09],
     'label': "0.371"},
]

for bracket in p_value_brackets:
    xs, ys = bracket['x_coords'], bracket['y_coords']
    # Consecutive points give the left leg, the horizontal bar, the right leg.
    for x0, x1, y0, y1 in zip(xs, xs[1:], ys, ys[1:]):
        fig.add_shape(
            dict(type="line", xref="x", yref="y", x0=x0, x1=x1, y0=y0, y1=y1),
            line=dict(color='rgba(0,0,0,1)', width=1.5),
            row=3, col=2,
        )
    fig.add_annotation(
        dict(text=bracket['label'], name="p-value", xref="x",
             x=(xs[0] + xs[2]) / 2, y=ys[1] + 0.03, showarrow=False),
        font=dict(size=30, color="blue", family="Times new Roman"),
        row=3, col=2,
    )

# Per-panel y-axis ranges; only the left column carries the axis title.
y_ranges = {
    (1, 1): (2.65, 3.95),
    (1, 2): (1.6, 3.4),
    (2, 1): (1.1, 2.9),
    (2, 2): (1.1, 2.9),
    (3, 1): (0.2, 1.5),
    (3, 2): (0.03, 0.47),
}
for (row_idx, col_idx), y_range in y_ranges.items():
    title_kwargs = {"title_text": "average AE"} if col_idx == 1 else {}
    fig.update_yaxes(range=y_range, row=row_idx, col=col_idx, **title_kwargs)

fon_sz = 30
fig.update_layout(
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5,
                font=dict(size=40)),
    template="simple_white",
    font_family="Times new Roman",
    font_color="black",
    font_size=fon_sz,
    height=1200,
    width=2200,
)
fig.update_annotations(font=dict(family="Times new Roman", size=30))
fig.for_each_xaxis(lambda axis: axis.update(showgrid=True))
fig.for_each_yaxis(lambda axis: axis.update(showgrid=True))

fig.show()
fig.write_image("../Link to img/Fig2.png")
fig.write_html("../images/interactive/Fig2.html")


Supplementary Figure 2S1: per-patient pre / seizure / post band means

In [33]:
# NOTE(review): this whole cell is commented out and currently does nothing.
# When enabled, it builds one 3x2 figure per patient (pat_id 1..16) from
# pat_grouped_data, using the same panel layout as Figure 2 but without p-value
# brackets, and writes each figure as a PNG to two directories.
# NOTE(review): the output files are named "Fig4S3_ID<pid>.png" while the
# heading above refers to Figure 2S1 -- confirm which figure number is intended.
# NOTE(review): .iloc[0] raises IndexError if a patient has no row for a band.
# for pid in range(1,17):        
#         x_coords = [0, 1, 2]
#         s=pat_grouped_data.query('pat_id=='+str(pid))[['band','pre_mean','sez_mean','post_mean']]
#         fig = make_subplots(rows=3, cols=2,vertical_spacing=0.05,horizontal_spacing=0.03,shared_xaxes=True,subplot_titles=(
#                                 "Delta (δ) Band",
#                                 "Theta (θ) Band",
#                                 "Alpha (α) Band",
#                                 "Beta (β) Band",
#                                 "Lgamma (Lγ) Band",
#                                 "Hgamma (Hγ) Band"))
        
#         bad=["delta","theta","alpha","beta","lgamma","hgamma"]
#         bd_idx=0

#         for ridx in range(1,4):
#                 for cidx in range(1,3):
#                         bd=bad[bd_idx]
#                         means=s.query('band=="'+bd+'"')[['pre_mean','sez_mean','post_mean']].iloc[0].values
#                         fig.add_trace(go.Scatter(x=x_coords,y=means,mode='lines+markers',
#                                         #error_y=dict(type='data', array=sems,color='rgba(0,0,0,0.5)', thickness=1.5, width=10),
#                                         name=bandds[bandds.name==bd].sym.values[0],
#                                         marker=dict(color=bandds[bandds.name==bd].color.values[0],size=10),
#                                         line=dict(color=bandds[bandds.name==bd].color.values[0],width=3)
#                                         ),row=ridx, col=cidx)
#                         fig.update_xaxes(tickvals=x_coords,ticktext=["pre", "seizure", "post"],row=ridx, col=cidx)
#                         bd_idx+=1

#         fig.update_yaxes(title_text="average AE",row=1, col=1)
#         fig.update_yaxes(title_text="average AE",row=2, col=1)
#         fig.update_yaxes(title_text="average AE",row=3, col=1)

#         # Add shapes
#         # fig.update_layout(
#         #     shapes=[
#         #         dict(type="rect", xref="x1",yref="y1",
#         #             x0=0, y0=2.65, x1=2, y1=3.95, line_width=3),
#         #         dict(type="rect", xref="x2", yref='y2',
#         #              x0=0, y0=1.6, x1=2, y1=3.4),
#         #         dict(type="rect", xref="x3", yref="y3",
#         #              x0=0, y0=1.1, x1=2, y1=2.9),
#         #         dict(type="rect", xref="x4", yref="y4",
#         #              x0=0, y0=1.1, x1=2, y1=2.9),
#         #         dict(type="rect", xref="x5", yref="y5",
#         #              x0=0, y0=0.2, x1=2, y1=1.5),
#         #         dict(type="rect", xref="x6", yref="y6",
#         #              x0=0, y0=0.03, x1=2, y1=0.47)
#         #         ])

#         fon_sz=30;
#         fig.update_layout(
#                 title_text="Patient ID"+str(pid),
#                 legend=dict(orientation="h",yanchor="bottom",y=1.02,xanchor="center",x=0.5,font=dict(size=40)),
#                 template="simple_white",
#                 font_family="Times new Roman",font_color="black",font_size=fon_sz,height=1200,width=2200)
#         fig.update_annotations(font=dict(family="Times new Roman", size=30))
#         fig.for_each_xaxis(lambda x: x.update(showgrid=True))
#         fig.for_each_yaxis(lambda x: x.update(showgrid=True))

                
#         #fig.show()    
#         fig.write_image("../Link to img/sup/fig2s/Fig4S3_ID"+str(pid)+".png")
#         fig.write_image("../images/SUP/fig2s/Fig4S3_ID"+str(pid)+".png")   