In [None]:
# Stop warnings
import warnings
warnings.filterwarnings("ignore")

# Imports
import numpy as np
import pandas as pd
import os
import json
import sys
import cortex
from scipy import stats
import matplotlib.pyplot as plt

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots


sys.path.append("{}/../../../utils".format(os.getcwd()))
from plot_utils import *
from surface_utils import load_surface
from pycortex_utils import get_rois, calculate_vertex_areas, data_from_rois, load_surface_pycortex
from maths_utils import linear_regression_surf, multipletests_surface

# Import data base
main_dir = '/Users/uriel/disks/meso_shared'
# main_dir = '/home/ulascombes//disks/meso_shared'
project_dir = 'RetinoMaps'
subjects = ['sub-01']
subject = 'sub-03'
format_ = 'fsnative'
tsv_dir ='{}/{}/derivatives/pp_data/{}/{}/prf/tsv'.format(main_dir, 
                                                                project_dir, 
                                                                subject,
                                                         format_)


with open('../../../settings.json') as f:
    json_s = f.read()
    analysis_info = json.loads(json_s)
rois = analysis_info['rois']


In [None]:
# Load the '<subject>_css-all_derivatives.tsv' table found in tsv_dir
data = pd.read_table(
    '{}/{}_css-all_derivatives.tsv'.format(tsv_dir, subject)
)

# [lower, upper] bounds passed to the plotting helpers further down
ecc_th = [0, 15]
size_th = [0.1, 20]
rsq_th = [0, 1]
pcm_th = [0, 20]

# heat map

In [None]:
def gaus_2d(gauss_x, gauss_y, gauss_sd, screen_side, grain=200):
    """
    Evaluate an isotropic 2D Gaussian on a square mesh centred on zero.

    Parameters
    ----------
    gauss_x : x position of the Gaussian mean in dva (e.g. 1 dva)
    gauss_y : y position of the Gaussian mean in dva (e.g. 1 dva)
    gauss_sd : standard deviation of the Gaussian in dva (e.g. 1 dva)
    screen_side : side of the square mesh in dva (e.g. 20 dva spans -10 to 10 dva)
    grain : number of mesh points along each side (default = 200)

    Returns
    -------
    x : 1D coordinates of the mesh along x
    y : 1D coordinates of the mesh along y
    z : 2D array (grain x grain) of Gaussian values, rows following y and columns following x

    """
    half_side = screen_side / 2
    x = np.linspace(-half_side, half_side, grain)
    y = np.linspace(-half_side, half_side, grain)
    mesh_x, mesh_y = np.meshgrid(x, y)

    # Exponent of the Gaussian: squared distance to the mean over 2 * sd^2
    two_var = 2. * gauss_sd ** 2.
    exponent = (mesh_x - gauss_x) ** 2. / two_var + (mesh_y - gauss_y) ** 2. / two_var

    # Normalisation term 1 / (2 * pi * sd^2)
    norm = 1. / (2. * np.pi * gauss_sd * gauss_sd)
    gauss_z = norm * np.exp(-exponent)
    return x, y, gauss_z

In [None]:
# RGB colour of each ROI, keyed by ROI name
colormap_dict = {'V1': (243, 231, 155),
                 'V2': (250, 196, 132),
                 'V3': (248, 160, 126),
                 'V3AB': (235, 127, 134),
                 'LO': (150, 0, 90),
                 'VO': (0, 0, 200),
                 'hMT+': (0, 25, 255),
                 'iIPS': (0, 152, 255),
                 'sIPS': (44, 255, 150),
                 'iPCS': (151, 255, 0),
                 'sPCS': (255, 234, 0),
                 'mPCS': (255, 111, 0)
                }
# Same colours as plotly 'rgb(r,g,b)' strings, in dictionary order
roi_colors = ['rgb({},{},{})'.format(*rgb) for rgb in colormap_dict.values()]

# Mesh and drawing settings
# grain = [50,100,150,200,250,300,350,400,450,500]
grain = 50                       # mesh points per side
line_width = 1
screen_side = 30                 # side of the square mesh handed to gaus_2d
half_side = screen_side / 2      # the mesh and the dotted axes span -half_side to half_side
box_half_side = 10               # the dotted square spans -10 to 10 on both axes
prf_cols = ['prf_x', 'prf_y', 'prf_size', 'prf_loo_r2']

# One panel per ROI, so the figure follows the ROI list of the settings file
fig = make_subplots(rows=1, cols=len(rois))

# General figure settings
template_specs = dict(axes_color="rgba(0, 0, 0, 1)",
                      axes_width=2,
                      axes_font_size=15,
                      bg_col="rgba(255, 255, 255, 1)",
                      font='Arial',
                      title_font_size=15,
                      plot_width=1.5)
fig_template = plotly_template(template_specs)

for i, roi in enumerate(rois):
    # Look the colour up by ROI name so it cannot drift if `rois` is ordered
    # differently from colormap_dict; ROIs without an entry are drawn in black
    roi_color = 'rgb({},{},{})'.format(*colormap_dict.get(roi, (0, 0, 0)))

    # Keep only rows with all the values needed below: a single NaN would
    # otherwise propagate to every cell of the summed map
    df_roi = (data.loc[data.roi == roi]
                  .dropna(subset=prf_cols)
                  .reset_index(drop=True))

    # Mesh coordinates are defined up front so the panel can still be drawn
    # when the ROI has no usable row
    x = np.linspace(-half_side, half_side, grain)
    y = np.linspace(-half_side, half_side, grain)
    gauss_z_tot = np.zeros((grain, grain))

    # Sum the Gaussian of every row, weighted by its prf_loo_r2
    for vert in range(len(df_roi)):
        x, y, gauss_z = gaus_2d(gauss_x=df_roi.prf_x[vert],
                                gauss_y=df_roi.prf_y[vert],
                                gauss_sd=df_roi.prf_size[vert],
                                screen_side=screen_side,
                                grain=grain)

        gauss_z_tot += gauss_z * df_roi.prf_loo_r2[vert]

    # Rescale to [0, 1]; a flat map is left as is to avoid dividing by zero
    z_min, z_max = gauss_z_tot.min(), gauss_z_tot.max()
    if z_max > z_min:
        gauss_z_tot = (gauss_z_tot - z_min) / (z_max - z_min)

    # Summed map drawn as contour lines
    fig.add_trace(go.Contour(x=x,
                             y=y,
                             z=gauss_z_tot,
                             colorscale='hot',
                             showscale=False,
                             contours_coloring='lines',
                             line_width=0.5,
                             line_smoothing=1
                            ), row=1, col=i+1)

    # Dotted vertical and horizontal lines through the origin
    fig.add_trace(go.Scatter(x=[0, 0],
                             y=[-half_side, half_side],
                             mode='lines',
                             line=dict(dash='dot', color=roi_color, width=line_width)
                            ), row=1, col=i+1)

    fig.add_trace(go.Scatter(x=[-half_side, half_side],
                             y=[0, 0],
                             mode='lines',
                             line=dict(dash='dot', color=roi_color, width=line_width)
                            ), row=1, col=i+1)

    # Dotted square centred on the origin
    fig.add_shape(type="rect",
                  x0=-box_half_side,
                  y0=-box_half_side,
                  x1=box_half_side,
                  y1=box_half_side,
                  line=dict(dash='dot', color=roi_color, width=line_width),
                  row=1, col=i+1)

# Fully transparent axis colour
fig.update_xaxes(color=('rgba(255,255,255,0)'))
fig.update_yaxes(color=('rgba(255,255,255,0)'))

fig.update_layout(height=200,
                  width=1200,
                  showlegend=False,
                  template=fig_template,
                  margin_l=100,
                  margin_r=50,
                  margin_t=50,
                  margin_b=50)

fig.show()

In [None]:
# Save the all-ROI heat map under its own name: the previous "_violins5" name
# was shared with other exports of this notebook, which overwrote each other
fig.write_image("/Users/uriel/Downloads/{}_prf_heatmap_rois.pdf".format(subject))

In [None]:
# Heat map of a single ROI, drawn on a larger figure
roi = 'sPCS'
grain = 50                       # mesh points per side
line_width = 1
screen_side = 30                 # side of the square mesh handed to gaus_2d
half_side = screen_side / 2      # the mesh and the dotted axes span -half_side to half_side
box_half_side = 10               # the dotted square spans -10 to 10 on both axes
fig = go.Figure()

# Colour of the selected ROI, looked up by name rather than through an index
# left over from an earlier loop
roi_color = 'rgb({},{},{})'.format(*colormap_dict[roi])

# Keep only rows with all the values needed below; NaNs in unrelated columns
# no longer discard a row
df_roi = (data.loc[data.roi == roi]
              .dropna(subset=['prf_x', 'prf_y', 'prf_size', 'prf_loo_r2'])
              .reset_index(drop=True))

# Mesh coordinates are defined up front so the figure can still be drawn when
# the ROI has no usable row
x = np.linspace(-half_side, half_side, grain)
y = np.linspace(-half_side, half_side, grain)
gauss_z_tot = np.zeros((grain, grain))

# Sum the Gaussian of every row, weighted by its prf_loo_r2
for vert in range(len(df_roi)):
    x, y, gauss_z = gaus_2d(gauss_x=df_roi.prf_x[vert],
                            gauss_y=df_roi.prf_y[vert],
                            gauss_sd=df_roi.prf_size[vert],
                            screen_side=screen_side,
                            grain=grain)

    gauss_z_tot += gauss_z * df_roi.prf_loo_r2[vert]

# Rescale to [0, 1]; a flat map is left as is to avoid dividing by zero
z_min, z_max = gauss_z_tot.min(), gauss_z_tot.max()
if z_max > z_min:
    gauss_z_tot = (gauss_z_tot - z_min) / (z_max - z_min)

# Summed map drawn as contour lines
fig.add_trace(go.Contour(x=x,
                         y=y,
                         z=gauss_z_tot,
                         colorscale='hot',
                         showscale=False,
                         contours_coloring='lines',
                         line_width=0.5,
                         line_smoothing=0.85
                        ))

# Dotted vertical and horizontal lines through the origin
fig.add_trace(go.Scatter(x=[0, 0],
                         y=[-half_side, half_side],
                         mode='lines',
                         line=dict(dash='dot', color=roi_color, width=line_width)
                        ))

fig.add_trace(go.Scatter(x=[-half_side, half_side],
                         y=[0, 0],
                         mode='lines',
                         line=dict(dash='dot', color=roi_color, width=line_width)
                        ))

# Dotted square centred on the origin
fig.add_shape(type="rect",
              x0=-box_half_side,
              y0=-box_half_side,
              x1=box_half_side,
              y1=box_half_side,
              line=dict(dash='dot', color=roi_color, width=line_width))

# Fully transparent axis colour
fig.update_xaxes(color=('rgba(255,255,255,0)'))
fig.update_yaxes(color=('rgba(255,255,255,0)'))

fig.update_layout(height=800,
                  width=800,
                  showlegend=False,
                  template=fig_template,
                  margin_l=100,
                  margin_r=50,
                  margin_t=50,
                  margin_b=50)
fig.show()

In [None]:
 # Scratch cell: rebuild the Gaussian mesh of the df_roi row currently indexed by `vert`
 x, y, gauss_z = gaus_2d(
     gauss_x=df_roi.at[vert, 'prf_x'],
     gauss_y=df_roi.at[vert, 'prf_y'],
     gauss_sd=df_roi.at[vert, 'prf_size'],
     screen_side=30,
     grain=grain,
 )

In [None]:
# Scratch cell: add the Gaussian of every df_roi row, weighted by its
# prf_loo_r2, to the existing gauss_z_tot (gauss_z_tot is not reset here)
for vert in range(len(df_roi)):
    weight = df_roi.at[vert, 'prf_loo_r2']
    x, y, gauss_z = gaus_2d(
        gauss_x=df_roi.at[vert, 'prf_x'],
        gauss_y=df_roi.at[vert, 'prf_y'],
        gauss_sd=df_roi.at[vert, 'prf_size'],
        screen_side=30,
        grain=grain,
    )
    gauss_z_tot += weight * gauss_z

In [None]:
# Scratch check for missing values in df_roi
df_roi.dropna()  # the result is not stored, so df_roi itself is left untouched
# Positions of the rows whose prf_y is NaN (shown as the cell output)
np.where(df_roi['prf_y'].isna())

In [None]:
# Save the single-ROI heat map under a name that includes the ROI: the previous
# "_violins5" name was shared with other exports, which overwrote each other
fig.write_image("/Users/uriel/Downloads/{}_prf_heatmap_{}.pdf".format(subject, roi))

# figs

In [None]:
# Violin-plot figure built from the derivatives table and the thresholds set above
fig = prf_violins_plot(
    data,
    subject,
    fig_height=1080,
    fig_width=1920,
    ecc_th=ecc_th,
    size_th=size_th,
    rsq_th=rsq_th,
    pcm_th=pcm_th,
)

In [None]:
# Display the figure currently held in `fig`; the PDF export below is disabled
fig.show()
# fig.write_image("/Users/uriel/Downloads/{}_violins5.pdf".format(subject))

In [None]:
# Figure returned by prf_ecc_size_plot for the thresholds set above
fig2 = prf_ecc_size_plot(
    data,
    subject,
    fig_height=400,
    fig_width=800,
    ecc_th=ecc_th,
    size_th=size_th,
    rsq_th=rsq_th,
)

In [None]:
# Display the figure held in `fig2`
fig2.show()

In [None]:
# prf_polar_plot returns a sequence of figures and a matching sequence of labels
figures, hemis = prf_polar_plot(data, subject, fig_height=300, fig_width=1920, ecc_th=ecc_th, size_th=size_th, rsq_th=rsq_th)

# Save each figure as a PDF named after the subject and its paired label
for figure, hemi in zip(figures, hemis):
    figure.write_image('/Users/uriel/Downloads/{}_subplot_polar_{}.pdf'.format(subject, hemi))


In [None]:
# Figure returned by prf_contralaterality_plot for the thresholds set above
fig5 = prf_contralaterality_plot(
    data,
    subject,
    fig_height=300,
    fig_width=1920,
    ecc_th=ecc_th,
    size_th=size_th,
    rsq_th=rsq_th,
)

In [None]:
# Display the figure held in `fig5`
fig5.show()

# Fig 6

In [None]:
# Figure returned by prf_ecc_pcm_plot for the thresholds set above
fig6 = prf_ecc_pcm_plot(
    data,
    subject,
    fig_height=400,
    fig_width=800,
    ecc_th=ecc_th,
    pcm_th=pcm_th,
    rsq_th=rsq_th,
)

In [None]:
# Display the figure held in `fig6`
fig6.show()

In [None]:
# Save fig6 as '<subject>_pcm.pdf' in the hard-coded Downloads folder
fig6.write_image("/Users/uriel/Downloads/{}_pcm.pdf".format(subject))

# Fig 7

In [None]:
# Figure returned by categories_proportions_roi_plot
fig7 = categories_proportions_roi_plot(
    data,
    subject,
    fig_height=300,
    fig_width=1920,
)

In [None]:
# Display the figure held in `fig7`
fig7.show()

In [None]:
# Save fig7 as a PDF; the underscore after the subject keeps the file name
# consistent with the other exports of this notebook
fig7.write_image("/Users/uriel/Downloads/{}_percentage.pdf".format(subject))

In [None]:
# Figure returned by surface_rois_categories_plot
fig8 = surface_rois_categories_plot(
    data,
    subject,
    fig_height=1080,
    fig_width=1920,
)

In [None]:
# Display the figure held in `fig8`
fig8.show()

In [None]:
# Figure returned by surface_rois_all_categories_plot
fig9 = surface_rois_all_categories_plot(
    data,
    subject,
    fig_height=1080,
    fig_width=1920,
)


In [None]:
# Display the figure held in `fig9`
fig9.show()