In [23]:
from os import listdir, PathLike
from os.path import join
from typing import List, Tuple, Dict
from dataclasses import dataclass
import h5py
import numpy as np
import pandas as pd
from hdmf.backends.hdf5 import H5DataIO
from pynwb import NWBFile, TimeSeries
from pynwb.file import Subject
from pynwb.ecephys import ElectricalSeries, ElectrodeGroup, LFP
from pynwb.behavior import BehavioralEvents
import nixio
import regex as re
from usz_neuro_conversion.common import (
    SessionContext,
    NixContext,
    get_metadata_row,
    read_nix,
    get_date,
    write_nwb,
    standardize_sex,
    find_nix_files, get_matlab_matrix_scalars_ragged, get_micro_dir
)
from joblib import Parallel, delayed
import multiprocessing

In [62]:
def read_matlab(ctx: SessionContext, n_jobs=None):
    """Read every micro-electrode MATLAB file of ``ctx.subject`` into the global ``micros``.

    One joblib job per electrode file loads the ``trial`` and ``time`` cell arrays via
    ``get_matlab_matrix``; ``micros`` is then replaced by ``{electrode_label: Micro}``.

    Args:
        ctx: Session context; only ``ctx.subject`` is used to select the files.
        n_jobs: Number of parallel joblib workers. ``None`` (default) uses
            ``multiprocessing.cpu_count()``. Each running job holds one electrode's
            full arrays, so a smaller value can lower peak memory.

    Raises:
        KeyError: if no micro-electrode files were found for ``ctx.subject``.
        AssertionError: if no electrode was read.
    """
    global micros
    # Drop the previous data first so a failed load cannot leave stale electrodes behind.
    micros = {}

    micro_files = _find_micro_data_files(ctx)
    if ctx.subject not in micro_files:
        raise KeyError(f"No micro-electrode MATLAB files found for subject {ctx.subject}")
    files = micro_files[ctx.subject]

    def read_electrode(electrode, path):
        with h5py.File(path, 'r') as electrode_file:
            return electrode, Micro(
                trials=get_matlab_matrix(electrode_file, "trial"),
                times=get_matlab_matrix(electrode_file, "time"),
            )

    if n_jobs is None:
        n_jobs = multiprocessing.cpu_count()
    # Source: https://stackoverflow.com/a/50926231
    loaded = Parallel(n_jobs=n_jobs)(
        delayed(read_electrode)(electrode, path) for electrode, path in files.items()
    )
    micros = dict(loaded)
    assert len(micros) > 0

In [52]:
def get_matlab_matrix(file: "h5py.File", variable: str) -> np.ndarray:
    """Stack the entries of a MATLAB v7.3 (HDF5) cell array into one float64 array.

    The cell array is looked up at ``data/<variable>``, or at ``dataMicro/<variable>``
    when the file has no top-level ``data`` entry. Each row of that dataset holds an
    object reference in its first column, and every referenced dataset is copied into
    the result.

    Args:
        file: Open HDF5 file (anything offering ``keys()``, ``get()`` and ``[]``).
        variable: Name of the cell-array variable, e.g. ``"trial"`` or ``"time"``.

    Returns:
        Array of shape ``(number_of_references, *shape_of_first_referenced_dataset)``.

    Raises:
        KeyError: if the variable exists in neither group.
        AssertionError: if the cell array is empty.
        ValueError: if a referenced dataset's shape differs from the first one.
    """
    group = "data" if "data" in file.keys() else "dataMicro"
    cell_array = file.get(f"{group}/{variable}")
    if cell_array is None:
        raise KeyError(f"{group}/{variable} not found in MATLAB file")
    refs = [row[0] for row in cell_array]
    assert len(refs) > 0

    first = np.asarray(file[refs[0]][:])
    matrices = np.zeros((len(refs), *first.shape))
    for i, ref in enumerate(refs):
        data = np.asarray(file[ref][:])
        # Without this check NumPy would broadcast e.g. a (1, C) entry into every row.
        if data.shape != first.shape:
            raise ValueError(
                f"{group}/{variable}[{i}] has shape {data.shape}, expected {first.shape}"
            )
        matrices[i] = data
    return matrices

In [53]:
@dataclass(frozen=True)
class Micro:
    """Immutable pair of stacked arrays read from one micro-electrode MATLAB file.

    Both arrays are produced by ``get_matlab_matrix`` and carry one leading entry per
    cell of the corresponding MATLAB cell array.
    """

    # Stacked ``trial`` cell array of the file.
    trials: np.ndarray
    # Stacked ``time`` cell array of the file.
    times: np.ndarray


# Electrode label -> Micro; replaced wholesale by read_matlab.
micros: Dict[str, Micro] = {}

In [7]:
def _find_micro_data_files(ctx: SessionContext) -> Dict[int, Dict[str, PathLike]]:
    """Group the micro-electrode MATLAB files in ``get_micro_dir(ctx)`` by subject.

    Every file name matching ``MATLAB_RE`` yields (patient number, electrode index,
    electrode label). The patient number is translated to the corrected subject number
    via ``CORRECTED_PATIENT``; files of patients missing from that table are skipped.

    Returns:
        ``{subject: {electrode_label: path}}``. Within each subject, electrodes are
        ordered by the numeric electrode index in the file name (ties broken by file
        name). ``os.listdir`` order is arbitrary, and callers such as
        ``read_lfp_trials`` derive the matrix column order from this dict's order.

    Raises:
        AssertionError: if no matching file of a listed patient is found.
    """
    micro_dir = get_micro_dir(ctx)
    found = []
    for file_name in listdir(micro_dir):
        match = MATLAB_RE.match(file_name)
        if not match:
            continue
        patient, electrode_index, electrode = match.groups()
        patient = int(patient)
        if patient not in CORRECTED_PATIENT:
            continue
        found.append(
            (int(electrode_index), file_name, CORRECTED_PATIENT[patient], electrode)
        )

    micro_files = {}
    for _electrode_index, file_name, subject, electrode in sorted(found):
        micro_files.setdefault(subject, {})[electrode] = join(micro_dir, file_name)
    assert len(micro_files) > 0
    return micro_files


In [10]:
# Patient number found in the micro-electrode file names -> subject number used by this
# conversion. The tuple lists the file-name patient numbers in subject order 1, 2, ..., 9.
CORRECTED_PATIENT = {
    raw_patient: subject_id
    for subject_id, raw_patient in enumerate((28, 22, 19, 30, 33, 13, 23, 29, 16), start=1)
}

In [5]:
def create_context(subject: int, session: int) -> SessionContext:
    """Load the NIX file of one recording session and pair it with a new NWB file.

    The NWB metadata combines fixed values for this dataset with the ``Recording
    location`` property of the NIX ``General`` section (stored as ``lab``).

    Args:
        subject: Subject number; passed to ``NixContext`` and used in the NWB identifier.
        session: Session number; passed to ``NixContext`` and used in the NWB identifier.

    Returns:
        The ``SessionContext`` built by ``NixContext.to_session_context`` from the loaded
        NIX file and the new ``NWBFile``.
    """
    project = "Human_MTL_units_scalp_EEG_and_iEEG_verbal_WM"
    nix_context = NixContext(subject, session, project=project)
    nix = read_nix(nix_context)
    general = nix.sections["General"]
    nwb = NWBFile(
        session_description="Running experiment as described in the experiment description",
        identifier=f"{project}_subject{subject:02}_session{session:02}",
        session_start_time=get_date(nix_context),
        lab=general.props["Recording location"].values[0],
        # Hard-coded rather than read from the file, whose stored value has broken UTF-8.
        institution="Universitätsspital Zürich, 8091 Zurich, Switzerland",
        experimenter="Boran, Ece",
        keywords=[
            "Neuroscience",
            "Electrophysiology",
            "Human",
            "Awake",
            "Local field potential",
            "Neuronal action potential",
            "Spikes",
            "Medial temporal lobe",
            "Hippocampus",
            "Entorhinal cortex",
            "Amygdala",
            "Scalp EEG",
            "Intracranial EEG",
            "Cognitive task",
            "Verbal working memory",
            "Epilepsy",
        ],
    )
    return nix_context.to_session_context(nix, nwb)

In [9]:
# Matches micro-electrode files such as "Micro_Data_Patient_04_Electrode_01_uAR.mat".
# Groups: (patient number, electrode index, electrode label).
# The extension dot is escaped and the pattern is anchored at the end, so names like
# "..._uARxmat" or "..._uAR.mat.bak" are rejected.
MATLAB_RE = re.compile(r"Micro_Data_Patient_(\d+)_Electrode_(\d+)_u([A-Z]+)\.mat$")

In [63]:
# First approach: load every electrode of subject 1 into the global `micros`.
subject, session = 1, 1
ctx = create_context(subject, session)
read_matlab(ctx)

In [64]:
# Electrode labels loaded by read_matlab for the current subject.
micros.keys()

dict_keys(['AHL', 'AL', 'ECL', 'PHR'])

In [68]:
# Shapes of the AHL electrode's stacked trial and time arrays, one per line.
print(micros["AHL"].trials.shape, micros["AHL"].times.shape, sep="\n")

(200, 256000, 8)
(200, 256000, 1)


In [80]:
# Stack every electrode's trials into one array: (electrode, trial, sample, channel).
# np.stack raises if the electrodes' trial arrays do not all share one shape.
trials = np.stack([micro.trials for micro in micros.values()])
# Reorder to electrode - channel/subelectrode - trial - values per timestamp
trials = trials.transpose(0, 3, 1, 2)
trials.shape

(4, 8, 200, 256000)

In [81]:
# Merge the electrode and channel axes: row e * n_channels + c holds electrode e, channel c
# (C-order reshape), giving (electrode*channel, trial, sample).
trials_reshaped = trials.reshape(
    trials.shape[0] * trials.shape[1], trials.shape[2], trials.shape[3]
)
trials_reshaped.shape

(32, 200, 256000)

In [88]:
# Concatenate each channel's trials along time, then put time first:
# row trial * n_samples + sample, column = channel.
n_channels, n_trials, n_samples = trials_reshaped.shape
trials_reshaped_again = trials_reshaped.reshape(n_channels, n_trials * n_samples).T
trials_reshaped_again.shape

(51200000, 32)

In [90]:
# Concatenate the AHL electrode's per-trial timestamps (first column of `times`) in the same
# trial-major order as the rows of `trials_reshaped_again`.
# NOTE(review): if the stored times are trial-relative, this vector is not monotonic —
# verify before using it as NWB timestamps.
times = np.array(micros["AHL"].times[:, :, 0], dtype=np.float64).ravel()
times.shape

(51200000,)

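# New approach: preallocate one (measurements × channels) matrix per session and fill it trial by trial from each electrode file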

In [96]:
def get_matlab_trial_info(file_name: str) -> pd.DataFrame:
    """Load the per-trial CSV table that accompanies a micro-electrode ``.mat`` file.

    The CSV is expected next to the MATLAB file, with the same name and a ``.csv``
    extension. Only the final extension is swapped, so directory names that happen
    to contain ``.mat`` are left untouched.

    Args:
        file_name: Path of the ``.mat`` file (``str`` or path-like).

    Returns:
        The comma-separated table as a DataFrame.
    """
    import os

    csv_name = os.path.splitext(os.fspath(file_name))[0] + ".csv"
    with open(csv_name, "r") as file:
        return pd.read_csv(file, sep=",")


def get_trial_indices(micro_info: pd.DataFrame, session: int) -> List[int]:
    """Return the index labels of the rows whose ``Session`` column equals ``session``.

    The labels come back as a NumPy array (``Index.values``). Callers use them as
    positions, which matches only while ``micro_info`` keeps a default RangeIndex.
    """
    session_column = micro_info["Session"]
    # Source: https://stackoverflow.com/a/17215844
    return session_column[session_column == session].index.values


@dataclass(frozen=True)
class MicroData:
    """Preallocated session LFP matrix together with the layout used to fill it."""

    # Zero-initialised matrix of shape
    # (session trials * measurements_per_trial, electrodes * channels_per_electrode);
    # its contents are written in place by read_lfp_trials.
    matrix: np.ndarray
    # Rows contributed by one trial (first dimension of a trial dataset).
    measurements_per_trial: int
    # Columns contributed by one electrode file (second dimension of a trial dataset).
    channels_per_electrode: int
    # Number of electrode files of the subject.
    electrodes: int
    # Trials in the reference file across all sessions, not only the current one.
    total_trials: int
    # Per-trial CSV table of the reference file; its Session column selects trials.
    trial_info: pd.DataFrame


def prepare_micro_data(ctx: SessionContext) -> MicroData:
    """Allocate the zero-filled LFP matrix of one session and describe its layout.

    Dimensions come from an arbitrary reference electrode file:
    - the first entry of its ``dataMicro/trial`` cell array gives
      ``(measurements, channels)``;
    - its CSV (``get_matlab_trial_info``) selects the trials of the session.

    All electrode files are assumed to share this layout.

    Args:
        ctx: Session context; ``ctx.subject`` selects the files and ``ctx.session`` the
            trials. NOTE(review): assumes ``SessionContext`` exposes ``session`` — confirm.

    Returns:
        ``MicroData`` whose ``matrix`` has shape
        ``(measurements * trials_in_session, electrodes * channels)``.
    """
    files = _find_micro_data_files(ctx)[ctx.subject]
    num_electrodes = len(files)
    reference_file = next(iter(files.values()))  # arbitrary

    with h5py.File(reference_file, 'r') as file:
        refs = [entry[0] for entry in file["dataMicro/trial"]]
        # Only the shape is needed, so use the dataset metadata instead of reading samples.
        measurements, channels = file[refs[0]].shape

    trial_info = get_matlab_trial_info(reference_file)
    # Use the context's session instead of the notebook-global `session` (hidden state).
    trials_in_session = get_trial_indices(trial_info, ctx.session)

    matrix = np.zeros((measurements * len(trials_in_session), num_electrodes * channels))
    return MicroData(
        matrix=matrix,
        measurements_per_trial=measurements,
        channels_per_electrode=channels,
        electrodes=num_electrodes,
        total_trials=len(refs),
        trial_info=trial_info,
    )


def read_lfp_trials(ctx: SessionContext, micro_data: MicroData):
    """Fill ``micro_data.matrix`` in place with the micro LFP trials of one session.

    Layout, with ``M = measurements_per_trial`` and ``C = channels_per_electrode``:
    - electrode ``e`` (in ``_find_micro_data_files`` order) fills columns
      ``[e * C, (e + 1) * C)``;
    - the ``k``-th selected trial fills rows ``[k * M, (k + 1) * M)``.

    The trial selection comes from ``micro_data.trial_info`` (the reference file's CSV)
    and is applied to every electrode.

    Args:
        ctx: Session context; ``ctx.subject`` selects the files and ``ctx.session`` the
            trials. NOTE(review): assumes ``SessionContext`` exposes ``session`` — confirm.
        micro_data: Container returned by ``prepare_micro_data`` for the same ``ctx``.

    Raises:
        ValueError: if the electrode count, the number of selected trials, or a trial's
            shape does not match the layout ``micro_data`` was allocated for.
    """
    files = _find_micro_data_files(ctx)[ctx.subject]
    if len(files) != micro_data.electrodes:
        raise ValueError(
            f"Found {len(files)} electrode files but micro_data was allocated for "
            f"{micro_data.electrodes}"
        )

    n_rows = micro_data.measurements_per_trial
    n_cols = micro_data.channels_per_electrode
    # Use the context's session instead of the notebook-global `session` (hidden state).
    indices_in_session = get_trial_indices(micro_data.trial_info, ctx.session)
    if len(indices_in_session) * n_rows != micro_data.matrix.shape[0]:
        raise ValueError(
            f"{len(indices_in_session)} trials selected, but the matrix has room for "
            f"{micro_data.matrix.shape[0] // n_rows}"
        )

    for electrode_index, path in enumerate(files.values()):
        channel_index = electrode_index * n_cols
        with h5py.File(path, 'r') as electrode_file:
            refs = [entry[0] for entry in electrode_file["dataMicro/trial"]]
            for trial_index, ref_index in enumerate(indices_in_session):
                data = electrode_file[refs[ref_index]]
                # Without this check NumPy would broadcast e.g. a (1, C) trial into
                # every row of its slot.
                if data.shape != (n_rows, n_cols):
                    raise ValueError(
                        f"{path}: trial {ref_index} has shape {data.shape}, "
                        f"expected {(n_rows, n_cols)}"
                    )
                measurement_index = trial_index * n_rows
                micro_data.matrix[
                    measurement_index:measurement_index + n_rows,
                    channel_index:channel_index + n_cols,
                ] = data[:]

In [167]:
# New approach on subject 6, session 1: allocate the matrix and show its layout.
subject, session = 6, 1
ctx = create_context(subject, session)

micro_data = prepare_micro_data(ctx)
print("\n".join(
    f"{label} {value}"
    for label, value in (
        ("matrix.shape", micro_data.matrix.shape),
        ("measurements_per_trial", micro_data.measurements_per_trial),
        ("channels_per_electrode", micro_data.channels_per_electrode),
        ("electrodes", micro_data.electrodes),
        ("total_trials", micro_data.total_trials),
        ("trial_info.shape", micro_data.trial_info.shape),
    )
))

matrix.shape (12800000, 64)
measurements_per_trial 256000
channels_per_electrode 8
electrodes 8
total_trials 349
trial_info.shape (349, 15)


In [169]:
# Fill micro_data.matrix in place; read_lfp_trials returns nothing.
# NOTE(review): the "349" lines below are not produced by the read_lfp_trials defined
# above (it has no print call) — they look like stale output; re-run to refresh.
read_lfp_trials(ctx=ctx, micro_data=micro_data)

349
349
349
349
349
349
349
349


In [170]:
# Spot-check: row 9 is the 10th measurement of the first selected trial; column 4 is
# channel index 4 of the first electrode file (columns are grouped per electrode).
print(micro_data.matrix[9, 4])

-5.2339244549504125


In [171]:
# Spot-check: first measurement of the first selected trial, first channel of the first electrode.
print(micro_data.matrix[0, 0])


3.3570360935542003
