In [1]:
import numpy as np
import h5py
from scipy.io import loadmat
import pandas as pd
from tqdm.notebook import tqdm
import os
import glob

In [62]:
# Registries of the distinct stimulus grids seen so far, keyed "GRID<k>" in
# first-seen order; test-phase and training-phase grids are numbered separately.
ALL_TEST_GRIDS = dict()
ALL_TRAIN_GRIDS = dict()
def read_h5py_string(dataset):
    """
    Dereference a row of HDF5 object references and decode each target as text.

    ``dataset[()][0]`` is taken as the sequence of references. Each reference is
    resolved through ``dataset.file``, and the referenced object's raw buffer is
    decoded with the "utf-16" codec. This presumably matches MATLAB v7.3 char
    arrays stored as 16-bit code units (TODO confirm for each field read).

    Returns a list of str, one per reference, in reference order.
    """
    return [
        dataset.file[reference][()].tobytes().decode("utf-16")
        for reference in dataset[()][0]
    ]

In [56]:
def mk_ontopness(REQUIRED_FORM):
    """
    Check if there is ontopness in pattern, i.e. two different bricks stacked vertically.

    Two vertically adjacent cells are in contact when both are non-zero and hold
    different values. Every pair of adjacent rows with at least one contact
    column counts as one ontop contact.

    Parameters
    ----------
    REQUIRED_FORM : np.ndarray
        2-D grid; 0 marks an empty cell, non-zero values label bricks.

    Returns
    -------
    ontopness : bool
        True if at least one ontop contact exists.
    count_ontop : int
        Number of adjacent row pairs with an ontop contact.
    ontop, below : label values
        When there is a contact: of the two smallest non-zero labels, the one
        whose top-most cell has the smaller row index is ``ontop``, the other
        ``below``. Both stay 0 if there is no contact, or if both bricks start on
        the same row. Only the first two labels are compared, so the grid is
        expected to contain exactly two bricks.
    """
    ontopness = False
    count_ontop = 0
    ontop = 0
    below = 0

    if REQUIRED_FORM.size == 0:
        return ontopness, count_ontop, ontop, below

    upper = REQUIRED_FORM[:-1, :]
    lower = REQUIRED_FORM[1:, :]
    # Use boolean masks rather than multiplying labels: for integer grids a
    # product such as 16 * 16 wraps to 0 in uint8 and would hide a real contact.
    touching = (upper != lower) & (upper != 0) & (lower != 0)
    count_ontop = int(np.count_nonzero(touching.any(axis=1)))

    if count_ontop > 0:
        ontopness = True
        # Loop-invariant: which brick is on top depends only on the whole grid.
        elements_form = np.unique(REQUIRED_FORM)
        elements_form = elements_form[elements_form != 0]

        top_row_first = np.min(np.where(REQUIRED_FORM == elements_form[0])[0])
        top_row_second = np.min(np.where(REQUIRED_FORM == elements_form[1])[0])

        if top_row_first < top_row_second:
            ontop = elements_form[0]
            below = elements_form[1]
        elif top_row_first > top_row_second:
            ontop = elements_form[1]
            below = elements_form[0]

    return ontopness, count_ontop, ontop, below

def mk_besideness(REQUIRED_FORM):
    """
    Check if there is besideness in pattern, i.e. two different bricks side by side.

    Horizontal adjacency in ``REQUIRED_FORM`` is vertical adjacency in its
    transpose. This therefore delegates to ``mk_ontopness`` on
    ``REQUIRED_FORM.T`` instead of duplicating the scan logic.

    Returns
    -------
    besideness : bool
        True if at least one side-by-side contact exists.
    count_beside : int
        Number of adjacent column pairs with a contact.
    left, right : label values
        Of the two smallest non-zero labels, the one whose left-most cell has the
        smaller column index is ``left``, the other ``right``. Both stay 0 if
        there is no contact, or if both bricks start in the same column.
    """
    besideness, count_beside, left, right = mk_ontopness(REQUIRED_FORM.T)
    return besideness, count_beside, left, right

def brick_connectedness(stim_grid):
    """
    Assign structural roles to the three bricks of a stimulus grid.

    The grid is expected to contain exactly three bricks: a middle brick, one
    brick connected to it via besideness and one connected to it via ontopness.

    Parameters
    ----------
    stim_grid : np.ndarray
        2-D grid; 0 marks an empty cell, each non-zero value labels one brick.

    Returns
    -------
    bricks_conn_trial : np.ndarray, shape (3,)
        [element connected to middle via besideness, middle element,
         element connected to middle via ontopness]
    bricks_rel_trial : list of length 4
        [left element, ontop element, right element, below element]

    Raises
    ------
    ValueError
        If the grid does not hold exactly three bricks, if a role is not filled
        by exactly one brick, or if a left/ontop/right/below element cannot be
        determined.
    """
    labels = np.unique(stim_grid)
    # Drop the background explicitly; slicing off the first label would remove
    # a real brick from a grid that has no empty cells.
    bricks = labels[labels != 0]
    if bricks.size != 3:
        raise ValueError(f"expected 3 bricks in grid, found {bricks.size}: {bricks}")

    # parts[k] is the grid with brick k erased, i.e. it holds the other two bricks.
    parts = []
    for brick in bricks:
        part = np.copy(stim_grid)
        part[stim_grid == brick] = 0
        parts.append(part)

    # Evaluate every pairwise relation once and reuse the results.
    ontop_results = [mk_ontopness(part) for part in parts]
    beside_results = [mk_besideness(part) for part in parts]
    pair_ontop = np.array([result[0] for result in ontop_results], dtype=bool)
    pair_beside = np.array([result[0] for result in beside_results], dtype=bool)

    # Brick k belongs to every pair except parts[k]. Counts are reduced to
    # booleans before negating: on integer arrays `~` is bitwise (~2 & 1 == 1),
    # which would misclassify a brick involved in two relations of one kind.
    has_ontop = (int(pair_ontop.sum()) - pair_ontop.astype(int)) > 0
    has_beside = (int(pair_beside.sum()) - pair_beside.astype(int)) > 0

    role_masks = (
        ("besideness", has_beside & ~has_ontop),
        ("middle", has_beside & has_ontop),
        ("ontopness", has_ontop & ~has_beside),
    )
    role_indices = []
    for role, mask in role_masks:
        matches = np.flatnonzero(mask)
        if matches.size != 1:
            raise ValueError(
                f"expected exactly one '{role}' brick, found {bricks[matches]} in grid:\n{stim_grid}"
            )
        role_indices.append(matches[0])
    bricks_conn_trial = bricks[role_indices]

    bricks_rel_trial = [0, 0, 0, 0]
    # The first pair (in brick order) showing a relation supplies its elements.
    for is_ontop, _, ontop, below in ontop_results:
        if is_ontop:
            bricks_rel_trial[1], bricks_rel_trial[3] = ontop, below
            break
    for is_beside, _, left, right in beside_results:
        if is_beside:
            bricks_rel_trial[0], bricks_rel_trial[2] = left, right
            break
    if 0 in bricks_rel_trial:
        raise ValueError(
            f"could not determine left/ontop/right/below elements {bricks_rel_trial} in grid:\n{stim_grid}"
        )

    return bricks_conn_trial, bricks_rel_trial





In [None]:
def _find_or_register_grid(registry, grid):
    """
    Return the name under which ``grid`` is stored in ``registry``.

    An unseen grid is stored (as a copy) under ``"GRID<k>"``, where ``k`` is the
    registry size before insertion, so names follow first-seen order across
    calls. ``np.array_equal`` is used so that grids of different shapes never
    match through broadcasting.
    """
    for name, known_grid in registry.items():
        if np.array_equal(known_grid, grid):
            return name
    name = f"GRID{len(registry)}"
    registry[name] = np.array(grid)
    return name


def load_test_data_perparticipant(participant_num, data_root="/Users/mishaal/personalproj/clarion_replay"):
    """
    Load, check and save one participant's test-phase behavioural and MEG data.

    Reads every ``raw/Behav/s<id>/T*.mat`` session file plus
    ``raw/data/s<id>/Data_inference.mat`` and ``Class_data.mat`` (opened with
    h5py), and builds one row per trial. It then writes ``meg_data.npy``,
    ``classifier_coeffs.npy``, ``classifier_intercepts.npy`` and
    ``test_data.csv`` to ``processed/test_data/s<id>/``. It also creates
    ``processed/train_data/s<id>/``.

    Each trial's grid is named through the global ``ALL_TEST_GRIDS`` registry;
    unseen grids are added to it. Brick roles are taken from the MEG file when
    it stores both ``bricks_conn_trial`` and ``bricks_rel_trial``. Otherwise
    they are recomputed with ``brick_connectedness``.

    Parameters
    ----------
    participant_num : str or int
        Participant id as used in the ``s<id>`` folder names; must be
        convertible with ``int``.
    data_root : str, optional
        Project root containing ``raw/`` and ``processed/``.

    Returns
    -------
    pd.DataFrame
        One row per trial (the same table written to ``test_data.csv``).

    Raises
    ------
    AssertionError
        If trial counts or correctness flags disagree between the behavioural
        and MEG files. This is checked before anything is written.
    """
    trials_per_session = 48  # test trials expected in each session file
    behav_dir = os.path.join(data_root, "raw", "Behav", f"s{participant_num}")
    meg_dir = os.path.join(data_root, "raw", "data", f"s{participant_num}")
    test_out_dir = os.path.join(data_root, "processed", "test_data", f"s{participant_num}")
    train_out_dir = os.path.join(data_root, "processed", "train_data", f"s{participant_num}")

    session_files = sorted(glob.glob(os.path.join(behav_dir, "T*.mat")))
    num_sessions = len(session_files)

    # Pull everything needed into memory, then let the context manager close the file.
    with h5py.File(os.path.join(meg_dir, "Data_inference.mat"), "r") as meg_data:
        meg_signal_data = np.asarray(meg_data["data"]).transpose(2, 1, 0)
        meg_correct_dup = np.array(meg_data["correct_trials_all"]).T
        # Precomputed brick roles are only present in some files; when absent
        # they are recomputed from each grid below.
        has_stored_roles = "bricks_conn_trial" in meg_data and "bricks_rel_trial" in meg_data
        if has_stored_roles:
            stored_conn = np.array(meg_data["bricks_conn_trial"]).T
            stored_rel = np.array(meg_data["bricks_rel_trial"]).T
        # One label per grid presentation. Only used for the trial-count check:
        # the labels apparently are not unique per grid, so grids are named via
        # ALL_TEST_GRIDS instead.
        stim_labels = read_h5py_string(meg_data["stimlabel"])

    assert num_sessions * trials_per_session == len(stim_labels), (
        f"mismatch in trial numbers for participant {participant_num} "
        f"{num_sessions * trials_per_session} {len(stim_labels)}"
    )

    # Load the binomial classifiers.
    with h5py.File(os.path.join(meg_dir, "Class_data.mat"), "r") as classifier_data:
        betas = np.array(classifier_data["betas_loc"]).T
        intercepts = np.array(classifier_data["intercepts_loc"]).T

    os.makedirs(test_out_dir, exist_ok=True)
    # The training-data loaders also write into this folder.
    os.makedirs(train_out_dir, exist_ok=True)

    pid = int(participant_num)
    p_df = {"PID": [], "Session": [], "Trial": [], "Grid_Name": [],
            "left_element": [], "ontop_element": [], "right_element": [], "below_element": [],
            "besideness": [], "middle": [], "ontopness": [],
            "Q_Brick_Middle": [], "Q_Brick_Left": [], "Q_Relation": [], "True Relation": [],
            "Correct": [], "RT": []}

    absolute_trial_index = 0
    for session_num, filename in enumerate(session_files, start=1):
        behav_data = loadmat(filename)["res"][0, 0]["behav"][0, 0]
        stimulus_grids = behav_data["SOLUTIONS_BUILT"]
        correctness = behav_data["correct"]
        rts = behav_data["rt"]
        q_stimuli = behav_data["stim_catch"]
        # A brick is presented in the middle and another in the top-left corner;
        # the relation of the top-left brick to the middle brick is asked. This
        # is the identity of the brick in the top-left corner.
        query_relation = behav_data["question_catch"]
        true_relation = behav_data["relation_catch"]  # the relation in question

        n_trials = correctness.shape[1]
        p_df["PID"].extend([pid] * n_trials)
        p_df["Session"].extend([session_num] * n_trials)
        p_df["Trial"].extend(range(1, n_trials + 1))
        p_df["Q_Brick_Left"].extend(q_stimuli[:, 0].flatten())
        p_df["Q_Brick_Middle"].extend(q_stimuli[:, 1].flatten())
        p_df["Q_Relation"].extend(query_relation.flatten())
        p_df["True Relation"].extend(true_relation.flatten())
        p_df["Correct"].extend(correctness.flatten())
        p_df["RT"].extend(rts.flatten())

        for j in range(n_trials):
            grid = stimulus_grids[:, :, j]
            p_df["Grid_Name"].append(_find_or_register_grid(ALL_TEST_GRIDS, grid))

            if has_stored_roles:
                conn = stored_conn[absolute_trial_index + j]
                rel = stored_rel[absolute_trial_index + j]
            else:
                conn, rel = brick_connectedness(grid)
            # rel: [left, ontop, right, below]; conn: [besideness, middle, ontopness]
            p_df["left_element"].append(rel[0])
            p_df["ontop_element"].append(rel[1])
            p_df["right_element"].append(rel[2])
            p_df["below_element"].append(rel[3])
            p_df["besideness"].append(conn[0])
            p_df["middle"].append(conn[1])
            p_df["ontopness"].append(conn[2])

        absolute_trial_index += n_trials

    # Validate before writing anything to disk.
    assert absolute_trial_index == len(stim_labels), (
        f"behavioural trial count {absolute_trial_index} != MEG label count "
        f"{len(stim_labels)} for participant {participant_num}"
    )
    assert np.array_equal(meg_correct_dup.flatten(), np.array(p_df["Correct"])), \
        f"correctness mismatch for participant {participant_num}"

    np.save(os.path.join(test_out_dir, "meg_data.npy"), meg_signal_data)
    np.save(os.path.join(test_out_dir, "classifier_coeffs.npy"), betas)
    np.save(os.path.join(test_out_dir, "classifier_intercepts.npy"), intercepts)

    trials_df = pd.DataFrame(p_df)
    trials_df.to_csv(os.path.join(test_out_dir, "test_data.csv"))
    return trials_df





In [59]:
# Build the processed test-phase dataset for every participant folder.
participant_dirs = glob.glob("/Users/mishaal/personalproj/clarion_replay/raw/Behav/s*")
# Keep the lexicographic sort: grid names in ALL_TEST_GRIDS follow first-seen order.
participants = sorted(os.path.basename(path)[1:] for path in participant_dirs if os.path.isdir(path))
dfs = []
for p in tqdm(participants, desc="test data"):
    # Participant 18 is skipped. Compare the id numerically: a substring test
    # would also skip ids such as "118" or "180".
    if int(p) == 18:
        continue
    p_df = load_test_data_perparticipant(p)
    dfs.append(p_df)

# Concatenate the per-participant tables into one CSV.
pd.concat(dfs).to_csv("/Users/mishaal/personalproj/clarion_replay/processed/test_data/all_test_data.csv")

  0%|          | 0/20 [00:00<?, ?it/s]

In [68]:
def _load_train_data_participant(participant_num, file_prefix, grid_field, count_trials, out_name, data_root):
    """
    Shared implementation of the two training-data loaders in this cell.

    Reads ``raw/Behav/Training_MEG/s<id>/<file_prefix>*.mat`` in sorted order.
    Each trial's grid is named through the global ``ALL_TRAIN_GRIDS`` registry;
    unseen grids are added as ``"GRID<k>"`` in first-seen order. One row per
    trial is written to ``processed/train_data/s<id>/<out_name>``.

    Parameters
    ----------
    participant_num : str or int
        Participant id as used in the ``s<id>`` folder names; must be
        convertible with ``int``.
    file_prefix : str
        Session-file name prefix to glob for.
    grid_field : str
        Field of ``res_train`` holding the grids, indexed as ``[:, :, trial]``.
    count_trials : callable
        ``count_trials(behav_data, stimulus_grids) -> int`` giving the number of
        trials in one session file.
    out_name : str
        CSV file name written into the participant's ``processed/train_data`` folder.
    data_root : str
        Project root containing ``raw/`` and ``processed/``.

    Returns
    -------
    pd.DataFrame
        Columns PID, Session, Trial, Grid_Name; one row per trial.
    """
    session_files = sorted(glob.glob(os.path.join(
        data_root, "raw", "Behav", "Training_MEG", f"s{participant_num}", f"{file_prefix}*.mat")))
    out_dir = os.path.join(data_root, "processed", "train_data", f"s{participant_num}")

    pid = int(participant_num)
    p_df = {"PID": [], "Session": [], "Trial": [], "Grid_Name": []}

    for session_num, filename in enumerate(session_files, start=1):
        behav_data = loadmat(filename)["res_train"][0, 0]
        stimulus_grids = behav_data[grid_field]
        n_trials = count_trials(behav_data, stimulus_grids)

        p_df["Session"].extend([session_num] * n_trials)
        p_df["Trial"].extend(range(1, n_trials + 1))
        p_df["PID"].extend([pid] * n_trials)

        for j in range(n_trials):
            grid = stimulus_grids[:, :, j]
            # array_equal never matches grids of different shapes via broadcasting.
            for name, known_grid in ALL_TRAIN_GRIDS.items():
                if np.array_equal(known_grid, grid):
                    break
            else:
                name = f"GRID{len(ALL_TRAIN_GRIDS)}"
                ALL_TRAIN_GRIDS[name] = np.array(grid)
            p_df["Grid_Name"].append(name)

    # Create the folder here instead of relying on the test-data loader having made it.
    os.makedirs(out_dir, exist_ok=True)
    train_df = pd.DataFrame(p_df)
    train_df.to_csv(os.path.join(out_dir, out_name))
    return train_df


def load_train_construction_data_participant(participant_num, data_root="/Users/mishaal/personalproj/clarion_replay"):
    """
    Load one participant's construction-training sessions and save them.

    Session files are ``T*.mat``. Grids come from ``res_train.INFO_FORM``, and
    the trial count is the number of columns of ``res_train.correct``. The
    result is written to ``processed/train_data/s<id>/train_data_constr.csv``
    and returned as a DataFrame (PID, Session, Trial, Grid_Name).
    """
    return _load_train_data_participant(
        participant_num,
        file_prefix="T",
        grid_field="INFO_FORM",
        count_trials=lambda behav_data, grids: behav_data["correct"].shape[1],
        out_name="train_data_constr.csv",
        data_root=data_root,
    )


def load_train_rel_data_participant(participant_num, data_root="/Users/mishaal/personalproj/clarion_replay"):
    """
    Load one participant's relational-training sessions and save them.

    Session files are ``D*.mat``. Grids come from ``res_train.SOLUTIONS``, and
    the trial count is the size of their third axis. The result is written to
    ``processed/train_data/s<id>/train_data_rel.csv`` and returned as a
    DataFrame (PID, Session, Trial, Grid_Name).
    """
    return _load_train_data_participant(
        participant_num,
        file_prefix="D",
        grid_field="SOLUTIONS",
        count_trials=lambda behav_data, grids: grids.shape[2],
        out_name="train_data_rel.csv",
        data_root=data_root,
    )

In [70]:
# Build the processed relational-training dataset for every participant folder.
participant_dirs = glob.glob("/Users/mishaal/personalproj/clarion_replay/raw/Behav/s*")
# Keep the lexicographic sort: grid names in ALL_TRAIN_GRIDS follow first-seen order.
participants = sorted(os.path.basename(path)[1:] for path in participant_dirs if os.path.isdir(path))
dfs = []
for p in tqdm(participants, desc="train (relational) data"):
    # Participant 18 is skipped. Compare the id numerically: a substring test
    # would also skip ids such as "118" or "180".
    if int(p) == 18:
        continue
    # For the construction dataset instead, call load_train_construction_data_participant(p)
    # and write the result to processed/train_data/all_train_cons_data.csv.
    p_df = load_train_rel_data_participant(p)
    dfs.append(p_df)

# Concatenate the per-participant tables into one CSV.
pd.concat(dfs).to_csv("/Users/mishaal/personalproj/clarion_replay/processed/train_data/all_train_rel_data.csv")

  0%|          | 0/20 [00:00<?, ?it/s]