In [1]:
from os import listdir, PathLike
from typing import List, Tuple, Dict
import numpy as np
import h5py
import pandas as pd
from pynwb import NWBFile, TimeSeries
from pynwb.behavior import BehavioralEvents
from pynwb.file import Subject
from pynwb.ecephys import ElectrodeGroup
from os.path import join
import nixio
import regex as re
from usz_neuro_conversion.common import (
    SessionContext,
    NixContext,
    get_metadata_row,
    read_nix,
    get_date,
    write_nwb,
    standardize_sex,
    find_nix_files,
    get_micro_dir,
    get_matlab_matrix, get_matlab_matrix_scalars_ragged
)

In [2]:
def convert_nix_to_nwb(subject: int, session: int) -> SessionContext:
    """Create the session context for one recording and run every writer on it.

    The order of the calls matters: ``write_behavior`` creates the "behavior"
    processing module that ``write_events`` and ``write_trial_data`` look up
    through ``_get_main_time_series``, and ``add_electrode_columns`` runs
    before any electrode is added.

    Args:
        subject: Subject number, forwarded to ``create_context``.
        session: Session number, forwarded to ``create_context``.

    Returns:
        The context all writers were applied to.
    """
    context = create_context(subject, session)

    # Subject and electrode metadata.
    write_subject(context)
    add_electrode_columns(context)
    electrode_group = write_ieeg_electrodes(context)
    write_ieeg_measurements(context)

    # Behavior first: the event and trial writers reference its time series.
    write_behavior(context)
    write_events(context)
    write_trial_data(context)

    # Unit and micro-wire data.
    write_waveforms(context, electrode_group)
    write_lfp(context)
    return context

In [3]:
def create_context(subject: int, session: int) -> SessionContext:
    """Read the NIX file of one session and pair it with a fresh ``NWBFile``.

    The NWB file is initialised with session-level metadata: the lab comes
    from the "Recording location" property of the NIX "General" section, the
    start time from ``get_date``, and the remaining fields are fixed values
    for the "Human_MTL_units_visual_WM" project.

    Args:
        subject: Subject number used to locate the NIX file and build the identifier.
        session: Session number used to locate the NIX file and build the identifier.

    Returns:
        The session context produced by ``NixContext.to_session_context``.
    """
    nix_context = NixContext(subject, session, project="Human_MTL_units_visual_WM")
    nix = read_nix(nix_context)
    general = nix.sections["General"]
    nwb = NWBFile(
        # Fixed a duplicated word ("the the") in this user-visible description.
        session_description="Running experiment as described in the experiment description",
        identifier=f"Human_MTL_units_visual_WM_subject{subject:02}_session{session:02}",
        session_start_time=get_date(nix_context),
        lab=general.props["Recording location"].values[0],
        institution="Universitätsspital Zürich, 8091 Zurich, Switzerland",  # Broken UTF-8 in file
        related_publications=_get_related_publications(nix),
        experimenter="Boran, Ece",
        experiment_description=_get_experiment(nix),
        keywords=[
            "Visual",
            "Spatial",
            "Neural decoding",
            "Hippocampus",
            "Entorhinal cortex",
        ],
    )
    return nix_context.to_session_context(nix, nwb)

In [4]:
def _get_experiment(nix: nixio.File) -> str:
    """Assemble the free-text experiment description stored in the NWB file.

    The task name is read from the NIX "Task" section; the description and
    the URL are fixed texts defined here.

    Args:
        nix: Open NIX file of the session.

    Returns:
        Three newline-separated lines: task name, task description, task URL.
    """
    task_name = nix.sections["Task"].props["Task name"].values[0]
    # Broken UTF-8 in file
    task_desc = "The task is a change detection task designed to examine the visual working memory of subjects. In each trial, arrays of colored squares were presented and had to be memorized. The number of squares determined the set size: 1, 2, 4 or 6. There was a total 192 trials per session. Each trial started with a warning signal (0.4 s) that was a red fixation dot. The fixation dot was then changed to black (0.4 – 0.5 s, jittered). A memory array (encoding period, 0.8 s) was followed by a delay (retention interval, 0.9 s). After the delay, a test array was shown (2 s) followed by a jittered inter-trial interval of 1.3 to 2.3 s. The participants indicated by button press (“Same” or “Different”, forced choice) whether the test array differed from the memory array. If the arrays differed, only one square changed in colour, but all squares remained on the same location. The fixation dot was visible on the screen during the whole trial period. Eight different colours were used for the memory and test array (yellow, red, green, blue, magenta, cyan, grey, black). Before starting the sessions, participants conducted trial runs in a practice session to learn the task. In this session we verified if participants were colour-blind and could discriminate all colours. Practice sessions were repeated until the participant understood the task and was able to follow the pace of the trials."
    task_url = "https://www.neurobs.com/ex_files/expt_view?id=285&tree_item_url=klaver12.exp&item_id=klaver12.exp"  # Found online
    description_lines = (
        f"Task Name: {task_name}",
        f"Task Description: {task_desc}",
        f"Task URL: {task_url}",
    )
    return "\n".join(description_lines)

In [5]:
def _get_related_publications(_nix: nixio.File) -> List[str]:
    """Return the DOI links of the publications related to this dataset.

    Args:
        _nix: Unused; the list is fixed and not read from the NIX file.

    Returns:
        The DOI URLs with surrounding whitespace stripped.
    """
    publications = []
    for doi in ("https://doi.org/10.1016/j.neuroimage.2022.119123",):  # Found online
        publications.append(doi.strip())
    return publications

In [6]:
def write_subject(ctx: SessionContext):
    """Attach a ``Subject`` built from the metadata row to the NWB file.

    Age is written as ``P<years>Y`` with the age truncated to an integer, and
    the sex value is passed through ``standardize_sex``.

    Args:
        ctx: Session context whose ``nwb.subject`` is set.
    """
    row = get_metadata_row(ctx.to_nix_context())
    raw_age = row["Age"]
    raw_sex = row["Sex"]
    subject = Subject(
        subject_id=f"{ctx.subject:02}",
        age=f"P{int(raw_age)}Y",
        description=_get_subject_description(ctx),
        species="Homo sapiens",
        sex=standardize_sex(raw_sex),
    )
    ctx.nwb.subject = subject

In [7]:
def _get_subject_description(ctx: SessionContext) -> str:
    """Build the multi-line, human-readable description of the subject.

    Handedness, pathology and the SOZ electrodes come from the metadata row;
    the implanted electrodes come from the NIX "Subject" section.

    Args:
        ctx: Session context providing the NIX file and the metadata lookup.

    Returns:
        Four ``"<field>: <value>"`` lines joined by newlines.
    """
    metadata = get_metadata_row(ctx.to_nix_context())
    subject_props = ctx.nix.sections["Subject"].props
    fields = [
        ("Handedness", metadata["Handedness"]),
        ("Pathology", metadata["Pathology"]),
        ("Implanted electrodes", subject_props["Implanted electrodes"].values[0]),
        (
            "Electrodes in seizure onset zone (SOZ)",
            metadata["Electrodes in seizure onset zone (SOZ)"],
        ),
    ]
    return "\n".join(f"{field}: {value}" for field, value in fields)

In [8]:
def add_electrode_columns(ctx: SessionContext):
    """Declare the custom electrode columns ("label", "is_inside_soz") on the NWB file.

    Args:
        ctx: Session context whose NWB file receives the columns.
    """
    custom_columns = (
        ("label", "Channel label referenced by other data arrays"),
        (
            "is_inside_soz",
            "Indicates whether the electrode is inside the seizure onset zone (SOZ)",
        ),
    )
    for column_name, column_description in custom_columns:
        ctx.nwb.add_electrode_column(
            name=column_name,
            description=column_description,
        )

In [9]:
def _get_session_data(ctx: SessionContext) -> nixio.Block:
    """Return the NIX block named ``Data_Subject_<SS>_Session_<NN>`` for this context.

    Args:
        ctx: Session context providing the NIX file and the subject/session numbers.

    Returns:
        The matching block of the NIX file.
    """
    block_name = f"Data_Subject_{ctx.subject:02}_Session_{ctx.session:02}"
    return ctx.nix.blocks[block_name]

In [10]:
def write_ieeg_electrodes(ctx: SessionContext) -> ElectrodeGroup:
    """Create the recording device, the "ieeg" electrode group and its electrodes.

    One electrode row is added per row returned by ``_get_ieeg_electrodes``
    for the current subject.

    Args:
        ctx: Session context whose NWB file receives the device, group and electrodes.

    Returns:
        The newly created electrode group, so later writers can reference it.
    """
    nwb = ctx.nwb

    device = nwb.create_device(
        name="ATLAS Neurophysiology System",
        manufacturer="Neuralynx, Inc.",
        description="iEEG recording system",
    )

    # A single group holds every electrode of this subject.
    electrode_group = nwb.create_electrode_group(
        name="ieeg",
        description="iEEG electrodes",
        device=device,
        location="Intracranial",
    )

    # Explicit loop instead of DataFrame.apply: the call is made only for its
    # side effect on the NWB file and its return value is never used.
    for _, row in _get_ieeg_electrodes(ctx).iterrows():
        _add_row_to_ieeg_electrodes(nwb, electrode_group, row)
    return electrode_group

In [11]:
# Electrode table for the visual task, read once when this cell runs.
# Missing anatomical locations become "unspecific" before the columns are
# cast to their final dtypes.
VISUAL_TASK_ELECTRODES = (
    pd.read_csv("../out/metadata/visual_task_electrodes.tsv", sep="\t")
    .assign(
        anatomical_location=lambda table: table["anatomical_location"].fillna(
            "unspecific"
        )
    )
    .astype({"label": "string", "anatomical_location": "string", "inside_soz": "bool"})
)


def _get_ieeg_electrodes(ctx: SessionContext) -> pd.DataFrame:
    """Return the rows of ``VISUAL_TASK_ELECTRODES`` that belong to the current subject.

    Args:
        ctx: Session context; ``ctx.subject`` is compared with the "participant" column.

    Returns:
        The matching rows with a fresh 0..N-1 index; the previous index is
        kept as an "index" column because ``reset_index`` is called without
        ``drop``.
    """
    belongs_to_subject = VISUAL_TASK_ELECTRODES["participant"] == ctx.subject
    return VISUAL_TASK_ELECTRODES.loc[belongs_to_subject].reset_index()

In [12]:
def _get_ieeg_electrode_labels(ctx: SessionContext) -> List[str]:
    """Return the "label" column of the current subject's electrodes as a list.

    Args:
        ctx: Session context identifying the subject.

    Returns:
        The electrode labels in table order.
    """
    return _get_ieeg_electrodes(ctx)["label"].tolist()

In [13]:
def _add_row_to_ieeg_electrodes(
        nwb: NWBFile, electrode_group: ElectrodeGroup, row: pd.Series
):
    """Add one electrode to the NWB file from one row of the electrode table.

    Coordinates are remapped from the table's axes to the axes NWB expects
    (see the comments below); a NaN coordinate is passed on as ``None``.

    Args:
        nwb: NWB file receiving the electrode.
        electrode_group: Group the electrode is assigned to.
        row: Row providing "label", "anatomical_location", "inside_soz", "x", "y" and "z".
    """

    def _coordinate(column: str, negate: bool):
        # NaN marks a missing coordinate and is forwarded as None.
        value = row[column]
        if np.isnan(value):
            return None
        return -value if negate else value

    # Got MNI map: +X is right, +Y is anterior, +Z is superior according to <https://kathleenhupfeld.com/mni-template-coordinate-systems/>
    # But need NWB: +X is posterior, +Y is inferior, +Z is right according to <https://pynwb.readthedocs.io/en/stable/pynwb.file.html#pynwb.file.NWBFile.add_electrode>
    nwb.add_electrode(
        group=electrode_group,
        label=row["label"],
        location=row["anatomical_location"],
        reference="Common intracranial reference",
        is_inside_soz=row["inside_soz"],
        x=_coordinate("y", negate=True),
        y=_coordinate("z", negate=True),
        z=_coordinate("x", negate=False),
        filtering="Passband, 1 to 8000 Hz",
    )

In [14]:
def write_ieeg_measurements(ctx: SessionContext):
    """Create an electrode table region covering every iEEG electrode of the subject.

    The region returned by pynwb is not stored or returned here.

    Args:
        ctx: Session context whose NWB file the region is created on.
    """
    nwb_file = ctx.nwb
    electrode_count = _get_ieeg_electrode_count(ctx)
    nwb_file.create_electrode_table_region(
        region=list(range(electrode_count)),  # reference row indices 0 to N-1
        description="ieeg electrodes",
    )

In [15]:
# Matches names of the form "iEEG_Data_Trial_<number>" and captures the number.
# NOTE(review): not referenced in the visible part of this notebook — confirm it is still needed.
_IEEG_RE = re.compile(r"iEEG_Data_Trial_(\d+)")

In [16]:
def _get_ieeg_electrode_count(ctx: SessionContext) -> int:
    """Return how many iEEG electrodes the current subject has.

    Args:
        ctx: Session context identifying the subject.

    Returns:
        The number of electrode labels found for the subject.
    """
    labels = _get_ieeg_electrode_labels(ctx)
    return len(labels)

In [17]:
def write_events(ctx: SessionContext):
    """Write one NWB epoch per trial event, spanning until the next event.

    Event tags are read from the "Trial events single tags spike times" group
    of the session block; tags named "Response" are skipped. Each event time
    is shifted onto the continuous session timeline, on which trial ``k``
    (0-based) starts at ``k * total_trial_duration``. The final epoch ends at
    the end of the last trial.

    Args:
        ctx: Session context whose NWB file receives the epochs. The
            "behavior" processing module must already exist because every
            epoch references its response time series.

    Raises:
        AssertionError: If two consecutive events are not strictly increasing in time.
        IndexError: If a tag name does not match ``_EVENT_RE``.
    """
    nwb = ctx.nwb
    session = _get_session_data(ctx)
    tags = session.groups["Trial events single tags spike times"].tags

    # (0-based trial index, time within the trial, event name)
    events = []
    for tag in tags:
        name, trial_number = _EVENT_RE.findall(tag.name)[0]
        if name != "Response":
            events.append((int(trial_number) - 1, tag.position[0], name))
    if not events:
        return

    # Order by trial and then by time inside the trial, so epochs stay
    # chronological even when the tags of one trial are stored out of order.
    events.sort(key=lambda event: (event[0], event[1]))

    offset = _get_time_offset(ctx)
    total_trial_duration = _get_total_trial_duration(ctx)
    timeline = [
        (name, time - offset + trial_index * total_trial_duration)
        for trial_index, time, name in events
    ]

    # The session ends after the last trial. Using the number of events here
    # (as before) overshoots whenever a trial contains more than one event.
    trial_count = events[-1][0] + 1
    timeline.append(("END", trial_count * total_trial_duration))

    timeseries = _get_main_time_series(ctx)
    for (name, start), (_, end) in zip(timeline, timeline[1:]):
        # Validate before writing so no invalid epoch ends up in the file.
        assert start < end
        nwb.add_epoch(
            start_time=start,
            stop_time=end,
            tags=name,
            timeseries=timeseries,
        )

In [18]:
def _get_main_time_series(ctx: SessionContext) -> List[TimeSeries]:
    """Return the "response" time series of the behavior module as a one-element list.

    Args:
        ctx: Session context whose NWB file already contains the "behavior"
            processing module with "BehavioralEvents.response".

    Returns:
        A list holding the single "response" time series.
    """
    behavior_module = ctx.nwb.processing["behavior"]
    behavioral_events = behavior_module.get("BehavioralEvents.response")
    response_series = behavioral_events.get_timeseries("response")
    return [response_series]

In [19]:
# Matches event tag names such as "Event_Retention_Trial_183_Spike_Times".
# Group 1 is the event name (letters only), group 2 is the trial number.
_EVENT_RE = re.compile(r"Event_([a-zA-Z]+)_.*Trial_(\d+)_Spike_Times")

In [20]:
def write_trial_data(ctx: SessionContext):
    """Add one NWB trial per entry of the NIX "Trial properties" section.

    Trials are laid out back to back: trial ``k`` (0-based, derived from the
    1-based "Trial number" property) covers
    ``[k * duration, (k + 1) * duration]`` with ``duration`` taken from
    ``_get_total_trial_duration``. Each trial also records the set size and
    whether "Match" equals 1.

    Args:
        ctx: Session context whose NWB file receives the trial columns and trials.
    """
    nwb = ctx.nwb
    trial_columns = (
        (
            "set_size",
            "Number of colored squares in array presented during encoding period. Either 1, 2, 4, or 6.",
        ),
        (
            "solution",
            'The correct answer to the question "Was the array presented during the test the same as the one presented during the encoding period?"',
        ),
    )
    for column_name, column_description in trial_columns:
        nwb.add_trial_column(name=column_name, description=column_description)

    trial_sections = ctx.nix.sections["Session"].sections["Trial properties"].sections
    trial_length = _get_total_trial_duration(ctx)
    for section in trial_sections:
        props = section.props
        trial_index = int(props["Trial number"].values[0]) - 1
        begin = trial_index * trial_length
        nwb.add_trial(
            id=trial_index,
            start_time=begin,
            stop_time=begin + trial_length,
            set_size=int(props["Set size"].values[0]),
            solution=props["Match"].values[0] == 1,
            timeseries=_get_main_time_series(ctx),
        )

In [21]:
def write_behavior(ctx: SessionContext):
    """Store the participant's per-trial response in a "behavior" processing module.

    For every entry of the NIX "Trial properties" section, one sample is
    appended: the value is whether "Response" equals 1, and the timestamp is
    the "Response time" shifted onto the continuous session timeline
    (``time - offset + trial_index * total_trial_duration``).

    Args:
        ctx: Session context whose NWB file receives the processing module.
    """
    nwb = ctx.nwb
    behavior_module = nwb.create_processing_module(
        name="behavior", description="Data for all trials in this session."
    )
    trials = ctx.nix.sections["Session"].sections["Trial properties"].sections
    offset = _get_time_offset(ctx)
    duration = _get_total_trial_duration(ctx)
    data = []
    timestamps = []
    for trial in trials:
        # Read every property the same way write_trial_data does.
        props = trial.props
        trial_number = int(props["Trial number"].values[0]) - 1
        data.append(props["Response"].values[0] == 1)
        time = props["Response time"].values[0] - offset + trial_number * duration
        timestamps.append(time)
        # The time between the response and the start of the next trial is
        # filler, because NWB requires all trials on one continuous stretch.
        # It could be marked with nwb.add_invalid_time_interval(start_time=
        # min(time, stop_time), stop_time=(trial_number + 1.0) * duration),
        # but that is not necessary since there are no measurements there.

    time_series = TimeSeries(
        name="response",
        data=data,
        timestamps=timestamps,
        description='The participant\'s answer to the question "Was the array presented during the test the same as the one presented during the encoding period?"',
        unit="n/a",  # Might as well use https://github.com/rly/ndx-events, but it's not built-in...
        continuity="instantaneous",
    )

    behavioral_events = BehavioralEvents(
        name="BehavioralEvents.response", time_series=time_series
    )

    behavior_module.add(behavioral_events)

In [22]:
def write_waveforms(ctx: SessionContext, ieeg_electrode_group: ElectrodeGroup):
    """Add one NWB unit per spike waveform found in the session block.

    For every data array in the "Spike waveforms" group, the matching
    per-trial arrays from the "Spike times" group are collected, ordered by
    trial number and merged onto the continuous session timeline. Waveform
    values are multiplied by 1e-6 before being stored. Nothing is written
    when the session has no waveforms.

    Args:
        ctx: Session context providing the NIX block and the NWB file.
        ieeg_electrode_group: Electrode group every unit is assigned to.

    Raises:
        AssertionError: If spike times exist without waveforms, or if the
            electrode/channel of a spike-time array disagrees with the
            waveform of the same unit.
        KeyError: If a unit has a waveform but no spike-time array.
    """
    nwb = ctx.nwb
    session = _get_session_data(ctx)
    waveform_arrays = session.groups["Spike waveforms"].data_arrays
    spike_time_arrays = session.groups["Spike times"].data_arrays
    if len(waveform_arrays) == 0:
        assert len(spike_time_arrays) == 0
        return

    # (unit, electrode, channel, data array), ordered by unit number.
    waveforms = []
    for waveform_array in waveform_arrays:
        unit, electrode, channel = _WAVEFORM_RE.findall(waveform_array.name)[0]
        waveforms.append((int(unit), electrode, channel, waveform_array))
    waveforms.sort(key=lambda entry: entry[0])

    # unit -> trial number -> (electrode, channel, spike times).
    # The trial number is stored as an int: with the string taken from the
    # array name, sorting would be lexicographic ("1", "10", "100", ..., "2")
    # and spike times would be shifted into the wrong trials.
    unit_to_trial_to_spike_times = {}
    for spike_time_array in spike_time_arrays:
        unit, electrode, channel, trial = _SPIKE_TIMES_RE.findall(
            spike_time_array.name
        )[0]
        unit_to_trial_to_spike_times.setdefault(int(unit), {})[int(trial)] = (
            electrode,
            channel,
            spike_time_array[:],
        )

    nwb.add_unit_column(
        name="offset",
        description="The offset in seconds of the first waveform voltage relative to the spike event",
    )
    # Sampling information is taken from the first waveform array in the file.
    waveform_time_axis = waveform_arrays[0].dimensions[1]
    nwb.units.waveform_rate = 1.0 / waveform_time_axis.sampling_interval
    waveform_offset = waveform_time_axis.offset

    # The observation intervals are the same for every unit.
    obs_intervals = _get_obs_intervals(ctx)

    for unit, electrode, channel, waveform_array in waveforms:
        trial_to_spike_times = unit_to_trial_to_spike_times[unit]

        spike_times_by_trial = []
        for trial in sorted(trial_to_spike_times):
            electrode_, channel_, trial_spike_times = trial_to_spike_times[trial]
            assert electrode == electrode_
            assert channel == channel_
            spike_times_by_trial.append(trial_spike_times)
        # NOTE(review): trials are placed on the timeline by their position in
        # this list, which assumes the unit has one array for every trial.
        unit_spike_times = _untrialize_irregular_timestamps(spike_times_by_trial, ctx)

        electrode_index = _get_electrode_index(ctx, f"{electrode}{channel}")

        # Read the array once; row 0 is stored as the mean, row 1 as the SD.
        voltages = waveform_array[:]
        means = [micro_volt * 1e-6 for micro_volt in voltages[0]]
        sds = [micro_volt * 1e-6 for micro_volt in voltages[1]]

        nwb.add_unit(
            id=unit,
            electrode_group=ieeg_electrode_group,
            electrodes=[electrode_index],
            waveform_mean=means,
            waveform_sd=sds,
            spike_times=unit_spike_times,
            obs_intervals=obs_intervals,
            offset=waveform_offset
        )

In [23]:
def _get_obs_intervals(ctx: SessionContext) -> List[Tuple[float, float]]:
    """Return one observation interval per trial on the continuous session timeline.

    Trial ``i`` starts at ``i * total_trial_duration`` and is observed for the
    data-collection duration.

    Args:
        ctx: Session context providing the trial list and the durations.

    Returns:
        ``(start, stop)`` pairs, one for each entry of "Trial properties".
    """
    trial_sections = ctx.nix.sections["Session"].sections["Trial properties"].sections
    observed_length = _get_data_collection_duration(ctx)
    trial_length = _get_total_trial_duration(ctx)
    intervals = []
    for trial_index in range(len(trial_sections)):
        start = trial_index * trial_length
        intervals.append((start, start + observed_length))
    return intervals

In [24]:
# Matches waveform array names such as "Spike_Waveform_Unit_1_uAHL_2".
# Groups: unit number, electrode name (the letters after the "u"), channel number.
_WAVEFORM_RE = re.compile(r"Spike_Waveform_Unit_(\d+)_u([a-zA-Z]+)_(\d+)")
# Matches spike-time array names such as "Spike_Times_Unit_36_uPHR_1_Trial_16".
# Groups: unit number, electrode name, channel number, trial number.
_SPIKE_TIMES_RE = re.compile(r"Spike_Times_Unit_(\d+)_u([a-zA-Z]+)_(\d+)_Trial_(\d+)")

In [25]:
def _untrialize_irregular_timestamps(
        timestamps: List[List[float]], ctx: SessionContext
) -> List[float]:
    """Flatten per-trial timestamps onto the continuous session timeline.

    The entry at position ``i`` of ``timestamps`` is treated as trial ``i``
    (0-based); each of its values becomes
    ``time - offset + i * total_trial_duration``.

    Args:
        timestamps: One sequence of trial-relative times per trial, in trial order.
        ctx: Session context used to look up the offset and the trial duration.

    Returns:
        All shifted times in a single flat list, trial after trial.
    """
    shift = _get_time_offset(ctx)
    trial_length = _get_total_trial_duration(ctx)
    return [
        time - shift + trial_index * trial_length
        for trial_index, trial_times in enumerate(timestamps)
        for time in trial_times
    ]

In [26]:
def _get_electrode_index(ctx: SessionContext, electrode: str) -> int:
    """Return the id of the first electrode whose label contains ``electrode``.

    Args:
        ctx: Session context whose NWB electrode table is searched.
        electrode: Text to look for inside the electrode labels.

    Returns:
        The "id" value of the first matching row of the electrode table.

    Raises:
        ValueError: If no electrode label contains ``electrode``. Previously a
            bare ``StopIteration`` without any message escaped from ``next``.
    """
    nwb = ctx.nwb
    labels = nwb.electrodes["label"][:]
    for index, label in enumerate(labels):
        # NOTE(review): this is a substring test, so "AHL1" also matches
        # "AHL10" — confirm whether an exact comparison is intended.
        if electrode in label:
            return nwb.electrodes["id"][index]
    raise ValueError(
        f"No electrode label contains {electrode!r} "
        f"({len(labels)} electrodes searched)"
    )

In [27]:
def _get_data_collection_duration(ctx: SessionContext) -> float:
    """Return the "Trial duration" property of the NIX "Session" section.

    Args:
        ctx: Session context providing the NIX file.

    Returns:
        The first value of the "Trial duration" property.
    """
    session_props = ctx.nix.sections["Session"].props
    return session_props["Trial duration"].values[0]

In [28]:
def _get_uncollected_duration(_ctx: SessionContext) -> float:
    """Return the fixed gap, in seconds, appended after each trial's recorded data.

    Args:
        _ctx: Unused; the value is a constant.

    Returns:
        2.3
    """
    # Max time between trials according to https://www.sciencedirect.com/science/article/pii/S1053811922002518?via%3Dihub
    max_time_between_trials = 2.3
    return max_time_between_trials

In [29]:
def _get_total_trial_duration(ctx: SessionContext) -> float:
    """Return the length of one trial slot on the continuous session timeline.

    Args:
        ctx: Session context forwarded to the two duration helpers.

    Returns:
        The data-collection duration plus the uncollected duration.
    """
    return _get_data_collection_duration(ctx) + _get_uncollected_duration(ctx)

In [30]:
def _get_time_offset(_ctx: SessionContext) -> float:
    """Return the fixed offset that callers subtract from trial-relative times.

    Args:
        _ctx: Unused; the value is a constant.

    Returns:
        -1.7
    """
    offset_from_event_data = -1.7  # taken from looking at event data
    return offset_from_event_data

In [31]:
def write_lfp(ctx: SessionContext):
    """Load the micro-wire MATLAB files of the current session and print a preview.

    At this stage nothing is written to the NWB file: for every part file of
    the session the labels, the "time" matrix and the "trial" matrix are read
    and summarised on stdout.

    Args:
        ctx: Session context identifying the subject and session.

    Raises:
        KeyError: If no micro data file was found for the subject/session.
    """
    micro_files = _find_micro_data_files(ctx)
    files = micro_files[ctx.subject][ctx.session]
    # Visit the parts in ascending part number rather than directory-listing order.
    for part_number in sorted(files):
        part = files[part_number]
        print("Loading part", part)
        # Context manager so the HDF5 handle is closed even if reading fails.
        with h5py.File(part, 'r') as file:
            labels = get_matlab_matrix_scalars_ragged(file, "label")
            # Each label is a sequence of character codes.
            labels = [''.join([chr(character) for character in word]) for word in labels]
            print(len(labels))
            times = get_matlab_matrix(file, "time")
            print(times.shape)
            print(times[:100, 0])
            trials = get_matlab_matrix(file, "trial")
            print(trials.shape)
            print(trials[:100, 0])

In [32]:
def _find_micro_data_files(ctx: SessionContext) -> Dict[int, Dict[int, Dict[int, PathLike]]]:
    """Index the micro data files found in the micro directory.

    File names are parsed with ``MATLAB_RE``. The patient number from the
    file name is translated through ``CORRECTED_PATIENT``; files whose
    patient number is not in that table are ignored.

    Args:
        ctx: Session context used to locate the micro directory.

    Returns:
        Nested mapping ``subject -> session -> part -> file path``.
    """
    micro_dir = get_micro_dir(ctx)
    micro_files = {}
    for file_name in listdir(micro_dir):
        match = MATLAB_RE.match(file_name)
        if not match:
            continue
        patient, session, part = map(int, match.groups())
        if patient not in CORRECTED_PATIENT:
            continue
        subject = CORRECTED_PATIENT[patient]
        parts = micro_files.setdefault(subject, {}).setdefault(session, {})
        parts[part] = join(micro_dir, file_name)
    return micro_files


In [33]:
# Maps the patient number that appears in the micro data file names
# ("..._Patient_<n>_...") to the subject number used for the nested result of
# _find_micro_data_files. Files whose patient number is not listed here are skipped.
CORRECTED_PATIENT = {
    13: 1,
    19: 2,
    22: 3,
    23: 4,
    28: 5,
    29: 6,
    30: 7,
    34: 8,
    35: 9,
    37: 10,
    38: 11,
    40: 12,
    41: 13,
}

In [34]:
# Matches micro data file names such as
# "Micro_Intervals_Patient_19_Session_01_Part_01_Interval_0_NaN_s.mat".
# Groups: patient number, session number, part number.
# The dot before "mat" is escaped and the pattern is anchored at the end, so
# look-alike names (e.g. "...s.mat.bak" or "...sXmat") are not picked up.
MATLAB_RE = re.compile(r"Micro_Intervals_Patient_(\d+)_Session_(\d+)_Part_(\d+)_Interval_0_NaN_s\.mat$")

Main

In [35]:
# Exploratory run: build the context for subject 1, session 1 and preview its
# micro data with write_lfp. Nothing is converted or written to disk here.
if __name__ == "__main__":
    ctx = create_context(1, 1)
    write_lfp(ctx)
    # Full conversion of the same session, currently disabled:
    #context = convert_nix_to_nwb(1, 1)
    #write_nwb(context)

Loading part C:\Users\conta\git\janhohenheim\usz-neuro-conversion\in\to_convert\Human_MTL_units_visual_WM\micro_data\Micro_Intervals_Patient_13_Session_01_Part_01_Interval_0_NaN_s.mat
64
(7427008, 1)
[3.125000e-05 2.812500e-04 5.312500e-04 7.812500e-04 1.031250e-03
 1.281250e-03 1.531250e-03 1.781250e-03 2.031250e-03 2.281250e-03
 2.531250e-03 2.781250e-03 3.031250e-03 3.281250e-03 3.531250e-03
 3.781250e-03 4.031250e-03 4.281250e-03 4.531250e-03 4.781250e-03
 5.031250e-03 5.281250e-03 5.531250e-03 5.781250e-03 6.031250e-03
 6.281250e-03 6.531250e-03 6.781250e-03 7.031250e-03 7.281250e-03
 7.531250e-03 7.781250e-03 8.031250e-03 8.281250e-03 8.531250e-03
 8.781250e-03 9.031250e-03 9.281250e-03 9.531250e-03 9.781250e-03
 1.003125e-02 1.028125e-02 1.053125e-02 1.078125e-02 1.103125e-02
 1.128125e-02 1.153125e-02 1.178125e-02 1.203125e-02 1.228125e-02
 1.253125e-02 1.278125e-02 1.303125e-02 1.328125e-02 1.353125e-02
 1.378125e-02 1.403125e-02 1.428125e-02 1.453125e-02 1.478125e-02
 1.50312

In [None]:
# Batch conversion: convert every session that find_nix_files reports for the
# project and write each result with write_nwb.
if __name__ == "__main__":
    project = "Human_MTL_units_visual_WM"
    failed_sessions = []
    for subject, sessions in find_nix_files(project).items():
        for session in sessions:
            print(f"Converting subject {subject} session {session}")
            try:
                context = convert_nix_to_nwb(subject, session)
                write_nwb(context)
                print("Done")
            except Exception as e:
                # Best effort: one broken session must not stop the batch.
                # repr() keeps the exception type visible even when the
                # exception carries no message.
                failed_sessions.append((subject, session))
                print(f"Failed to convert {subject} {session}")
                print(repr(e))
    if failed_sessions:
        print(f"{len(failed_sessions)} session(s) failed: {failed_sessions}")
    print("Everything done!")