## sessions
- each session is a different dataset
- folders inside each session: responses, poses and trials
- trials is not a modality, it's just for me to keep track of information when evaluating the model

### poses
- `data.mem`: $(x, y, z)$ coordinates of each joint for the full session, interleaved per joint ($x_0, y_0, z_0, x_1, y_1, z_1, \dots$). shape: `(n_frames, 3*n_joints)`
- `meta.yml`: dictionary with metadata, see `metadata` below
- `meta/com.npy`: center of mass (com) for the full session, approximated as the mean $(x, y, z)$ position of the neck, spine, shoulders, hips and `S_tail` joints. shape: `(n_frames, 3)`
- `meta/joints.npy`: name of the joints tracked
- `meta/skeleton.npy`: adjacency list for all the joints

### responses
- `data.mem`: binary spike raster for the full session, sampled at 1 kHz (1 if the channel spiked in that 1 ms sample, 0 otherwise). shape: `(10*n_frames, n_channels)`, i.e. 10 samples per pose frame
- `meta.yml`: dictionary with metadata, see `metadata` below
- `meta/areas.npy`: brain area label of each channel. shape: `(n_channels,)`
- `meta/bad_channels.npy`: 0-based indices of the bad channels (the session `meta.yml` lists them 1-based)

### metadata (meta.yml files)
- `dtype`: necessary to load `.mem` files
- `is_mem_mapped`: if it's `.mem` or `.npy`
- `modality`: sequence or trial
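- `n_signals`: number of columns in `data.mem` (3 × number of joints for poses, number of channels for responses)
- `n_timestamps`: number of rows in `data.mem` (frames for poses, 1 ms samples for responses)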
- `phase_shift_per_signal`: useless for me, always false
- `sampling_rate`: 100 for poses, 1000 for responses, target and trials
- `start_time`: 0 in my case, i think

### trials
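- each `trials/<idx>.yml` file stores:
    - `first_frame_idx`: when the trial starts, in 1 kHz response samples (divide by 10 to index the poses)
    - `num_frames`: trial length, in the same units
    - `tier`: `train`, `validation` or `test`
    - `trial_idx`: index of the trial within the session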

### to keep in mind
- test trials (index of the held-out trial within each session)
    - bex_20230226: 1
    - jon_20230203: 1
    - luk_20230126: 0
- validation trials (index of the held-out trial within each session; every other trial goes to training)
    - bex_20230226: 2
    - jon_20230126: 0
    - luk_20230202: 1

In [None]:
import csv
import os
import pickle

from pathlib import Path

os.chdir(Path().resolve().parent)

import numpy as np
import yaml

### utils

In [2]:
# Joint names: index i of `keypoints` names joint i. The order is 8 head/trunk
# points, then left arm, right arm, left leg, right leg (5 points each), then
# three points along the tail.
_HEAD_AND_TRUNK = ["neck", "spine", "head", "L_ear", "R_ear", "L_eye", "R_eye", "nose"]
_ARM_PARTS = ("shoulder", "elbow", "wrist", "upperArm", "lowerArm")
_LEG_PARTS = ("hip", "knee", "ankle", "upperLeg", "lowerLeg")
_TAIL = ["S_tail", "M_tail", "E_tail"]

keypoints = (
    _HEAD_AND_TRUNK
    + [
        f"{side}_{part}"
        for parts in (_ARM_PARTS, _LEG_PARTS)
        for side in ("L", "R")
        for part in parts
    ]
    + _TAIL
)

# Joints averaged to approximate the center of mass (see process_poses).
com_kps2idx = [
    keypoints.index(name)
    for name in ("neck", "spine", "L_shoulder", "R_shoulder", "L_hip", "R_hip", "S_tail")
]


def _limb_edges(root):
    """Edges of one limb whose 5 joints start at index ``root``.

    Layout per limb: root, middle joint, end joint, upper segment, lower
    segment. The upper segment links root and middle joint; the lower segment
    links middle and end joint.
    """
    upper, lower = root + 3, root + 4
    return [[upper, root], [upper, root + 1], [lower, root + 1], [lower, root + 2]]


# Undirected edge list [a, b] over joint indices.
skeleton = (
    [[0, j] for j in (1, 2, 3, 4, 7, 8, 13)]  # neck
    + [[2, j] for j in (3, 4, 5, 6)]  # head -> ears / eyes
    + [[7, j] for j in (3, 4, 5, 6)]  # nose -> ears / eyes
    + [edge for root in (8, 13, 18, 23) for edge in _limb_edges(root)]  # L/R arm, L/R leg
    + [[28, 1], [28, 18], [28, 23], [29, 28], [29, 30]]  # tail
)

### process trials

In [None]:
def process_trials(filename):
    TRIALS_DIR = os.path.join(filename, "trials")
    os.makedirs(TRIALS_DIR, exist_ok=True)

    with open(os.path.join(filename, "meta.yml"), "r") as f:
        meta = yaml.safe_load(f)

    test_trials = {
        "bex_20230226": [1],
        "jon_20230203": [1],
        "luk_20230126": [0],
    }
    valid_trials = {
        "bex_20230226": [2],
        "jon_20230126": [0],
        "luk_20230202": [1],
    }
    union_trials = {
        key: valid_trials.get(key, []) + test_trials.get(key, [])
        for key in set(valid_trials) | set(test_trials)
    }
    train_trials = {
        filename: (
            list(range(len(meta["trials"])))
            if filename not in union_trials
            else [
                i
                for i in range(len(meta["trials"]))
                if filename in union_trials and i not in union_trials[filename]
            ]
        )
    }
    tiers = {
        "train": train_trials[filename],
        "validation": valid_trials.get(filename, []),
        "test": test_trials.get(filename, []),
    }
    trials_start = [10 * trial[0] for trial in meta["trials"]]
    trials_end = [10 * trial[1] for trial in meta["trials"]]

    for tier in tiers:
        for idx in tiers[tier]:
            with open(os.path.join(TRIALS_DIR, f"{idx}.yml"), "w") as f:
                data = {
                    "first_frame_idx": trials_start[idx],
                    "num_frames": (trials_end[idx] - trials_start[idx]),
                    "tier": tier,
                    "trial_idx": idx,
                }
                yaml.dump(data, f)

### process responses

In [None]:
def process_responses(filename):
    RESP_DIR = os.path.join(filename, "responses")
    os.makedirs(RESP_DIR, exist_ok=True)

    MRESP_DIR = os.path.join(RESP_DIR, "meta")
    os.makedirs(MRESP_DIR, exist_ok=True)

    with open(os.path.join(filename, "meta.yml"), "r") as f:
        meta = yaml.safe_load(f)

    SPIKES_DIR = os.path.join(filename, "spikes")
    spikes = np.zeros((10 * meta["total_n_frames"], meta["n_channels"]))
    files = [
        os.path.join(SPIKES_DIR, f"spiketimes_ch{i + 1}.txt")
        for i in range(meta["n_channels"])
    ]

    for i, fn in enumerate(files):
        with open(fn, "r") as f:
            fcsv = csv.reader(f, delimiter=",")
            spiketimes = np.array(next(fcsv)).astype(float)
            spiketimes = spiketimes[
                (0 < spiketimes) & (spiketimes < 10 * meta["total_n_frames"])
            ].astype(int)
            spikes[spiketimes, i] = 1

    mmap = np.memmap(
        os.path.join(RESP_DIR, "data.mem"),
        dtype="float64",
        mode="w+",
        shape=spikes.shape,
    )
    mmap[:] = spikes
    mmap.flush()

    with open(os.path.join(RESP_DIR, "meta.yml"), "w") as f:
        meta_resp = {
            "dtype": "float64",
            "end_time": len(spikes),
            "is_mem_mapped": True,
            "modality": "sequence",
            "n_signals": spikes.shape[-1],
            "n_timestamps": len(spikes),
            "phase_shift_per_signal": False,
            "sampling_rate": 1000,
            "start_time": 0,
        }
        yaml.dump(meta_resp, f)

    with open(os.path.join(MRESP_DIR, "areas.npy"), "wb") as f:
        areas = np.empty(meta["n_channels"], dtype="<U3")
        for area, ranges in meta["areas"].items():
            for start, end in ranges:
                areas[start:end] = area
        np.save(f, areas)

    with open(os.path.join(MRESP_DIR, "bad_channels.npy"), "wb") as f:
        # bad channels are 1-indexed
        np.save(f, np.array(meta["bad_channels"]) - 1)

### process poses

In [None]:
def process_poses(filename):
    POSES_DIR = os.path.join(filename, "poses")
    os.makedirs(POSES_DIR, exist_ok=True)

    MPOSES_DIR = os.path.join(POSES_DIR, "meta")
    os.makedirs(MPOSES_DIR, exist_ok=True)

    with open(os.path.join(filename, "meta.yml"), "r") as f:
        meta = yaml.safe_load(f)

    KEYPOINTS_DIR = os.path.join(filename, "keypoints")
    files = [
        os.path.join(KEYPOINTS_DIR, f)
        for f in ("x_coords.txt", "y_coords.txt", "z_coords.txt")
    ]
    coords = []

    for fn in files:
        with open(fn, "r") as f:
            fcsv = csv.reader(f, delimiter=",")
            coords.append([np.array(row).astype(float) for row in fcsv])

    coords = np.vstack(coords)
    coords = np.reshape(coords, (-1, 3 * meta["n_keypoints"]), "F")
    mmap = np.memmap(
        os.path.join(POSES_DIR, "data.mem"),
        dtype="float64",
        mode="w+",
        shape=coords.shape,
    )
    mmap[:] = coords
    mmap.flush()

    with open(os.path.join(POSES_DIR, "meta.yml"), "w") as f:
        meta_resp = {
            "dtype": "float64",
            "end_time": len(coords),
            "is_mem_mapped": True,
            "modality": "sequence",
            "n_signals": coords.shape[-1],
            "n_timestamps": len(coords),
            "phase_shift_per_signal": False,
            "sampling_rate": 100,
            "start_time": 0,
        }
        yaml.dump(meta_resp, f)

    with open(os.path.join(MPOSES_DIR, "com.npy"), "wb") as f:
        coords = np.reshape(coords, (-1, meta["n_keypoints"], 3))
        com = np.mean(coords[:, com_kps2idx, :], axis=1)
        np.save(f, com)

    with open(os.path.join(MPOSES_DIR, "joints.npy"), "wb") as f:
        np.save(f, np.array(keypoints))

    with open(os.path.join(MPOSES_DIR, "skeleton.npy"), "wb") as f:
        np.save(f, np.array(skeleton))

    assert len(com) == len(coords), f"{len(com)}, {len(coords)}"

### normalization

the pose coordinates in this dataset are not normalized, and some huge values lead to numerical issues. to fix this, i load the poses of every training trial from all sessions and take the min and max value of each trial (over all joints and axes). i then compute the median of the maxima and the median of the minima and save them as `poses/meta/max.npy` and `poses/meta/min.npy` in every session, so the poses can be normalized when the dataset is loaded.

In [None]:
def compute_normalization_factor(sessions):
    all_max, all_min = [], []

    for session in sorted(sessions):
        POSES_DIR = os.path.join(session, "poses")
        TRIALS_DIR = os.path.join(session, "trials")

        poses_meta = yaml.safe_load(open(os.path.join(POSES_DIR, "meta.yml")))
        coords = np.memmap(
            os.path.join(POSES_DIR, "data.mem"),
            dtype=poses_meta["dtype"],
            mode="r",
            shape=(poses_meta["n_timestamps"], poses_meta["n_signals"]),
        )

        for trial in sorted(os.listdir(TRIALS_DIR)):
            trial_info = yaml.safe_load(open(os.path.join(TRIALS_DIR, trial)))
            if trial_info["tier"] != "train":
                continue
            start = int(trial_info["first_frame_idx"] // 10)
            end = start + int(trial_info["num_frames"] // 10)
            all_max.append(np.max(coords[start:end]))
            all_min.append(np.min(coords[start:end]))

    median_max = np.median(all_max)
    median_min = np.median(all_min)

    for session in sessions:
        MPOSES_DIR = os.path.join(session, "poses", "meta")
        with open(os.path.join(MPOSES_DIR, "max.npy"), "wb") as f:
            np.save(f, np.array(median_max))
        with open(os.path.join(MPOSES_DIR, "min.npy"), "wb") as f:
            np.save(f, np.array(median_min))

### process everything

In [3]:
# Recording sessions, named "<subject>_<YYYYMMDD>"; each one is processed
# into its own dataset directory.
_SESSION_DATES = {
    "bex": ["20230221", "20230222", "20230223", "20230224", "20230225", "20230226"],
    "jon": [
        "20230125",
        "20230126",
        "20230127",
        "20230130",
        "20230131",
        "20230202",
        "20230203",
    ],
    # NOTE: Zurna asked to ignore luk_20230126, but it is still listed here.
    "luk": ["20230126", "20230127", "20230130", "20230131", "20230202", "20230203"],
}
sessions = [
    f"{subject}_{date}"
    for subject, dates in _SESSION_DATES.items()
    for date in dates
]

In [18]:
# Choose which modalities get (re)built for every session. Trials and poses
# are currently switched off; only the spike responses are regenerated.
RUN_PROCESS_TRIALS = False
RUN_PROCESS_RESPONSES = True
RUN_PROCESS_POSES = False

for session_dir in sessions:
    os.makedirs(session_dir, exist_ok=True)
    if RUN_PROCESS_TRIALS:
        process_trials(session_dir)
    if RUN_PROCESS_RESPONSES:
        process_responses(session_dir)
    if RUN_PROCESS_POSES:
        process_poses(session_dir)

In [11]:
# Write the shared pose max/min (medians over all training trials) into every session.
compute_normalization_factor(sessions=sessions)

### sanity checks

In [5]:
# Sanity check: number of files in each session's trials directory, one line per session.
for session_name in sessions:
    trial_files = os.listdir(os.path.join(session_name, "trials"))
    print(len(trial_files))

2
1
2
2
2
3
1
2
1
1
2
2
2
2
2
1
2
2
1


In [None]:
# Sanity check: shape of every array under <session>/responses/meta, files
# visited in sorted name order.
for session_name in sessions:
    resp_meta_dir = os.path.join(session_name, "responses", "meta")
    for npy_name in sorted(os.listdir(resp_meta_dir)):
        with open(os.path.join(resp_meta_dir, npy_name), "rb") as handle:
            loaded = np.load(handle)
        print(loaded.shape)

(10,)
(128,)
(8,)
(128,)
(6,)
(128,)
(6,)
(128,)
(6,)
(128,)
(8,)
(128,)
(12,)
(128,)
(10,)
(128,)
(9,)
(128,)
(17,)
(128,)
(13,)
(128,)
(0,)
(128,)
(6,)
(128,)
(18,)
(128,)
(10,)
(128,)
(20,)
(128,)
(17,)
(128,)
(17,)
(128,)
(15,)
(128,)


In [None]:
# Sanity check: shape of every array under <session>/poses/meta, files
# visited in sorted name order.
for session_name in sessions:
    poses_meta_dir = os.path.join(session_name, "poses", "meta")
    for npy_name in sorted(os.listdir(poses_meta_dir)):
        with open(os.path.join(poses_meta_dir, npy_name), "rb") as handle:
            loaded = np.load(handle)
        print(loaded.shape)

()
(135853, 3)
(36, 2)
()
(31,)
()
(112657, 3)
(36, 2)
()
(31,)
()
(175686, 3)
(36, 2)
()
(31,)
()
(226345, 3)
(36, 2)
()
(31,)
()
(171907, 3)
(36, 2)
()
(31,)
()
(196826, 3)
(36, 2)
()
(31,)
()
(203300, 3)
(36, 2)
()
(31,)
()
(222287, 3)
(36, 2)
()
(31,)
()
(117084, 3)
(36, 2)
()
(31,)
()
(108728, 3)
(36, 2)
()
(31,)
()
(180612, 3)
(36, 2)
()
(31,)
()
(165895, 3)
(36, 2)
()
(31,)
()
(242329, 3)
(36, 2)
()
(31,)
()
(166605, 3)
(36, 2)
()
(31,)
()
(195368, 3)
(36, 2)
()
(31,)
()
(146859, 3)
(36, 2)
()
(31,)
()
(170496, 3)
(36, 2)
()
(31,)
()
(151799, 3)
(36, 2)
()
(31,)
()
(109482, 3)
(36, 2)
()
(31,)


In [None]:
def _open_sequence(modality_dir):
    """Memory-map ``<modality_dir>/data.mem`` read-only.

    The dtype and shape come from ``<modality_dir>/meta.yml``.
    """
    with open(os.path.join(modality_dir, "meta.yml"), "r") as f:
        seq_meta = yaml.safe_load(f)
    return np.memmap(
        os.path.join(modality_dir, "data.mem"),
        dtype=seq_meta["dtype"],
        mode="r",
        shape=(seq_meta["n_timestamps"], seq_meta["n_signals"]),
    )


# Sanity check: poses (100 Hz) and responses (1 kHz) of a session should
# cover the same span, i.e. responses have 10x as many rows as poses.
for session in sessions:
    coords = _open_sequence(os.path.join(session, "poses"))
    spikes = _open_sequence(os.path.join(session, "responses"))
    print(session, coords.shape, spikes.shape)
    if len(spikes) != 10 * len(coords):
        print(
            f"  WARNING: {session}: expected {10 * len(coords)} response rows, "
            f"got {len(spikes)}"
        )

bex_20230221 (135853, 93) (1358530, 128)
bex_20230222 (112657, 93) (1126570, 128)
bex_20230223 (175686, 93) (1756860, 128)
bex_20230224 (226345, 93) (2263450, 128)
bex_20230225 (171907, 93) (1719070, 128)
bex_20230226 (196826, 93) (1968260, 128)
jon_20230125 (203300, 93) (2033000, 128)
jon_20230126 (222287, 93) (2222870, 128)
jon_20230127 (117084, 93) (1170840, 128)
jon_20230130 (108728, 93) (1087280, 128)
jon_20230131 (180612, 93) (1806120, 128)
jon_20230202 (165895, 93) (1658950, 128)
jon_20230203 (242329, 93) (2423290, 128)
luk_20230126 (166605, 93) (1666050, 128)
luk_20230127 (195368, 93) (1953680, 128)
luk_20230130 (146859, 93) (1468590, 128)
luk_20230131 (170496, 93) (1704960, 128)
luk_20230202 (151799, 93) (1517990, 128)
luk_20230203 (109482, 93) (1094820, 128)


### convert to open with forge

In [None]:
# Export every trial of every session as a pickle for opening with forge:
# poses as (n_frames, n_joints, 3) with the y and z axes swapped, the
# skeleton edge list, and the frame duration in ms.
SAMPLES_PER_POSE_FRAME = 10  # trial bounds are in 1 kHz samples; poses are 100 Hz

os.makedirs("forge", exist_ok=True)

for session in sessions:
    POSES_DIR = os.path.join(session, "poses")
    MPOSES_DIR = os.path.join(POSES_DIR, "meta")
    TRIALS_DIR = os.path.join(session, "trials")

    with open(os.path.join(POSES_DIR, "meta.yml"), "r") as f:
        poses_meta = yaml.safe_load(f)
    coords = np.memmap(
        os.path.join(POSES_DIR, "data.mem"),
        dtype=poses_meta["dtype"],
        mode="r",
        shape=(poses_meta["n_timestamps"], poses_meta["n_signals"]),
    )
    n_joints = poses_meta["n_signals"] // 3  # columns are x, y, z per joint

    # Use a local name so the module-level `skeleton` list used by
    # process_poses is not overwritten with a loaded array.
    with open(os.path.join(MPOSES_DIR, "skeleton.npy"), "rb") as f:
        session_skeleton = np.load(f)

    for i, trial in enumerate(sorted(os.listdir(TRIALS_DIR))):
        with open(os.path.join(TRIALS_DIR, trial), "r") as f:
            trial_info = yaml.safe_load(f)
        trial_start = int(trial_info["first_frame_idx"] // SAMPLES_PER_POSE_FRAME)
        trial_end = int(
            (trial_info["first_frame_idx"] + trial_info["num_frames"])
            // SAMPLES_PER_POSE_FRAME
        )
        # Name the output after the stored trial index, not the position in the
        # lexicographically sorted listing ("10.yml" sorts before "2.yml").
        trial_idx = trial_info.get("trial_idx", i)

        poses = coords[trial_start:trial_end].reshape(-1, n_joints, 3).copy()
        poses[..., [1, 2]] = poses[..., [2, 1]]  # swap y and z axes
        data = {
            "sequence": poses,
            "skeleton": session_skeleton,
            "frametime": 1000 // poses_meta["sampling_rate"],
        }
        with open(os.path.join("forge", f"{session}_trial_{trial_idx}.pkl"), "wb") as f:
            pickle.dump(data, f)