In [1]:
# IMPORTS
# ---------
# Stop warnings
import warnings
warnings.filterwarnings("ignore")

# General imports
import os
import sys
import json
import glob
import numpy as np
import h5py
import pandas as pd
opj = os.path.join

# Plotly imports
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.offline as pxo

In [2]:
# FUNCTIONS
# ---------
def weighted_regression(x_reg,y_reg,weight_reg):
    """
    Compute a weighted least-squares linear regression of y on x.

    Minimises sum(w * (y - (coef * x + intercept))**2). For a single
    predictor this is the closed-form solution also obtained with
    sklearn.linear_model.LinearRegression().fit(x, y, sample_weight=w).

    Samples where x, y OR the weight is NaN are dropped; one common mask is
    applied to all three arrays so they stay aligned.

    Parameters
    ----------
    x_reg : array (1D)
        x values to regress
    y_reg : array (1D)
        y values to regress
    weight_reg : array (1D)
        weight values (e.g. r2, 0 to 1) for weighted regression

    Returns
    -------
    coef_reg : array, shape (1, 1)
        regression coefficient (slope)
    intercept_reg : array, shape (1,)
        regression intercept

    Raises
    ------
    ValueError
        if the inputs differ in length, no sample survives NaN removal,
        or the remaining weights sum to zero
    """
    x_reg = np.asarray(x_reg, dtype=float).ravel()
    y_reg = np.asarray(y_reg, dtype=float).ravel()
    weight_reg = np.asarray(weight_reg, dtype=float).ravel()
    if not (x_reg.shape == y_reg.shape == weight_reg.shape):
        raise ValueError("x_reg, y_reg and weight_reg must have the same length")

    # single mask shared by x, y and weights (separate masks misalign them)
    valid = ~np.isnan(x_reg) & ~np.isnan(y_reg) & ~np.isnan(weight_reg)
    x_v, y_v, w_v = x_reg[valid], y_reg[valid], weight_reg[valid]
    if x_v.size == 0:
        raise ValueError("no valid (non-NaN) samples to regress")
    if np.sum(w_v) == 0:
        raise ValueError("weights of the valid samples sum to zero")

    # weighted means and (un-normalised) weighted co-variances
    x_mean = np.average(x_v, weights=w_v)
    y_mean = np.average(y_v, weights=w_v)
    x_dev, y_dev = x_v - x_mean, y_v - y_mean
    sxx = np.sum(w_v * x_dev * x_dev)
    sxy = np.sum(w_v * x_dev * y_dev)

    # constant x: the minimum-norm least-squares solution has zero slope
    slope = sxy / sxx if sxx != 0 else 0.0
    intercept = y_mean - slope * x_mean

    # same shapes as LinearRegression fitted on a (n, 1) target
    coef_reg = np.array([[slope]])
    intercept_reg = np.array([intercept])
    return coef_reg, intercept_reg

def rgb2rgba(input_col,alpha_val):
    """
    Add (or replace) the alpha channel of a plotly color string.

    Parameters
    ----------
    input_col : str
        color value, either 'rgb(r,g,b)' (e.g. 'rgb(200,200,200)') or
        'rgba(r,g,b,a)', in which case the existing alpha is replaced
    alpha_val : float
        transparency value (0.0 to 1.0)

    Returns
    -------
    rgba_col : str
        color value in rgba (e.g. 'rgba(200,200,200, 0.5)'); the channel text
        of the input is kept verbatim

    Raises
    ------
    ValueError
        if input_col is not an 'rgb(...)' or 'rgba(...)' string
    """
    col = input_col.strip()
    open_idx, close_idx = col.find('('), col.rfind(')')
    if open_idx == -1 or close_idx < open_idx:
        raise ValueError("expected an 'rgb(...)' or 'rgba(...)' color, got {!r}".format(input_col))
    model = col[:open_idx].strip().lower()
    if model not in ('rgb', 'rgba'):
        raise ValueError("expected an 'rgb(...)' or 'rgba(...)' color, got {!r}".format(input_col))

    channels = col[open_idx + 1:close_idx]
    if model == 'rgba':
        # drop the existing alpha component, keep r,g,b text as-is
        channels = channels.rsplit(',', 1)[0]

    rgba_col = "rgba({}, {})".format(channels, alpha_val)
    return rgba_col

def adjust_lightness(input_rgb, amount=0.5):
    """
    Scale the HLS lightness of an rgb color string.

    The color is converted to HLS, its lightness is multiplied by `amount`
    (then clipped to [0, 1]) and it is converted back. amount < 1 darkens,
    amount > 1 lightens and amount == 1 leaves the color unchanged.

    Parameters
    ----------
    input_rgb : str
        color value (e.g. 'rgb(200,200,200)'), channels in 0-255; an
        'rgba(...)' string is also accepted and its alpha is dropped
    amount : float
        multiplicative lightness factor (e.g. 0.5 halves the lightness,
        1.25 raises it by 25 %)

    Returns
    -------
    output_col : str
        color value in rgb (e.g. 'rgb(200,200,200)')

    Raises
    ------
    ValueError
        if the string is not parseable or a channel lies outside 0-255
    """
    import colorsys

    # text between the parentheses, e.g. '200,200,200'
    inner = input_rgb[input_rgb.index('(') + 1:input_rgb.rindex(')')]
    parts = inner.split(',')
    if len(parts) not in (3, 4):
        raise ValueError("expected 3 or 4 color channels, got {!r}".format(input_rgb))
    channels = [float(p) / 255 for p in parts[:3]]
    if any(not 0 <= ch <= 1 for ch in channels):
        raise ValueError("color channels must be within 0-255, got {!r}".format(input_rgb))

    hue, light, sat = colorsys.rgb_to_hls(*channels)
    r, g, b = colorsys.hls_to_rgb(hue, max(0, min(1, amount * light)), sat)
    r, g, b = (int(np.round(v * 255, 0)) for v in (r, g, b))

    output_col = "rgb({},{},{})".format(r, g, b)
    return output_col

def plotly_template(template_specs):
    """
    Build a reusable plotly template from a dictionary of style settings.

    Parameters
    ----------
    template_specs : dict
        figure settings; the keys read here are 'plot_width', 'font',
        'axes_font_size', 'title_font_size', 'bg_col', 'axes_width' and
        'axes_color'

    Returns
    -------
    fig_template : plotly.graph_objects.layout.Template
        template holding default violin, barpolar and pie trace styles, the
        layout defaults and the annotation defaults
    """
    import plotly.graph_objects as go

    spec = template_specs
    black = "rgba(0, 0, 0, 1)"
    tmpl = go.layout.Template()

    # default trace styling
    tmpl.data.violin = [go.Violin(
        box_visible=False, points=False, opacity=1,
        line_color=black, line_width=spec['plot_width'],
        width=0.8, marker_symbol='x', marker_opacity=0.5,
        hoveron='violins',
        meanline_visible=True, meanline_color=black,
        meanline_width=spec['plot_width'],
        showlegend=False)]

    tmpl.data.barpolar = [go.Barpolar(
        marker_line_color="rgba(0,0,0,1)",
        marker_line_width=spec['plot_width'],
        showlegend=False, thetaunit='radians')]

    tmpl.data.pie = [go.Pie(
        showlegend=False,
        textposition=["inside", "none"],
        marker_line_color=['rgba(0,0,0,1)', 'rgba(255,255,255,0)'],
        marker_line_width=[spec['plot_width'], 0],
        rotation=0, direction="clockwise", hole=0.4, sort=False)]

    # settings common to both cartesian axes
    shared_axis = {
        'color': spec['axes_color'],
        'showgrid': False,
        'ticks': "outside",
        'ticklen': 0,
        'tickwidth': spec['axes_width'],
        'title_font_family': spec['font'],
        'title_font_size': spec['title_font_size'],
        'tickfont_family': spec['font'],
        'tickfont_size': spec['axes_font_size'],
        'zeroline': False,
        'zerolinecolor': spec['axes_color'],
        'zerolinewidth': spec['axes_width'],
        'hoverformat': '.1f',
    }
    # per-axis overrides: x visible on [0, 1], y hidden
    per_axis = {
        'xaxis': {'visible': True, 'linewidth': spec['axes_width'], 'range': [0, 1]},
        'yaxis': {'visible': False, 'linewidth': 0},
    }
    # both polar axes fully hidden
    hidden_polar_axis = {'visible': False, 'showticklabels': False, 'ticks': ''}

    # flatten everything into plotly's magic-underscore keyword names
    layout_kwargs = {
        'font_family': spec['font'],
        'font_size': spec['axes_font_size'],
        'plot_bgcolor': spec['bg_col'],
    }
    for axis_name, extras in per_axis.items():
        for key, value in {**shared_axis, **extras}.items():
            layout_kwargs['{}_{}'.format(axis_name, key)] = value
    for axis_name in ('polar_radialaxis', 'polar_angularaxis'):
        for key, value in hidden_polar_axis.items():
            layout_kwargs['{}_{}'.format(axis_name, key)] = value
    tmpl.layout = go.Layout(**layout_kwargs)

    # annotation defaults
    tmpl.layout.annotationdefaults = go.layout.Annotation(
        font_color=spec['axes_color'],
        font_family=spec['font'],
        font_size=spec['title_font_size'])

    return tmpl


In [3]:
# DATA
# ----
# Get inputs
subject = 'sub-001'
task = 'GazeCenterFS'
preproc = 'fmriprep_dct_pca'

# Define analysis parameters
with open('mri_analysis/settings.json') as f:
    analysis_info = json.load(f)
base_dir = analysis_info['base_dir']
h5_dir = "{base_dir}/pp_data/{subject}/gauss/h5".format(base_dir = base_dir, subject = subject)
    
# load deriv data
rsq_idx, ecc_idx, polar_real_idx, polar_imag_idx , size_idx, \
    amp_idx, baseline_idx, cov_idx, x_idx, y_idx, hemi_idx = 0,1,2,3,4,5,6,7,8,9,10
rsq_dict, ecc_dict, polar_real_dict, polar_imag_dict, size_dict, x_dict, y_dict, hemi_dict = {}, {}, {}, {}, {}, {}, {}, {}
ecc_sample_dict, size_sample_dict = {}, {}

# column names of the 'pRF/derivatives' dataset, in stored order
deriv_cols = ['rsq','ecc','polar_real','polar_imag','size','amp','baseline','cov','x','y','hemi']
meta_cols = ['roi', 'subject', 'task', 'preproc']

# create raw dataframe: one frame per ROI, concatenated once at the end
roi_frames = []
for roi in analysis_info['rois']:
    h5_path = "{h5_dir}/{roi}_{task}_{preproc}.h5".format(h5_dir = h5_dir, roi = roi, task = task, preproc = preproc)
    # context manager closes the file; [()] reads the dataset into memory first
    with h5py.File(h5_path, 'r') as h5_file:
        deriv_data = h5_file['{folder_alias}/derivatives'.format(folder_alias = 'pRF')][()]
    df_roi = pd.DataFrame(deriv_data, columns = deriv_cols)
    # scalar assignment broadcasts the value to every row
    df_roi['roi'] = roi
    df_roi['subject'] = subject
    df_roi['task'] = task
    df_roi['preproc'] = preproc
    roi_frames.append(df_roi)

if roi_frames:
    df_raw = pd.concat(roi_frames, ignore_index=True, axis = 0)
else:
    df_raw = pd.DataFrame(columns = deriv_cols + meta_cols)

# filter dataframe
rsqr_th, size_th, ecc_th = analysis_info['rsqr_th'], analysis_info['size_th'], analysis_info['ecc_th']
df = df_raw[(df_raw.rsq >= rsqr_th) & 
        (df_raw['size'] >= size_th[0]) & (df_raw['size'] <= size_th[1]) & 
        (df_raw.ecc >= ecc_th[0]) & (df_raw.ecc <= ecc_th[1])]

#df.to_csv('{subject}_{task}_{preproc}.gz'.format(subject=subject, task=task, preproc=preproc),compression='gzip', float_format='%.4f')

In [4]:
# figure-wide style settings consumed by plotly_template
template_specs = {
    'axes_color': "rgba(0, 0, 0, 1)",       # figure axes color
    'axes_width': 2,                        # figure axes line width
    'axes_font_size': 15,                   # font size of axes
    'bg_col': "rgba(255, 255, 255, 1)",     # figure background color
    'font': 'Helvetica',                    # general font used
    'title_font_size': 18,                  # font size of titles
    'plot_width': 1.5,                      # plot line width
}

fig_template = plotly_template(template_specs)



# To do
1. Make a threshold slider.
2. Make an entries slider to pick the condition to plot.
3. Make a plot with two entries (conditions 1 and 2 => split violins, two circles, two sets of lines with changing hue).


In [7]:
# SETTINGS
# --------
# general figure settings
fig_height, fig_width = 900,1200
# build a new list: appending to px.colors.qualitative.Prism mutated plotly's
# shared palette in place and grew it on every re-run of this cell
rois_colors = list(px.colors.qualitative.Prism) + ['rgb(180, 180, 180)']
rois = analysis_info['rois']
rows, cols = 4, 12
fig_title = '<b>Subject:</b> <i>{subject}</i> | <b>Task:</b> <i>{task}</i> | <b>PP:</b> <i>{preproc}</i>'.format(subject=subject, 
                    task=task, preproc=preproc)

y_label_trace, x_label_trace, trace_range = 'Size (dva)', 'Eccentricity (dva)', [0,12]
trace_tickvals = np.linspace(trace_range[0],trace_range[1],4)
trace_ticktexts = ['{:g}'.format(x) for x in trace_tickvals]
line_x = np.linspace(trace_range[0], trace_range[1], 60)
bins = 12
bin_angle = 2*np.pi/bins
barpolar_hovertemplate = "Angle: %{text:.0f}°<br>Prop: %{r:.0f}%<extra></extra>"
barpolar_range=[0,30]
# fixed seed so the random voxel subsample in the ecc/size scatter is reproducible
sample_seed = 0
rng = np.random.default_rng(sample_seed)

# subplot settings
# rows: 1 = r2 violins, 2 = polar-angle histograms, 3 = contra-laterality pies,
# 4 = three eccentricity/size panels, each spanning 4 columns
column_widths = [1] * cols
row_heights = [4,1,1,4]
sb_specs = [[{} for _ in range(cols)],
            [{'type':'barpolar'} for _ in range(cols)],
            [{'type':'domain'} for _ in range(cols)],
            [{'colspan':4},None,None,None,{'colspan':4},None,None,None,{'colspan':4},None,None,None]]

fig = make_subplots(rows=rows, cols=cols, specs=sb_specs, print_grid=False, vertical_spacing=0.04, horizontal_spacing=0.02,
                    column_widths=column_widths, row_heights=row_heights, shared_yaxes=True)

# subplot (row, col) of each ROI, indexed by ROI number
cols_violin, rows_violin        = list(range(1, cols+1)), [1]*cols
cols_barpolar, rows_barpolar    = list(range(1, cols+1)), [2]*cols
cols_pie, rows_pie              = list(range(1, cols+1)), [3]*cols
cols_trace, rows_trace          = [1 + 4*(i//4) for i in range(cols)], [4]*cols

# DRAWING
# -------
for num,(roi,roi_color) in enumerate(zip(rois,rois_colors)):
    roi_data = df[df.roi==roi]
    
    # r2 violin plots
    fig.add_trace(go.Violin(y=roi_data.rsq, name=roi, span=[0, 1], orientation="v", spanmode='manual', fillcolor=roi_color),
                  row=rows_violin[num], col=cols_violin[num])

    # polar angle histogram, r2-weighted, as % of the ROI's total weight
    pol_angles = np.angle(roi_data.polar_real + 1j * roi_data.polar_imag)
    hist, bin_edges = np.histogram(a=pol_angles, range=(-np.pi,np.pi), bins=bins, weights=roi_data.rsq)
    # barpolar centres each bar on theta: use bin centres (len == len(hist)),
    # not the bins+1 edges, otherwise every bar is rotated by half a bin
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    fig.add_trace(go.Barpolar(r=(hist/np.nansum(hist))*100, theta=bin_centers, width=np.ones_like(hist)*bin_angle,
                              text=np.rad2deg(bin_centers), hovertemplate=barpolar_hovertemplate, marker_color=roi_color),
                  row=rows_barpolar[num], col=cols_barpolar[num])
    
    # contra-laterality ratio: r2-weighted fraction of each hemisphere's pRFs
    # lying in the opposite hemifield, averaged over the two hemispheres
    # NOTE(review): assumes hemi == 2 is the right and hemi == 1 the left hemisphere — confirm
    hemi_ratios = []
    for hemi_val, contra_mask in ((2, roi_data.x < 0), (1, roi_data.x > 0)):
        in_hemi = roi_data.hemi == hemi_val
        hemi_ratios.append(roi_data[in_hemi & contra_mask].rsq.sum() / roi_data[in_hemi].rsq.sum())
    # a hemisphere without surviving voxels gives 0/0 = nan; nanmean keeps the other one
    cl_ratio = np.nanmean(hemi_ratios)

    fig.add_trace(go.Pie(labels=["Contra-lateral","Ipsi-lateral"], hoverinfo='label+percent', values=[cl_ratio,1-cl_ratio],
                         marker_colors=[roi_color,'rgba(255,255,255,0)']), row=rows_pie[num], col=cols_pie[num])

    # eccentricity/size scatter of a random subsample of this ROI's voxels
    idx_sample = roi_data.index[rng.permutation(roi_data.shape[0])[:analysis_info['sample_num']]]
    fig.add_trace(go.Scatter(x=roi_data.ecc.loc[idx_sample], y=roi_data['size'].loc[idx_sample], mode='markers', showlegend=False, hoverinfo='none',
                             marker_symbol='circle', marker_size=10, marker_color=adjust_lightness(roi_color, amount=1.25),
                             marker_line_color='black', marker_line_width=0.5, marker_opacity=0.2), row=rows_trace[num], col=cols_trace[num])

# eccentricity/size regression lines, drawn in a second pass so each line sits
# on top of all scatter points sharing its subplot
for num,(roi,roi_color) in enumerate(zip(rois,rois_colors)):
    roi_data = df[df.roi==roi]
    ecc_size_coeff, ecc_size_intercept = weighted_regression(np.array(roi_data.ecc), np.array(roi_data['size']), np.array(roi_data.rsq))
    # coef (1, 1) * line_x (60,) + intercept (1,) -> shape (1, 60)
    line_y = ecc_size_coeff*line_x+ecc_size_intercept
    fig.add_trace(go.Scatter(x=line_x, y=line_y[0], name=roi, mode='lines', line_width=4, line_color=roi_color, showlegend=False),
                  row=rows_trace[num], col=cols_trace[num])

# annotations
fig.add_annotation(xref="paper", yref="paper", x=-0.082, y=0.56, text='Polar<br>angle', showarrow=False, textangle=-90)
fig.add_annotation(xref="paper", yref="paper", x=-0.082, y=0.435, text='Contra-<br>laterality', showarrow=False, textangle=-90)
fig.add_annotation(xref="paper", yref="paper", x=0, y=1.05, text=fig_title, showarrow=False)

# LAYOUT
# ------
# r2 range [0, 1] for every violin subplot (yaxis, yaxis2, ..., yaxis12)
violin_yranges = {'yaxis{}_range'.format(i if i > 1 else ''): [0,1] for i in range(1, cols+1)}

fig.layout.update(  # figure settings
                    template=fig_template, width=fig_width, height=fig_height, margin_l=100, margin_r=20, margin_t=50, margin_b=100,
                    # range violin
                    **violin_yranges,
                    # bar polar (uncomment to fix range)
                    # polar_radialaxis_range=barpolar_range,polar2_radialaxis_range=barpolar_range,polar3_radialaxis_range=barpolar_range,
                    # polar4_radialaxis_range=barpolar_range,polar5_radialaxis_range=barpolar_range,polar6_radialaxis_range=barpolar_range,
                    # polar7_radialaxis_range=barpolar_range,polar8_radialaxis_range=barpolar_range,polar9_radialaxis_range=barpolar_range,
                    # polar10_radialaxis_range=barpolar_range,polar11_radialaxis_range=barpolar_range,polar12_radialaxis_range=barpolar_range,
                    # y axis violin
                    yaxis_visible=True, yaxis_linewidth=template_specs['axes_width'], yaxis_title_text='R\u00b2', yaxis_ticklen=8, 
                    # traces #13
                    yaxis13_visible=True, yaxis13_linewidth=template_specs['axes_width'], yaxis13_title_text=y_label_trace, 
                    yaxis13_range=trace_range, yaxis13_ticklen=8, yaxis13_tickvals=trace_tickvals, yaxis13_ticktext=trace_ticktexts,
                    xaxis13_visible=True, xaxis13_linewidth=template_specs['axes_width'], xaxis13_title_text=x_label_trace, 
                    xaxis13_range=trace_range, xaxis13_ticklen=8, xaxis13_tickvals=trace_tickvals, xaxis13_ticktext=trace_ticktexts,
                    # traces #14
                    yaxis14_visible=True, yaxis14_linewidth=template_specs['axes_width'], yaxis14_showticklabels=False, 
                    yaxis14_range=trace_range, yaxis14_ticklen=8, yaxis14_tickvals=trace_tickvals,
                    xaxis14_visible=True, xaxis14_linewidth=template_specs['axes_width'], xaxis14_title_text=x_label_trace, 
                    xaxis14_range=trace_range, xaxis14_ticklen=8, xaxis14_tickvals=trace_tickvals, xaxis14_ticktext=trace_ticktexts,
                    # traces #15
                    yaxis15_visible=True, yaxis15_linewidth=template_specs['axes_width'], yaxis15_showticklabels=False, 
                    yaxis15_range=trace_range, yaxis15_ticklen=8, yaxis15_tickvals=trace_tickvals, 
                    xaxis15_visible=True, xaxis15_linewidth=template_specs['axes_width'], xaxis15_title_text=x_label_trace, 
                    xaxis15_range=trace_range, xaxis15_ticklen=8, xaxis15_tickvals=trace_tickvals, xaxis15_ticktext=trace_ticktexts,)


fig.show(config = {'displayModeBar': False})

In [67]:
# SAVE
# -----
# static vector export first, then a standalone interactive page without the toolbar
fig.write_image(file="fig1.svg")
html_config = {'displayModeBar': False}
fig.write_html(file="fig1.html", config=html_config)