In [10]:
import numpy as np
import h5py
from scipy.io import loadmat
import pandas as pd
from tqdm.notebook import tqdm
import os
import glob

In [144]:
# Registry of every distinct test grid encountered so far, shared across
# participants. Keys are generated names ("GRID0", "GRID1", ...) and values are
# the corresponding grid arrays; load_test_data_perparticipant adds an entry
# whenever it meets a grid that is not already stored here.
ALL_TEST_GRIDS = {}
def read_h5py_string(dataset):
    """Decode the strings referenced by the first row of an HDF5 reference dataset.

    Parameters
    ----------
    dataset : h5py-like dataset
        ``dataset[()][0]`` must yield an iterable of object references, and
        ``dataset.file[ref]`` must resolve each reference to an object whose
        full contents (``obj[()]``) expose ``tobytes()``.

    Returns
    -------
    list of str
        One string per reference, in reference order, obtained by decoding
        the referenced object's raw bytes as UTF-16.
    """
    first_row_refs = dataset[()][0]  # the references live in the first row
    return [
        dataset.file[reference][()].tobytes().decode("utf-16")
        for reference in first_row_refs
    ]




In [None]:
def _lookup_or_register_grid(grid):
    """Return the ALL_TEST_GRIDS name for ``grid``, registering it first if it is new."""
    for name, known_grid in ALL_TEST_GRIDS.items():
        # array_equal checks shape before content, so grids of a different
        # size can neither broadcast into a false match nor raise.
        if np.array_equal(known_grid, grid):
            return name
    name = f"GRID{len(ALL_TEST_GRIDS)}"
    # Store a copy: the caller passes a slice of a whole session's grid stack,
    # and keeping the view would keep that entire stack alive.
    ALL_TEST_GRIDS[name] = grid.copy()
    return name


def load_test_data_perparticipant(participant_num, base_dir="/Users/mishaal/personalproj/clarion_replay", trials_per_session=48):
    """Convert one participant's raw test-phase files into processed outputs.

    Reads the behavioural session files (``raw/Behav/s<ID>/T*.mat``), the MEG
    file (``raw/data/s<ID>/Data_inference.mat``) and the classifier file
    (``raw/data/s<ID>/Class_data.mat``), builds a one-row-per-trial table and
    writes, under ``processed/test_data/s<ID>/``:

    * ``meg_data.npy`` - the MEG ``data`` array with its axes reversed,
    * ``classifier_coeffs.npy`` / ``classifier_intercepts.npy`` - the
      transposed ``betas_loc`` / ``intercepts_loc`` arrays,
    * ``test_data.csv`` - the per-trial table.

    New grids found in the behavioural data are added to the module-level
    ``ALL_TEST_GRIDS`` registry. The ``processed/train_data/s<ID>`` directory
    is also created (nothing is written to it here).

    Parameters
    ----------
    participant_num : str or int
        Participant identifier as it appears after the ``s`` in folder names;
        must be convertible with ``int()``.
    base_dir : str, optional
        Project root containing the ``raw`` and ``processed`` trees.
    trials_per_session : int, optional
        Number of trials each session is expected to contribute to the MEG
        trial tables.

    Returns
    -------
    pandas.DataFrame
        The per-trial table that was written to ``test_data.csv``.

    Raises
    ------
    AssertionError
        If the number of MEG trials is not ``sessions * trials_per_session``,
        or if the behavioural correctness does not match the copy stored with
        the MEG data.
    IndexError
        If the behavioural sessions contain more trials than the MEG trial
        tables have rows.
    """
    subject = f"s{participant_num}"
    behav_dir = os.path.join(base_dir, "raw", "Behav", subject)
    raw_data_dir = os.path.join(base_dir, "raw", "data", subject)
    test_out_dir = os.path.join(base_dir, "processed", "test_data", subject)
    train_out_dir = os.path.join(base_dir, "processed", "train_data", subject)

    session_files = sorted(glob.glob(os.path.join(behav_dir, "T*.mat")))
    num_sessions = len(session_files)

    # Everything needed is copied into numpy arrays inside the `with` blocks so
    # the HDF5 handles can be closed immediately instead of leaking.
    with h5py.File(os.path.join(raw_data_dir, "Data_inference.mat"), "r") as meg_data:
        meg_signal_data = np.transpose(np.array(meg_data["data"]), (2, 1, 0))
        meg_correct_dup = np.array(meg_data["correct_trials_all"]).T
        bricks_conn_trial = np.array(meg_data["bricks_conn_trial"]).T
        bricks_rel_trial = np.array(meg_data["bricks_rel_trial"]).T
    # The "stimlabel" entries are deliberately not read: the labels turned out
    # not to be unique, so grid names are derived from the grid contents below.

    # load the binomial classifiers
    with h5py.File(os.path.join(raw_data_dir, "Class_data.mat"), "r") as classifier_data:
        betas = np.array(classifier_data["betas_loc"]).T
        intercepts = np.array(classifier_data["intercepts_loc"]).T

    n_total_trials = bricks_conn_trial.shape[0]
    assert num_sessions * trials_per_session == n_total_trials, f"mismatch in trial numbers for participant {participant_num} {num_sessions * trials_per_session} {n_total_trials}"

    os.makedirs(test_out_dir, exist_ok=True)
    os.makedirs(train_out_dir, exist_ok=True)

    # Key order here is the column order of the CSV.
    columns = {"PID": [int(participant_num)] * n_total_trials,
               "Session": [], "Trial": [], "Grid_Name": [],
               "left_element": [], "ontop_element": [], "right_element": [], "below_element": [],
               "besideness": [], "middle": [], "ontopness": [],
               "Q_Brick_Middle": [], "Q_Brick_Left": [], "Q_Relation": [], "True Relation": [], "Correct": [], "RT": []}

    rel_columns = (("left_element", 0), ("ontop_element", 1), ("right_element", 2), ("below_element", 3))
    conn_columns = (("besideness", 0), ("middle", 1), ("ontopness", 2))
    n_meg_rows = min(bricks_rel_trial.shape[0], bricks_conn_trial.shape[0])

    absolute_trial_index = 0  # row of the MEG trial tables where the current session starts
    for idx, filename in enumerate(session_files):
        behav_data = loadmat(filename)["res"][0, 0]["behav"][0, 0]
        stimulus_grids = behav_data["SOLUTIONS_BUILT"]
        correctness = behav_data["correct"]
        rts = behav_data["rt"]
        q_stimuli = behav_data["stim_catch"]  # assumes one row per trial, two bricks per row - TODO confirm
        # A brick is presented in the middle and another in the top-left corner;
        # the relation of the top-left brick to the middle brick is asked.
        # This is the identity of the brick in the top-left corner.
        query_relation = behav_data["question_catch"]
        true_relation = behav_data["relation_catch"]  # the relation in question

        n_trials = correctness.shape[1]
        session_stop = absolute_trial_index + n_trials
        # A slice would silently truncate past the end of the MEG tables and
        # misalign every later row, so fail loudly instead.
        if session_stop > n_meg_rows:
            raise IndexError(
                f"session {idx + 1} of participant {participant_num} ends at trial {session_stop}, "
                f"but the MEG trial tables only have {n_meg_rows} rows"
            )
        session_rows = slice(absolute_trial_index, session_stop)

        # save experiment data
        columns["Session"].extend([idx + 1] * n_trials)
        columns["Trial"].extend(range(1, n_trials + 1))

        for column, source_column in rel_columns:
            columns[column].extend(bricks_rel_trial[session_rows, source_column])
        for column, source_column in conn_columns:
            columns[column].extend(bricks_conn_trial[session_rows, source_column])

        columns["Q_Brick_Left"].extend(q_stimuli[:, 0].flatten())
        columns["Q_Brick_Middle"].extend(q_stimuli[:, 1].flatten())
        columns["Q_Relation"].extend(query_relation.flatten())
        columns["True Relation"].extend(true_relation.flatten())

        columns["Correct"].extend(correctness.flatten())
        columns["RT"].extend(rts.flatten())

        # Name each trial's grid by its contents, reusing names across participants.
        for j in range(n_trials):
            columns["Grid_Name"].append(_lookup_or_register_grid(stimulus_grids[:, :, j]))

        absolute_trial_index = session_stop

    # Validate and assemble the table before writing anything, so a failing
    # participant does not leave partial output behind.
    assert np.array_equal(meg_correct_dup.flatten(), np.array(columns["Correct"])), f"correctness mismatch for participant {participant_num}"
    p_df = pd.DataFrame(columns)

    # Save MEG data
    np.save(os.path.join(test_out_dir, "meg_data.npy"), meg_signal_data)
    # save classifier data
    np.save(os.path.join(test_out_dir, "classifier_coeffs.npy"), betas)
    np.save(os.path.join(test_out_dir, "classifier_intercepts.npy"), intercepts)

    p_df.to_csv(os.path.join(test_out_dir, "test_data.csv"))

    return p_df





In [None]:
# Participant folders are named "s<ID>"; keep only the ID part (drop the
# leading "s") and process participants in sorted ID-string order.
participants = sorted(
    folder.rsplit("/", 1)[-1][1:]
    for folder in glob.glob("/Users/mishaal/personalproj/clarion_replay/raw/Behav/s*")
)

# Convert each participant in turn, collecting the per-participant tables.
dfs = []
for p in tqdm(participants):
    p_df = load_test_data_perparticipant(p)
    dfs.append(p_df)

# Stack the per-participant tables and write the combined CSV.
all_participants_df = pd.concat(dfs)
all_participants_df.to_csv("/Users/mishaal/personalproj/clarion_replay/processed/test_data/all_test_data.csv")