<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Parameters" data-toc-modified-id="Parameters-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Parameters</a></span></li><li><span><a href="#Path" data-toc-modified-id="Path-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Path</a></span></li><li><span><a href="#Setup-figure" data-toc-modified-id="Setup-figure-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Setup figure</a></span></li><li><span><a href="#Load-data" data-toc-modified-id="Load-data-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Load data</a></span></li><li><span><a href="#Create-pycortex-vertex-objects" data-toc-modified-id="Create-pycortex-vertex-objects-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Create pycortex vertex objects</a></span></li><li><span><a href="#Create-figures-etc" data-toc-modified-id="Create-figures-etc-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Create figures etc</a></span></li></ul></div>

In [None]:
import cortex as cx
import platform
from pathlib import Path
import json
import nilearn.surface as surface
import pandas as pd
import numpy as np
import os, shutil, shlex, subprocess
import h5py
import matplotlib.pyplot as plt
from matplotlib.colors import hsv_to_rgb
from matplotlib import cm

# needs prfpy for prf spatial plot, and for model predictions at some point
from prfpy.rf import gauss2D_iso_cart
from prfpy.model import Iso2DGaussianModel

from prf_expect.utils import io

In [None]:
from shutil import which

# Report whether Inkscape is on PATH and which version it is.
# NOTE(review): pycortex presumably needs it to render the ROI overlays
# (with_rois=True below) — confirm for the pycortex version in use.
inkscape_path = which("inkscape")
print(inkscape_path)
if inkscape_path is None:
    print("inkscape not found on PATH")
else:
    # Capture the output so it shows up in the notebook; os.system would write it
    # to the kernel process's terminal instead.
    inkscape_version = subprocess.run(
        [inkscape_path, "--version"], capture_output=True, text=True
    )
    print(inkscape_version.stdout.strip() or inkscape_version.stderr.strip())

In [None]:
# Show where pycortex keeps its subject database and which user config file it read.
print(f"cortex.database.default_filestore: {cx.database.default_filestore}")
print(f"cortex.options.usercfg: {cx.options.usercfg}")

### Parameters

In [None]:
# Hemisphere codes and their long-form names.
hemis, hemi_names = ["L", "R"], ["left", "right"]

# pRF model whose fits / predictions are loaded, and its parameter names.
model = "norm"
prf_par_names = "x y prf_size prf_ampl bold_bsl r2".split()

# Surface space (also used as the pycortex subject) and the participant to show.
space, subject = "fsaverage", "sub-002"

# Whether to load model predictions (yhat) and pRF parameter estimates.
need_yhat = need_prf = True

### Path

In [None]:
settings = io.load_settings()
# Session label used both in the derivatives directory and in the file names.
session = "ses-1"

# Root of this subject/session's pRF derivatives.
prf_viz_dir = (
    Path(settings["general"]["data_dir"], "data")
    / "derivatives"
    / "prf_data"
    / subject
    / session
)

# Per-hemisphere timecourse files. File names are built from `space`/`session`
# instead of hard-coding "fsaverage"/"ses-1", so changing the parameters above
# actually changes what is loaded.
tc_L, tc_R = (
    str(
        prf_viz_dir
        / "cut_and_averaged"
        / f"{subject}_{session}_task-pRF_space-{space}_hemi-{hemi}_desc-denoised_bold_psc_mean.npy"
    )
    for hemi in ("L", "R")
)
# Table of fitted pRF parameters for `model`.
prf_pars_path = str(
    prf_viz_dir
    / "prf_fits"
    / "prf_params"
    / f"{subject}_{session}_final-fit_space-{space}_model-{model}_stage-iter_desc-prf_params.tsv"
)
# Design matrix of the pRF task.
dm_fn = str(prf_viz_dir / "dms" / "dm_task-pRF_run-01.npy")
# Model predictions (yhat) for `model`.
pred_fn = (
    prf_viz_dir
    / "prf_fits"
    / "prf_predictions"
    / f"{subject}_{session}_task-pRF_final-fit_space-{space}_model-{model}_stage-iter_desc-prf_pred.npy"
)

### Setup figure

In [None]:
# Pixel height of the flatmap raster used for the click -> vertex lookup, and the
# overall matplotlib figure size (inches).
flatmap_height, full_figsize = 2048, (12, 8)

In [None]:
# Square visual-field grid on which a vertex's pRF is evaluated for plotting:
# the same 1-D axis is used for x and y.
vf_extent = [-8, 8]
nr_vf_pix = 200
prf_space_x, prf_space_y = np.meshgrid(
    *[np.linspace(*vf_extent, num=nr_vf_pix, endpoint=True)] * 2
)

### Load data
Load the vertex timecourses, the pRF parameters, the model predictions, and the design matrix.

In [None]:
# Vertex timecourses: left then right hemisphere, stacked so that each row is one
# vertex (the row index is the vertex index used by the flatmap callbacks).
tc_data = np.concatenate([np.load(tc_L).T, np.load(tc_R).T])

# pRF parameter table.
if need_prf:
    prf_pars_df = pd.read_csv(prf_pars_path, sep="\t", header=0)

# Model predictions, plotted on top of the data timecourses.
if need_yhat:
    pred_data = np.load(pred_fn)

# Design matrix. Time is taken to be the last axis (prfpy convention — TODO confirm),
# so summing over the two spatial axes gives the stimulated area per timepoint.
design_matrix = np.load(dm_fn)
stim_per_timepoint = design_matrix.sum(axis=(0, 1))

# Guideline for the timecourse plot: stimulated area rescaled to 0-5.
dm_guideline = (
    5
    * (stim_per_timepoint - stim_per_timepoint.min())
    / (stim_per_timepoint.max() - stim_per_timepoint.min())
)

# Stimulus-on-screen flag per timepoint. It must reduce over the same spatial axes
# as dm_guideline; summing over axes (1, 2) gave one value per image column
# instead of one per timepoint.
sos = stim_per_timepoint != 0

### Create pycortex vertex objects

In [None]:
# Per-vertex temporal mean of the timecourses, shown when no pRF maps are loaded.
pscmean_v = cx.Vertex(np.mean(tc_data, axis=1), subject=space, cmap="hsv")

In [None]:
# Columns of the pRF parameter table; it only exists when need_prf is True
# (evaluating it unconditionally raised NameError otherwise).
prf_pars_df.columns if need_prf else None

In [None]:
if need_prf:
    # Vertices whose fit r2 is not above this value are hidden on the maps.
    thresh = 0.15
    rsq = np.array(prf_pars_df["r2"], dtype=float)
    # NaN r2 compares False, so unfitted vertices count as below threshold too.
    below_thresh = ~(rsq > thresh)

    # Float copies, so NaN can be written without touching prf_pars_df
    # (an integer column would otherwise reject NaN).
    angs_n = np.array(prf_pars_df["polar"], dtype=float)
    eccen = np.array(prf_pars_df["ecc"], dtype=float)
    prf_size = np.nan_to_num(np.array(prf_pars_df["prf_size"], dtype=float))
    angs_n[below_thresh] = np.nan
    eccen[below_thresh] = np.nan
    prf_size[below_thresh] = np.nan

    # Alpha (dim2) data for the 2D colormaps: r2 itself above threshold, 0 (== vmin2)
    # otherwise. This must be a float array: assigning NaN into a boolean mask casts
    # to True, which made every vertex fully opaque.
    # NOTE(review): 0 is presumably the transparent end of the *_alpha colormaps.
    rsq_mask = np.where(below_thresh, 0.0, rsq)

    def _alpha_vertex(values, cmap, vmin, vmax):
        """Vertex2D with `values` as colour (vmin..vmax) and `rsq_mask` (0..0.5) as alpha."""
        return cx.Vertex2D(
            dim1=values,
            dim2=rsq_mask,
            subject=space,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            vmin2=0,
            vmax2=0.5,
        )

    polar_v = _alpha_vertex(angs_n, "Retinotopy_HSV_alpha", -np.pi, np.pi)
    size_v = _alpha_vertex(prf_size, "hot_alpha", 0, 7.0)
    eccen_v = _alpha_vertex(eccen, "spectral_alpha", 0, 7.0)
    x_v = _alpha_vertex(np.nan_to_num(prf_pars_df["x"]), "seismic_alpha", -10, 10.0)
    y_v = _alpha_vertex(np.nan_to_num(prf_pars_df["y"]), "seismic_alpha", -10, 10.0)

###################################################################################################
# Flatmap pixel -> vertex lookup, used by the click callback
###################################################################################################

# `mask` flags the pixels of a `flatmap_height`-tall flatmap raster that are covered
# by cortex; `extents` presumably gives the flatmap coordinates that raster spans.
mask, extents = cx.quickflat.utils.get_flatmask(space, height=flatmap_height)
# Row i of `vc` holds (in `.indices`) the vertex for the i-th cortex pixel.
# NOTE: `_make_vertex_cache` is a private pycortex helper and may change between versions.
vc = cx.quickflat.utils._make_vertex_cache(space, height=flatmap_height)

# For each cortex pixel, its running number in mask order (= row in `vc`); 0 elsewhere.
mask_index = np.zeros(mask.shape)
mask_index[mask] = np.arange(np.count_nonzero(mask))

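### Create figures etc

Click a point on the flatmap to plot that vertex's timecourse (and its pRF). A plain click overlays the new timecourse on earlier ones; Shift-click clears them first. Press 1–5 to switch the flatmap between polar angle, eccentricity, pRF size, pRF x and pRF y.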

In [None]:
%matplotlib ipympl
full_fig = plt.figure(constrained_layout=True, figsize=full_figsize)
gs = full_fig.add_gridspec(3, 3)
flatmap_ax = full_fig.add_subplot(gs[:2, :])
timecourse_ax = full_fig.add_subplot(gs[2, :2])
prf_ax = full_fig.add_subplot(gs[2, 2])

flatmap_ax.set_title('flatmap')
timecourse_ax.set_title('timecourse')
prf_ax.set_title('prf')

###################################################################################################
###################################################################################################
#######
# redraw per-vertex data
#######
###################################################################################################
###################################################################################################

def redraw_vertex_plots(vertex, refresh):
    """Plot one vertex's timecourse and, if pRF parameters are loaded, its pRF.

    Parameters
    ----------
    vertex : int
        Positional row index into ``tc_data``, ``pred_data`` and ``prf_pars_df``.
    refresh : bool
        If True, clear the timecourse axis first; otherwise the new timecourse is
        overlaid on those already drawn. The pRF axis is always redrawn.
    """
    if refresh:
        timecourse_ax.clear()
        # clear() also removes the title set when the figure was created.
        timecourse_ax.set_title('timecourse')
    timecourse_ax.axhline(0, color='black', lw=0.25)
    timecourse_ax.plot(tc_data[vertex], "o", markersize=5)
    if need_yhat:
        timecourse_ax.plot(pred_data[vertex])
    n_timepoints = tc_data.shape[1]
    timecourse_ax.set_xticks(range(0, n_timepoints, 5))  # major tick every 5 TRs
    timecourse_ax.set_xlim(0, n_timepoints)

    if not need_prf:
        return

    # `vertex` is an array position, so look the row up positionally.
    pars = prf_pars_df.iloc[vertex]
    # TODO: plot model predictions with prfpy's Iso2DGaussianModel (e.g. when prf_ampl != 0).
    prf = gauss2D_iso_cart(prf_space_x,
                           prf_space_y,
                           [pars['x'], pars['y']],
                           pars['prf_size'])
    prf_ax.clear()
    # NOTE(review): matshow uses origin='upper', so row 0 (y == vf_extent[0]) is drawn
    # at the top of the axis — check this matches the intended y convention.
    prf_ax.matshow(prf, extent=vf_extent + vf_extent, cmap='cubehelix')
    prf_ax.axvline(0, color='white', linestyle='dashed', lw=0.5)
    prf_ax.axhline(0, color='white', linestyle='dashed', lw=0.5)
    prf_ax.set_title(f"Vertex index: {vertex}\n"
                     f"x: {round(pars['x'], 3)}, y: {round(pars['y'], 3)}\n"
                     f"rsq: {round(pars['r2'], 3)}")

def zoom_to_roi(axis, space, roi, hem, margin=10.0):
    """Set ``axis`` limits to the bounding box of ``roi`` in one hemisphere's flatmap.

    Parameters
    ----------
    axis : matplotlib Axes showing a pycortex flatmap of ``space``.
    space : str
        pycortex subject whose ROIs and flat surfaces are used.
    roi : str
        ROI name passed to ``cx.get_roi_verts``.
    hem : {'left', 'right'}
        Hemisphere whose flat surface is used.
    margin : float
        Padding added on every side of the bounding box, in flatmap coordinates.

    Returns
    -------
    list
        ``[xmin, xmax, ymin, ymax]`` as applied to ``axis``.

    Raises
    ------
    ValueError
        If ``hem`` is not 'left'/'right', or the ROI has no vertices in that hemisphere.
    """
    if hem not in ("left", "right"):
        raise ValueError(f"hem must be 'left' or 'right', got {hem!r}")

    roi_verts = cx.get_roi_verts(space, roi)[roi]
    roi_map = cx.Vertex.empty(space)
    roi_map.data[roi_verts] = 1

    (lflatpts, _), (rflatpts, _) = cx.db.get_surf(space, "flat", nudge=True)
    hemi_pts = lflatpts if hem == "left" else rflatpts
    roi_pts = hemi_pts[np.nonzero(getattr(roi_map, hem))[0], :2]
    if roi_pts.size == 0:
        # Without this check, min()/max() fail with an opaque zero-size-array error.
        raise ValueError(f"ROI {roi!r} has no vertices in the {hem} hemisphere")

    xmin, ymin = roi_pts.min(0) - margin
    xmax, ymax = roi_pts.max(0) + margin
    bounds = [xmin, xmax, ymin, ymax]
    axis.axis(bounds)
    return bounds

###################################################################################################
###################################################################################################
#######
# actual callback functions
#######
###################################################################################################
###################################################################################################

def onclick(event):
    """Mouse-click callback: find the flatmap vertex under the cursor and plot it.

    Shift-click clears previously drawn timecourses; a plain click overlays.
    Clicks outside the flatmap axis, or on flatmap pixels not covered by
    cortex, are ignored.
    """
    if event.inaxes is not flatmap_ax or event.xdata is None or event.ydata is None:
        return

    # Convert data coordinates to flatmask pixel indices using the flatmap's own
    # extents rather than the axis bounds: the bounds change when the view is
    # zoomed or panned, which made clicks pick the wrong vertex.
    x0, x1, y0, y1 = extents
    n_x, n_y = mask.shape  # mask is indexed [x, y]
    # Clip so a click exactly on the right/top edge does not index past the end.
    px = int(np.clip(n_x * (event.xdata - x0) / (x1 - x0), 0, n_x - 1))
    py = int(np.clip(n_y * (event.ydata - y0) / (y1 - y0), 0, n_y - 1))
    if not mask[px, py]:
        # Background pixel: mask_index is 0 there, which would pick an arbitrary vertex.
        return

    clicked_vertex = vc[int(mask_index[px, py])]
    redraw_vertex_plots(clicked_vertex.indices[0], (event.key == 'shift'))
    plt.draw()


def onkey(event):
    """Key-press callback: switch which pRF map the flatmap axis shows.

    Keys: '1' polar angle, '2' eccentricity, '3' pRF size, '4' pRF x, '5' pRF y.
    Other keys are ignored. All keys are ignored when ``need_prf`` is False, because
    the pRF vertex objects then do not exist (previously this raised NameError).
    """
    if not need_prf:
        return
    views = {
        '1': (polar_v, 'Polar Angle'),
        '2': (eccen_v, 'Eccentricity'),
        '3': (size_v, 'pRF Size'),
        '4': (x_v, 'pRF X'),
        '5': (y_v, 'pRF Y'),
    }
    if event.key not in views:
        return
    vertex_data, title = views[event.key]
    # NOTE(review): the axis is not cleared first, so each switch presumably draws a
    # new flatmap image on top of the previous ones.
    cx.quickshow(vertex_data, with_rois=True, with_curvature=True,
                 fig=flatmap_ax, with_colorbar=False)
    flatmap_ax.set_title(title)
    plt.draw()
    
###################################################################################################
###################################################################################################
#######
# start
#######
###################################################################################################
###################################################################################################
# Initial flatmap: polar angle when pRF parameters are loaded (titled as in onkey),
# otherwise the per-vertex temporal mean of the timecourses.
if need_prf:
    cx.quickshow(polar_v, with_rois=True, with_curvature=True,
                 fig=flatmap_ax, with_colorbar=False)
    flatmap_ax.set_title('Polar Angle')
else:
    cx.quickshow(pscmean_v, with_rois=True, with_curvature=True,
                 fig=flatmap_ax, with_colorbar=False)

# Optional: zoom onto an ROI (zoom_to_roi takes `space`, not `subject`), e.g.
# new_bounds = zoom_to_roi(axis=flatmap_ax, space=space, roi='V2', hem='left', margin=10.0)

# Connect the interactive callbacks; keep the ids so they can be removed with
# full_fig.canvas.mpl_disconnect(...).
click_cid = full_fig.canvas.mpl_connect('button_press_event', onclick)
key_cid = full_fig.canvas.mpl_connect('key_press_event', onkey)