In [None]:
# Imports
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import os
import scipy
import matplotlib.pyplot as plt
from scipy import stats
import random
from sklearn.linear_model import LinearRegression

# Figure imports
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

# Data paths
# The data root defaults to the original lab-server location; set the
# GAZE_EXP_MAIN_DIR environment variable to run the notebook elsewhere
# without editing this cell.
main_dir = os.environ.get('GAZE_EXP_MAIN_DIR', '/home/mszinte/disks/meso_S/data')
project_dir = 'gaze_exp'
pp_dir = "{}/{}/derivatives/pp_data".format(main_dir, project_dir)
tsv_dir = '{}/sub-all/tsv'.format(pp_dir)      # directory holding the input .tsv tables
fig_dir = '{}/sub-all/figures'.format(pp_dir)  # figures are saved here (created if missing)
os.makedirs(fig_dir, exist_ok=True)

In [None]:
def weighted_regression(x_reg, y_reg, weight_reg):
    """
    Weighted least-squares fit of ``y = coef * x + intercept``.

    Samples where x, y or the weight is NaN/inf are dropped jointly, so the
    three inputs stay aligned sample by sample.

    Parameters
    ----------
    x_reg : array-like (1D)
        x values to regress
    y_reg : array-like (1D)
        y values to regress, same length as x_reg
    weight_reg : array-like (1D)
        weight per sample (e.g. r2 values between 0 and 1), same length as x_reg

    Returns
    -------
    coef_reg : ndarray, shape (1, 1)
        regression slope (index it as ``coef_reg[0][0]``)
    intercept_reg : ndarray, shape (1,)
        regression intercept (index it as ``intercept_reg[0]``)

    Raises
    ------
    ValueError
        if the inputs differ in length, no finite sample remains, or the
        remaining weights sum to zero.
    """
    import numpy as np

    x_reg = np.asarray(x_reg, dtype=float).ravel()
    y_reg = np.asarray(y_reg, dtype=float).ravel()
    weight_reg = np.asarray(weight_reg, dtype=float).ravel()
    if not (x_reg.size == y_reg.size == weight_reg.size):
        raise ValueError('x_reg, y_reg and weight_reg must have the same length, '
                         'got {}, {} and {}.'.format(x_reg.size, y_reg.size, weight_reg.size))

    # One shared mask: filtering the weights with their own NaN mask would
    # misalign them with x/y whenever the NaNs are not in the same positions.
    valid = np.isfinite(x_reg) & np.isfinite(y_reg) & np.isfinite(weight_reg)
    x, y, w = x_reg[valid], y_reg[valid], weight_reg[valid]
    if x.size == 0:
        raise ValueError('No finite (x, y, weight) sample to regress.')
    if np.sum(w) == 0:
        raise ValueError('Weights of the finite samples sum to zero.')

    # Closed-form weighted least squares:
    # slope = weighted cov(x, y) / weighted var(x),
    # intercept = weighted mean(y) - slope * weighted mean(x).
    x_mean = np.average(x, weights=w)
    y_mean = np.average(y, weights=w)
    dx = x - x_mean
    if np.ptp(x) == 0:
        # Constant x: the slope is undetermined; use the minimum-norm choice 0.
        slope = 0.0
    else:
        slope = np.sum(w * dx * (y - y_mean)) / np.sum(w * dx * dx)
    intercept = y_mean - slope * x_mean

    coef_reg = np.array([[slope]], dtype=float)
    intercept_reg = np.array([intercept], dtype=float)
    return coef_reg, intercept_reg

In [None]:
# Load data
data = pd.read_table('{}/sub-all_prf_cf.tsv'.format(tsv_dir))
# 'Unnamed: 0' is presumably the row index written when the table was saved.
# DataFrame.drop returns a new frame, so the result must be assigned back.
data = data.drop(columns=['Unnamed: 0'])
data.head()

In [None]:
# Filter data
# Eccentricity and size bounds are inclusive; the r2 bounds are exclusive.
ecc_th = [0, 20]
size_th = [0, 20]
rsq_th = [0, 1]

# Rows with any parameter outside its threshold.
# NOTE(review): the r2 upper bound is only applied to cf_center_cf_rsq, not to
# prf_rsq_loo — confirm this is intended.
outside_th = ((data.prf_ecc < ecc_th[0]) | (data.prf_ecc > ecc_th[1]) |
              (data.prf_size < size_th[0]) | (data.prf_size > size_th[1]) |
              (data.cf_center_ecc < ecc_th[0]) | (data.cf_center_ecc > ecc_th[1]) |
              (data.cf_center_size < size_th[0]) | (data.cf_center_size > size_th[1]) |
              (data.prf_rsq_loo <= rsq_th[0]) | (data.cf_center_cf_rsq <= rsq_th[0]) |
              (data.cf_center_cf_rsq >= rsq_th[1]))

# Keep in-threshold rows that have no missing value. Selecting rows gives the
# same result as writing NaN over whole rows and dropping them, without writing
# NaN into every column (which upcasts integer columns to float).
data = data.loc[~outside_th].dropna()

rois = pd.unique(data.roi)
mask = pd.notnull(data.subject)
subjects = pd.unique(data.subject[mask])

# Define colors (one per ROI, indexed by ROI position in the plotting cell)
roi_colors = px.colors.sequential.Sunset[:4] + px.colors.sequential.Rainbow[:]

In [None]:
# Scatter pRF vs CF centre coordinates per ROI (one figure per subject), with
# the pRF-r2-weighted regression line overlaid.
subjects = ['sub-001','sub-002']
rois = ['V1', 'V2', 'V3','V3AB', 'LO','VO','hMT+','iIPS', 'sIPS','iPCS','sPCS', 'mPCS']
params = ['x','y']

rows = len(params)   # one row per coordinate
cols = len(rois)     # one column per ROI (a hard-coded 14 left two empty panels)
num_to_plot = 200    # max number of randomly drawn vertices shown per panel
axis_lim = 15        # symmetric axis range, also the extent of the regression line
seed = 0             # fixes the vertex subsampling so figures are reproducible

rng = np.random.default_rng(seed)
tick_vals = [-axis_lim, 0, axis_lim]

for subject in subjects:
    fig = make_subplots(rows=rows, cols=cols, print_grid=False,
                        vertical_spacing=0.3,
                        horizontal_spacing=0.02
                       )
    for row_idx, param in enumerate(params):
        row = row_idx + 1
        for roi_idx, roi in enumerate(rois):
            col = roi_idx + 1
            color = roi_colors[roi_idx]
            df = data.loc[(data.subject == subject) & (data.roi == roi)]
            x_cor = np.array(df['prf_{}'.format(param)])
            y_cor = np.array(df['cf_center_{}'.format(param)])
            r_cor = np.array(df['prf_rsq_loo'])

            # No vertex left after filtering: leave the panel empty instead of
            # failing in the regression.
            if x_cor.size == 0:
                continue

            # plot a random subset of at most num_to_plot vertices
            n_plot = min(num_to_plot, x_cor.size)
            vertex_to_plot = rng.choice(x_cor.size, size=n_plot, replace=False)
            fig.add_trace(go.Scatter(x=x_cor[vertex_to_plot],
                                     y=y_cor[vertex_to_plot], mode='markers',
                                     marker=dict(color=color, size=4, opacity=0.3,
                                                 line=dict(width=0)),
                                     showlegend=False),
                          row=row, col=col)

            # regression line fitted on all vertices, weighted by pRF r2
            slope, intercept = weighted_regression(x_cor, y_cor, r_cor)
            cor_x = np.linspace(-axis_lim, axis_lim, 100)
            cor_y = slope[0][0] * cor_x + intercept[0]
            fig.add_trace(go.Scatter(x=cor_x, y=cor_y, mode='lines',
                                     line=dict(color=color, width=3),
                                     showlegend=False),
                          row=row, col=col)

        # axis titles: x title on every panel of the row, y title on the first column only
        fig.update_xaxes(title_text='pRF {} coord (dva)'.format(param), row=row)
        fig.update_yaxes(title_text='CF {} coord (dva)'.format(param), row=row, col=1)

    # same range and ticks on every panel
    fig.update_xaxes(range=[-axis_lim, axis_lim], tickmode='array', tickvals=tick_vals, ticktext=tick_vals)
    fig.update_yaxes(range=[-axis_lim, axis_lim], tickmode='array', tickvals=tick_vals, ticktext=tick_vals)
    fig.update_layout(height=500, width=cols * 170, showlegend=False, template='simple_white')

    fig.show()
    fig.write_image("{}/{}_cor_xy.pdf".format(fig_dir, subject))