# Cortical Magnification

## Introduction

### Dependencies

In [None]:
import os, sys, pimms, pandas, warnings
from pathlib import Path
from functools import reduce, partial

import numpy as np
import scipy as sp
import nibabel as nib
import neuropythy as ny

import matplotlib as mpl
import matplotlib.pyplot as plt
import ipyvolume as ipv
import torch

### Dependency Initialization

In [None]:
# Matplotlib preferences: a light-weight sans-serif font for all figure text.
font_data = {
    'family': 'sans-serif',
    'size': 10,
    'weight': 'light',
}
font_data['sans-serif'] = ['HelveticaNeue', 'Helvetica', 'Arial']
mpl.rc('font', **font_data)
mpl.rcParams['axes.unicode_minus'] = False
# Figures are rendered at relatively high resolution (multiples of 72 dpi),
# with an even higher resolution used when saving them to disk.
mpl.rcParams['savefig.dpi'] = 72*4
mpl.rcParams['figure.dpi'] = 72*2

### Data Loading

In [None]:
def subdata(sid, h,
            vdpath='/data/crcns2021/results/proc/labels/mean',
            cachepath='/data/crowding/hcp-cmag',
            overwrite=False):
    """Return the per-vertex retinotopy data of one hemisphere of a subject.

    The data are read from the cache file ``<cachepath>/<h>.<sid>.npy``. When
    that file does not exist (or `overwrite` is true) it is first built from
    the neuropythy ``hcp_lines`` dataset and from the ventral/dorsal label
    files ``<vdpath>/<sid>/<h>.ventral_label.mgz`` and
    ``<vdpath>/<sid>/<h>.dorsal_label.mgz``, then written to the cache.

    Parameters
    ----------
    sid
        Subject ID; used as the key into ``ny.data['hcp_lines'].subjects`` and
        as the name of the subject's label directory.
    h
        Hemisphere key into the subject's ``hemis`` mapping (the callers in
        this file use ``'lh'`` and ``'rh'``).
    vdpath
        Directory that contains one sub-directory of label files per subject.
    cachepath
        Directory in which the cache files are stored; it is created when a
        cache file has to be written and the directory is missing.
    overwrite
        If true, rebuild the cache file even when it already exists.

    Returns
    -------
    dict
        Keys ``label``, ``x``, ``y``, ``sigma``, ``polar_angle``,
        ``eccentricity``, ``surface_area`` and ``cod``; each value is one row
        of the cached array, with one entry per vertex that has a positive
        label.

    Raises
    ------
    RuntimeError
        If the cache has to be built and the ventral or dorsal label file is
        missing.
    """
    cachefile = Path(cachepath) / f'{h}.{sid}.npy'
    if overwrite or not cachefile.is_file():
        sub = ny.data['hcp_lines'].subjects[sid]
        hem = sub.hemis[h]
        lbl = np.array(hem.prop('visual_area'))
        vdpath = Path(vdpath) / str(sid)
        vpath = vdpath / f'{h}.ventral_label.mgz'
        dpath = vdpath / f'{h}.dorsal_label.mgz'
        if not vpath.is_file():
            raise RuntimeError(f"ventral file not found for subject {sid}/{h}")
        if not dpath.is_file():
            raise RuntimeError(f"dorsal file not found for subject {sid}/{h}")
        vlbl = ny.load(str(vpath))
        dlbl = ny.load(str(dpath))
        # Vertices that are still unlabeled (0) take the ventral label first,
        # and whatever remains unlabeled after that takes the dorsal label.
        ii = (lbl == 0)
        lbl[ii] = vlbl[ii]
        ii = (lbl == 0)
        lbl[ii] = dlbl[ii]
        nz = (lbl > 0)
        x = hem.prop('prf_x')
        y = hem.prop('prf_y')
        r = hem.prop('prf_radius')
        t = hem.prop('prf_polar_angle')
        e = hem.prop('prf_eccentricity')
        w = hem.prop('prf_variance_explained')
        a = hem.prop('midgray_surface_area')
        rows = np.stack([lbl, x, y, r, t, e, a, w], axis=0, dtype=np.float32)
        # Only the labeled vertices are cached.
        dat = rows[:, nz]
        # np.save does not create missing directories, so make sure the cache
        # directory exists before writing.
        cachefile.parent.mkdir(parents=True, exist_ok=True)
        np.save(cachefile, dat)
    # Always read the data back from disk so that a freshly built result is
    # identical to what later (cached) calls return.
    rows = np.load(cachefile)
    (lbl, x, y, r, t, e, a, w) = rows
    return dict(
        label=lbl,
        x=x, y=y, sigma=r,
        polar_angle=t, eccentricity=e,
        surface_area=a,
        cod=w)

## Fitting Functions

### Fitting C.Mag. Models

#### Method of Cumulative Surface Area

In [None]:
def fit_cmag_cumecc(prf_x, prf_y, surface_areas, formfn, params0,
                    method=None,
                    lossfn='mse',
                    weights=None,
                    argtx=None):
    """Fit a model of cumulative surface area as a function of eccentricity.

    The samples are sorted by eccentricity (``hypot(prf_x, prf_y)``), the
    surface areas are cumulatively summed in that order, and the parameters of
    `formfn` are optimized with ``scipy.optimize.minimize`` so that
    ``formfn(eccentricity, *params)`` matches the cumulative surface area
    according to `lossfn`.

    Parameters
    ----------
    prf_x, prf_y
        Per-sample pRF center coordinates.
    surface_areas
        Per-sample surface areas, in the same order as `prf_x` / `prf_y`.
    formfn
        Model function called as ``formfn(eccentricity, *params)``.
    params0
        Initial parameter values (in the model's own parameter space).
    method
        Passed through to ``scipy.optimize.minimize``.
    lossfn
        ``'mse'`` (mean squared error; weighted mean if `weights` is given),
        ``'rss'`` (unweighted sum of squared errors), or a callable
        ``lossfn(gold, pred)`` that returns the value to minimize. `weights`
        only affects the ``'mse'`` loss.
    weights
        Optional per-sample weights in the same order as the inputs; they are
        re-ordered together with the samples.
    argtx
        Optional pair ``(to_opt, from_opt)``: `to_opt` maps `params0` into the
        space in which the optimizer works and `from_opt` maps optimizer
        parameters back to model parameters. Defaults to the identity.

    Returns
    -------
    The result object of ``scipy.optimize.minimize`` with its ``x`` attribute
    mapped back to model parameters via ``argtx[1]``.

    Raises
    ------
    ValueError
        If `lossfn` is neither callable nor one of ``'mse'`` / ``'rss'``.
    """
    from scipy.optimize import minimize
    prf_x = np.asarray(prf_x)
    prf_y = np.asarray(prf_y)
    sarea = np.asarray(surface_areas)
    eccen = np.hypot(prf_x, prf_y)
    ii = np.argsort(eccen)
    sarea = sarea[ii]
    eccen = eccen[ii]
    cumsa = np.cumsum(sarea)
    if weights is not None:
        weights = np.asarray(weights)
        # Per-sample weights must follow the same eccentricity ordering as the
        # data they weight; anything else (e.g. a scalar) is left to broadcast.
        if weights.shape == ii.shape:
            weights = weights[ii]
    if argtx is None:
        argtx = (lambda a:a, lambda a:a)
    if callable(lossfn):
        loss = lossfn
    elif lossfn == 'rss':
        def loss(gold, pred):
            return np.sum((gold - pred)**2)
    elif lossfn == 'mse':
        if weights is None:
            def loss(gold, pred):
                return np.mean((gold - pred)**2)
        else:
            wsum = np.sum(weights)
            def loss(gold, pred):
                return np.sum(weights * (gold - pred)**2) / wsum
    else:
        raise ValueError(
            f"lossfn must be 'mse', 'rss', or a callable; got {lossfn!r}")
    def stepfn(params):
        # The optimizer works in the transformed space; map back before
        # evaluating the model.
        return loss(cumsa, formfn(eccen, *argtx[1](params)))
    params0 = argtx[0](params0)
    r = minimize(stepfn, params0, method=method)
    r.x = argtx[1](r.x)
    return r

def fitall_cmag_cumecc(data, formfn, params0,
                       method=None,
                       lossfn='mse',
                       weights=None,
                       argtx=None,
                       filter=None,
                       labels=None):
    """Run `fit_cmag_cumecc` for every subject, hemisphere, and label.

    Parameters
    ----------
    data
        Mapping of subject ID to a sequence of per-hemisphere data dicts; each
        dict is indexed with the keys ``'label'``, ``'x'``, ``'y'`` and
        ``'surface_area'`` (plus the `weights` key, if given).
    formfn, params0, method, lossfn, argtx
        Passed through to `fit_cmag_cumecc`.
    weights
        Optional key into each hemisphere's data dict whose values are used as
        the per-vertex weights of the fit.
    filter
        Optional function that takes a hemisphere's data dict and returns a
        boolean mask of the vertices to include; by default all vertices are
        included.
    labels
        Label values to fit; defaults to ``1..10``.

    Returns
    -------
    dict
        Maps each subject ID to a tuple (one entry per hemisphere) of tuples
        (one entry per label) of fit results. The entry is ``None`` for labels
        with fewer than 5 selected vertices.
    """
    kw = dict(method=method, lossfn=lossfn, argtx=argtx)
    if not filter:
        filter = lambda dat: True
    if labels is None:
        labels = np.arange(1,11)
    result = {}
    for (sid, sdata) in data.items():
        sres = []
        for hdata in sdata:
            # The filter does not depend on the label, so evaluate it once
            # per hemisphere.
            keep = filter(hdata)
            hres = []
            for lbl in labels:
                ii = keep & (hdata['label'] == lbl)
                if np.sum(ii) < 5:
                    r = None
                else:
                    # `weights` is the key name and must stay intact for the
                    # following iterations, so the selected values go in a
                    # separate local variable.
                    w = None if not weights else hdata[weights][ii]
                    r = fit_cmag_cumecc(
                        hdata['x'][ii],
                        hdata['y'][ii],
                        hdata['surface_area'][ii],
                        formfn,
                        params0,
                        weights=w,
                        **kw)
                hres.append(r)
            sres.append(tuple(hres))
        result[sid] = tuple(sres)
    return result

def filt_base(subdat, maxecc=7):
    """Mask of the entries whose eccentricity is strictly below `maxecc`."""
    ecc = subdat['eccentricity']
    return ecc < maxecc
def filt_wedge(subdat, minangle, maxangle):
    """Base mask restricted to polar angles in [`minangle`, `maxangle`]."""
    ang = subdat['polar_angle']
    in_wedge = (minangle <= ang) & (ang <= maxangle)
    return filt_base(subdat) & in_wedge
def filt_ring(subdat, minecc, maxecc):
    """Base mask restricted to eccentricities in [`minecc`, `maxecc`]."""
    ecc = subdat['eccentricity']
    in_ring = (minecc <= ecc) & (ecc <= maxecc)
    return filt_base(subdat) & in_ring
def filt_sect(subdat, minang, maxang, minecc, maxecc):
    """Intersection of the base, wedge, and ring masks."""
    base = filt_base(subdat)
    wedge = filt_wedge(subdat, minang, maxang)
    ring = filt_ring(subdat, minecc, maxecc)
    return base & wedge & ring

### C.Mag. Model Forms

#### Horton & Hoyt Functions

In [None]:
def HH91(x, a=17.3, b=0.75):
    """Evaluate the Horton & Hoyt form ``(a / (x + b))**2`` at `x`.

    `x` is presumably the eccentricity and the result the areal cortical
    magnification (confirm units against the caller).
    """
    linear = a / (x + b)
    return linear**2

def HH91_integral(x, a=17.3, b=0.75):
    """Evaluate ``pi * a**2 * (log((x + b) / b) - x / (x + b))``.

    This is the integral of ``2*pi*r * HH91(r, a, b)`` over ``r`` from 0 to
    `x`, divided by two (presumably because one hemisphere represents half of
    the visual field -- confirm).
    """
    shifted = x + b
    log_term = np.log(shifted / b)
    ratio_term = x / shifted
    return a**2 * np.pi * (log_term - ratio_term)

def HH91_c1(totalarea, b=0.75, maxecc=7):
    """Return the `a` for which ``HH91_integral(maxecc, a, b) == totalarea``.

    Computed in closed form as
    ``sqrt(totalarea / pi / (log((maxecc + b) / b) - maxecc / (maxecc + b)))``.
    """
    outer = maxecc + b
    shape = np.log(outer / b) - maxecc/outer
    return np.sqrt(totalarea / np.pi / shape)

# (forward, inverse) parameter transforms for use as the `argtx` option of the
# fitting functions: the optimizer works on the square roots of the parameters
# and the squares of its values are handed to the model.
HH91_argtx = (
    np.sqrt,
    lambda x: np.square(np.array(x)),
)

#### Beta Functions

In [None]:
from scipy.stats import beta

def beta_form(x, a, b, maxecc=7):
    """Beta-distribution CDF with shape parameters `a`, `b` at ``x / maxecc``."""
    fraction = x / maxecc
    return beta.cdf(fraction, a, b)

def beta_loss(gold, pred):
    """Mean squared error between `gold` scaled by its maximum and `pred`."""
    normalized = gold / np.max(gold)
    residual = normalized - pred
    return np.mean(residual**2)

# (forward, inverse) parameter transforms for use as the `argtx` option of the
# fitting functions: the optimizer works on the square roots of the parameters
# and the squares of its values are handed to the model.
beta_argtx = (
    np.sqrt,
    lambda x: np.square(np.array(x)),
)

## Loading Data

In [None]:
# The subject IDs: every entry of the dataset's subject list that is not
# matched by the exclusion filter below.
# NOTE(review): the filter collects the FIRST field of every exclusion record
# whose first field is not 'mean'; confirm that this field holds a subject ID.
sids = np.setdiff1d(
    ny.data['hcp_lines'].subject_list,
    [excl[0]
     for excl in ny.data['hcp_lines'].exclusions
     if excl[0] != 'mean'])

# The data for each subject, as a (left hemisphere, right hemisphere) tuple.
data = {
    sid: tuple(subdata(sid, h) for h in ('lh', 'rh'))
    for sid in sids}

# The HCP visual areas have a max eccentricity of about 7°.
maxecc = 7

## Fitting

In [None]:
# Fit the beta-function form to every subject, hemisphere, and visual area;
# this should take ~2-3 minutes to run.
#
# The `filter` option selects the vertices that enter each fit. To fit only a
# wedge of the visual field (such as the upper vertical meridian) one can use
#   filter=lambda dat: filt_wedge(dat, -15, 15)
# and for the lower vertical meridian, which straddles ±180°, something like
#   filter=lambda dat: filt_wedge(dat, 165, 180) | filt_wedge(dat, -180, -165)
#
# The polar_angle data use the following convention:
#      0° is the upper vertical meridian,
#    +90° is the right horizontal meridian,
#    -90° is the left horizontal meridian,
#   ±180° is the lower vertical meridian.

fits = fitall_cmag_cumecc(
    data,
    formfn=beta_form,
    params0=[1, 3],
    argtx=beta_argtx,
    filter=filt_base,
    lossfn=beta_loss)

### Plotting a Subject's Fits

In [None]:
# Plot, for one subject, the empirical cumulative surface area (black) and the
# fitted beta-function prediction (red) for labels 1-4 (rows) of each
# hemisphere (columns).
sid = 111312

(fig, axs) = plt.subplots(4,2, figsize=(7,7), dpi=288, sharex=True, sharey=True)
fig.subplots_adjust(0,0,1,1,0.15,0.1)
subdat = data[sid]
subfit = fits[sid]

for (hii,axcol) in enumerate(axs.T):
    hdat = subdat[hii]
    hfit = subfit[hii]
    # The fits were computed with filter=filt_base, so the same vertex
    # selection is applied here; otherwise the cumulative curve and its total
    # (used to scale the prediction) would include vertices that the fit never
    # saw.
    keep = filt_base(hdat)
    for (lbl,fit,ax) in zip([1,2,3,4], hfit, axcol):
        if fit is None:
            # No fit was performed for this label (too few vertices).
            continue
        mask = keep & (hdat['label'] == lbl)
        ecc = hdat['eccentricity'][mask]
        sar = hdat['surface_area'][mask]
        order = np.argsort(ecc)
        cum = np.cumsum(sar[order])
        ecc = ecc[order]
        ax.plot(ecc, cum, 'k-', lw=0.5)
        # beta_form predicts the fraction of the total area, so scale it by
        # the total (the last value of the cumulative sum).
        pre = cum[-1] * beta_form(ecc, *fit.x)
        ax.plot(ecc, pre, 'r-', lw=0.5)
        ax.fill_between(ecc, cum, pre, color='r', alpha=0.2)

for ax in axs.flat:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim([0,8])
# NOTE(review): the label below states cm^2; confirm the units of the
# 'surface_area' data (the values are plotted without any unit conversion).
for ax in axs[:,0]:
    ax.set_ylabel(r'Cum. Surface Area [cm$^2$]')
for ax in axs[-1]:
    ax.set_xlabel('Eccentricity [deg]')

plt.show()

### Plotting Beta Parameters for all Subjects

In [None]:
# Collect the fitted parameters into an array indexed as
# (subject, hemisphere, label index, parameter). Labels that were not fit
# (their entry in `fits` is None) are filled with NaN so that the array stays
# rectangular; NaN points are simply not drawn.
allparams = np.array(
    [[np.stack([np.full(2, np.nan) if fit is None else fit.x
                for fit in hfit])
      for hfit in sfit]
     for sfit in fits.values()])

(fig,axs) = plt.subplots(4,2, figsize=(4,8), dpi=288, sharex=True, sharey=True)
fig.subplots_adjust(0,0,1,1,0.15,0.1)

# One column of axes per hemisphere, one row per label (1-4); each point is
# one subject's pair of fitted parameters.
for (hii,axcol) in enumerate(axs.T):
    for (lbl,ax) in zip([1,2,3,4], axcol):
        # Label `lbl` is stored at index lbl-1 (the fits use labels 1..10 by
        # default).
        (a_vals, b_vals) = allparams[:, hii, lbl-1, :].T
        ax.plot(a_vals, b_vals, 'ko', ms=0.5, alpha=0.5)

for ax in axs.flat:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim([0,2])
    ax.set_ylim([0,4])
for ax in axs[:,0]:
    ax.set_ylabel(r'$\beta$')
for ax in axs[-1]:
    ax.set_xlabel(r'$\alpha$')

plt.show()