In [1]:
from dotenv import load_dotenv
load_dotenv()
import os
import sys
sys.path.append(os.getenv('PYTHONPATH')) 
import pandas as pd
import os
import tqdm as tqdm
import glob as glob
import re

In [None]:
# Root folder that holds all datasets. Taken from the DATASETS_ROOT environment
# variable; a placeholder path is used when that variable is not set.
dataset_root = os.environ.get("DATASETS_ROOT", "/default/path/to/datasets")
print(f"dataset_root: {dataset_root}")

In [3]:
#Step 01: load each dataset's stiminfo file
# Every per-dataset stiminfo table lives in the same folder; build that path once.
stiminfo_dir = os.path.join(dataset_root, "MOSAIC", "stimuli", "datasets_stiminfo")

# One tab-separated table per dataset, read into its own DataFrame.
nsd_stiminfo = pd.read_table(os.path.join(stiminfo_dir, "nsd_stiminfo.tsv"))
bmd_stiminfo = pd.read_table(os.path.join(stiminfo_dir, "bmd_stiminfo.tsv"))
b5000_stiminfo = pd.read_table(os.path.join(stiminfo_dir, "b5000_stiminfo.tsv"))
things_stiminfo = pd.read_table(os.path.join(stiminfo_dir, "things_stiminfo.tsv"))
god_stiminfo = pd.read_table(os.path.join(stiminfo_dir, "god_stiminfo.tsv"))
deeprecon_stiminfo = pd.read_table(os.path.join(stiminfo_dir, "deeprecon_stiminfo.tsv"))
had_stiminfo = pd.read_table(os.path.join(stiminfo_dir, "had_stiminfo.tsv"))
nod_stiminfo = pd.read_table(os.path.join(stiminfo_dir, "nod_stiminfo.tsv"))


In [4]:
# Map each dataset name to its stiminfo table. The names double as the suffix
# appended to that dataset's columns in the merged table.
datasets = {
    'NSD': nsd_stiminfo,
    'BMD': bmd_stiminfo,
    'BOLD5000': b5000_stiminfo,
    'THINGS': things_stiminfo,
    'GOD': god_stiminfo,
    'deeprecon': deeprecon_stiminfo,
    'HAD': had_stiminfo,
    'NOD': nod_stiminfo,
}


def _suffix_columns(stiminfo, dataset_name):
    """Return `stiminfo` with `_<dataset_name>` appended to every column name except 'filename'.

    'filename' is left untouched because it is the key the tables are merged on.
    """
    return stiminfo.rename(
        columns={col: f"{col}_{dataset_name}" for col in stiminfo.columns if col != 'filename'}
    )


#base dataset is NSD. arbitrary choice, doesnt matter
base_dataset = 'NSD'
# Use `base_dataset` everywhere (not a hard-coded name) so changing the line above
# cannot suffix the wrong table or merge the base table in a second time.
merged_df = _suffix_columns(datasets[base_dataset], base_dataset)

for dataset_name, stiminfo in datasets.items():
    if dataset_name == base_dataset:
        continue  #skip base dataset, it is already the left side of the merge

    #rename columns to make them dataset specific, then merge on the shared key
    merged_df = pd.merge(
        merged_df,
        _suffix_columns(stiminfo, dataset_name),
        on='filename',
        how='outer',  # use 'outer' to keep all filenames
    )

# Per-subject repetition columns. After the outer merge they hold NaN for stimuli a
# dataset never showed; treat those as 0 repetitions and restore an integer dtype
# (the NaNs force the columns to float).
pattern = r"sub-.*_reps.*"
repetition_columns = [col for col in merged_df.columns if re.search(pattern, col)]
merged_df[repetition_columns] = merged_df[repetition_columns].fillna(0).astype(int)


In [None]:
#merge columns that should be the same
# Names of the dataset-specific source columns in merged_df; only datasets whose
# stiminfo table has a 'source' column contribute one.
source_columns = [
    f'source_{dataset_name}'
    for dataset_name in datasets.keys()
    if 'source' in datasets[dataset_name].columns
]

source = [] #one entry per row: the single agreed-upon source, or None if no dataset reports one
for idx, filename in enumerate(merged_df['filename']):
    # Collect the distinct non-missing sources the datasets report for this filename.
    possible_sources = set()
    for source_column in source_columns:
        possible_source = merged_df.loc[idx, source_column]
        if pd.notna(possible_source):
            possible_sources.add(possible_source)
    # Datasets disagreeing about a filename's source is a data error, not something to resolve silently.
    if len(possible_sources) > 1:
        raise ValueError(f"This filename {idx} {filename} has multiple possible sources: {possible_sources}. It should have just one.")
    source.append(next(iter(possible_sources), None))
merged_df['source'] = source

#drop the dataset-specific source columns now that they are consolidated into 'source'
merged_df = merged_df.drop(columns=source_columns)
print(merged_df.shape)

In [None]:
# For every dataset, find its repetition columns once and precompute a boolean
# (rows x columns) array marking where a subject saw the stimulus at least once.
# Doing this up front avoids rescanning all column names, and doing a scalar
# lookup per cell, for every row of the table.
shown_by_dataset = {}
for dataset in datasets.keys():
    repetition_columns = [col for col in merged_df.columns if f'_reps_{dataset}' in col]
    shown_by_dataset[dataset] = (repetition_columns, (merged_df[repetition_columns] > 0).to_numpy())

# filename -> {dataset: [repetition columns with > 0 repetitions]}, kept only for
# stimuli that were shown in more than one dataset.
overlap = {}
for row, filename in enumerate(merged_df['filename']):
    seen_by = {}
    for dataset, (repetition_columns, shown) in shown_by_dataset.items():
        sub_reps = [col for col, was_shown in zip(repetition_columns, shown[row]) if was_shown]
        if sub_reps:
            seen_by[dataset] = sub_reps
    if len(seen_by) > 1:
        overlap[filename] = seen_by

#summing up all individual unique stimuli across datasets will be 166,594.
# Of those, 2829 are not unique. 2805 have been seen across exactly two 
# datasets and 24 have been seen across exactly 3 datasets.
#Thus, there are 2853 duplicated 'filename' rows resulting in 163,741 unique stimuli in the compiled dataset (before exclusion of similar stim and resolution of test/train conflicts)
print(f"Found {len(overlap)} stimuli overlapping across datasets") #found 2829
assert len(overlap) == 2829, f"expected 2829 overlapping stimuli, found {len(overlap)}"


In [None]:
# Tally, over the stimuli shared between datasets, whether the datasets that
# label them agree on 'test', agree on 'train', or disagree (both labels occur).
test_stim_only = 0
train_stim_only = 0
mix_stim = 0
# Iterating `datasets` yields the dataset names. A separate local name is used so the
# `datasets` mapping itself is not replaced by a list, which would break re-running
# any earlier cell that calls `datasets.keys()` or indexes `datasets[name]`.
dataset_names = list(datasets)
for filename in overlap:
    merged_info = merged_df[merged_df['filename'] == filename]
    # `.item()` requires exactly one matching row for this filename and raises otherwise.
    labels = [merged_info[f'test_train_{dset}'].item() for dset in dataset_names]
    # Datasets that label this stimulus; NaN entries are skipped.
    test_or_train = [label for label in labels if not pd.isna(label)]
    observed_labels = set(test_or_train)
    if observed_labels == {'test'}:
        test_stim_only += 1
    elif observed_labels == {'train'}:
        train_stim_only += 1
    elif observed_labels == {'test', 'train'}:
        mix_stim += 1
    else:
        raise ValueError(f"stim {filename} should be in either test or train or both if it overlaps.")
    # Show the stimuli labelled by three datasets for manual inspection.
    if len(test_or_train) == 3:
        print(f"{filename}: {test_or_train}")
print(f"Test stim only: {test_stim_only}")
print(f"Train stim only: {train_stim_only}")
print(f"Found in both test and train: {mix_stim}")

In [None]:
# Histogram of how many datasets each overlapping stimulus appears in.
# The last bucket's key must be exactly "count_4+": it is the key incremented below,
# so any other spelling raises KeyError for a stimulus seen in 4 or more datasets.
overlap_count = {"count_1": 0, "count_2": 0, "count_3": 0, "count_4+": 0}
for filename, seen_by in overlap.items():
    n_datasets = len(seen_by)
    if n_datasets < 1:
        # An entry with no datasets should not exist in `overlap`.
        raise ValueError(f"Invalid seen_by {seen_by}")
    bucket = f"count_{n_datasets}" if n_datasets <= 3 else "count_4+"
    overlap_count[bucket] += 1
for k, v in overlap_count.items():
    print(f"{k}: {v}")


In [None]:
#GOD and deeprecon should overlap by 1250 stimuli, and BOLD5000 and NSD should overlap by 1410 stimuli.
#while other datasets have overlapping stimuli, these are the only two pairs I could identify that
#published how many stimuli overlap. We use this to check our numbers against those.
check_datasets = [("GOD", "deeprecon"), ("BOLD5000", "NSD")]
for pair in check_datasets:
    # Count, per row, how many datasets of the pair showed the stimulus at least once.
    n_datasets_seen = pd.Series(0, index=merged_df.index)
    for dataset in pair:
        repetition_columns = [col for col in merged_df.columns if f'_reps_{dataset}' in col]
        seen_in_dataset = (merged_df[repetition_columns] > 0).any(axis=1)
        n_datasets_seen += seen_in_dataset.astype(int)
    # Filenames shown in more than one dataset of the pair. The result is kept in its
    # own variable so the all-dataset `overlap` mapping computed earlier is not
    # overwritten by this sanity check.
    pair_overlap = set(merged_df.loc[n_datasets_seen > 1, 'filename'])
    print(f"Pair {pair}: Found {len(pair_overlap)} overlapping stimuli")


In [10]:
#reorder the dataframe for readability
# Group the columns by kind: per-subject repetition counts, aliases, and
# test/train labels. Each group keeps the column order it has in merged_df.
rep_pattern = r"sub-.*_reps.*"
alias_pattern = r"alias_.*"
testtrain_pattern = r"test_train_.*"

repetition_columns = list(filter(re.compile(rep_pattern).search, merged_df.columns))
alias_columns = list(filter(re.compile(alias_pattern).search, merged_df.columns))
testtrain_columns = list(filter(re.compile(testtrain_pattern).search, merged_df.columns))

# Final layout: identifying columns first, then aliases, labels, and repetition counts.
# Columns that are not listed here are not carried over.
desired_ordering = ['filename', 'source', *alias_columns, *testtrain_columns, *repetition_columns]

merged_df = merged_df.loc[:, desired_ordering]

In [None]:
# Write the compiled table next to the per-dataset stiminfo files as a
# tab-separated file, without the DataFrame index.
print("saving stimulus info file...")
compiled_stiminfo_path = os.path.join(
    dataset_root, "MOSAIC", "stimuli", "datasets_stiminfo", "compiled_dataset_stiminfo.tsv"
)
merged_df.to_csv(compiled_stiminfo_path, sep='\t', index=False)