In [None]:
# Imports
import os
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# Figure imports
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.express as px
from plot_utils import plotly_template

# Define parameters
# Individual participants; 'group' is appended for plotting the across-subject
# average that the group-averaging cell writes to disk.
subjects = ['sub-001', 'sub-002', 'sub-003', 'sub-004',
            'sub-005', 'sub-006', 'sub-007', 'sub-008']
subjects_plot = subjects + ['group']  # derived so the two lists cannot drift apart
rois = ['V1', 'V2', 'V3', 'V3AB', 'hMT+', 'LO',
        'VO', 'iIPS', 'sIPS', 'iPCS', 'sPCS', 'mPCS']
TR = 1.3  # NOTE(review): presumably the repetition time in seconds -- confirm

# Graph specific plot
gaze_tasks = ['GazeCenter', 'GazeLeft', 'GazeRight']
attend_tasks = ['AttendBar', 'AttendFix']
attend_tasks_txt = ['Attend-bar', 'Attend-fix']

# Define folders
# The data root defaults to the original location but can be redirected with
# the GAZE_PRF_BASE_DIR environment variable, so the notebook is runnable on
# machines where the data live elsewhere.
base_dir = os.environ.get('GAZE_PRF_BASE_DIR',
                          '/home/mszinte/disks/meso_S/data/gaze_prf')
bids_dir = base_dir
pp_dir = "{}/derivatives/pp_data".format(base_dir)

# General figure settings
# Appearance options handed to plotly_template; the resulting template is
# applied to every figure through fig.layout.update(template=fig_template).
template_specs = {
    'axes_color': "rgba(0, 0, 0, 1)",
    'axes_width': 2,
    'axes_font_size': 13,
    'bg_col': "rgba(255, 255, 255, 1)",
    'font': 'Arial',
    'title_font_size': 15,
    'plot_width': 1.5,
}
fig_template = plotly_template(template_specs)

In [None]:
# Subplot settings
margin_t, margin_b, margin_l, margin_r = 50, 100, 100, 50
rows, cols = 6, 2

# Rows alternate between a short row (relative height 0.2) and a tall row
# (relative height 1); the plotting cell draws the motion caption in the short
# rows and the decoded time courses in the tall rows (rows 2, 4 and 6).
row_heights, column_widths = [0.2, 1, 0.2, 1, 0.2, 1], [1, 1]
sb_specs = [[{} for _ in range(cols)] for _ in range(rows)]
subplot_width, subplot_height = 95, 120

# Figure size = sum of the scaled subplot sizes plus the outer margins
fig_width = sum(subplot_width * column_width for column_width in column_widths)
fig_height = sum(subplot_height * row_height for row_height in row_heights)
fig_width = fig_width + margin_l + margin_r
fig_height = fig_height + margin_t + margin_b

xaxis_range = [0, 1]
xaxis_tickvals = [0, .5, 1]
xaxis_ticktext = [0, .5, 1]
xaxis_title = 'Time (%)'

yaxis_range = [-8, 8]
yaxis_tick = 5  # number of ticks spread evenly over yaxis_range
yaxis_tickvals = np.linspace(yaxis_range[0], yaxis_range[1], yaxis_tick)
yaxis_title = 'Decoded<br>position (dva)'
# '{:g}' drops the trailing '.0' so ticks read -8, -4, 0, 4, 8
yaxis_ticktext = ['{:g}'.format(val) for val in yaxis_tickvals]

# Axis references of the six caption subplots (rows 1, 3 and 5 of the 6 x 2
# grid), used to anchor the caption rectangle and arrow in data coordinates.
xrefs, axrefs = ['x1', 'x2', 'x5', 'x6', 'x9', 'x10'], ['x1', 'x2', 'x5', 'x6', 'x9', 'x10']
yrefs, ayrefs = ['y1', 'y2', 'y5', 'y6', 'y9', 'y10'], ['y1', 'y2', 'y5', 'y6', 'y9', 'y10']


# Line widths and colors; each *_colors pair is [attend-bar, attend-fix]
# (index 0 is used for column 1, index 1 for column 2 by the plotting cell).
line_width = 3
line_width_pred = 2
gc_line_colors = ["rgba(243, 146, 0, 1)", "rgba(242, 190, 121, 1)"]
gc_area_colors = ["rgba(243, 146, 0, 0.3)", "rgba(242, 190, 121, 0.3)"]
gl_line_colors = ["rgba(41, 101, 44, 1)", "rgba(153, 198, 98, 1)"]
gl_area_colors = ["rgba(41, 101, 44, 0.3)", "rgba(153, 198, 98, 0.3)"]
gr_line_colors = ["rgba(142, 19, 84, 1)", "rgba(230, 151, 193, 1)"]
gr_area_colors = ["rgba(142, 19, 84, 0.3)", "rgba(230, 151, 193, 0.3)"]
retino_color = 'rgba(227, 6, 19, 0.5)'
spatio_color = 'rgba(29, 113, 184, 0.5)'

# Hover templates for the decoded traces.
# NOTE(review): plotly only substitutes %{...} placeholders, so the '%%' below
# is presumably displayed literally (and followed by ' s') -- confirm the
# intended unit and simplify the label if so.
gc_hover = 'Time: %{x:1.2f}%% s<br>' + 'Gaze center: %{y:1.2f} dva'
gl_hover = 'Time: %{x:1.2f}%% s<br>' + 'Gaze left: %{y:1.2f} dva'
gr_hover = 'Time: %{x:1.2f}%% s<br>' + 'Gaze right: %{y:1.2f} dva'

# Model predictions, sampled on 18 points (the same length as the 18-point
# time axis built in the group-averaging cell).
# The retinotopic prediction is the same 4 -> -4 ramp for every gaze
# condition; the spatiotopic prediction is that ramp shifted by +4 for gaze
# left and by -4 for gaze right.
gc_retino_pred = np.linspace(4, -4, 18)
gc_spatio_pred = gc_retino_pred
gl_retino_pred, gr_retino_pred = gc_retino_pred, gc_retino_pred
gl_spatio_pred, gr_spatio_pred = gc_retino_pred + 4, gc_retino_pred - 4

In [None]:
# across participants
# For every gaze x attention condition: pool the per-subject decoding tables,
# average the decoded x / h time courses across subjects for each ROI, and
# save the result as the 'group' table that the plotting cell reads back.
time = np.linspace(0, 1, 18)

for gaze_task in gaze_tasks:
    for attend_task in attend_tasks:
        # get data of all subjects and stack them once (no concat-in-a-loop)
        df_subs = []
        for subject in subjects:
            tsv_dir = '{}/{}/decode/tsv'.format(pp_dir, subject)
            df_sub_fn = "{}/{}_task-{}{}_decode_par_barpass.pkl".format(tsv_dir, subject, gaze_task, attend_task)
            # NOTE(review): unpickling runs arbitrary code; only load files
            # produced by this pipeline.
            df_subs.append(pd.read_pickle(df_sub_fn))
        df_group = pd.concat(df_subs)

        # create dataframe
        df_rois = []
        for roi in rois:
            # one filter + groupby per ROI, shared by the four statistics
            by_time = df_group.loc[df_group.roi == roi].groupby('Time')
            x_by_time = by_time['decoded_x_mean_barpass']
            h_by_time = by_time['decoded_h_mean_barpass']

            # Values are placed positionally against `time`, so the data must
            # contain exactly time.shape[0] distinct 'Time' values per ROI
            # (pandas raises on a length mismatch).
            # SEM uses ddof=1 (sample standard deviation / sqrt(n)); the
            # previous ddof=-1 divided the variance by n + 1, which
            # underestimates the standard error.
            df_rois.append(pd.DataFrame({
                'subject': ['group'] * time.shape[0],
                'roi': [roi] * time.shape[0],
                'Time': time,
                'decoded_x_mean_barpass': x_by_time.mean().to_numpy(),
                'decoded_x_sem_barpass': x_by_time.sem(ddof=1).to_numpy(),
                'decoded_h_mean_barpass': h_by_time.mean().to_numpy(),
                'decoded_h_sem_barpass': h_by_time.sem(ddof=1).to_numpy()}))

        # across rois
        df = pd.concat(df_rois)

        # save group data
        tsv_dir_group = '{}/group/decode/tsv'.format(pp_dir)
        os.makedirs(tsv_dir_group, exist_ok=True)  # first run: folder may not exist yet
        df_fn = "{}/group_task-{}{}_decode_par_barpass.pkl".format(tsv_dir_group, gaze_task, attend_task)
        print('saving {}'.format(df_fn))
        df.to_pickle(df_fn)

In [None]:
# Decoded bar position figures: one figure per subject (incl. 'group') and ROI.
# Each figure is a 6 x 2 grid: columns are Attend-bar / Attend-fix, tall rows
# 2, 4 and 6 show gaze center / left / right, and the short rows above them
# hold a motion caption (black box with a leftward white arrow).
# The 'group' entry requires the group tables written by the averaging cell.

def _set_axis(fig, axis, **props):
    """Update layout properties of one axis of ``fig``.

    ``axis`` is a layout key such as 'xaxis3'; every keyword ``prop=value`` is
    forwarded to ``fig.layout.update`` as the magic-underscore key
    '<axis>_<prop>' (the same call the former exec-built strings produced).
    """
    fig.layout.update(**{'{}_{}'.format(axis, prop): val for prop, val in props.items()})


# Plot settings of each gaze condition (colors are [attend-bar, attend-fix]).
# Only the gaze-left / attend-bar panel contributes legend entries.
gaze_settings = {
    'GazeCenter': dict(line_colors=gc_line_colors, area_colors=gc_area_colors, row=2,
                       retino_line_dash='solid', spatio_line_dash='dash',
                       retino_pred=gc_retino_pred, spatio_pred=gc_spatio_pred,
                       hover=gc_hover, showlegend=False),
    'GazeLeft': dict(line_colors=gl_line_colors, area_colors=gl_area_colors, row=4,
                     retino_line_dash='dash', spatio_line_dash='dash',
                     retino_pred=gl_retino_pred, spatio_pred=gl_spatio_pred,
                     hover=gl_hover, showlegend=True),
    'GazeRight': dict(line_colors=gr_line_colors, area_colors=gr_area_colors, row=6,
                      retino_line_dash='dash', spatio_line_dash='dash',
                      retino_pred=gr_retino_pred, spatio_pred=gr_spatio_pred,
                      hover=gr_hover, showlegend=False),
}
attend_cols = {'AttendBar': 1, 'AttendFix': 2}

# Axis numbers of the caption subplots (rows 1, 3, 5) and data subplots (rows 2, 4, 6)
caption_axis_nums = [1, 2, 5, 6, 9, 10]
data_axis_nums = [3, 4, 7, 8, 11, 12]

for subject in subjects_plot:
    tsv_dir = '{}/{}/decode/tsv'.format(pp_dir, subject)

    subplot_titles = ['<b>Attend-bar</b><br>({})'.format(subject),
                      '<b>Attend-fix</b><br>({})'.format(subject),
                      '', '', '', '',
                      '', '', '', '']

    # get data: load each condition once per subject instead of once per ROI
    decode_dfs = {}
    for gaze_task in gaze_tasks:
        for attend_task in attend_tasks:
            df_fn = "{}/{}_task-{}{}_decode_par_barpass.pkl".format(tsv_dir, subject, gaze_task, attend_task)
            decode_dfs[gaze_task, attend_task] = pd.read_pickle(df_fn)

    for roi in rois:
        fig = make_subplots(rows=rows, cols=cols, specs=sb_specs, print_grid=False, vertical_spacing=0.04, horizontal_spacing=0.1,
                            column_widths=column_widths, row_heights=row_heights, subplot_titles=subplot_titles)

        # motion caption: added once per caption subplot (it used to be
        # re-added for every gaze x attend condition, stacking duplicates)
        for xref, yref, axref, ayref in zip(xrefs, yrefs, axrefs, ayrefs):
            fig.add_shape(type='rect', xref=xref, yref=yref, x0=0, y0=0, x1=1, y1=1,
                          line_width=2, fillcolor='black', line_color='white')

            fig.add_annotation(ax=0.75, x=0.25, ay=0.5, y=0.5,
                               xref=xref, yref=yref, axref=axref, ayref=ayref,
                               text='', showarrow=True, arrowhead=2, arrowcolor='white')

        for gaze_task in gaze_tasks:
            gaze = gaze_settings[gaze_task]
            row = gaze['row']

            for attend_task in attend_tasks:
                col = attend_cols[attend_task]
                line_color = gaze['line_colors'][col - 1]
                area_color = gaze['area_colors'][col - 1]
                # legend entries come from the attend-bar panel only
                showlegend = gaze['showlegend'] and attend_task == 'AttendBar'

                # data of the current roi
                df_cond = decode_dfs[gaze_task, attend_task]
                df_roi = df_cond.loc[df_cond.roi == roi]
                x_decode = df_roi.Time
                y_decode = df_roi.decoded_x_mean_barpass
                # error band: SEM across subjects for the group, std for individuals
                if subject == 'group':
                    eb_y_decode = df_roi.decoded_x_sem_barpass
                else:
                    eb_y_decode = df_roi.decoded_x_std_barpass

                # retino prediction
                fig.add_trace(go.Scatter(x=x_decode, y=gaze['retino_pred'], showlegend=showlegend, mode='lines',
                                         line_dash=gaze['retino_line_dash'], line_color=retino_color, line_width=line_width_pred,
                                         name='<i>retinotopic prediction</i>', legendgroup='retino_model', hoverinfo='skip'),
                              row=row, col=col)

                # spatio prediction
                fig.add_trace(go.Scatter(x=x_decode, y=gaze['spatio_pred'], showlegend=showlegend, mode='lines',
                                         line_dash=gaze['spatio_line_dash'], line_color=spatio_color, line_width=line_width_pred,
                                         name='<i>spatiotopic prediction</i>', legendgroup='spatio_model', hoverinfo='skip'),
                              row=row, col=col)

                # data: mean line, then upper bound followed by lower bound;
                # the order matters because fill='tonexty' fills towards the
                # previously added trace (the upper bound)
                fig.add_trace(go.Scatter(x=x_decode, y=y_decode, showlegend=False, mode='lines', line_color=line_color,
                                         line_width=line_width, connectgaps=False, name='', hovertemplate=gaze['hover']),
                              row=row, col=col)
                fig.add_trace(go.Scatter(x=x_decode, y=y_decode + eb_y_decode, showlegend=False, mode='lines', fillcolor=area_color,
                                         line_width=0, connectgaps=False, hoverinfo='skip'),
                              row=row, col=col)
                fig.add_trace(go.Scatter(x=x_decode, y=y_decode - eb_y_decode, showlegend=False, mode='lines', fillcolor=area_color,
                                         line_width=0, connectgaps=False, hoverinfo='skip', fill='tonexty'),
                              row=row, col=col)

        # roi label in every data subplot
        for axis_num in data_axis_nums:
            fig.add_annotation(x=0.9, y=-6, xref='x{}'.format(axis_num), yref='y{}'.format(axis_num),
                               showarrow=False, text='{}'.format(roi))

        # caption subplots: unit square with hidden axes
        for axis_num in caption_axis_nums:
            _set_axis(fig, 'xaxis{}'.format(axis_num), range=xaxis_range, title='', showticklabels=False, visible=False)
            _set_axis(fig, 'yaxis{}'.format(axis_num), range=[0, 1], title='', showticklabels=False, visible=False)

        # data subplots: x title/tick labels on the bottom row only,
        # y title/tick labels on the left column only
        for axis_num in data_axis_nums:
            xaxis, yaxis = 'xaxis{}'.format(axis_num), 'yaxis{}'.format(axis_num)

            _set_axis(fig, xaxis, range=xaxis_range, tickvals=xaxis_tickvals, ticktext=xaxis_ticktext)
            if axis_num in (11, 12):
                _set_axis(fig, xaxis, title=xaxis_title)
            else:
                _set_axis(fig, xaxis, showticklabels=False)

            if axis_num in (3, 7, 11):
                _set_axis(fig, yaxis, title=yaxis_title)
            else:
                _set_axis(fig, yaxis, showticklabels=False)
            _set_axis(fig, yaxis, range=yaxis_range, tickvals=yaxis_tickvals, ticktext=yaxis_ticktext)

        # set axis (no row/col selector: applies to every subplot)
        fig.update_xaxes(ticklen=8, linewidth=template_specs['axes_width'])
        fig.update_yaxes(ticklen=8, linewidth=template_specs['axes_width'])

        # set figure
        fig.layout.update(template=fig_template, width=fig_width, height=fig_height, margin_l=margin_l, margin_r=margin_r, margin_t=margin_t, margin_b=margin_b,
                          legend_yanchor='top', legend_y=-0.17, legend_x=-0.05, legend_xanchor='left', legend_bgcolor='rgba(255,255,255,0)', legend_tracegroupgap=1)

        fig.show(config={"displayModeBar": False})

        # NOTE(review): the file name is tagged with the last attend task only
        # ('task-AttendFix') although the figure shows both attend conditions
        # and all gaze conditions. The name is kept unchanged so existing
        # references to these files keep working -- consider renaming.
        fig_fn = "{}/{}_task-{}_decode_x_barpass_{}".format(tsv_dir, subject, attend_tasks[-1], roi)
        fig.write_image("{}.pdf".format(fig_fn))
        fig.write_html("{}.html".format(fig_fn), config={"displayModeBar": False})
