In [1]:
# Stop warnings
import warnings
warnings.filterwarnings("ignore")

# Imports
import os
import cv2
import sys
import time
import json
import copy
import cortex
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt


# Personal imports
sys.path.append("{}/../../../analysis_code/utils".format(os.getcwd()))
from plot_utils import *
from pycortex_utils import draw_cortex, set_pycortex_config_file, load_surface_pycortex, create_colormap, get_rois
from surface_utils import load_surface

In [2]:
# Directories: root of the shared data disk and the project folder name
main_dir, project_dir = '/Users/uriel/disks/meso_shared', 'RetinoMaps'

In [3]:
# Load analysis settings (path is relative to the notebook's working directory)
with open('../../settings.json') as f:
    json_s = f.read()
analysis_info = json.loads(json_s)

In [4]:
# Point pycortex at the project's database / colormaps folder
cortex_dir = '{main}/{project}/derivatives/pp_data/cortex'.format(main=main_dir,
                                                                   project=project_dir)
set_pycortex_config_file(cortex_dir)

# WebGL port
port_num = 25000

# pRF time series

In [5]:
# Plotly template settings shared by every figure of this notebook
template_specs = {
    'axes_color': "rgba(0, 0, 0, 1)",
    'axes_width': 2,
    'axes_font_size': 15,
    'bg_col': "rgba(255, 255, 255, 1)",
    'font': 'Arial',
    'title_font_size': 15,
    'plot_width': 1.5,
}
fig_template = plotly_template(template_specs)

## Functions

In [6]:
def gaus_2d_css(gauss_x, gauss_y, gauss_sd, n, screen_side, grain=200):
    """
    Build a square 2D Gaussian mesh and raise it to the CSS exponent n.

    Parameters
    ----------
    gauss_x : Gaussian center along x, in dva (e.g. 1 dva)
    gauss_y : Gaussian center along y, in dva (e.g. 1 dva)
    gauss_sd : Gaussian standard deviation, in dva (e.g. 1 dva)
    n : exponent parameter of the CSS model
    screen_side : side of the square mesh in dva, centered on 0
                  (e.g. 20 dva spans -10 to 10 dva)
    grain : number of samples along each side of the mesh (default = 200)

    Returns
    -------
    x : 1D sample positions along x (length grain)
    y : 1D sample positions along y (length grain)
    z : (grain, grain) mesh of the unit-volume Gaussian raised to the power n
        (rows follow y, columns follow x)
    """
    half_side = screen_side/2
    x = np.linspace(-half_side, half_side, grain)
    y = np.linspace(-half_side, half_side, grain)
    grid_x, grid_y = np.meshgrid(x, y)

    # Isotropic Gaussian normalized by 1/(2*pi*sd^2)
    two_var = 2.*gauss_sd**2.
    norm = 1./(2.*np.pi*gauss_sd*gauss_sd)
    exponent = (grid_x - gauss_x)**2./two_var + (grid_y - gauss_y)**2./two_var
    gaussian = norm*np.exp(-exponent)

    # CSS model: raise the spatial profile to the power n
    return x, y, gaussian**n

In [None]:
def draw_timeseries(bold_data, prf_prediction, vox_data, vox_model, TRs, roi, TR=1.2):
    """
    Draw a pRF summary figure: stimulus sequence, data/model time series and pRF heatmap.

    Parameters
    ----------
    bold_data : BOLD time series, time along axis 0 and vertices along axis 1
    prf_prediction : pRF model prediction time series, same layout as bold_data
    vox_data : index of the vertex whose BOLD data is plotted
    vox_model : index of the vertex whose pRF prediction and parameters are plotted
    TRs : number of time points (TRs) of the run along axis 0
    roi : ROI name used in the figure title
    TR : repetition time in seconds (default 1.2 s).
         NOTE(review): default inferred from 208 TRs spanning the 0-250 s x-axis;
         confirm against settings.json.

    Returns
    -------
    fig : plotly figure

    Notes
    -----
    Relies on notebook-level globals: template_specs, subject, prf_fit_data and the
    pRF derivative row indices (x_idx, y_idx, size_idx, n_idx, r2_idx, ecc_idx), plus
    plotly_template, make_subplots and go (presumably from `from plot_utils import *`).
    """
    # General figure settings
    fig_template = plotly_template(template_specs)

    # Subplot settings
    rows, cols = 2, 2
    margin_t, margin_b, margin_l, margin_r = 50, 50, 50, 50
    fig_ratio = 5
    fig_height = 1080/fig_ratio + (1080/fig_ratio*0.15) + margin_t+margin_b
    fig_width = 1920/fig_ratio + 1920/fig_ratio + margin_l+margin_r
    column_widths, row_heights = [1, 1], [0.15, 1]
    sb_specs = [[{}, {}], [{}, {}]]
    hover_data = 'Time: %{x:1.2f} s<br>' + 'z-score: %{y:1.2f}'
    hover_model = 'Time: %{x:1.2f} s<br>' + 'z-score: %{y:1.2f}'

    xaxis_range = [0, 250]
    yaxis_range = [-2, 3]
    yaxis_dtick = 1
    # 6 evenly spaced ticks over the whole run, in seconds
    x_tickvals = np.linspace(0, TRs, 6)*TR

    # Arrow coordinates [tail_x, head_x, tail_y, head_y] drawn on the stimulus panel
    # (names presumably stand for left/down/right/up-ward bar sweeps)
    lwd_mot = np.array([114*TR, 142*TR, 0.5, 0.5])
    dwd_mot = np.array([80*TR, 80*TR, 0.85, 0.15])
    rwd_mot = np.array([46*TR, 18*TR, 0.5, 0.5])
    uwd_mot = np.array([176*TR, 176*TR, 0.15, 0.85])

    # Start/end TRs of the stimulus blocks drawn as black rectangles
    x0_all = np.array([0, 16, 48, 64, 96, 112, 144, 160, 192])*TR
    x1_all = np.array([16, 48, 64, 96, 112, 144, 160, 192, 208])*TR

    data_col = 'rgba(0, 0, 0, 1)'
    model_col = 'rgba(200, 0, 0, 1)'
    subplot_titles = ['<b>{} time series </b> ({})'.format(roi, subject), '', '', '']
    screen_side = 10*2
    prf_xrange = [-10, 10]
    prf_yrange = [-10, 10]

    # Top-left anchor of the pRF parameter text block (pRF panel coordinates)
    x_par_txt = -10
    y_par_text = 9

    # Average consecutive pairs of TRs to plot fewer points (bins derived from TRs)
    n_bins = int(TRs)//2
    bold_data_mean = np.mean(bold_data[:n_bins*2].reshape(n_bins, 2, -1), axis=1)
    prf_pred_data_mean = np.mean(prf_prediction[:n_bins*2].reshape(n_bins, 2, -1), axis=1)
    time_bins = np.linspace(0, TRs*TR, n_bins)

    # pRF parameters of the model vertex
    prf_x = prf_fit_data[x_idx, vox_model]
    prf_y = prf_fit_data[y_idx, vox_model]
    prf_size = prf_fit_data[size_idx, vox_model]
    prf_n = prf_fit_data[n_idx, vox_model]
    prf_r2 = prf_fit_data[r2_idx, vox_model]
    prf_ecc = prf_fit_data[ecc_idx, vox_model]

    # create figure
    fig = make_subplots(rows=rows, cols=cols, specs=sb_specs, print_grid=False, vertical_spacing=0.05, horizontal_spacing=0.05,
                        column_widths=column_widths, row_heights=row_heights, subplot_titles=subplot_titles)

    # Timeseries stim
    for x0, x1 in zip(x0_all, x1_all):
        fig.add_shape(type='rect', xref='x', yref='y', x0=x0, y0=0, x1=x1, y1=1,
                      line_width=2, fillcolor='black', line_color='white')

    for coord_tp in (rwd_mot, dwd_mot, lwd_mot, uwd_mot):
        fig.add_annotation(ax=coord_tp[0], x=coord_tp[1], ay=coord_tp[2], y=coord_tp[3],
                           xref='x', yref='y', axref='x', ayref='y',
                           text='', showarrow=True, arrowhead=2, arrowcolor='white')

    # time series data
    fig.append_trace(go.Scatter(x=time_bins,
                                y=bold_data_mean[:, vox_data],
                                name='<i>data<i>',
                                showlegend=True, mode='markers', marker_color=data_col,
                                hovertemplate=hover_data,
                                line_width=0, opacity=1, marker_size=6), row=2, col=1)
    # time series predictions
    fig.append_trace(go.Scatter(x=time_bins,
                                y=prf_pred_data_mean[:, vox_model],
                                name='<i>model<i>',
                                showlegend=True, mode='lines', line_color=model_col,
                                hovertemplate=hover_model,
                                line_width=2, opacity=1), row=2, col=1)

    # pRF heatmap
    x, y, z = gaus_2d_css(gauss_x=prf_x, gauss_y=prf_y, gauss_sd=prf_size, n=prf_n, screen_side=screen_side)
    fig.append_trace(go.Heatmap(x=x, y=y, z=z, colorscale='viridis', showscale=False, hoverinfo='none'), row=2, col=2)

    fig.add_annotation(x=prf_x, ax=prf_x, y=prf_yrange[0], ay=prf_yrange[0]-0.5,
                       xref='x4', yref='y4', axref='x4', ayref='y4', yanchor="top", showarrow=True,
                       text='<i>pRFx</i> = {:1.2g}°'.format(prf_x), arrowhead=2, arrowwidth=2.5)

    fig.add_annotation(x=prf_xrange[1], ax=prf_xrange[1]+0.5, y=prf_y, ay=prf_y,
                       xref='x4', yref='y4', axref='x4', ayref='y4', yanchor="top", showarrow=True,
                       text='<i>pRFy</i> = {:1.2g}°'.format(prf_y), textangle=-90, arrowhead=2, arrowwidth=2.5)

    # Dotted cross-hair through the pRF center
    fig.add_shape(type='line', xref='x4', yref='y4', x0=prf_xrange[0], x1=prf_xrange[1], y0=prf_y, y1=prf_y,
                  line_width=2, line_color='white', line_dash='dot')

    fig.add_shape(type='line', xref='x4', yref='y4', x0=prf_x, x1=prf_x, y0=prf_yrange[0], y1=prf_yrange[1],
                  line_width=2, line_color='white', line_dash='dot')

    # pRF parameter text block, one line every 1.5 units
    fig.add_annotation(x=x_par_txt, y=y_par_text, xref='x4', yref='y4', xanchor="left", font_color='white', showarrow=False,
                       text='<i>pRF loo R<sup>2</sup></i> = {:1.2g}'.format(prf_r2))

    fig.add_annotation(x=x_par_txt, y=y_par_text-1.5, xref='x4', yref='y4', xanchor="left", font_color='white', showarrow=False,
                       text='<i>pRF size</i> = {:1.2g}°'.format(prf_size))

    fig.add_annotation(x=x_par_txt, y=y_par_text-3, xref='x4', yref='y4', xanchor="left", font_color='white', showarrow=False,
                       text='<i>pRF ecc</i> = {:1.2g}°'.format(prf_ecc))

    fig.add_annotation(x=x_par_txt, y=y_par_text-4.5, xref='x4', yref='y4', xanchor="left", font_color='white', showarrow=False,
                       text='<i>pRF angle</i> = {:3.0f}°'.format(np.angle(prf_x + 1j * prf_y, deg=True)))

    # set axis
    for row in np.arange(rows):
        for col in np.arange(cols):
            fig.update_xaxes(visible=True, ticklen=8, linewidth=template_specs['axes_width'], row=row+1, col=col+1)
            fig.update_yaxes(visible=True, ticklen=8, linewidth=template_specs['axes_width'], row=row+1, col=col+1)

    # Keep the pRF panel square (1 dva on x == 1 dva on y)
    fig.update_xaxes(scaleanchor="y4", scaleratio=1, row=2, col=2)
    fig.update_yaxes(scaleanchor="x4", scaleratio=1, row=2, col=2)
    fig.layout.update(xaxis_range=xaxis_range, xaxis_title='',
                      xaxis_visible=False, yaxis_visible=False,
                      yaxis_range=[0, 1], yaxis_title='',
                      xaxis4_range=prf_xrange, xaxis4_title='',
                      yaxis4_range=prf_yrange, yaxis4_title='',
                      xaxis4_visible=False, yaxis4_visible=False,
                      xaxis3_tickvals=x_tickvals, xaxis3_ticktext=np.round(x_tickvals),
                      xaxis3_range=xaxis_range, xaxis3_title='Time (seconds)',
                      yaxis3_range=yaxis_range, yaxis3_title='z-score', yaxis3_dtick=yaxis_dtick,
                      template=fig_template, width=fig_width, height=fig_height,
                      margin_l=margin_l+10, margin_r=margin_r-10, margin_t=margin_t-20, margin_b=margin_b+20,
                      legend_yanchor='top', legend_y=0.85, legend_xanchor='left',
                      legend_x=0.02, legend_bgcolor='rgba(255,255,255,0)')

    return fig

## Plots

In [7]:
# Subject, surface space and matching file extension to plot
subject, format_, extension = 'sub-11', 'fsnative', 'func.gii'

In [None]:
# Left-hemisphere pRF BOLD run (leave-one-out average, loo-1)
bold_dir = '{main}/{project}/derivatives/pp_data/{sub}/{fmt}/func/fmriprep_dct_loo_avg'.format(
    main=main_dir, project=project_dir, sub=subject, fmt=format_)
bold_fn = '{folder}/{sub}_task-pRF_hemi-L_fmriprep_dct_avg_loo-1_bold.{ext}'.format(
    folder=bold_dir, sub=subject, ext=extension)
bold_img, bold_data = load_surface(fn=bold_fn)

In [None]:
# Left-hemisphere CSS pRF model prediction matching the loo-1 BOLD run.
# Uses the notebook's `extension` so the file type follows `format_` like the BOLD file.
prf_dir ='{}/{}/derivatives/pp_data/{}/{}/prf/fit/'.format(main_dir, project_dir,subject,format_)
prf_pred_fn = '{}/{}_task-pRF_hemi-L_fmriprep_dct_avg_loo-1_prf-pred_css.{}'.format(prf_dir, subject, extension)
prf_pred_img, prf_pred_data = load_surface(fn = prf_pred_fn)

In [None]:
# Row indices of the pRF derivative maps in prf_fit_data
# NOTE(review): order assumed from the prf-deriv_css file layout — confirm against the writer.
ecc_idx, size_idx, x_idx, y_idx, n_idx, r2_idx = 1,4,7,8,11,12

# Left-hemisphere CSS pRF derivatives; file type follows `extension` like the BOLD file
prf_deriv_dir = '{}/{}/derivatives/pp_data/{}/{}/prf/prf_derivatives'.format(main_dir, project_dir,subject,format_)
prf_fit_fn = '{}/{}_task-pRF_hemi-L_fmriprep_dct_avg_loo-1_prf-deriv_css.{}'.format(prf_deriv_dir, subject, extension)
prf_fit_img, prf_fit_data = load_surface(fn = prf_fit_fn)

In [None]:
# Vertex whose data, model prediction and pRF are plotted
vertex_to_plot = 9383
# Number of time points of the run (time is along axis 0 of bold_data)
TRs = bold_data.shape[0]

fig = draw_timeseries(bold_data=bold_data, prf_prediction=prf_pred_data, vox_data=vertex_to_plot, vox_model=vertex_to_plot, TRs=TRs, roi='V1')
fig.write_image('/Users/uriel/Downloads/time_seris.pdf')
fig.show()

# GLM time series

In [None]:
# GLM task to load; other options: 'PurLoc', 'SacLoc', 'SacVELoc'
task = 'PurVELoc'

In [None]:
# Left-hemisphere BOLD run of the selected GLM task (leave-one-out average, loo-1).
# The filename is built from bold_glm_dir (not the pRF bold_dir) and follows `extension`.
bold_glm_dir = '{}/{}/derivatives/pp_data/{}/{}/func/fmriprep_dct_loo_avg'.format(main_dir, project_dir,subject,format_)
bold_glm_fn = '{}/{}_task-{}_hemi-L_fmriprep_dct_avg_loo-1_bold.{}'.format(bold_glm_dir, subject, task, extension)
bold_glm_img, bold_glm_data = load_surface(fn = bold_glm_fn)