In [None]:
%load_ext blackcellmagic
%load_ext autoreload
%autoreload 2

#### Imports and setup

In [None]:
import numpy as np
import pandas as pd
import h5py
import nibabel.freesurfer.mghformat as mgh

In [None]:
from spacestream.core.constants import SUBJECTS, CORE_ROI_NAMES
from spacestream.core.paths import DATA_PATH, RESULTS_PATH
from spacestream.utils.get_utils import get_mapping

In [None]:
# Analysis configuration
seeds = list(range(5))  # CV seeds used for the subject-to-subject estimates
hemis = ["lh", "rh"]

# Model variants, grouped by family; each family lists its task variants in the
# order they are analysed. model_types is the flat "<family>_<task>" list.
_MODEL_FAMILY_TASKS = (
    ("MB_RN50_v2", ("detection", "clip", "categorization")),
    ("MB_RN50", ("detection", "action", "categorization")),
    ("MB_RN18", ("detection", "action", "categorization")),
)
model_types = [
    family + "_" + task for family, tasks in _MODEL_FAMILY_TASKS for task in tasks
]
checkpoint = "checkpoint0"

#### Load and format data

In [None]:
def correct_for_voxel_noise_ceiling(NC, mapping):
    """Divide each unit's winning test correlation by the noise ceiling it maps to.

    Parameters
    ----------
    NC : np.ndarray
        Noise-ceiling estimates in percent R^2. They are divided by 100 and
        square-rooted here to obtain a correlation ceiling r.
    mapping : dict
        Must contain arrays (assumed numpy) under "winning_idx" (index into
        ``NC`` for each entry), "winning_roi" and "winning_test_corr".
        Modified in place: "winning_roi" is replaced by a float32 copy, and
        infinite "winning_test_corr" values are overwritten with NaN.

    Returns
    -------
    np.ndarray
        ``winning_test_corr / sqrt(NC[winning_idx] / 100)``. The result is NaN
        wherever the test correlation is exactly 0 or was infinite, and wherever
        the noise ceiling is 0 or NaN.
    """
    # NC is R^2 in percent -> correlation ceiling r
    brain_r = np.sqrt(NC[mapping["winning_idx"].astype(int)] / 100)
    # A zero ceiling gives no meaningful correction; make it undefined instead of
    # letting the division below produce +/-inf (which would poison nanmean).
    brain_r = np.where(brain_r == 0, np.nan, brain_r)

    mapping["winning_roi"] = mapping["winning_roi"].astype(np.float32)

    test_corr = mapping["winning_test_corr"]
    inf_mask = np.isinf(test_corr)
    # Only assign when needed: writing NaN into an integer array raises even
    # with an all-False mask.
    if inf_mask.any():
        test_corr[inf_mask] = np.nan

    corrected = test_corr / brain_r
    # Exactly-zero correlations are treated as missing rather than as a 0 result.
    corrected[test_corr == 0] = np.nan

    return corrected

In [None]:
# read in data
def _mapping_model_type(mtype):
    """Return the ``model_type`` name get_mapping expects for a label in ``model_types``.

    All task variants of one family (e.g. the three MB_RN18_* labels) request the
    same mapping; the task is selected afterwards via ``winning_task``.
    """
    if "18" in mtype:
        return "MB18"
    if "50_v2" in mtype:
        return "MB50_v2"
    return "MB50"


def _task_index(mtype):
    """Return the ``winning_task`` code for the task named in ``mtype``.

    The codes are: 0 = categorization, 1 = action or clip, 2 = anything else
    (detection).
    """
    if "categorization" in mtype:
        return 0
    if "action" in mtype or "clip" in mtype:
        return 1
    return 2


long = {
    "model_type": [],
    "hemi": [],
    "subject": [],
    "ROIS": [],
    "result": [],
}

for hemi in hemis:
    # The stream parcellation depends only on the hemisphere, so read it once per
    # hemi instead of once per subject.
    streams = (
        mgh.load(DATA_PATH + "brains/" + hemi + ".ministreams.mgz")
        .get_fdata()[:, 0, 0]
        .astype(int)
    )

    for subj in SUBJECTS:
        # Per-subject noise-ceiling estimates
        NC = mgh.load(
            DATA_PATH + "brains/NC/subj" + subj + "/" + hemi + ".nc_3trials.mgh"
        ).get_fdata()[:, 0, 0]
        NC_trim = NC[streams != 0]  # drop entries outside every stream (code 0)
        NC_trim[NC_trim == 0] = np.nan  # Set all 0s to nans to avoid dividing by 0

        for mtype in model_types:
            # NOTE(review): only model seed 0 is used here although `seeds` lists
            # five seeds — confirm this is intended.
            mapping = get_mapping(
                subj_name="subj" + str(subj),
                model_seed=0,
                hemi=hemi,
                model_type=_mapping_model_type(mtype),
                checkpoint=checkpoint,
            )

            corrected = correct_for_voxel_noise_ceiling(NC_trim, mapping)
            task = _task_index(mtype)  # same for every ROI of this model

            for ridx, r in enumerate(CORE_ROI_NAMES):
                long["model_type"].append(mtype)
                long["hemi"].append(hemi)
                long["subject"].append(subj)
                long["ROIS"].append(r)
                # Mean corrected correlation over units assigned to this ROI
                # (winning_roi codes are offset by 5) and to this model's task.
                long["result"].append(
                    np.nanmean(
                        corrected[
                            (mapping["winning_roi"] == ridx + 5)
                            & (mapping["winning_task"] == task)
                        ]
                    )
                )

In [None]:
df = pd.DataFrame(long)
# Sort by ROI to get the plotting order right. A stable sort keeps the original
# hemi/subject/model order within each ROI; the default quicksort does not
# guarantee that, so the exported CSV row order would be arbitrary.
df = df.sort_values("ROIS", kind="mergesort")

In [None]:
## load subject2subject estimates
def _s2s_results_path(subj, hemi, seed):
    """Build the voxel2voxel correlation-info file path for one target subject, hemi and CV seed."""
    return (
        RESULTS_PATH
        + "mappings/one_to_one/voxel2voxel/target_subj"
        + subj
        + "/mode_"
        + hemi
        + "_ministreams_HVA_only_radius5_max_iters100_constant_radius_2.0dist_cutoff_constant_dist_cutoff_spherical"
        + ("_CV_seed" + str(seed))
        + "_"
        + checkpoint
        + "_voxel2voxel_correlation_info.hdf5"
    )


n_rois = len(CORE_ROI_NAMES)
# axes: (subject, seed, ROI slot, hemi)
s2s_corrected_by_stream = np.zeros((len(SUBJECTS), len(seeds), n_rois, len(hemis)))

for hidx, hemi in enumerate(hemis):

    for sidx, subj in enumerate(SUBJECTS):

        for seedix, seed in enumerate(seeds):

            with h5py.File(_s2s_results_path(subj, hemi, seed), "r") as f:
                # Read each dataset once per file rather than once per ROI.
                s2s_corr = f["corrected_test_corr"][:]
                s2s_roi = f["winning_roi"][:]

            # Slot k holds ROI code (n_rois - 1 - k) + 5, i.e. ROIs are stored in
            # reverse CORE_ROI_NAMES order; the reformat step pairs slot k with
            # CORE_ROI_NAMES[::-1][k].
            for slot in range(n_rois):
                roi_code = (n_rois - 1 - slot) + 5
                s2s_corrected_by_stream[sidx, seedix, slot, hidx] = np.nanmean(
                    s2s_corr[s2s_roi == roi_code]
                )

# Average over hemispheres, then over seeds -> shape (subjects, ROI slots)
across_seed_corrected_mean = np.mean(np.mean(s2s_corrected_by_stream, axis=-1), axis=1)

In [None]:
# Show the subject-to-subject corrected correlation averaged over hemispheres and
# seeds: shape (len(SUBJECTS), len(CORE_ROI_NAMES)). Columns follow the slot order
# used when loading (reversed relative to CORE_ROI_NAMES, matching the [::-1] in
# the reformat cell).
across_seed_corrected_mean

In [None]:
# Reformat data: one row per (ROI, subject). ROIs are walked in reverse
# CORE_ROI_NAMES order because that is the slot order of the columns of
# across_seed_corrected_mean.
s2s_reformatted = pd.DataFrame(
    [
        {
            "subject": subject,
            "ROI": roi_name,
            "result": across_seed_corrected_mean[subj_pos, roi_pos],
        }
        for roi_pos, roi_name in enumerate(CORE_ROI_NAMES[::-1])
        for subj_pos, subject in enumerate(SUBJECTS)
    ]
)

In [None]:
# Save the dataframes for matlab plotting function
# matlab/F03_B.m
#
# File names carry the configured `checkpoint`, so running with another
# checkpoint neither mislabels its results as checkpoint0 nor overwrites them.
# NOTE(review): absolute cluster path — consider deriving it from RESULTS_PATH.
fig3b_output_dir = "/oak/stanford/groups/kalanit/biac2/kgs/projects/Dawn/SpaceStreamPaper/Revision/code/"

s2s_reformatted.to_csv(
    fig3b_output_dir + "new_Fig3b_noiseCeiling_" + checkpoint + ".csv", index=False
)
df.to_csv(
    fig3b_output_dir + "new_Fig3b_dataFrame_0420_" + checkpoint + ".csv", index=False
)
