In [None]:
# Imports
import os
import numpy as np
import nibabel as nb
import pandas as pd
import warnings
import pingouin

warnings.filterwarnings('ignore')

# Figure imports
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.express as px
from plot_utils import plotly_template

# Data locations: project root, BIDS root (same folder) and pre-processed derivatives
base_dir = '/home/mszinte/disks/meso_S/data/gaze_prf'
bids_dir = base_dir
pp_dir = base_dir + "/derivatives/pp_data"

# Analysis settings
subject = 'sub-005'        # subject plotted by the first figures
best_voxels_num = 250      # number used in the name of the results file to load
TR = 1.3                   # multiplies volume indices to get the time axes in seconds

# Figure style settings handed to plotly_template
template_specs = {
    'axes_color': "rgba(0, 0, 0, 1)",
    'axes_width': 2,
    'axes_font_size': 13,
    'bg_col': "rgba(255, 255, 255, 1)",
    'font': 'Arial',
    'title_font_size': 15,
    'plot_width': 1.5,
}
# Build the plotly figure template from template_specs (plotly_template is imported from plot_utils)
fig_template = plotly_template(template_specs)

In [None]:
def gaus_2d(gauss_x, gauss_y, gauss_sd, screen_side, grain=200):
    """
    Evaluate an isotropic, normalized 2D gaussian on a square grid

    Parameters
    ----------
    gauss_x : gaussian center on the x axis in dva (e.g. 1 dva)
    gauss_y : gaussian center on the y axis in dva (e.g. 1 dva)
    gauss_sd : gaussian standard deviation in dva (e.g. 1 dva)
    screen_side : side of the square grid in dva (e.g. 20 dva spans -10 to 10 dva)
    grain : number of grid points per side (default = 200)

    Returns
    -------
    x : grid coordinates along x (length grain)
    y : grid coordinates along y (length grain)
    z : gaussian values on the grid, shape (grain, grain), rows follow y and columns follow x

    """
    half_side = screen_side/2
    x = np.linspace(-half_side, half_side, grain)
    y = np.linspace(-half_side, half_side, grain)
    mesh_x, mesh_y = np.meshgrid(x, y)

    # peak height of a normalized 2D gaussian and the shared exponent denominator
    peak = 1./(2.*np.pi*gauss_sd*gauss_sd)
    two_var = 2.*gauss_sd**2.

    term_x = (mesh_x-gauss_x)**2./two_var
    term_y = (mesh_y-gauss_y)**2./two_var
    gauss_z = peak*np.exp(-(term_x+term_y))

    return x, y, gauss_z

In [None]:
def draw_timeseries(df, vox_data, vox_model):
    """
    Draw a voxel time series with a model prediction and the matching pRF

    The figure is a 2 x 2 grid of subplots:
    - top-left: stimulus periods (black rectangles) with white arrows
    - bottom-left: data of vox_data (markers) and prediction of vox_model (line)
    - bottom-right: 2D gaussian pRF of vox_model with its parameters
    No trace is added to the top-right subplot.

    Parameters
    ----------
    df : dataframe with the columns time_fs, data_fs, pred_fs, r2_fs, x_fs, y_fs, sd_fs and ecc_fs
    vox_data : index in df of the voxel whose data are plotted
    vox_model : index in df of the voxel whose prediction and pRF are plotted

    Returns
    -------
    fig : plotly figure

    Notes
    -----
    Relies on the notebook-level names template_specs, TR, subject,
    plotly_template and gaus_2d.

    """
    # R2 shown on the pRF panel: recomputed from the data/prediction correlation
    # when they come from different voxels, otherwise read from the dataframe
    if vox_data != vox_model:
        r2_val = pingouin.corr(df.data_fs[vox_data], df.pred_fs[vox_model]).iloc[0]['r']**2
    else:
        r2_val = df.r2_fs[vox_model]

    # General figure settings
    fig_template = plotly_template(template_specs)

    # Subplot settings
    rows, cols = 2, 2
    margin_t, margin_b, margin_l, margin_r = 50, 50, 50, 50
    fig_ratio = 5
    fig_height = 1080/fig_ratio + (1080/fig_ratio*0.15) + margin_t+margin_b
    fig_width = 1920/fig_ratio + 1920/fig_ratio + margin_l+margin_r
    column_widths, row_heights = [1,1], [0.15,1]
    sb_specs = [[{},{}],[{},{}]]
    hover_data = 'Time: %{x:1.2f} s<br>' + 'z-score: %{y:1.2f}'
    hover_model = 'Time: %{x:1.2f} s<br>' + 'z-score: %{y:1.2f}'

    xaxis_range = [0,195]
    yaxis_range = [-1,2]
    yaxis_dtick = 1
    x_tickvals = np.linspace(0,150,6)*TR
    # arrow coordinates as [tail x, head x, tail y, head y]
    # (presumably the motion direction of each stimulus pass -- TODO confirm)
    lwd_mot = np.array([ 32*TR,  20*TR, 0.5,  0.5])
    dwd_mot = np.array([ 90*TR, 102*TR, 0.5,  0.5])
    rwd_mot = np.array([ 61*TR,  61*TR, 0.85, 0.15])
    uwd_mot = np.array([131*TR, 131*TR, 0.15, 0.85])
    # start and end of each stimulus rectangle, given in volumes and converted with TR
    x0_all = np.array([0,10,42,52,70,80,112,122,140])*TR
    x1_all = np.array([10,42,52,70,80,112,122,140,150])*TR
    rolling = 3
    data_col = 'rgba(0, 0, 0, 1)'
    subplot_titles = ['<b>V1 time series </b> ({})'.format(subject),'','','']
    screen_side = 8.9*2
    prf_xrange = [-8.9,8.9]
    prf_yrange = [-5,5]
    x_par_txt = 2.4
    y_par_text = 4

    # pRF parameters of the model voxel
    prf_x = df.x_fs[vox_model]
    prf_y = df.y_fs[vox_model]
    prf_sd = df.sd_fs[vox_model]

    # create figure
    fig = make_subplots(rows=rows, cols=cols, specs=sb_specs, print_grid=False, vertical_spacing=0.05, horizontal_spacing=0.05, 
                        column_widths=column_widths, row_heights=row_heights,  subplot_titles=subplot_titles)

    # Timeseries stim (top-left subplot, axes x/y)
    for x0, x1 in zip(x0_all, x1_all):
        fig.add_shape(type='rect', xref='x', yref='y', x0=x0, y0=0, x1=x1, y1=1, 
                      line_width=2, fillcolor='black', line_color='white')
    for coord_tp in (lwd_mot, dwd_mot, rwd_mot, uwd_mot):
        fig.add_annotation(ax=coord_tp[0], x=coord_tp[1], ay=coord_tp[2], y=coord_tp[3], 
                           xref='x', yref='y', axref='x', ayref='y',
                           text='', showarrow=True, arrowhead=2, arrowcolor='white')

    # time series data: rolling mean over `rolling` samples, keeping one sample every `rolling`
    data_time = np.array(pd.Series(df.time_fs[vox_data]).rolling(window=rolling).mean())[::rolling]
    data_vals = np.array(pd.Series(df.data_fs[vox_data]).rolling(window=rolling).mean())[::rolling]
    fig.add_trace(go.Scatter(x=data_time, y=data_vals,
                             name='<i>data</i>',
                             showlegend=True, mode='markers', marker_color=data_col,
                             hovertemplate=hover_data,
                             line_width=0, opacity=1, marker_size=6), row=2, col=1)

    # time series predictions
    # NOTE(review): the model line is drawn in the data color; confirm whether
    # a distinct model color was intended
    fig.add_trace(go.Scatter(x=df.time_fs[vox_model],
                             y=df.pred_fs[vox_model],
                             name='<i>model</i>',
                             showlegend=True, mode='lines', line_color=data_col, 
                             hovertemplate=hover_model,
                             line_width=2, opacity=1), row=2, col=1)

    # pRF heatmap (bottom-right subplot, axes x4/y4)
    x, y, z = gaus_2d(gauss_x=prf_x, gauss_y=prf_y, gauss_sd=prf_sd, screen_side=screen_side)
    fig.add_trace(go.Heatmap(x=x, y=y, z=z, colorscale='viridis', showscale=False, hoverinfo='none'), row=2, col=2)

    # arrows and labels marking the pRF center on the bottom and right edges
    fig.add_annotation(x=prf_x, ax=prf_x, y=prf_yrange[0], ay=prf_yrange[0]-0.5,
                       xref='x4', yref='y4', axref='x4', ayref='y4', yanchor="top", showarrow=True,
                       text='<i>pRFx</i> = {:1.2g}°'.format(prf_x), arrowhead=2, arrowwidth=2.5)

    fig.add_annotation(x=prf_xrange[1], ax=prf_xrange[1]+0.5, y=prf_y, ay=prf_y,
                       xref='x4', yref='y4', axref='x4', ayref='y4', yanchor="top", showarrow=True, 
                       text='<i>pRFy</i> = {:1.2g}°'.format(prf_y), textangle=-90, arrowhead=2, arrowwidth=2.5)

    # dotted cross-hair through the pRF center
    fig.add_shape(type='line', xref='x4', yref='y4', x0=prf_xrange[0], x1=prf_xrange[1], y0=prf_y, y1=prf_y, 
                  line_width=2, line_color='white', line_dash='dot')
    fig.add_shape(type='line', xref='x4', yref='y4', x0=prf_x, x1=prf_x, y0=prf_yrange[0], y1=prf_yrange[1], 
                  line_width=2, line_color='white', line_dash='dot')

    # pRF parameter text
    fig.add_annotation(x=x_par_txt, y=y_par_text, xref='x4', yref='y4', xanchor="left", font_color='white', showarrow=False,
                       text='<i>pRF R2</i> = {:1.2g}'.format(r2_val))
    fig.add_annotation(x=x_par_txt, y=y_par_text-1, xref='x4', yref='y4', xanchor="left", font_color='white', showarrow=False,
                       text='<i>pRF size</i> = {:1.2g}°'.format(prf_sd))
    fig.add_annotation(x=x_par_txt, y=y_par_text-2, xref='x4', yref='y4',  xanchor="left", font_color='white', showarrow=False, 
                       text='<i>pRF ecc</i> = {:1.2g}°'.format(df.ecc_fs[vox_model]))
    fig.add_annotation(x=x_par_txt, y=y_par_text-3, xref='x4', yref='y4',  xanchor="left", font_color='white', showarrow=False, 
                       text='<i>pRF angle</i> = {:3.0f}°'.format(np.angle(prf_x + 1j * prf_y, deg=True)))

    # set axis
    for row in range(1, rows+1):
        for col in range(1, cols+1):
            fig.update_xaxes(visible=True, ticklen=8, linewidth=template_specs['axes_width'], row=row, col=col)
            fig.update_yaxes(visible=True, ticklen=8, linewidth=template_specs['axes_width'], row=row, col=col)

    fig.layout.update(xaxis_range=xaxis_range, xaxis_title='', 
                      xaxis_visible=False, yaxis_visible=False,
                      yaxis_range=[0,1], yaxis_title='',
                      xaxis4_range=prf_xrange, xaxis4_title='', 
                      yaxis4_range=prf_yrange, yaxis4_title='', 
                      xaxis4_visible=False, yaxis4_visible=False,
                      xaxis3_tickvals=x_tickvals, xaxis3_ticktext=np.round(x_tickvals),
                      xaxis3_range=xaxis_range, xaxis3_title='Time (seconds)',
                      yaxis3_range=yaxis_range, yaxis3_title='z-score',yaxis3_dtick=yaxis_dtick,
                      template=fig_template, width=fig_width, height=fig_height, 
                      margin_l=margin_l+10, margin_r=margin_r-10, margin_t=margin_t-20, margin_b=margin_b+20,
                      legend_yanchor='top', legend_y=0.85, legend_xanchor='left', 
                      legend_x=0.02, legend_bgcolor='rgba(255,255,255,0)')

    return fig

#### V1 time series plot

In [None]:
# load the results dataframe (pickle file stored in the subject's tsv folder)
tsv_dir = '{}/{}/prf/tsv'.format(pp_dir, subject)
df_fn = "{}/{}_all_res_best{}.pkl".format(tsv_dir, subject, int(best_voxels_num))
df = pd.read_pickle(df_fn)

# voxel used for the figure; candidates can be listed with
# df.loc[(df.roi=='V1') & (df.r2_fs>0.4) & (df.sd_fs<1.5) & (df.x_fs<-1)]
figure1_vox = 114

# draw the figure with data and model taken from the same voxel
fig = draw_timeseries(df=df, vox_data=figure1_vox, vox_model=figure1_vox)

# save the figure as pdf and html
plot_config = {"displayModeBar": False}
pdf_fn = "{}/{}_V1-timeseries.pdf".format(tsv_dir, subject)
html_fn = "{}/{}_V1-timeseries.html".format(tsv_dir, subject)
fig.write_image(pdf_fn)
fig.write_html(html_fn, config=plot_config)

# show the figure
fig.show(config=plot_config)

#### V1 time series plot for slides

In [None]:
# load the results dataframe (pickle file stored in the subject's tsv folder)
tsv_dir = '{}/{}/prf/tsv'.format(pp_dir, subject)
df_fn = "{}/{}_all_res_best{}.pkl".format(tsv_dir, subject, int(best_voxels_num))
df = pd.read_pickle(df_fn)

# one slide per model voxel, always plotted against the data of the same voxel;
# the last model voxel is the data voxel itself
vox_data = 160
vox_models = [52, 360, 451, 160]

slide_fn = "{}/{}_V1-timeseries_slide{}.pdf"
for anim, vox_model in enumerate(vox_models):
    fig = draw_timeseries(df=df, vox_data=vox_data, vox_model=vox_model)
    fig.show(config={"displayModeBar": False})
    fig.write_image(slide_fn.format(tsv_dir, subject, anim))

#### V1/hMT+ time series across gaze conditions

In [None]:
# voxels plotted in the next cell (other candidates: 20, 21 for V1; 1115 for hMT+)
v1_num_vox = 35
hmt_num_vox = 1107

# switch to another subject (overrides the notebook-level `subject`)
# and load its results dataframe (pickle file)
subject = 'sub-003'
tsv_dir = '{}/{}/prf/tsv'.format(pp_dir, subject)
df_fn = "{}/{}_all_res_best{}.pkl".format(tsv_dir, subject, int(best_voxels_num))
df = pd.read_pickle(df_fn)

In [None]:
# Figure with 6 rows x 2 columns: one column per ROI (V1, hMT+) and, for each
# gaze condition (gc, gl, gr), a thin stimulus row above a time series row.
# Uses df, subject, tsv_dir, v1_num_vox and hmt_num_vox from the previous cell.

# General figure settings
fig_template = plotly_template(template_specs)

# Subplot settings
rows, cols = 6, 2
margin_t, margin_b, margin_l, margin_r = 80, 90, 50, 50
fig_ratio = 5
ratio_fs_gaze = 159/195
fig_height = (1080/fig_ratio + (1080/fig_ratio*0.15))*rows/2 + margin_t + margin_b
fig_width = (1920*ratio_fs_gaze/fig_ratio)*cols + margin_l + margin_r
column_widths,row_heights = [1,1],[0.15,1,0.15,1,0.15,1]
sb_specs = [[{},{}],[{},{}],[{},{}],[{},{}],[{},{}],[{},{}]]

xaxis_range = [0,159]
yaxis_range = [-1,2]
yaxis_dtick = 1
x_tickvals = np.linspace(0,122,5)*TR
# arrow coordinates as [tail x, head x, tail y, head y]
mot_1 = np.array([ 24*TR,  15*TR, 0.5, 0.5])
mot_2 = np.array([ 43*TR,  52*TR, 0.5, 0.5])
mot_3 = np.array([ 80*TR,  71*TR, 0.5, 0.5])
mot_4 = np.array([ 99*TR, 108*TR, 0.5, 0.5])
# start and end of each stimulus rectangle, given in volumes and converted with TR
x0_all = np.array([0,10,29,39,57,67,85,95,113])*TR
x1_all = np.array([10,28,38,56,66,84,94,112,122])*TR
rolling = 3
data_col = 'rgba(0, 0, 0, 1)'
model_col = 'rgba(200, 0, 0, 1)'
retino_col = 'rgba(227, 6, 19, 1)'
spatio_col = 'rgba(29, 113, 184, 1)'
subplot_titles = ['<b>V1 time series </b><br>({}, attend-bar)'.format(subject),
                  '<b>hMT+ time series </b><br>({}, attend-bar)'.format(subject),
                  '','','','','','','','','','']
hover_data = 'Time: %{x:1.2f} s<br>' + 'z-score: %{y:1.2f}'
hover_model = 'Time: %{x:1.2f} s<br>' + 'z-score: %{y:1.2f}'

# create figure
fig = make_subplots(rows=rows, cols=cols, specs=sb_specs, print_grid=False, vertical_spacing=0.03, horizontal_spacing=0.08, 
                    column_widths=column_widths, row_heights=row_heights,  subplot_titles=subplot_titles)

for roi in ['V1','hMT+']:
    # column, voxel and axis references of the three time series rows of this ROI
    if roi == 'V1':     col_val, df2plot, num_vox, xrefs, yrefs = 1, df, v1_num_vox, ['x3','x7','x11'], ['y3','y7','y11']
    elif roi == 'hMT+': col_val, df2plot, num_vox, xrefs, yrefs = 2, df, hmt_num_vox, ['x4','x8','x12'], ['y4','y8','y12']

    for gaze_pos in ['gc','gl','gr']:
        # row, spatiotopic line style and axis references of this gaze condition
        if gaze_pos == 'gc':     row_val, line_dash, xref, yref = 2, 'dash', xrefs[0], yrefs[0]
        elif gaze_pos == 'gl':   row_val, line_dash, xref, yref = 4, 'solid', xrefs[1], yrefs[1]
        elif gaze_pos == 'gr':   row_val, line_dash, xref, yref = 6, 'solid', xrefs[2], yrefs[2]

        # a single subplot feeds the legend, the others share its legend groups
        showlegend = (roi == 'V1' and gaze_pos == 'gl')

        # data: rolling mean over `rolling` samples, keeping one sample every `rolling`
        fig.add_trace(go.Scatter(x=np.array(pd.Series(df2plot['time_gaze'][num_vox]).rolling(window=rolling).mean())[::rolling],
                                 y=np.array(pd.Series(df2plot['data_{}_ab'.format(gaze_pos)][num_vox]).rolling(window=rolling).mean())[::rolling],
                                 name='<i>data</i>', legendgroup='data',
                                 showlegend=showlegend, mode='markers', marker_color=data_col, hovertemplate=hover_data,
                                 line_width=0, opacity=1, marker_size=6), row=row_val, col=col_val)

        # retinotopic prediction
        fig.add_trace(go.Scatter(x=df2plot['time_gaze'][num_vox],
                                 y=df2plot['retino_pred_{}_ab'.format(gaze_pos)][num_vox],
                                 name='<i>retinotopic prediction</i>', legendgroup='retino_model',
                                 showlegend=showlegend, mode='lines', line_color=retino_col, hovertemplate=hover_model,
                                 line_width=2, opacity=1), row=row_val, col=col_val)

        # spatiotopic prediction
        fig.add_trace(go.Scatter(x=df2plot['time_gaze'][num_vox], 
                                 y=df2plot['spatio_pred_{}_ab'.format(gaze_pos)][num_vox], 
                                 name='<i>spatiotopic prediction</i>', legendgroup='spatio_model',
                                 showlegend=showlegend, mode='lines', line_color=spatio_col, hovertemplate=hover_model, line_dash=line_dash,
                                 line_width=2, opacity=1), row=row_val, col=col_val)

        # values of the two predictions (columns ending in _r2), written in the model colors
        fig.add_annotation(x=4, y=1.9, xref=xref, yref=yref, xanchor="left", font_color=retino_col, showarrow=False,
                           text='{:1.2f}'.format(df2plot['retino_pred_{}_ab_r2'.format(gaze_pos)][num_vox]))

        fig.add_annotation(x=4, y=1.6, xref=xref, yref=yref, xanchor="left", font_color=spatio_col, showarrow=False,
                           text='{:1.2f}'.format(df2plot['spatio_pred_{}_ab_r2'.format(gaze_pos)][num_vox]))


# Timeseries stim (rows 1, 3 and 5 of both columns)
xrefs = ['x','x2','x5','x6','x9','x10']
yrefs = ['y','y2','y5','y6','y9','y10']
axrefs = ['x','x2','x5','x6','x9','x10']
ayrefs = ['y','y2','y5','y6','y9','y10']

for xref,yref,axref,ayref in zip(xrefs, yrefs, axrefs, ayrefs):
    for x0,x1 in zip(x0_all,x1_all):
        fig.add_shape(type='rect', xref=xref, yref=yref, x0=x0, y0=0, x1=x1, y1=1, 
                      line_width=2, fillcolor='black', line_color='white')
    for coord_tp in (mot_1, mot_2, mot_3, mot_4):
        fig.add_annotation(ax=coord_tp[0], x=coord_tp[1], ay=coord_tp[2], y=coord_tp[3], 
                           xref=xref, yref=yref, axref=axref,ayref=ayref,
                           text='', showarrow=True, arrowhead=2, arrowcolor='white')

# set axis
for row in np.arange(rows):
    for col in np.arange(cols):
        fig.update_xaxes(visible=True, ticklen=8, linewidth=template_specs['axes_width'], row=row+1, col=col+1)
        fig.update_yaxes(visible=True, ticklen=8, linewidth=template_specs['axes_width'], row=row+1, col=col+1)

fig.layout.update(# time stim axis
                  xaxis_range=xaxis_range, xaxis_title='',     xaxis2_range=xaxis_range, xaxis2_title='', 
                  xaxis_visible=False, yaxis_visible=False,    xaxis2_visible=False, yaxis2_visible=False,
                  yaxis_range=[0,1], yaxis_title='',           yaxis2_range=[0,1], yaxis2_title='',
                  xaxis5_range=xaxis_range, xaxis5_title='',   xaxis6_range=xaxis_range, xaxis6_title='', 
                  xaxis5_visible=False, yaxis5_visible=False,  xaxis6_visible=False, yaxis6_visible=False,
                  yaxis5_range=[0,1], yaxis5_title='',         yaxis6_range=[0,1], yaxis6_title='',
                  xaxis9_range=xaxis_range, xaxis9_title='',   xaxis10_range=xaxis_range, xaxis10_title='', 
                  xaxis9_visible=False, yaxis9_visible=False,  xaxis10_visible=False, yaxis10_visible=False,
                  yaxis9_range=[0,1], yaxis9_title='',         yaxis10_range=[0,1], yaxis10_title='',
 
                  # time series axis
                  xaxis3_tickvals=x_tickvals, xaxis3_showticklabels=False,                        xaxis4_tickvals=x_tickvals, xaxis4_showticklabels=False,
                  xaxis3_range=xaxis_range, xaxis3_title='',                                      xaxis4_range=xaxis_range, xaxis4_title='',
                  yaxis3_range=yaxis_range, yaxis3_title='z-score',yaxis3_dtick=yaxis_dtick,      yaxis4_range=yaxis_range, yaxis4_title='', yaxis4_dtick=yaxis_dtick,
                  xaxis7_tickvals=x_tickvals, xaxis7_showticklabels=False,                        xaxis8_tickvals=x_tickvals, xaxis8_showticklabels=False,
                  xaxis7_range=xaxis_range, xaxis7_title='',                                      xaxis8_range=xaxis_range, xaxis8_title='',
                  yaxis7_range=yaxis_range, yaxis7_title='z-score', yaxis7_dtick=yaxis_dtick,     yaxis8_range=yaxis_range, yaxis8_title='', yaxis8_dtick=yaxis_dtick,                  
                  xaxis11_tickvals=x_tickvals, xaxis11_ticktext=np.round(x_tickvals),             xaxis12_tickvals=x_tickvals, xaxis12_ticktext=np.round(x_tickvals),
                  xaxis11_range=xaxis_range, xaxis11_title='Time (seconds)',                      xaxis12_range=xaxis_range, xaxis12_title='Time (seconds)',
                  yaxis11_range=yaxis_range, yaxis11_title='z-score', yaxis11_dtick=yaxis_dtick,  yaxis12_range=yaxis_range, yaxis12_title='', yaxis12_dtick=yaxis_dtick,
                  
                  # general settings
                  template=fig_template, width=fig_width, height=fig_height,
                  margin_l=margin_l+10, margin_r=margin_r-10, margin_t=margin_t-20, margin_b=margin_b+20,
                  legend_yanchor='top', legend_y=-0.08, legend_xanchor='left', 
                  legend_x=0.001, legend_bgcolor='rgba(255,255,255,0)')

# show and save figure
# NOTE(review): the pdf and html names differ ("hMT+_out" vs "hMT+-out");
# kept as is in case other scripts rely on these file names
fig.show(config={"displayModeBar": False})
fig.write_image("{}/{}_V1-hMT+_out_of_set_timeseries.pdf".format(tsv_dir, subject))
fig.write_html("{}/{}_V1-hMT+-out_of_set_timeseries.html".format(tsv_dir, subject),config={"displayModeBar": False})