In [None]:
import cortex
import numpy as np
import gc
import time
import matplotlib.pyplot as plt
from matplotlib.colors import hsv_to_rgb, to_rgb
import os
import nibabel as nib
from pathlib import Path
import matplotlib.patches as mpatches
import pandas as pd
import matplotlib.colors as mcolors
from nibabel.freesurfer.io import read_annot, write_annot


# Show which pycortex user config file and database filestore this kernel is using.
for cortex_setting in (cortex.options.usercfg, cortex.database.default_filestore):
    print(cortex_setting)

# Setup

In [None]:
nsd_path = Path('/mnt/d/Datasets/NSD/')
freesurfer_path = nsd_path / 'nsddata/freesurfer/'

# NSD subject names: subj01 ... subj08.
subjects = [f'subj0{i + 1}' for i in range(8)]

# ROI label names per NSD ROI file (prf = retinotopic, floc = functional localizer).
roi_names = {
    'prf-visualrois': ['V1', 'V2', 'V3', 'V4'],
    'floc-bodies': ['mTL-bodies', 'FBA-1', 'FBA-2', 'EBA'],
    'floc-faces': ['aTL-faces', 'mTL-faces', 'FFA-2', 'FFA-1', 'OFA'],
    'floc-places': ['RSC', 'PPA', 'OPA'],
    # A missing comma here previously merged 'VWFA-2' and 'VWFA-1' into one string.
    'floc-words': ['mTL-words', 'mfs-words', 'VWFA-2', 'VWFA-1', 'OWFA'],
}


# Utilities

In [None]:
def rand_cmap(nlabels, type='bright', first_color_black=True, last_color_black=False, verbose=True,
              seed=None):
    """
    Create a random colormap for matplotlib, useful for label / segmentation images.

    :param nlabels: Number of labels (size of the colormap); must be >= 1
    :param type: 'bright' for strong colors (random HSV), 'soft' for pastel colors (limited RGB range)
    :param first_color_black: If True, replace the first color with black
    :param last_color_black: If True, replace the last color with black
    :param verbose: If True, print the number of labels and draw the colormap as a colorbar
    :param seed: Optional seed for a private ``np.random.default_rng`` generator, making the colormap
        reproducible. None (default) draws from NumPy's global random state, as before.
    :return: a ``LinearSegmentedColormap`` with ``nlabels`` entries, or None if ``type`` is invalid
    :raises ValueError: if ``nlabels`` < 1
    """
    # Validate before importing matplotlib so bad arguments fail fast.
    if type not in ('bright', 'soft'):
        print('Please choose "bright" or "soft" for type')
        return None
    if nlabels < 1:
        raise ValueError(f'nlabels must be >= 1, got {nlabels}')

    import colorsys
    from matplotlib.colors import LinearSegmentedColormap

    # Both the np.random module and a Generator expose .uniform(low=, high=).
    rng = np.random if seed is None else np.random.default_rng(seed)

    if verbose:
        print('Number of labels: ' + str(nlabels))

    if type == 'bright':
        # Random hue, saturation in [0.2, 1), value in [0.9, 1): vivid colors.
        rand_hsv = [(rng.uniform(low=0.0, high=1),
                     rng.uniform(low=0.2, high=1),
                     rng.uniform(low=0.9, high=1)) for _ in range(nlabels)]
        rand_rgb = [colorsys.hsv_to_rgb(h, s, v) for h, s, v in rand_hsv]
    else:
        # Pastel colors: every RGB channel limited to [0.6, 0.95).
        low, high = 0.6, 0.95
        rand_rgb = [(rng.uniform(low=low, high=high),
                     rng.uniform(low=low, high=high),
                     rng.uniform(low=low, high=high)) for _ in range(nlabels)]

    if first_color_black:
        rand_rgb[0] = [0, 0, 0]
    if last_color_black:
        rand_rgb[-1] = [0, 0, 0]

    # from_list needs anchors at both 0 and 1, i.e. at least two colors; repeat a lone color.
    anchors = rand_rgb if nlabels > 1 else rand_rgb * 2
    random_colormap = LinearSegmentedColormap.from_list('new_map', anchors, N=nlabels)

    if verbose:
        from matplotlib import colorbar, colors
        from matplotlib import pyplot as plt
        _, ax = plt.subplots(1, 1, figsize=(15, 0.5))
        bounds = np.linspace(0, nlabels, nlabels + 1)
        norm = colors.BoundaryNorm(bounds, nlabels)
        colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None,
                              boundaries=bounds, format='%1i', orientation='horizontal')

    return random_colormap

In [None]:
def import_flat(fs_subject, patch, hemis=('lh', 'rh'), cx_subject=None,
                flat_type='freesurfer', auto_overwrite=False,
                freesurfer_subject_dir=None, clean=True):
    """Import a flat brain from freesurfer into the pycortex database.

    NOTE: This will delete the overlays.svg file for this subject, since THE
    FLATMAPS WILL CHANGE, as well as all cached information (e.g. old flatmap
    boundaries, roi svg intermediate renders, etc).

    Unlike pycortex's own importer, the FreeSurfer patch coordinates are written
    as-is (no X/Y axis swap and no Y flip).

    Parameters
    ----------
    fs_subject : str
        Freesurfer subject name
    patch : str
        Base name of the flat patch; ".flat" is appended to it (e.g. "full").
    hemis : sequence of str
        Hemispheres to import. Defaults to both hemispheres.
    cx_subject : str or None
        Pycortex subject name. Defaults to ``fs_subject``.
    flat_type : {'freesurfer', 'slim'}
        'freesurfer' reads the FreeSurfer patch via ``cortex.freesurfer.get_surf``;
        'slim' reads an OBJ file located via ``cortex.freesurfer.get_paths(..., type='slim')``.
    auto_overwrite : bool
        If False (default), ask for interactive confirmation before overwriting files.
    freesurfer_subject_dir : str or None
        Directory for freesurfer subjects. None defaults to the environment variable
        $SUBJECTS_DIR.
    clean : bool
        If True, the flat surface is cleaned to remove the disconnected polys.

    Returns
    -------
    None
        Returns early without writing anything if the user declines the prompt.

    Raises
    ------
    ValueError
        If ``flat_type`` is not 'freesurfer' or 'slim' (checked before any file is touched).
    """
    if flat_type not in ('freesurfer', 'slim'):
        raise ValueError(f"flat_type must be 'freesurfer' or 'slim', got {flat_type!r}")

    if not auto_overwrite:
        proceed = input(('Warning: This is intended to over-write .gii files storing\n'
                         'flatmap vertex locations for this subject, and will result\n'
                         'in deletion of the overlays.svg file and all cached info\n'
                         'for this subject (because flatmaps will fundamentally change).\n'
                         'Proceed? [y]/n: '))
        if proceed.lower() not in ['y', 'yes', '']:
            print(">>> Elected to quit rather than delete & overwrite files.")
            return

    if cx_subject is None:
        cx_subject = fs_subject
    surfs = os.path.join(cortex.database.default_filestore, cx_subject, "surfaces", "flat_{hemi}.gii")

    for hemi in hemis:
        if flat_type == 'freesurfer':
            # Patch coordinates are used as-is (axis reorder / Y flip deliberately left out).
            flat, polys, _ = cortex.freesurfer.get_surf(fs_subject, hemi, "patch", patch + ".flat",
                                                        freesurfer_subject_dir=freesurfer_subject_dir)
        else:  # 'slim'
            flat_file = cortex.freesurfer.get_paths(fs_subject, hemi, type='slim',
                                                    freesurfer_subject_dir=freesurfer_subject_dir)
            flat_file = flat_file.format(name=patch + ".flat")
            # `formats` is not imported in this notebook; use the pycortex module explicitly.
            flat, polys = cortex.formats.read_obj(flat_file)

        if clean:
            polys = cortex.freesurfer._remove_disconnected_polys(polys)
            flat = cortex.freesurfer._move_disconnect_points_to_zero(flat, polys)

        fname = surfs.format(hemi=hemi)
        print("saving to %s" % fname)
        cortex.formats.write_gii(fname, pts=flat, polys=polys)

    # clear the cache, per #81
    cortex.database.db.clear_cache(cx_subject)
    # Remove overlays.svg file (FLATMAPS HAVE CHANGED)
    overlays_file = cortex.database.db.get_paths(cx_subject)['overlays']
    if os.path.exists(overlays_file):
        os.unlink(overlays_file)

In [None]:
import cortex
# Import each NSD subject's FreeSurfer surfaces into the pycortex database.
for subject_name in subjects:
    cortex.freesurfer.import_subj(subject_name, freesurfer_subject_dir=freesurfer_path)

In [None]:
cortex.freesurfer.import_subj('fsaverage', freesurfer_subject_dir=freesurfer_path)  # shared surface for the group flatmaps

In [None]:
# Replace every subject's flatmap with the 'full' FreeSurfer patch (no confirmation prompt).
for subject_name in subjects:
    import_flat(subject_name, 'full', auto_overwrite=True, freesurfer_subject_dir=freesurfer_path)

In [None]:
run_name = 'modified_dbscan5'
reruns_mode = 'multiple'
tag = 'linear_decoding__group-22_reruns-2'
results_path = nsd_path / f'derivatives/figures/concept_maps_voxel_v5/{tag}'

# `cluster` (one label per row of W_concat) is not created in this cell; fail with a clear
# message on a fresh kernel instead of an unexplained NameError mid-computation.
if 'cluster' not in globals():
    raise NameError('`cluster` (cluster label per row of the saved W matrix) must be defined '
                    'before running this cell -- TODO: load it explicitly here.')

# Mean weight vector per cluster, L2-normalised.
W_concat = np.load(results_path / f'{tag}__W.npy')
# np.unique sorts, so [1:] drops the smallest label -- presumably the DBSCAN noise label (-1); TODO confirm.
cluster_ids = np.unique(cluster)[1:]
W = np.stack([W_concat[cluster == cluster_id].mean(axis=0) for cluster_id in cluster_ids])
W = W / np.linalg.norm(W, axis=-1, keepdims=True)
print(f'{W.shape=}')

# Per-subject component masks; voxel_counts[s, c] = number of positive entries in masks[s][c].
masks = [np.load(results_path / run_name / subject_name / 'mask.npy') for subject_name in subjects]
component_label_names = np.arange(masks[0].shape[0])
component_ids = component_label_names.copy()
voxel_counts = np.stack([(mask > 0).sum(axis=1) for mask in masks])

subject_ids = np.load(results_path / f'{tag}__subject_id.npy')

In [None]:
import gc

# For every (min_neighbors, eps) run and every component: write lh/rh FreeSurfer .annot files and
# a PNG/SVG flatmap on fsaverage showing which subjects' masks cover each vertex.


def _vertex_label_ids(data_all):
    """Encode per-vertex subject coverage as annotation label indices.

    data_all : array of shape (n_subjects, n_vertices) holding 1.0 where a subject's component
        covers the vertex, 0.0 where it does not, NaN where the subject has no data.

    Returns an int array (n_vertices,): 0 = NaN for every subject, 1 = covered by no subject,
    k + 2 = covered only by subject k, n_subjects + 2 = covered by more than one subject.
    """
    hits = data_all == 1.0  # NaN compares False
    n_hits = hits.sum(axis=0)
    ids = np.ones(data_all.shape[1], dtype=int)
    single = n_hits == 1
    ids[single] = hits[:, single].argmax(axis=0) + 2
    ids[n_hits > 1] = data_all.shape[0] + 2
    ids[np.all(np.isnan(data_all), axis=0)] = 0
    return ids


# --- Sweep configuration ---
num_models = 1
min_neighbors_grid = (3,)
eps_grid = (0.55, 0.6, 0.65, 0.5, 0.7, 0.45)
tag = 'linear_decoding__group-22_reruns-2'
#tag = 'linear_encoding_large'
suffix = '_expanded'
num_vertices = 163842
reruns_mode = 'multiple'  # 'multiple': per-run masks; 'average': derived from the global `cluster` labels

# Existing PNGs whose st_ctime is younger than this many seconds are skipped (0 => always re-render).
overwrite_time = 0 * 60 * 60

if reruns_mode not in ('average', 'multiple'):
    raise ValueError(f"reruns_mode must be 'average' or 'multiple', got {reruns_mode!r}")

# --- Plot configuration (identical for every run) ---
space = 'fsavg'
# NOTE(review): '⠀' is U+2800 (blank braille), presumably a placeholder label in the overlay SVG -- confirm.
gyri_and_sulci_list = ['STS', 'SF', 'IPS', 'poCS', '⠀', 'ORBS', 'SFS', 'IFS', 'preCS', 'CS']
HCP_roi_list = ['IPS1', 'MIP', 'IP1', 'IP0', 'PF', 'PSL', 'PGi', 'STV', 'TPOJ1']  # currently unused
roi_list = ['V1', 'V2', 'V3', 'V4', 'OPA', 'RSC', 'PPA', 'EBA', 'FFA'] + gyri_and_sulci_list
# Word-selective ROI set for one hand-picked component (see use_rois below). 'mfs-words' matches
# the floc-words label in roi_names (it was misspelled 'msf-words').
word_rois = ['V1', 'V2', 'V3', 'V4', 'OPA', 'RSC', 'PPA', 'EBA',
             'VWFA-1', 'VWFA-2', 'mfs-words', 'mTL-words', 'OWFA']

subject_colors = np.array([to_rgb(f'tab:{c}') for c in
                           ('blue', 'orange', 'green', 'red', 'purple', 'pink', 'olive', 'brown')])
mixing_behavior = 'white'  # vertices covered by more than one subject are drawn white

# All subjects are combined on a single surface, so everything is rendered on fsaverage.
cx_subject = 'fsaverage'
overlay_file = Path(cortex.database.default_filestore) / 'fsaverage/overlays_floc.svg'

# FreeSurfer colour table (RGBT, integer 0-255), row index == label id from _vertex_label_ids.
label_names = ['unknown', 'model_input', *subjects, 'overlap']
subject_rgb = np.round(subject_colors * 255).astype(int)  # round: int cast would truncate 30.99... to 30
ctab = np.vstack([
    [[0, 0, 0, 255], [1, 1, 1, 0]],
    np.column_stack([subject_rgb, np.zeros(len(subject_rgb), dtype=int)]),
    [[255, 255, 255, 0]],
]).astype(int)

results_path = nsd_path / f'derivatives/figures/concept_maps_voxel_v5/{tag}'
# Loaded once: these do not depend on the run. subject_ids is only read ('average' mode).
subject_ids = np.load(results_path / f'{tag}__subject_id.npy')
voxel_ids_all = [np.load(results_path / f'voxel_ids__{space}__{subject_name}.npy') for subject_name in subjects]

for min_neighbors in min_neighbors_grid:
    for run_idx, eps in enumerate(eps_grid):
        gc.collect()
        print(f'{min_neighbors=}, {run_idx=}, {eps=}')
        run_name = f'num_models-{num_models}/min_neighbors-{min_neighbors}/run-{run_idx}'
        mask_path = results_path / run_name
        out_path = mask_path / f'pycortex_flatmaps/component_colors{suffix}'
        out_path.mkdir(exist_ok=True, parents=True)

        masks = [np.load(mask_path / subject_name / f'mask{suffix}.npy') for subject_name in subjects]
        component_ids = np.arange(masks[0].shape[0])
        voxel_counts = np.stack([(mask > 0).sum(axis=1) for mask in masks])

        for component_id in component_ids:
            png_path = out_path / f'component-{component_id}.png'
            # st_ctime is the inode change time on Linux, not strictly the creation time.
            if png_path.exists() and time.time() - png_path.stat().st_ctime < overwrite_time:
                print(f'skipping {component_id=}')
                continue

            # Per-subject coverage of this component on the surface vertices:
            # 1.0 = covered, 0.0 = not covered, NaN = all of the vertex's voxel ids are -1.
            data_all = []
            for subject_id, voxel_ids in enumerate(voxel_ids_all):
                if reruns_mode == 'average':
                    subject_data = (cluster[subject_ids == subject_id] == component_id).astype(float)
                else:  # 'multiple'
                    subject_data = (masks[subject_id][component_id] > 0).astype(float)
                data = subject_data[voxel_ids]
                data[voxel_ids == -1] = np.nan  # -1 would otherwise wrap around to the last voxel
                data_all.append(np.nanmax(data, axis=1))
            data_all = np.stack(data_all)
            all_nan = np.all(np.isnan(data_all), axis=0)

            # NOTE(review): assumes vertices are ordered lh then rh in two equal halves -- confirm.
            lh_data_all_ids, rh_data_all_ids = np.split(_vertex_label_ids(data_all), 2)
            write_annot(out_path / f'lh.component-{component_id}.annot', lh_data_all_ids, ctab, label_names, fill_ctab=True)
            write_annot(out_path / f'rh.component-{component_id}.annot', rh_data_all_ids, ctab, label_names, fill_ctab=True)

            # RGB per vertex: subject colour, white where subjects overlap, NaN where there is no data.
            num_overlaps = np.nansum(data_all, axis=0)
            num_vertices = data_all.shape[1]
            data_color = np.zeros((num_vertices, 3))
            for subject_id, color in enumerate(subject_colors):
                data_color[data_all[subject_id] == 1] = color
            if mixing_behavior == 'white':
                data_color[num_overlaps > 1] = 1
            data_color[all_nan] = np.nan

            # Hand-picked exception: this component is shown with word-selective ROIs.
            use_rois = word_rois if (eps == 0.6 and component_id == 11) else roi_list

            channels = [cortex.dataset.Vertex(channel, cx_subject, vmin=0, vmax=1) for channel in data_color.T]
            braindata = cortex.dataset.VertexRGB(*channels, cx_subject)
            cortex.quickflat.make_figure(braindata, with_sulci=False, with_rois=True, with_curvature=True,
                                         with_colorbar=False, overlay_file=overlay_file, roi_list=use_rois)

            handles = [
                mpatches.Patch(color=subject_colors[s], label=f'{name} ({voxel_counts[s, component_id]})')
                for s, name in enumerate(subjects)
                if voxel_counts[s, component_id] > 0
            ]
            if mixing_behavior == 'white':
                handles.append(mpatches.Patch(color=(1, 1, 1), label='overlap'))
            plt.legend(handles=handles, loc='upper center', ncols=4)

            plt.savefig(png_path)
            cortex.quickflat.make_svg(out_path / f'component-{component_id}.svg', braindata, with_curvature=True,
                                      with_labels=True, overlay_file=overlay_file)
            # Close the flatmap and any figure make_svg may have opened, so memory does not grow over the sweep.
            plt.close('all')
