In [None]:
# %matplotlib qt
# plot_size = 5
%matplotlib notebook
plot_size = 3
fig_basenum = 30

In [None]:
from retinotopic_helpers import *
import matplotlib.pyplot as plt
from functools import reduce
from operator import add
from mne.io import read_raw_fif

In [None]:
from mne_viz_circle import plot_connectivity_circle

In [None]:
# Channel selection passed to prepare_gain below. Alternatives: 'grad' or
# 'mag'. With 'meg' both are combined, which may make little sense, since
# the scale of 'grad' dominates!
ch_type = 'meg'

# Visual areas of interest (NB: later cells re-bind `regions`).
regions = ['V{:d}'.format(area) for area in (1, 2, 3)]

In [None]:
import os

# Root of the scratch data. Set the CFF_RETINOTOPY_DATA environment variable
# to run elsewhere instead of editing this machine-specific default.
data_path = os.environ.get('CFF_RETINOTOPY_DATA',
                           '/Users/cjb/projects/CFF_Retinotopy/scratch')
subject = '030_WAH'
subjects_dir = join(data_path, 'fs_subjects_dir')
fname_fwd = join(data_path, 'restricted_1LBEM-fwd.fif')
fname_raw = join(data_path, 'VS1_cropped1sec.fif')
# Only the measurement info is needed, so the raw data are not loaded.
info = read_raw_fif(fname_raw, preload=False).info

In [None]:
if ch_type == 'meg':  # re-scale mags
    # NOTE(review): coil_scale / mag_scale are not referenced anywhere else
    # in this notebook yet (see the TODO below); they are only bound when
    # ch_type == 'meg'.
    coil_scale, mag_scale = get_mag_scaling_factor(info)

## TODO: decide where to apply the magnetometer scaling factors (`coil_scale`, `mag_scale`)

In [None]:
# Load the forward solution from fname_fwd. The following cell re-binds
# `fwd`, so re-run this cell before re-running that one.
fwd = read_forward_solution(fname_fwd)

In [None]:
# change to surface coords
# NOTE(review): this re-binds `fwd` to prepare_gain's output, so the cell is
# not idempotent -- running it twice feeds an already-prepared forward back
# into prepare_gain. Re-run the read_forward_solution cell first.
fwd = prepare_gain(fwd, ch_type=ch_type)

In [None]:
# get all RM-labels
# Used below as labels[region][hemi] -> list of labels, with region keys
# such as 'V1' and hemi in ('lh', 'rh').
labels = get_RM_labels(subject, subjects_dir=subjects_dir)

In [None]:
# DEBUG
# ii = 10
# lablist = [labels['V1']['lh'][ii], labels['V1']['rh'][ii],
#            labels['V1']['lh'][ii] + labels['V1']['rh'][ii]]

# patch_sensitivity(fwd, lablist)

## Plot patch sensitivity for each retinotopic location

In [None]:
# Flatten labels into one list: region by region, lh before rh.
regions = labels.keys()
alllabs = [lab
           for reg in regions
           for hemi in ('lh', 'rh')
           for lab in labels[reg][hemi]]

In [None]:
# For normalization: sensitivities of every label (all regions, both hemis).
# The plotting cells below take np.max over each measure as a common colour
# scale, so panels are comparable.
# Takes a few secs, but hard to do without?
all_patch_sens = patch_sensitivity(fwd, alllabs)

In [None]:
# Each region separately: one polar panel per (measure, region), with both
# hemispheres in each panel. Colours are normalized by the maximum over ALL
# labels (all_patch_sens) so panels are comparable; 'ci' uses normalizer 1.
plot_vals = ['mean_sens', 'ci', 'w_sens', 'tot_sens']
plot_cms  = [cm.hot, cm.bone_r, cm.inferno, cm.inferno]
hemis = ('lh', 'rh')

# Use len(regions) for both the grid and the figure size so they agree.
n_rows, n_cols = len(plot_vals), len(regions)

# patch_sensitivity does not depend on the plotted measure: compute it once
# per (region, hemisphere) rather than once per row.
region_sens = {(reg, hemi): patch_sensitivity(fwd, labels[reg][hemi])
               for reg in regions for hemi in hemis}

fig = plt.figure(num=fig_basenum + 1,
                 figsize=(n_cols * plot_size, n_rows * plot_size))
fig.clear()
for irow, (pval, pcm) in enumerate(zip(plot_vals, plot_cms)):
    maxval = np.max(all_patch_sens[pval]) if pval != 'ci' else 1
    for icol, reg in enumerate(regions):
        # the polax_bardata_* helpers collect bars between calls; start
        # each panel empty
        polax_bardata_reset()
        ax = plt.subplot(n_rows, n_cols, irow * n_cols + icol + 1,
                         projection='polar')

        for hemi_bit, hemi in enumerate(hemis):
            patch_sens = region_sens[(reg, hemi)]
            for il, lab in enumerate(labels[reg][hemi]):
                # labels named '...None...' are skipped (not plotted)
                if 'None' in lab.name:
                    continue
                ecc_ind, ang_ind = get_ecc_ang_inds(lab.name)
                polax_bardata_append(hemi_bit, ecc_ind, ang_ind,
                                     patch_sens[pval][il])

        polax_bardata_setcols(ax, cmap=pcm, normalizer=maxval)
        ax.set_title('{reg:s} - {val:s}'.format(reg=reg, val=pval))

In [None]:
# from mpl_toolkits.axes_grid1 import make_axes_locatable
# sm = plt.cm.ScalarMappable(cmap=pcm, norm=plt.Normalize(vmin=0, vmax=1))
# # fake up the array of the scalar mappable. Urgh...
# sm._A = []
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(sm, cax=cax)

In [None]:
# Combine V1, V2 and V3: for each retinotopic location, add the labels of all
# regions into one label and plot its sensitivity. One polar panel per
# measure, with both hemispheres drawn in the same panel.
plot_vals = ['mean_sens', 'ci', 'w_sens', 'tot_sens']
plot_cms  = [cm.hot, cm.bone_r, cm.inferno, cm.inferno]

regions = labels.keys()
hemis = ('lh', 'rh')
region_tag = '+'.join(regions)  # e.g. 'V1+V2+V3'

# Combined labels and their sensitivities do not depend on the plotted
# measure, so build them once per hemisphere instead of once per row.
combined_sens = {}  # hemi -> list of (ecc_ind, ang_ind, patch_sens)
for hemi in hemis:
    # one list per region, each holding all locations
    reglabs = [labels[r][hemi] for r in regions]
    entries = []
    # zip(*...) transposes to one (V1, V2, V3) tuple per location
    for reglab in zip(*reglabs):
        # add the per-region labels of this location into a single label
        lab = reduce(add, reglab)
        ecc_ind, ang_ind = get_ecc_ang_inds(lab.name)
        entries.append((ecc_ind, ang_ind, patch_sensitivity(fwd, lab)))
    combined_sens[hemi] = entries

fig = plt.figure(num=fig_basenum + 2,
                 figsize=(plot_size, len(plot_vals) * plot_size))
fig.clear()
for irow, (pval, pcm) in enumerate(zip(plot_vals, plot_cms)):
    # Start each panel from empty bar data, as the per-region cell does;
    # without this, bars from the previous cell / earlier rows leak in.
    polax_bardata_reset()
    ax = plt.subplot(len(plot_vals), 1, irow + 1, projection='polar')

    for hemi_bit, hemi in enumerate(hemis):
        for ecc_ind, ang_ind, patch_sens in combined_sens[hemi]:
            polax_bardata_append(hemi_bit, ecc_ind, ang_ind,
                                 patch_sens[pval][0])

    # colour once, after both hemispheres have been added
    maxval = np.max(all_patch_sens[pval]) if pval != 'ci' else 1
    polax_bardata_setcols(ax, cmap=pcm, normalizer=maxval)
    ax.set_title('{reg:s} - {val:s}'.format(reg=region_tag, val=pval))

## Total cross-cancellation between retinotopic locations

How much does the summed response of the selected visual areas (V1, V2 and/or V3, as chosen in the next cell) at one stimulus location (_e.g._, SSVEP/Fs) cancel or otherwise interact with the response from a different location?

In [None]:
# Regions whose labels are summed at each retinotopic location. Other
# choices used during exploration: ['V1'] or ['V1', 'V2', 'V3'].
regions = ['V1', 'V2']

# One label per (hemisphere, location): all lh locations first, then rh,
# in the order the location lists appear in `labels`.
reglabs = []
for hemi in ('lh', 'rh'):
    # zip(*...) transposes [region][location] into per-location tuples
    for loc_labels in zip(*(labels[r][hemi] for r in regions)):
        # reduce returns the lone label unchanged when a single region is
        # selected, and the summed label otherwise
        reglabs.append(reduce(add, loc_labels))

In [None]:
# cross-cancellation between all pairs of labels in `reglabs`:
#   D[i, j]     = sens(label_i + label_j) / (sens(label_i) + sens(label_j)),
#                 using `sens_measure` (presumably < 1 when the two patches
#                 partially cancel -- TODO confirm against patch_sensitivity)
#   DcosT[i, j] = dot(sigvec_i, sigvec_j) / (tot_sens_i * tot_sens_j), the
#                 cosine between signal vectors if tot_sens is the norm of
#                 sigvec (TODO confirm)
# Node layout arrays for the connectivity-circle plots are filled as well.
sens_measure = 'tot_sens'

n_hemi = len(radii) * len(theta_starts_deg)  # locations per hemisphere
n_nodes = 2 * n_hemi

# NaN marks entries that are never computed (e.g. the diagonal). NaN instead
# of np.empty so unvisited nodes are visibly missing rather than garbage.
# (np.float was removed in NumPy 1.24; float64 is the np.full default here.)
D = np.full((n_nodes, n_nodes), np.nan)
DcosT = np.full((n_nodes, n_nodes), np.nan)
self_sens = np.full(n_nodes, np.nan)
node_angles = np.full(n_nodes, np.nan)
node_radii = np.full(n_nodes, np.nan)
node_heights = np.full(n_nodes, np.nan)
# Pre-sized so each name can be placed at its node index
label_names = ['' for _ in range(n_nodes)]

# Per-label quantities, computed once instead of once per pair.
single_sens = [patch_sensitivity(fwd, lab) for lab in reglabs]
node_info = []  # (hemi index, ecc index, angle index, node index)
for lab in reglabs:
    ecc_ind, ang_ind = get_ecc_ang_inds(lab.name)
    hidx = 0 if lab.hemi == 'lh' else 1
    node_info.append((hidx, ecc_ind, ang_ind,
                      hidx * n_hemi + len(radii) * ang_ind + ecc_ind))

for i_reg, reflab in enumerate(reglabs):
    ref_sens = single_sens[i_reg]
    ref_hidx, r_ecc_ind, r_ang_ind, ridx = node_info[i_reg]

    # Angle on the circle: wedge start plus an eccentricity-dependent offset
    # within the wedge; the rh is mirrored.
    theta = theta_starts_deg[r_ang_ind] + \
        (r_ecc_ind + 1) * wedge_width / (len(radii) + 1)
    circ_na = 90 + theta if ref_hidx == 0 else 90 + 360 - theta

    node_angles[ridx] = circ_na
    node_heights[ridx] = r_ecc_ind + 1
    node_radii[ridx] = 10. + r_ecc_ind
    label_names[ridx] = reflab.name
    self_sens[ridx] = ref_sens[sens_measure][0]

    # upper triangle only, mirrored below (empty for the last label)
    for j_reg in range(i_reg + 1, len(reglabs)):
        trglab = reglabs[j_reg]
        trg_sens = single_sens[j_reg]
        cidx = node_info[j_reg][3]
        cmb_sens = patch_sensitivity(fwd, reflab + trglab)

        # [0] extracts the scalar for the single label, as for self_sens,
        # instead of assigning a size-1 array into a scalar slot
        D[ridx, cidx] = D[cidx, ridx] = cmb_sens[sens_measure][0] / \
            (ref_sens[sens_measure][0] + trg_sens[sens_measure][0])

        DcosT[ridx, cidx] = DcosT[cidx, ridx] = \
            np.dot(ref_sens['sigvec'][:, 0], trg_sens['sigvec'][:, 0]) / \
            (ref_sens['tot_sens'][0] * trg_sens['tot_sens'][0])

In [None]:
import seaborn as sns

In [None]:
# Distribution of the pairwise cosine values (computed, non-NaN pairs only).
fig = plt.figure(num=fig_basenum + 20, figsize=(2 * plot_size, 2 * plot_size))
fig.clear()
ax = plt.subplot(111)
# sns.distplot is deprecated (seaborn >= 0.11); histplot with kde=True and
# stat='density' reproduces its density-normalized histogram + KDE.
sns.histplot(DcosT[~np.isnan(DcosT)], kde=True, stat='density', ax=ax)
ax.set_xlabel('cos(theta) between patch signal vectors')
ax.set_title('Pairwise signal-vector similarity')

In [None]:
fig = plt.figure(num=fig_basenum + 3, figsize=(3 * plot_size, 1.5 * plot_size))
fig.clear()

# Left: 1/D, the inverse of the cross-cancellation ratio. NaN cells (diagonal,
# pairs never computed) are drawn white via set_bad.
ax = plt.subplot(121)
cmap = sns.cubehelix_palette(8, as_cmap=True)
cmap.set_bad('white', 1.)

inv_D = 1 / D
vmin, vmax = np.nanmin(inv_D), np.nanmax(inv_D)
sns.heatmap(inv_D, cmap=cmap, vmin=vmin, vmax=vmax,
            square=True, xticklabels=5, yticklabels=5,
            linewidths=.5, cbar_kws={"shrink": .5}, ax=ax)
ax.set_title('1 / D')

# Right: cos(theta) on a diverging map. Limits are symmetric about zero so
# the neutral middle colour marks orthogonal signal vectors (with nanmin /
# nanmax limits the midpoint would sit at an arbitrary value).
ax = plt.subplot(122)
cmap_cosT = sns.diverging_palette(220, 10, sep=80, n=8, as_cmap=True)
cmap_cosT.set_bad('white', 1.)

vmax = np.nanmax(np.abs(DcosT))
vmin = -vmax
sns.heatmap(DcosT, cmap=cmap_cosT, vmin=vmin, vmax=vmax,
            square=True, xticklabels=5, yticklabels=5,
            linewidths=.5, cbar_kws={"shrink": .5}, ax=ax)
ax.set_title('cos(theta)')

In [None]:
# For development / debugging
# import mne_viz_circle
# import importlib
# importlib.reload(mne_viz_circle)
# plot_connectivity_circle = mne_viz_circle.plot_connectivity_circle

In [None]:
fig = plt.figure(num=fig_basenum + 4, figsize=(3 * plot_size, 3 * plot_size))
fig.clear()

# Node colour encodes each location's own sensitivity relative to the largest.
normalizer = np.max(self_sens)
node_colors = [polax_get_colour(node_sens, cmap=cm.inferno,
                                normalizer=normalizer)
               for node_sens in self_sens]

# Edges: 1/D (cross-cancellation), strongest 10 drawn.
# node_heights = None  # can't get this "right"
plot_connectivity_circle(1/D, label_names,
                         n_lines=10,
                         node_angles=node_angles,
                         node_radii=node_radii,
                         node_heights=node_heights,
                         node_colors=node_colors,
                         node_linewidth=1.0,
                         linewidth=4.,
                         colormap=cm.inferno_r,
                         facecolor='white', textcolor='black',
                         title='Cross-cancellation', fig=fig)

In [None]:
fig = plt.figure(num=fig_basenum + 5, figsize=(3 * plot_size, 3 * plot_size))
fig.clear()

# Node colour encodes each location's own sensitivity relative to the largest.
normalizer = np.max(self_sens)
node_colors = [polax_get_colour(ss, cmap=cm.inferno, normalizer=normalizer)
               for ss in self_sens]

# Same layout as the cross-cancellation circle, but the edges are
# cos(theta) between signal vectors (DcosT), so the title says so.
# node_heights = None  # can't get this "right"
plot_connectivity_circle(DcosT, label_names, n_lines=10, node_radii=node_radii,
                         node_angles=node_angles, node_heights=node_heights,
                         title='Signal-vector similarity (cos theta)', fig=fig,
                         facecolor='white', textcolor='black',
                         colormap=cmap_cosT, node_linewidth=1.0,
                         linewidth=4., node_colors=node_colors)

In [None]:
# Fixed-orientation copy of the forward solution for the topomap plots
# below; copy=True leaves `fwd` itself untouched.
fwd_fixed = convert_forward_solution(fwd, copy=True, force_fixed=True)

In [None]:
# For development / debugging
# import retinotopic_helpers
# import importlib
# importlib.reload(retinotopic_helpers)
# _stc_from_labels = retinotopic_helpers._stc_from_labels
# plot_stc_topomap = retinotopic_helpers.plot_stc_topomap
# plot_region_interaction_topomap = retinotopic_helpers.plot_region_interaction_topomap
# get_2D_connectivity_matrix_value = retinotopic_helpers.get_2D_connectivity_matrix_value

In [None]:
fig = plt.figure(num=fig_basenum + 6, figsize=(3 * plot_size, 1 * plot_size))
fig.clear()

# Example pair: first V1 match for '144' in lh and for '160' in rh.
l_one = find_labels_in_list(labels['V1']['lh'], '144')[0]
l_two = find_labels_in_list(labels['V1']['rh'], '160')[0]

plot_region_interaction_topomap([l_one, l_two],
                                fwd_fixed, info, fig=fig)

canc = get_2D_connectivity_matrix_value(1/D, l_one, l_two)
cosT = get_2D_connectivity_matrix_value(DcosT, l_one, l_two)
# clip: rounding can push |cos| just past 1, where arccos returns NaN
theta_deg = np.degrees(np.arccos(np.clip(cosT, -1., 1.)))
print('Cancellation:\t{}\nTheta (deg):\t{}'.format(canc, theta_deg))

In [None]:
fig = plt.figure(num=fig_basenum + 7, figsize=(3 * plot_size, 1 * plot_size))
fig.clear()

# Example pair within lh V1: first matches for '136' and '260'.
l_one = find_labels_in_list(labels['V1']['lh'], '136')[0]
l_two = find_labels_in_list(labels['V1']['lh'], '260')[0]

plot_region_interaction_topomap([l_one, l_two],
                                fwd_fixed, info, fig=fig)
canc = get_2D_connectivity_matrix_value(1/D, l_one, l_two)
cosT = get_2D_connectivity_matrix_value(DcosT, l_one, l_two)
# clip: rounding can push |cos| just past 1, where arccos returns NaN
theta_deg = np.degrees(np.arccos(np.clip(cosT, -1., 1.)))
print('Cancellation:\t{}\nTheta (deg):\t{}'.format(canc, theta_deg))

In [None]:
fig = plt.figure(num=fig_basenum + 8, figsize=(3 * plot_size, 1 * plot_size))
fig.clear()

# Example pair: first V1 match for '66' in lh and for '80' in rh.
l_one = find_labels_in_list(labels['V1']['lh'], '66')[0]
l_two = find_labels_in_list(labels['V1']['rh'], '80')[0]

plot_region_interaction_topomap([l_one, l_two],
                                fwd_fixed, info, fig=fig)
canc = get_2D_connectivity_matrix_value(1/D, l_one, l_two)
cosT = get_2D_connectivity_matrix_value(DcosT, l_one, l_two)
# clip: rounding can push |cos| just past 1, where arccos returns NaN
theta_deg = np.degrees(np.arccos(np.clip(cosT, -1., 1.)))
print('Cancellation:\t{}\nTheta (deg):\t{}'.format(canc, theta_deg))

In [None]:
fig = plt.figure(num=fig_basenum + 9, figsize=(3 * plot_size, 1 * plot_size))
fig.clear()

# Example pair within lh V1: first matches for '129' and '130'.
l_one = find_labels_in_list(labels['V1']['lh'], '129')[0]
l_two = find_labels_in_list(labels['V1']['lh'], '130')[0]

plot_region_interaction_topomap([l_one, l_two],
                                fwd_fixed, info, fig=fig)
canc = get_2D_connectivity_matrix_value(1/D, l_one, l_two)
cosT = get_2D_connectivity_matrix_value(DcosT, l_one, l_two)
# clip: rounding can push |cos| just past 1, where arccos returns NaN
theta_deg = np.degrees(np.arccos(np.clip(cosT, -1., 1.)))
print('Cancellation:\t{}\nTheta (deg):\t{}'.format(canc, theta_deg))

In [None]:
fig = plt.figure(num=fig_basenum + 10, figsize=(3 * plot_size, 1 * plot_size))
fig.clear()

# Example pair within rh V1: first matches for '129' and '130'.
l_one = find_labels_in_list(labels['V1']['rh'], '129')[0]
l_two = find_labels_in_list(labels['V1']['rh'], '130')[0]

plot_region_interaction_topomap([l_one, l_two],
                                fwd_fixed, info, fig=fig)
canc = get_2D_connectivity_matrix_value(1/D, l_one, l_two)
cosT = get_2D_connectivity_matrix_value(DcosT, l_one, l_two)
# clip: rounding can push |cos| just past 1, where arccos returns NaN
theta_deg = np.degrees(np.arccos(np.clip(cosT, -1., 1.)))
print('Cancellation:\t{}\nTheta (deg):\t{}'.format(canc, theta_deg))

In [None]:
# Inspect the lh V1 labels (e.g. to pick search strings for the pairs above).
v1_lh_labels = labels['V1']['lh']
v1_lh_labels