## This notebook walks you through an example of running a synthetic localizer experiment. You will:
1. Load a pretrained brain-optimized model trained on NSD subjects to predict the entire brain
2. Run batch inference on the model for all NSD subjects on the popular fLoc experiment (Stigliani et al., 2015) http://vpnl.stanford.edu/fLoc 
3. Compute word, face, body, and place contrasts on the predicted whole brain responses and record the precision-recall curve metrics
4. Define a binary contrast per subject based on max F1 score
5. Visualize the model predicted contrast with the ground truth contrast on a flatmap

In [None]:
import glob
import os
import urllib.request
import zipfile
from collections import defaultdict
from pathlib import Path

import hcp_utils as hcp
import matplotlib.pyplot as plt
import numpy as np
import s3fs
from nilearn import plotting
from PIL import Image
from sklearn.metrics import precision_recall_curve, average_precision_score

import mosaic
from mosaic.models.transforms import SelectROIs
from mosaic.utils.inference import MosaicInference


### 1. Load a pretrained brain-optimized model trained on NSD subjects to predict the entire brain

In [None]:

# Pretrained model configuration: CNN8 backbone, "multihead" framework, trained on the NSD subjects,
# with vertices="all" requesting whole-brain outputs.
pretrained_settings = dict(
    backbone_name="CNN8",
    vertices="all",
    framework="multihead",
    subjects="NSD",
)
model, model_config = mosaic.from_pretrained(**pretrained_settings)

### 2. Run batch inference on the model for all NSD subjects on the popular fLoc experiment (Stigliani et al., 2015) http://vpnl.stanford.edu/fLoc 


In [None]:
# Download the fLoc images (Stigliani et al., 2015) unless they are already present.
# Done in pure Python so the cell does not depend on wget/unzip being installed.
FLOC_URL = "http://vpnl.stanford.edu/fLoc/fLoc_stimuli.zip"
FLOC_DIR = "./fLoc_stimuli"
if not os.path.exists(FLOC_DIR):
    zip_path = "fLoc_stimuli.zip"
    try:
        urllib.request.urlretrieve(FLOC_URL, zip_path)
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(".")
    finally:
        # always clean up the archive, even if the download or extraction failed
        if os.path.exists(zip_path):
            os.remove(zip_path)
else:
    print("./fLoc_stimuli folder already exists")

# glob order depends on the file system; sort so the stimulus (and prediction row) order is reproducible
stimulus_paths = sorted(glob.glob("./fLoc_stimuli/*jpg"))
print(f"Found {len(stimulus_paths)} fLoc stimuli (should be 1584)")


def _load_rgb(image_path):
    """Open an image file, convert it to an RGB PIL image, and release the file handle."""
    with Image.open(image_path) as image:
        return image.convert("RGB")


images = [_load_rgb(image_path) for image_path in stimulus_paths]  # convert to PIL images

In [None]:
# by default the model uses cuda if available, otherwise cpu
inference = MosaicInference(
    model=model,
    batch_size=32,
    model_config=model_config
)

results = inference.run(
    images=images,
    names_and_subjects={"NSD": "all"}
)

# Reformat output by filename:
#   results_by_filename["<subjectID>_<dataset>"]["<stimulus filename>"] -> prediction vector (numpy)
print("Reorganizing predictions by filename")
stimulus_names = [Path(path).name for path in stimulus_paths]
results_by_filename = defaultdict(dict)
for dataset, dataset_predictions in results.items():
    for subjectID, prediction in dataset_predictions.items():
        # Transfer the whole prediction matrix to host memory once, not once per image.
        prediction_np = prediction.detach().cpu().numpy()
        # Rows are assumed to follow the order of `images`, i.e. of `stimulus_paths`.
        if prediction_np.shape[0] != len(stimulus_names):
            raise ValueError(
                f"{dataset}/{subjectID}: got {prediction_np.shape[0]} predictions "
                f"for {len(stimulus_names)} stimuli"
            )
        subject_key = f"{subjectID}_{dataset}"
        for stimulus_name, row in zip(stimulus_names, prediction_np):
            results_by_filename[subject_key][stimulus_name] = row

### 3. Compute word, face, body, and place contrasts on the predicted whole brain responses and record the precision-recall curve metrics
We compare the contrasts synthetically generated from the model with the ground-truth ROI classes defined in the original NSD fLoc experiment. A precision-recall (PR) curve summarizes binary classifications made at many different thresholds: for a threshold X, the vertices whose contrast value is at least X form the proposed ROI class (1), and all other vertices do not (0). Repeating this for many thresholds traces out the PR curve. This lets us evaluate the synthetic ROI predictions without committing to a single (usually arbitrary) threshold. For each subject and contrast, we also record a cutoff: the threshold that maximizes the F1 score. This cutoff is later used as the "best" threshold.

In [None]:
def compute_contrast(activations):
    """
    Compute fLoc category-selectivity contrasts from per-image model predictions.
    http://vpnl.stanford.edu/fLoc/

    Each category's response is first averaged over its images. A contrast is then the mean
    response to its preferred categories minus the mean response to the other (non-scrambled)
    categories. Only the four contrasts matching the NSD fLoc ROI classes are computed:
    words, bodies, faces and places.

    INPUTS:
    - activations: dict mapping an fLoc stimulus filename (str) to the model prediction for that
                    image (array, flattened to 1-D; all predictions must have the same length).
                    The category is the filename prefix before the first '-'
                    (e.g. a name like 'adult-1.jpg' -> 'adult').
    OUTPUTS:
    - contrasts: dict with keys 'words', 'bodies', 'faces', 'places' (in that order). Each value
                    is a 1-D array with one positive-minus-negative mean activation per vertex.
    RAISES:
    - KeyError: if a filename's category prefix is not a known fLoc category.
    - ValueError: if a category needed by a contrast has no images.
    """
    categories = ["body", "limb", "child", "adult", "corridor", "house", "car", "instrument", "word", "number", "scrambled"]
    # contrast name -> (preferred categories, comparison categories); scrambled images are in neither
    contrast_definitions = {
        'words': (["word", "number"],
                  ["body", "limb", "child", "adult", "corridor", "house", "car", "instrument"]),
        'bodies': (["body", "limb"],
                   ["child", "adult", "corridor", "house", "car", "instrument", "word", "number"]),
        'faces': (["child", "adult"],
                  ["word", "number", "body", "limb", "corridor", "house", "car", "instrument"]),
        'places': (["corridor", "house"],
                   ["word", "number", "body", "limb", "child", "adult", "car", "instrument"]),
    }

    category_dict = {cat: [] for cat in categories}
    for filename, activation in activations.items():
        cat = filename.split('-')[0]
        if cat not in category_dict:
            raise KeyError(f"Unrecognized fLoc category {cat!r} in filename {filename!r}")
        category_dict[cat].append(np.asarray(activation).flatten())

    # Average each category over its images. Empty categories are skipped so that a missing
    # category only fails when a contrast actually needs it.
    avg_activation = {cat: np.mean(np.vstack(acts), axis=0) for cat, acts in category_dict.items() if acts}

    contrasts = {}
    for con_name, (pos, neg) in contrast_definitions.items():
        missing = [cat for cat in pos + neg if cat not in avg_activation]
        if missing:
            raise ValueError(f"Contrast {con_name!r} needs categories with no images: {missing}")
        positive_activation = sum(avg_activation[label] for label in pos) / len(pos)
        negative_activation = sum(avg_activation[label] for label in neg) / len(neg)
        contrasts[con_name] = positive_activation - negative_activation
    return contrasts

#Compute precision - recall curve
def compute_pr_curve(contrast_values, truth_froi_indices, n_vertices=91282):
    """
    Compute the precision-recall curve of a contrast map used as a detector of a ground-truth ROI.

    Each vertex is one sample. It is labeled positive if its index is in ``truth_froi_indices``,
    and its score is its contrast value (higher = more likely to be in the ROI).

    Args:
        contrast_values: 1-D array of contrast values, one per vertex.
        truth_froi_indices: set or 1-D integer array of ground-truth ROI vertex indices.
        n_vertices: total number of vertices; must equal ``len(contrast_values)``.

    Returns:
        precision, recall, thresholds: arrays from ``sklearn.metrics.precision_recall_curve``
            (``thresholds`` is one element shorter than ``precision`` and ``recall``).
        auc_pr: average precision from ``sklearn.metrics.average_precision_score``.

    Raises:
        ValueError: if ``contrast_values`` does not have ``n_vertices`` entries.
    """
    contrast_values = np.asarray(contrast_values)
    if contrast_values.shape[0] != n_vertices:
        raise ValueError(
            f"contrast_values has {contrast_values.shape[0]} entries but n_vertices={n_vertices}"
        )

    # Binary ground truth: True for vertices inside the ROI
    y_true = np.zeros(n_vertices, dtype=bool)
    y_true[list(truth_froi_indices)] = True

    # Contrast values serve directly as classifier scores (higher = more likely in the ROI)
    precision, recall, thresholds = precision_recall_curve(y_true, contrast_values)
    auc_pr = average_precision_score(y_true, contrast_values)

    return precision, recall, thresholds, auc_pr

In [None]:
#compute the contrasts
contrast_names = ['words', 'bodies', 'faces', 'places']
# subject key ("<subjectID>_<dataset>") -> {contrast name: per-vertex contrast array}
subject_contrasts = {subject: compute_contrast(activations) for subject, activations in results_by_filename.items()}
# every subject should yield exactly the contrasts listed above
assert all(set(contrasts) == set(contrast_names) for contrasts in subject_contrasts.values())

In [None]:
#download NSD rois resampled to fsLR32k space
fs = s3fs.S3FileSystem(anon=True)

# Download entire folder recursively
s3_folder = 'mosaicfmri/assets/nsd_rois/'
local_folder = './nsd_rois'
# Later cells read ./nsd_rois/<subject>/roi_masks/..., i.e. the data lives directly under
# local_folder. The previous check looked for ./nsd_rois/nsd_rois, which never exists, so the
# folder was re-downloaded on every run. Treat a non-empty local_folder as already downloaded.
if not (os.path.isdir(local_folder) and os.listdir(local_folder)):
    print("Downloading nsd_rois folder from s3 bucket")
    fs.get(s3_folder, local_folder, recursive=True)
else:
    print("nsd_rois folder already downloaded.")

In [None]:
#compute PR stats against ground truth
# Select the same vertex set the model predicts so ground truth and predictions line up.
if model_config['vertices'] == 'visual':
    rois = [f"GlasserGroup_{x}" for x in range(1, 6)]
elif model_config['vertices'] == 'all':
    rois = [f"GlasserGroup_{x}" for x in range(1, 23)]
else:
    raise ValueError(
        f"Unsupported model_config['vertices']: {model_config['vertices']!r} (expected 'visual' or 'all')"
    )
ROI_selection = SelectROIs(selected_rois=rois)
n_selected_vertices = len(ROI_selection.selected_roi_indices)


def _load_selected_ground_truth(subject, contrast_name):
    """
    Load a subject's NSD fLoc ROI mask for one contrast, restricted to the selected vertices.

    The lh/rh masks are stored on the full fsLR32k meshes. They are reduced to the defined
    grayordinates (hcp.vertex_info.grayl / grayr) and then to ROI_selection's vertices.
    """
    roi_dir = os.path.join("./nsd_rois", subject.split('_NSD')[0], "roi_masks")
    lh_truth = np.load(os.path.join(roi_dir, f"lh.floc-{contrast_name}_fsLR32k_space_resampled.npy"))
    rh_truth = np.load(os.path.join(roi_dir, f"rh.floc-{contrast_name}_fsLR32k_space_resampled.npy"))
    ground_truth_data = np.hstack((lh_truth[hcp.vertex_info.grayl], rh_truth[hcp.vertex_info.grayr]))
    return ground_truth_data[ROI_selection.selected_roi_indices]


subject_list = list(subject_contrasts.keys())
contrasts_list = list(subject_contrasts[subject_list[0]].keys())
cutoffs = {subject: {contrast_name: 0 for contrast_name in contrasts_list} for subject in subject_list} #keep track of best cutoffs
# pr_results[contrast][ground-truth subject][predicted subject] -> PR metrics
pr_results = {contrast_name: {subject_truth: {subject_predict: {} for subject_predict in subject_list} for subject_truth in subject_list} for contrast_name in contrasts_list}

for contrast_name in contrasts_list:
    for subA in subject_list:  # subject whose ground-truth ROI is the target
        ground_truth_froi_indices = np.where(_load_selected_ground_truth(subA, contrast_name) > 0)[0]
        # AUC_pr of a random binary classifier is the proportion of positives, p / (p + n)
        random_auc_pr = len(ground_truth_froi_indices) / n_selected_vertices

        for subB in subject_list:  # subject whose predicted contrast is scored against subA's ROI
            contrast_values = subject_contrasts[subB][contrast_name]
            precision, recall, thresholds, auc_pr = compute_pr_curve(contrast_values, ground_truth_froi_indices, n_vertices=contrast_values.shape[0])
            pr_results[contrast_name][subA][subB] = {
                'precision': precision,
                'recall': recall,
                'thresholds': thresholds,
                'auc_pr': auc_pr,
                'random_auc_pr': random_auc_pr}

            # The best cutoff must come from the diagonal (a subject's own prediction vs. its own
            # ground truth) because it is later applied to that same subject's contrast map.
            # Previously it was overwritten for every subB, leaving the last subject's value.
            if subB == subA:
                # sklearn returns one more precision/recall value than thresholds (a final point
                # with no threshold); drop it so f1 aligns with thresholds
                f1 = 2 * (precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-6)
                cutoffs[subA][contrast_name] = thresholds[np.argmax(f1)]

### 4. Define a binary contrast per subject based on max F1 score

In [None]:
#threshold contrast by maximum F1 score
subject_roi_predictions = {subjectID: defaultdict(list) for subjectID in results_by_filename.keys()}
# per contrast: number of subjects predicting each selected vertex to be in the ROI
allsubject_roi_predictions = {con_name: np.zeros((len(ROI_selection.selected_roi_indices),)) for con_name in contrast_names}
for subject in results_by_filename.keys():
    for con_name, stat in subject_contrasts[subject].items():
        # precision_recall_curve counts a vertex as positive when score >= threshold, so use >=
        # to reproduce the F1-maximizing operating point (">" drops the cutoff vertex itself).
        mask = (stat >= cutoffs[subject][con_name]).astype(stat.dtype)
        # flatnonzero always yields a 1-D index array (argwhere(...).squeeze() is 0-d for one hit)
        subject_roi_predictions[subject][con_name] = np.flatnonzero(mask).tolist()
        allsubject_roi_predictions[con_name] += mask

### 5. Visualize the model predicted contrast with the ground truth contrast on a flatmap

In [None]:
def plot_flatmap(stat, title="", cmap='hot', cmap_flag=False, save_flag=False, vmin=None, vmax=None, threshold=None):
    """
    Plot a map on the left and right HCP flat surfaces side by side in one figure.

    Args:
        stat: data accepted by ``hcp.left_cortex_data`` / ``hcp.right_cortex_data``; it is split into
            the two hemispheres and drawn on ``hcp.mesh.flat_left`` / ``hcp.mesh.flat_right`` with
            ``hcp.mesh.sulc_left`` / ``hcp.mesh.sulc_right`` as background maps.
        title: figure suptitle; also the filename prefix when saving.
        cmap: colormap passed to ``nilearn.plotting.plot_surf``.
        cmap_flag: if True, draw a colorbar for each hemisphere.
        save_flag: if True, save ``{title}_flatmap.svg`` and a 300-dpi ``{title}_flatmap.png``
            in the working directory.
        vmin, vmax, threshold: passed through to ``plot_surf``. None (the default, as before)
            leaves the choice to nilearn.
    """
    cortex_data_left = hcp.left_cortex_data(stat)
    cortex_data_right = hcp.right_cortex_data(stat)

    # one 3D axes per hemisphere
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), subplot_kw={'projection': '3d'})
    fig.subplots_adjust(wspace=-.4)  # pull the two hemispheres closer together
    hemispheres = (
        (hcp.mesh.flat_left, cortex_data_left, hcp.mesh.sulc_left),
        (hcp.mesh.flat_right, cortex_data_right, hcp.mesh.sulc_right),
    )
    for ax, (surf_mesh, surf_data, bg_map) in zip(axes, hemispheres):
        plotting.plot_surf(surf_mesh, surf_data,
                           threshold=threshold, bg_map=bg_map, vmin=vmin, vmax=vmax,
                           colorbar=cmap_flag, cmap=cmap,
                           axes=ax)

    # flip along the horizontal
    for ax in axes:
        ax.invert_yaxis()

    fig.suptitle(title)
    if save_flag:
        if not title:
            print("Warning: title is blank, so saved output filenames may overwrite one another.")
        fig.savefig(f"{title}_flatmap.svg")
        fig.savefig(f"{title}_flatmap.png", dpi=300)
    plt.show()
    plt.close(fig)  # free the figure; this is called in loops

In [None]:
#compare single subject ROI class predictions with ground truth
comparison_subject = 'sub-01_NSD' ### change me


def _to_display_mask(values):
    """Return a float copy of `values` with 0 -> NaN and positive values -> 100, for flatmap display."""
    # Copy as float: NaN cannot be stored in an integer array (in case an ROI mask is stored with an
    # integer dtype), and the caller's array is left unmodified.
    display = np.array(values, dtype=float)
    display[display == 0] = np.nan
    display[display > 0] = 100
    return display


roi_dir = os.path.join("./nsd_rois", comparison_subject.split('_NSD')[0], "roi_masks")
for con_name in contrast_names:
    # predicted ROI: rebuild the binary mask on the selected vertices, then go to whole brain
    mask = np.zeros((len(ROI_selection.selected_roi_indices),))
    mask[subject_roi_predictions[comparison_subject][con_name]] = 1
    wb_predictions = _to_display_mask(ROI_selection.sample2wb(mask))
    plot_flatmap(wb_predictions, title=f"{comparison_subject} subject {con_name} Prediction", cmap='hot_r')

    # ground-truth ROI: full lh/rh meshes -> defined grayordinates
    lh_truth = np.load(os.path.join(roi_dir, f"lh.floc-{con_name}_fsLR32k_space_resampled.npy"))
    rh_truth = np.load(os.path.join(roi_dir, f"rh.floc-{con_name}_fsLR32k_space_resampled.npy"))
    ground_truth_data = _to_display_mask(np.hstack((lh_truth[hcp.vertex_info.grayl], rh_truth[hcp.vertex_info.grayr])))
    plot_flatmap(ground_truth_data, title=f"{comparison_subject} subject {con_name} Truth", cmap='hot_r')