## sessions
- each session is a different dataset
- folders inside each session: responses (written to `spikes`, plus a binned copy in `spike_count`), poses, target and trials
- trials is not a modality, it's just for me to keep track of information when evaluating the model

### poses
- `data.mem`: $(x, y, z)$ coordinates of each joint for the full session. shape: `(n_frames, 3*n_joints)`
- `meta.yml`: dictionary with metadata, see `metadata` below
- `meta/com.npy`: $(x, y, z)$ coordinate of the center of mass (com) for the full session. shape: `(n_frames, 3)`
- `meta/joints.npy`: name of the joints tracked
- `meta/skeleton.npy`: adjacency list for all the joints

### responses
- `data.mem`: spike train of each neuron for the full session, one row per 1 ms frame (1000 Hz). shape: `(n_frames, n_neurons)`
- `meta.yml`: dictionary with metadata, see `metadata` below
- `meta/areas.npy`: array with the brain areas

### target
- `data.mem`: array with the target label of each frame: 0 before the cue is shown (and outside trials), 1 after the cue is shown if it indicates left, and 2 if it indicates right. shape: `(n_frames, 1)`
- `meta.yml`: dictionary with metadata, see `metadata` below

### metadata (meta.yml files)
- `dtype`: necessary to load `.mem` files
- `is_mem_mapped`: if it's `.mem` or `.npy`
- `modality`: sequence or trial
- `n_signals`: number of joints, neurons etc
- `n_timestamps`: number of frames 
- `phase_shift_per_signal`: useless for me, always false
- `sampling_rate`: 100 for poses, 1000 for responses, target and trials
- `start_time`: always 0 (every processing function writes 0)
- `end_time`: always equal to `n_timestamps`

### trials
- trial_start: when each trial starts
- trial_end: when each trial ends
- toc (time of commitment): when the monkey committed to a target
- trial_type: precue, gbyk, feedback
- walk_start: when the monkey starts walking
- walk_end: when the monkey stops walking
- cue_start: time of the signal that shows the reward
- cue_end: end of the signal
- choice: L, R
- reward: L, R (can be different from choice if the monkey ignores the highest reward)

##### to keep in mind
- figure out how to parse the toc
- how are walk_start and walk_end represented? relative to trial start?
- how is walk_start/mt_on and walk_end/mt_off encoded?
- for ken_20230618, some trials are off by one, which gives nan for coords. should be fixed when we have data for the full session.
- if we don't know when the cue is shown for "precue" trials, how should we define the target? or is the trial start time in this case the time when the cue is shown?

In [13]:
import os
import pickle
import random
import shutil
import sys

from pathlib import Path

import h5py
import numpy as np
import pandas as pd
import yaml

sys.path.append(str(Path().resolve().parent))

### utils

In [14]:
def extract_strings(references, file):
    """Decode character arrays referenced from an HDF5-style file into strings.

    Each element of ``references`` is a row whose first entry is used as a key
    into ``file``; the referenced dataset is iterated row by row and the first
    value of every row is turned into a character with ``chr``. Presumably
    these are MATLAB char arrays stored in a v7.3 ``.mat`` file -- TODO confirm.

    Returns:
        One decoded string per reference, in input order.
    """
    decoded = []
    for row in references:
        char_rows = file[row[0]]
        decoded.append("".join(chr(code_row[0]) for code_row in char_rows))
    return decoded


# Joint names for the old ("denoised") pose format. The order matches the rows
# that process_poses(..., old_format=True) keeps, i.e. the "Body" rows of the
# .mat file after dropping the first five.
old_keypoints = [
    "L_wrist",  # 0
    "L_elbow",  # 1
    "L_shoulder",  # 2
    "R_wrist",  # 3
    "R_elbow",  # 4
    "R_shoulder",  # 5
    "L_ankle",  # 6
    "L_knee",  # 7
    "L_hip",  # 8
    "R_ankle",  # 9
    "R_knee",  # 10
    "R_hip",  # 11
    "E_tail",  # 12
    "S_tail",  # 13
    "neck",  # 14
    "head",  # 15
    "L_ear",  # 16
    "R_ear",  # 17
    "L_eye",  # 18
    "R_eye",  # 19
    "nose",  # 20
]

# Bones of the old-format skeleton as [i, j] index pairs into old_keypoints.
# Indices are positional, so this list must change whenever old_keypoints does.
old_skeleton = [
    [1, 0],  # L_elbow - L_wrist
    [1, 2],  # L_elbow - L_shoulder
    [4, 3],  # R_elbow - R_wrist
    [4, 5],  # R_elbow - R_shoulder
    [7, 6],  # L_knee - L_ankle
    [7, 8],  # L_knee - L_hip
    [10, 9],  # R_knee - R_ankle
    [10, 11],  # R_knee - R_hip
    [13, 8],  # S_tail - L_hip
    [13, 11],  # S_tail - R_hip
    [13, 12],  # S_tail - E_tail
    [14, 2],  # neck - L_shoulder
    [14, 5],  # neck - R_shoulder
    [14, 13],  # neck - S_tail
    [14, 15],  # neck - head
    [14, 16],  # neck - L_ear
    [14, 17],  # neck - R_ear
    [14, 20],  # neck - nose
    [15, 16],  # head - L_ear
    [15, 17],  # head - R_ear
    [15, 18],  # head - L_eye
    [15, 19],  # head - R_eye
    [20, 16],  # nose - L_ear
    [20, 17],  # nose - R_ear
    [20, 18],  # nose - L_eye
    [20, 19],  # nose - R_eye
]

# Joint names for the new (spike-sorted) pose format, in ASCII-sorted order.
# process_poses(..., old_format=False) keeps every "Body" row, so this order is
# assumed to match the row order in the new .mat files -- TODO confirm.
new_keypoints = [
    "E_tail",  # 0
    "L_ankle",  # 1
    "L_ear",  # 2
    "L_elbow",  # 3
    "L_eye",  # 4
    "L_hip",  # 5
    "L_knee",  # 6
    "L_shoulder",  # 7
    "L_wrist",  # 8
    "R_ankle",  # 9
    "R_ear",  # 10
    "R_elbow",  # 11
    "R_eye",  # 12
    "R_hip",  # 13
    "R_knee",  # 14
    "R_shoulder",  # 15
    "R_wrist",  # 16
    "S_tail",  # 17
    "head",  # 18
    "neck",  # 19
    "nose",  # 20
]
# Bones as [i, j] index pairs into new_keypoints. This is the same 26-bone
# graph as old_skeleton, re-indexed for the sorted joint order.
new_skeleton = [
    [3, 7],  # L_elbow - L_shoulder
    [3, 8],  # L_elbow - L_wrist
    [6, 1],  # L_knee - L_ankle
    [6, 5],  # L_knee - L_hip
    [11, 15],  # R_elbow - R_shoulder
    [11, 16],  # R_elbow - R_wrist
    [14, 9],  # R_knee - R_ankle
    [14, 13],  # R_knee - R_hip
    [17, 0],  # S_tail - E_tail
    [17, 5],  # S_tail - L_hip
    [17, 13],  # S_tail - R_hip
    [18, 2],  # head - L_ear
    [18, 4],  # head - L_eye
    [18, 10],  # head - R_ear
    [18, 12],  # head - R_eye
    [19, 2],  # neck - L_ear
    [19, 7],  # neck - L_shoulder
    [19, 10],  # neck - R_ear
    [19, 15],  # neck - R_shoulder
    [19, 17],  # neck - S_tail
    [19, 18],  # neck - head
    [19, 20],  # neck - nose
    [20, 2],  # nose - L_ear
    [20, 4],  # nose - L_eye
    [20, 10],  # nose - R_ear
    [20, 12],  # nose - R_eye
]

### process trials

In [15]:
def process_trials(SESSION_DIR: Path) -> None:
    """Write one YAML file per successful trial to ``SESSION_DIR / "trials"``.

    Reads ``../data/gbyk/<session>.csv`` (one row per trial) and the
    ``spikes/trial_start`` / ``spikes/trial_end`` arrays of
    ``../data/gbyk/<session>.mat``, keeps only rows whose ``outcome`` is
    ``"success"`` and splits them deterministically (seed 0) into 20% test,
    20% validation and the remaining trials as train. The tiers are disjoint.

    Each ``<idx:05d>.yml`` holds ``choice``, ``cue_frame_idx``,
    ``first_frame_idx``, ``num_frames``, ``reward``, ``tier``, ``trial_idx``
    and ``type``, where ``idx`` is the trial's position among successful trials.

    Args:
        SESSION_DIR: Output directory of the session. Its name is also the
            stem of the source ``.csv`` and ``.mat`` files.
    """
    TRIALS_DIR = SESSION_DIR / "trials"
    TRIALS_DIR.mkdir(exist_ok=True, parents=True)

    filename = SESSION_DIR.name
    trials_info = pd.read_csv(f"../data/gbyk/{filename}.csv")
    # Lower-case the headers first so every lookup below sees the same names.
    trials_info.columns = trials_info.columns.map(str.lower)
    # Positional indices of successful trials; they also index the per-trial
    # arrays stored in the .mat file.
    success_idxs = np.flatnonzero((trials_info["outcome"] == "success").to_numpy())

    with h5py.File(f"../data/gbyk/{filename}.mat") as dataset:
        trials_start = dataset["spikes"]["trial_start"][:][0][success_idxs]
        trials_end = dataset["spikes"]["trial_end"][:][0][success_idxs]

    trials_info = trials_info.iloc[success_idxs].reset_index(drop=True)

    assert len(trials_info) == len(trials_start)
    assert len(trials_info) == len(trials_end)

    # Disjoint split. A private RNG seeded with 0 draws exactly the same samples
    # as ``random.seed(0)`` + ``random.sample`` without resetting global state.
    rng = random.Random(0)
    all_idxs = list(range(len(trials_info)))
    test_size = int(len(all_idxs) * 0.2)
    val_size = int(len(all_idxs) * 0.2)
    test_idxs = rng.sample(all_idxs, test_size)
    remaining_idxs = sorted(set(all_idxs) - set(test_idxs))
    val_idxs = rng.sample(remaining_idxs, val_size)
    # Train is the complement, so every trial file is written exactly once.
    train_idxs = sorted(set(remaining_idxs) - set(val_idxs))
    tiers = {"train": train_idxs, "validation": val_idxs, "test": test_idxs}

    # Turn the numeric ``reward`` column into a side: a value of 1.8 keeps the
    # chosen side, any other value flips it (L <-> R).
    trials_info["reward"] = np.where(
        trials_info["reward"] == 1.8,
        trials_info["choice"],
        trials_info["choice"].map({"L": "R", "R": "L"}),
    )
    # Cue time per trial type: precue -> precue_abstime, gbyk -> midpoint of
    # the cue window, anything else (feedback) -> trial end (empty cue period).
    trials_info["cue"] = np.where(
        trials_info["block"] == "precue",
        trials_info["precue_abstime"],
        np.where(
            trials_info["block"] == "gbyk",
            (trials_info["cuestart_abstime"] + trials_info["cueend_abstime"]) / 2,
            trials_end,
        ),
    )

    # TODO: add walk_start / walk_end once their encoding is known.
    for tier, idxs in tiers.items():
        for idx in idxs:
            data = {
                "choice": trials_info.loc[idx, "choice"],
                "cue_frame_idx": int(trials_info.loc[idx, "cue"]),
                "first_frame_idx": int(trials_start[idx]),
                "num_frames": int(trials_end[idx] - trials_start[idx]),
                "reward": trials_info.loc[idx, "reward"],
                "tier": tier,
                "trial_idx": idx,
                "type": trials_info.loc[idx, "block"],
            }
            with open(TRIALS_DIR / f"{idx:05d}.yml", "w") as f:
                yaml.dump(data, f)

### process spikes

In [16]:
def process_spikes(SESSION_DIR: Path) -> None:
    """Copy the session spike matrix and per-neuron brain areas to ``spikes/``.

    Reads ``spikes/session``, ``spikes/array_labels`` and ``spikes/array_code``
    from ``../data/gbyk/<session>.mat`` and writes:

    - ``spikes/data.mem``: float32 memmap with the shape stored in the file
      (rows = timestamps, columns = neurons, as recorded in ``meta.yml``);
    - ``spikes/meta.yml``: metadata, sampling rate 1000 Hz;
    - ``spikes/meta/areas.npy``: the array label of each neuron.

    Raises:
        AssertionError: if the number of array codes differs from the number
            of neurons. This is checked before anything is written.
    """
    SPIKES_DIR = SESSION_DIR / "spikes"
    SPIKES_DIR.mkdir(exist_ok=True, parents=True)

    META_SPIKES_DIR = SPIKES_DIR / "meta"
    META_SPIKES_DIR.mkdir(exist_ok=True, parents=True)

    filename = SESSION_DIR.name
    # Read everything up front so the HDF5 handle is closed deterministically.
    with h5py.File(f"../data/gbyk/{filename}.mat") as dataset:
        spikes = dataset["spikes"]["session"][:]
        areas = np.array(
            extract_strings(dataset["spikes"]["array_labels"], dataset)
        )
        # array_code is 1-based; shift it to 0-based indices into `areas`.
        array_code = dataset["spikes"]["array_code"][:].ravel() - 1

    assert len(array_code) == spikes.shape[-1], (
        f"{len(array_code)}, {spikes.shape[-1]}"
    )

    ### responses
    mmap = np.memmap(
        SPIKES_DIR / "data.mem",
        dtype="float32",
        mode="w+",
        shape=spikes.shape,
    )
    mmap[:] = spikes[:]
    mmap.flush()

    ### meta
    with open(SPIKES_DIR / "meta.yml", "w") as f:
        meta = {
            "dtype": "float32",
            "end_time": len(spikes),
            "is_mem_mapped": True,
            "modality": "sequence",
            "n_signals": spikes.shape[-1],
            "n_timestamps": len(spikes),
            "phase_shift_per_signal": False,
            "sampling_rate": 1000,
            "start_time": 0,
        }
        yaml.dump(meta, f)

    with open(META_SPIKES_DIR / "areas.npy", "wb") as f:
        np.save(f, areas[array_code.astype(int)])

### process spike count

In [17]:
def process_spike_count(SESSION_DIR: Path, sampling_rate: int = 20) -> None:
    """Bin the per-frame spike matrix in ``spikes/`` into spike counts.

    Consecutive groups of ``period = source_rate // sampling_rate`` frames are
    summed. ``source_rate`` is the ``sampling_rate`` stored in
    ``spikes/meta.yml`` (1000 if the key is missing). A trailing partial bin
    is dropped. Writes ``spike_count/data.mem`` and ``spike_count/meta.yml``,
    and copies ``spikes/meta`` into ``spike_count/meta``.

    Args:
        SESSION_DIR: Session directory that already contains ``spikes/``.
        sampling_rate: Target rate in Hz of the binned output.

    Raises:
        ValueError: if ``source_rate`` is not an integer multiple of
            ``sampling_rate``. The written metadata would otherwise claim a
            rate that does not match the bin width.
    """
    SPIKE_COUNT_DIR = SESSION_DIR / "spike_count"
    SPIKE_COUNT_DIR.mkdir(exist_ok=True, parents=True)
    META_SPIKE_COUNT_DIR = SPIKE_COUNT_DIR / "meta"
    META_SPIKE_COUNT_DIR.mkdir(exist_ok=True, parents=True)
    SPIKES_DIR = SESSION_DIR / "spikes"

    with open(SPIKES_DIR / "meta.yml", "r") as f:
        meta = yaml.safe_load(f)

    source_rate = meta.get("sampling_rate", 1000)
    if sampling_rate <= 0 or source_rate % sampling_rate:
        raise ValueError(
            f"sampling_rate={sampling_rate} must evenly divide the spike "
            f"sampling rate ({source_rate} Hz)"
        )
    period = int(source_rate // sampling_rate)

    n_timestamps, n_signals = meta["n_timestamps"], meta["n_signals"]
    spikes = np.memmap(
        SPIKES_DIR / "data.mem",
        dtype=meta["dtype"],
        mode="r",
        shape=(n_timestamps, n_signals),
    )

    # Drop the trailing partial bin, then sum each block of `period` frames.
    # Explicit dimensions keep this valid when there are zero full bins.
    n_bins = n_timestamps // period
    spike_count = (
        spikes[: n_bins * period].reshape(n_bins, period, n_signals).sum(axis=1)
    )

    mmap = np.memmap(
        SPIKE_COUNT_DIR / "data.mem",
        dtype="float32",
        mode="w+",
        shape=spike_count.shape,
    )
    mmap[:] = spike_count[:]
    mmap.flush()

    with open(SPIKE_COUNT_DIR / "meta.yml", "w") as f:
        meta = {
            "dtype": "float32",
            "end_time": len(spike_count),
            "is_mem_mapped": True,
            "modality": "sequence",
            "n_signals": spike_count.shape[-1],
            "n_timestamps": len(spike_count),
            "phase_shift_per_signal": False,
            "sampling_rate": sampling_rate,
            "start_time": 0,
        }
        yaml.dump(meta, f)

    # Per-neuron metadata (e.g. areas.npy) is unchanged by binning.
    shutil.copytree(
        SPIKES_DIR / "meta", META_SPIKE_COUNT_DIR, dirs_exist_ok=True
    )

### process poses

In [18]:
def process_poses(SESSION_DIR: Path, old_format: bool = False) -> None:
    """Export joint coordinates, centre of mass and skeleton to ``poses/``.

    Reads ``spikes/Traj`` (centre of mass) and ``spikes/Body`` (per-joint
    x/y/z) from ``../data/gbyk/<session>.mat`` and writes:

    - ``poses/data.mem``: float32 memmap of shape ``(n_frames, 3 * n_joints)``
      with columns ordered joint-major (j0x, j0y, j0z, j1x, ...);
    - ``poses/meta.yml``: metadata, sampling rate 100 Hz;
    - ``poses/meta/com.npy``: ``(n_frames, 3)`` centre of mass;
    - ``poses/meta/joints.npy`` / ``skeleton.npy``: joint names and bones of
      the chosen format.

    Args:
        SESSION_DIR: Output directory of the session. Its name is also the
            stem of the source ``.mat`` file.
        old_format: Use the old (denoised) layout. This drops the first five
            "Body" rows and uses ``old_keypoints`` / ``old_skeleton``.

    Raises:
        AssertionError: if centre of mass and joints have different frame
            counts. This is checked before anything is written.
    """
    POSES_DIR = SESSION_DIR / "poses"
    POSES_DIR.mkdir(exist_ok=True, parents=True)

    META_POSES_DIR = POSES_DIR / "meta"
    META_POSES_DIR.mkdir(exist_ok=True, parents=True)

    file_path = Path(f"../data/gbyk/{SESSION_DIR.name}.mat")
    # First five joints are useless for the denoised (old-format) datasets.
    first_joint = 5 if old_format else 0

    # Read everything up front so the HDF5 handle is closed deterministically.
    with h5py.File(file_path) as dataset:
        traj = dataset["spikes"]["Traj"]
        body = dataset["spikes"]["Body"]
        com = np.stack([traj[axis][0] for axis in "xyz"], axis=0).T
        # shape: (n_joints, 3, n_frames)
        coords = np.stack([body[axis][first_joint:] for axis in "xyz"], axis=1)

    # (n_joints, 3, n_frames) -> (3 * n_joints, n_frames) -> (n_frames, 3 * n_joints)
    coords = np.reshape(coords, (-1, coords.shape[-1])).T

    assert len(com) == len(coords), f"{len(com)}, {len(coords)}"

    ### center of mass
    with open(META_POSES_DIR / "com.npy", "wb") as f:
        np.save(f, com)

    ### coords
    mmap = np.memmap(
        POSES_DIR / "data.mem",
        dtype="float32",
        mode="w+",
        shape=coords.shape,
    )
    mmap[:] = coords
    mmap.flush()

    ### meta
    with open(POSES_DIR / "meta.yml", "w") as f:
        meta = {
            "dtype": "float32",
            "end_time": len(coords),
            "is_mem_mapped": True,
            "modality": "sequence",
            "n_signals": coords.shape[-1],
            "n_timestamps": len(coords),
            "phase_shift_per_signal": False,
            "sampling_rate": 100,
            "start_time": 0,
        }
        yaml.dump(meta, f)

    keypoints, skeleton = (
        (old_keypoints, old_skeleton) if old_format else (new_keypoints, new_skeleton)
    )
    with open(META_POSES_DIR / "joints.npy", "wb") as f:
        np.save(f, np.array(keypoints))

    with open(META_POSES_DIR / "skeleton.npy", "wb") as f:
        np.save(f, np.array(skeleton))

### process target

In [19]:
def process_target(SESSION_DIR: Path) -> None:
    """Build the per-frame target labels in ``target/`` from the trial files.

    The label array has one entry per frame of ``spikes/`` (same time base).
    It is 0 everywhere except from each trial's ``cue_frame_idx`` up to its
    end (``first_frame_idx + num_frames``). There it is 1 if the rewarded
    side is left ("L") and 2 if it is right ("R"), as specified in the notes.

    Raises:
        ValueError: if a trial file has a ``reward`` other than "L" or "R".
    """
    TARGET_DIR = SESSION_DIR / "target"
    TARGET_DIR.mkdir(exist_ok=True, parents=True)

    with open(SESSION_DIR / "spikes/meta.yml", "r") as f:
        meta_responses = yaml.safe_load(f)
    target = np.zeros(meta_responses["n_timestamps"])

    labels = {"L": 1, "R": 2}
    TRIALS_DIR = SESSION_DIR / "trials"
    # Only trial YAML files, in a deterministic order.
    for trial in sorted(TRIALS_DIR.glob("*.yml")):
        with open(trial, "r") as f:
            trial_info = yaml.safe_load(f)
        reward = trial_info["reward"]
        if reward not in labels:
            raise ValueError(f"{trial}: unexpected reward side {reward!r}")
        cue_idx = trial_info["cue_frame_idx"]
        end_idx = trial_info["first_frame_idx"] + trial_info["num_frames"]
        target[cue_idx:end_idx] = labels[reward]

    mmap = np.memmap(
        TARGET_DIR / "data.mem",
        dtype="float32",
        mode="w+",
        shape=target.shape,
    )
    mmap[:] = target[:]
    mmap.flush()

    ### meta
    with open(TARGET_DIR / "meta.yml", "w") as f:
        meta = {
            "dtype": "float32",
            "end_time": len(target),
            "is_mem_mapped": True,
            "modality": "sequence",
            "n_signals": target.shape[-1] if len(target.shape) > 1 else 1,
            "n_timestamps": len(target),
            "phase_shift_per_signal": False,
            "sampling_rate": 1000,
            "start_time": 0,
        }
        yaml.dump(meta, f)

### process everything

In [20]:
# Older "denoised" exports; these use the old keypoint layout.
sessions = [
    f"{name}_denoised"
    for name in ("bex_20230623", "ken_20230614", "ken_20230618")
]
old_format = True

In [21]:
# Spike-sorted exports (new keypoint layout), grouped by monkey.
# ken_20230614 is intentionally left out.
sessions = [
    f"{monkey}_{date}_spikes_sorted_SES"
    for monkey, dates in (
        ("bex", ("20230621", "20230624", "20230629", "20230630", "20230701", "20230708")),
        ("ken", ("20230618", "20230622", "20230629", "20230630", "20230701", "20230703")),
    )
    for date in dates
]
old_format = False

In [22]:
# Build every modality for each session. Order matters: spike_count reads
# spikes/, and target reads spikes/meta.yml plus the trials/ files.
for session in sessions:
    SESSION_DIR = Path("../data/gbyk") / session
    SESSION_DIR.mkdir(exist_ok=True, parents=True)
    for step in (process_trials, process_spikes, process_spike_count):
        step(SESSION_DIR)
    process_poses(SESSION_DIR, old_format)
    process_target(SESSION_DIR)

### sanity checks

In [None]:
# Number of trial files written per session. The processing cells write to
# ../data/gbyk/<session>, not to a CWD-relative <session> folder.
for session in sessions:
    TRIALS_DIR = Path("../data/gbyk") / session / "trials"
    print(session, len(list(TRIALS_DIR.glob("*.yml"))))

In [None]:
# Shapes of the per-neuron metadata arrays copied next to the binned counts.
# The pipeline writes spikes/ and spike_count/ (there is no responses/ folder).
for session in sessions:
    META_SPIKE_COUNT_DIR = Path("../data/gbyk") / session / "spike_count" / "meta"
    for path in sorted(META_SPIKE_COUNT_DIR.glob("*.npy")):
        print(path.name, np.load(path).shape)

In [None]:
# Shapes of com.npy, joints.npy and skeleton.npy for each session.
for session in sessions:
    META_POSES_DIR = Path("../data/gbyk") / session / "poses" / "meta"
    for path in sorted(META_POSES_DIR.glob("*.npy")):
        print(path.name, np.load(path).shape)

In [None]:
# Compare the on-disk shapes of poses (100 Hz) and spikes (1000 Hz).
for session in sessions:
    SESSION_DIR = Path("../data/gbyk") / session
    POSES_DIR = SESSION_DIR / "poses"
    SPIKES_DIR = SESSION_DIR / "spikes"
    with open(POSES_DIR / "meta.yml") as f:
        poses_meta = yaml.safe_load(f)
    with open(SPIKES_DIR / "meta.yml") as f:
        resp_meta = yaml.safe_load(f)
    coords = np.memmap(
        POSES_DIR / "data.mem",
        dtype=poses_meta["dtype"],
        mode="r",
        shape=(poses_meta["n_timestamps"], poses_meta["n_signals"]),
    )
    spikes = np.memmap(
        SPIKES_DIR / "data.mem",
        dtype=resp_meta["dtype"],
        mode="r",
        shape=(resp_meta["n_timestamps"], resp_meta["n_signals"]),
    )
    print(coords.shape, spikes.shape)

In [None]:
# Per-joint view of the last loaded session: (n_frames, n_joints, 3), with the
# joint count taken from the data instead of assuming 21.
coords = coords.reshape(-1, coords.shape[-1] // 3, 3)
print(coords[..., 1].max(), coords[..., 1].min())
print(coords[..., 0].max(), coords[..., 0].min())

### convert to open with forge

In [None]:
# Export each trial's pose sequence as a pickle that forge can open.
# NOTE(review): the scale 4.325 and offsets 0.3 / 0.05 are hand-tuned display
# constants whose origin is not documented here.
forge_dir = Path("forge")
forge_dir.mkdir(exist_ok=True)
for session in sessions:
    SESSION_DIR = Path("../data/gbyk") / session
    POSES_DIR = SESSION_DIR / "poses"
    META_POSES_DIR = POSES_DIR / "meta"
    TRIALS_DIR = SESSION_DIR / "trials"

    with open(POSES_DIR / "meta.yml") as f:
        poses_meta = yaml.safe_load(f)
    coords = np.memmap(
        POSES_DIR / "data.mem",
        dtype=poses_meta["dtype"],
        mode="r",
        shape=(poses_meta["n_timestamps"], poses_meta["n_signals"]),
    )
    n_joints = poses_meta["n_signals"] // 3
    skeleton = np.load(META_POSES_DIR / "skeleton.npy")

    # Trial indices are on the 1000 Hz clock; poses are sampled at
    # poses_meta["sampling_rate"] (100 Hz), hence this down-sampling factor.
    frame_step = 1000 // poses_meta["sampling_rate"]

    for i, trial in enumerate(sorted(TRIALS_DIR.glob("*.yml"))):
        with open(trial) as f:
            trial_info = yaml.safe_load(f)
        trial_start = int(trial_info["first_frame_idx"] // frame_step)
        trial_end = int(
            (trial_info["first_frame_idx"] + trial_info["num_frames"]) // frame_step
        )

        poses = coords[trial_start:trial_end] / 4.325
        poses = poses.reshape(-1, n_joints, 3).copy()
        # Swap y/z, mirror and shift x, shift y for forge's coordinate frame.
        poses[..., [1, 2]] = poses[..., [2, 1]]
        poses[..., [0]] = -poses[..., [0]] + 0.3
        poses[..., [1]] = poses[..., [1]] - 0.05
        data = {
            "sequence": poses,
            "skeleton": skeleton,
            "frametime": 1000 // poses_meta["sampling_rate"],
        }
        with open(forge_dir / f"{session}_trial_{i}.pkl", "wb") as f:
            pickle.dump(data, f)