# Take syllables inferred by the ARHMM, assign each session its mouse ID and experiment, and attach per-frame scalars

In [1]:
import os
import h5py
import numpy as np
import pandas as pd
from pathlib import Path
from copy import deepcopy
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import process_map
from datetime import datetime
from toolz.curried import pluck
from aging.organization.paths import FOLDERS
from aging.behavior.scalars import compute_scalars
from toolz import concat, compose, valmap, first, groupby, keyfilter, keymap

In [2]:
# CPU count taken from SLURM_CPUS_PER_TASK; falls back to 1 when the variable is unset.
n_cpus = int(os.getenv("SLURM_CPUS_PER_TASK", 1))

In [3]:
# Dataset version; selects the zero-padded version_XX output folder.
version = 6
data_folder = Path(f'/n/groups/datta/win/longtogeny/data/ontogeny/version_{version:02d}')
# HDF5 file holding the inferred syllables (read below; its top-level keys are session uuids).
syllable_path = data_folder.joinpath('all_data_pca/syllables.h5')

In [4]:
# Map each session uuid -> its results_00.h5 file, then keep only the sessions
# that have inferred syllables in syllable_path.
uuid_map = {}
for file in tqdm(concat(f.glob('**/results_00.h5') for f in FOLDERS)):
    try:
        with h5py.File(file, 'r') as h5f:
            uuid = h5f['metadata/uuid'][()].decode()
    except (OSError, KeyError):
        # unreadable file (OSError) or no stored uuid (KeyError): it cannot be
        # matched to a syllable sequence, so skip it instead of aborting the scan
        continue
    uuid_map[uuid] = file

with h5py.File(syllable_path, 'r') as h5f:
    h5f_uuids = list(h5f)
    # set gives O(1) membership tests instead of scanning the list per uuid
    syllable_uuids = set(h5f_uuids)
    uuid_map = {u: p for u, p in uuid_map.items() if u in syllable_uuids}

0it [00:00, ?it/s]

In [5]:
# Peek at the acquisition-metadata fields stored in one (arbitrary) results file.
example_file = next(iter(uuid_map.values()))
with h5py.File(example_file, 'r') as h5f:
    acquisition_fields = list(h5f['metadata/acquisition'])
    print(acquisition_fields)

['ColorDataType', 'ColorResolution', 'DepthDataType', 'DepthResolution', 'IsLittleEndian', 'NidaqChannels', 'NidaqSamplingRate', 'SessionName', 'StartTime', 'SubjectName']


In [15]:
def get_experiment(path: Path):
    """Infer the experiment name for a session from its results-file path.

    Checks run in order and the first match wins:

    1. "min" and "longtogeny" in the path -> ``longtogeny_v2_<parents[2]>``.
    2. "dlight" in the path -> ``"dlight"``.
    3. "longtogeny" in the path -> ``longtogeny_<sex>``, where ``<sex>`` is the
       name of ``parents[3]`` or, failing that, ``parents[2]``, and must be
       "males" or "females".
    4. "ontogeny" (case-insensitive) in the path and "community" not in it ->
       name of ``parents[3]``, or of ``parents[2]`` when ``parents[3]`` is
       "raw_data".
    5. "wheel" (case-insensitive) in the path -> ``"wheel"``.
    6. Otherwise the name of ``parents[2]`` (also printed, so unexpected
       layouts show up in the notebook output).

    Folder names are lower-cased in every branch except the fallback.

    Args:
        path: path to a session's results file.

    Returns:
        The experiment name.

    Raises:
        ValueError: a longtogeny path (branch 3) has no "males"/"females"
            folder at ``parents[3]`` or ``parents[2]``.
    """
    str_path = str(path)
    lower_path = str_path.lower()
    # NOTE(review): the "longtogeny" checks are case-sensitive while the
    # "ontogeny" check is not, so a "Longtogeny" folder would fall through to
    # branch 4 — confirm that is intended.
    if "min" in str_path and "longtogeny" in str_path:
        exp = f"longtogeny_v2_{path.parents[2].name.lower()}"
    elif "dlight" in str_path:
        exp = "dlight"
    elif "longtogeny" in str_path:
        # the sex folder is looked up at parents[3] first, then parents[2]
        sex = path.parents[3].name.lower()
        if sex not in ("males", "females"):
            sex = path.parents[2].name.lower()
        if sex not in ("males", "females"):
            raise ValueError(
                "Could not find a 'males'/'females' folder at parents[3] or "
                f"parents[2] of longtogeny path {str_path}"
            )
        exp = f"longtogeny_{sex}"
    elif "ontogeny" in lower_path and "community" not in str_path:
        exp = path.parents[3].name.lower()
        if exp == "raw_data":
            exp = path.parents[2].name.lower()
    elif "wheel" in lower_path:
        exp = "wheel"
    else:
        exp = path.parents[2].name
        print(exp)
    return exp


def insert_nans(timestamps, data, fps=30):
    """Pad ``data`` with NaN rows wherever ``timestamps`` skips frames.

    A gap is detected when consecutive timestamps are separated by more than
    one frame period (``1 / fps``). For a gap spanning ``m = floor(gap /
    period)`` periods, ``m - 1`` NaN rows are inserted before the frame that
    follows the gap, together with evenly spaced synthetic timestamps.

    Args:
        timestamps: 1-D sequence of frame timestamps, in the same units as
            ``1 / fps``.
        data: array whose first axis is aligned with ``timestamps``. It may
            be 1-D (one value per frame) or N-D (frames first). Integer and
            bool input is promoted to float64 so that NaN can be stored.
        fps: nominal frame rate, used to decide how many frames are missing.

    Returns:
        ``(filled_data, data_idx, filled_timestamps)``:
        ``filled_data`` is ``data`` with NaN rows at the missing frames (same
        number of dimensions as ``data``); ``data_idx`` gives, for each output
        row, the original row index as float (NaN for inserted rows);
        ``filled_timestamps`` includes the synthetic timestamps.
    """
    timestamps = np.asarray(timestamps)
    period = 1.0 / fps
    # prepend a virtual frame one period earlier so the first frame never counts as a gap
    frame_gaps = np.diff(np.insert(timestamps, 0, timestamps[0] - period))
    missing_frames = np.floor(frame_gaps / period)

    fill_idx = np.flatnonzero(missing_frames > 1)
    data_idx = np.arange(len(timestamps), dtype="float64")

    filled_data = np.array(data, copy=True)
    # np.insert casts NaN to the array's dtype; for integer/bool arrays that
    # produces arbitrary values (plus a RuntimeWarning), silently corrupting data
    if not np.issubdtype(filled_data.dtype, np.inexact):
        filled_data = filled_data.astype("float64")
    filled_timestamps = timestamps.copy()
    row_shape = filled_data.shape[1:]

    # walk the gaps back-to-front so earlier insertion indices stay valid
    for idx in fill_idx[::-1]:
        ninserts = int(missing_frames[idx] - 1)
        data_idx = np.insert(data_idx, idx, np.full(ninserts, np.nan))
        insert_timestamps = timestamps[idx - 1] + np.cumsum(np.full(ninserts, period))
        filled_data = np.insert(
            filled_data, idx, np.full((ninserts, *row_shape), np.nan), axis=0
        )
        filled_timestamps = np.insert(filled_timestamps, idx, insert_timestamps)

    return filled_data, data_idx, filled_timestamps

In [16]:
# Group session uuids by the experiment name inferred from each results-file path
# (experiment name -> list of uuids).
exp_groups = {
    exp: [session_uuid for session_uuid, _ in members]
    for exp, members in groupby(
        lambda item: get_experiment(item[1]), uuid_map.items()
    ).items()
}

In [17]:
# experiment names that were found
list(exp_groups.keys())

['ontogeny_females',
 'ontogeny_males',
 'longtogeny_males',
 'longtogeny_v2_females',
 'longtogeny_v2_males',
 'wheel',
 'dlight']

In [18]:
# number of sessions per experiment
{exp: len(session_uuids) for exp, session_uuids in exp_groups.items()}

{'ontogeny_females': 224,
 'ontogeny_males': 216,
 'longtogeny_males': 977,
 'longtogeny_v2_females': 757,
 'longtogeny_v2_males': 188,
 'wheel': 826,
 'dlight': 171}

In [15]:
def extract_scalars(path: Path, recon_key, rescaled_key):
    """Collect per-frame scalars for one session, NaN-padded over dropped frames.

    Reads the session/subject names from the acquisition metadata, every
    ``scalars/*`` dataset whose name contains "mm" plus ``angle`` and
    ``velocity_theta``, scalars produced by ``compute_scalars`` on the frames
    under ``recon_key`` (with the centroid and true depth), and scalars
    produced by ``compute_scalars`` on the frames under ``rescaled_key``
    (keys prefixed with ``rescaled_``). Every per-frame series goes through
    ``insert_nans`` against the file's timestamps.

    Args:
        path: results HDF5 file for the session.
        recon_key: dataset path of the reconstructed frames.
        rescaled_key: dataset path of the rescaled frames.

    Returns:
        A dict of column name -> value (timestamps are shifted to start at 0).
        Returns None if opening/reading raised OSError or KeyError; in that
        case the path and the error are printed.
    """
    try:
        with h5py.File(path, "r") as h5:
            session_name = h5["metadata/acquisition/SessionName"][()].decode()
            subject_name = h5["metadata/acquisition/SubjectName"][()].decode()
            scalar_group = h5["scalars"]
            wanted = [name for name in scalar_group if "mm" in name]
            wanted.extend(["angle", "velocity_theta"])

            # divided by 1000 — presumably ms -> s (TODO confirm stored units)
            ts = h5["timestamps"][()] / 1000

            def pad(series):
                # NaN-fill dropped frames so the series lines up with filled_ts
                return insert_nans(ts, series)[0]

            raw_scalars = {name: scalar_group[name][()] for name in wanted}
            filled_scalars = valmap(pad, raw_scalars)
            filled_ts = insert_nans(ts, ts)[2]

            frames = h5[recon_key][()]
            centroid_x = scalar_group["centroid_x_px"][()]
            centroid_y = scalar_group["centroid_y_px"][()]
            centroid = np.array([centroid_x, centroid_y]).T
            true_depth = h5["metadata/extraction/true_depth"][()]
            recon_scalars = valmap(pad, compute_scalars(frames, centroid, true_depth))

            # scalars from the rescaled frames, stored under "rescaled_*" keys
            frames = h5[rescaled_key][()]
            rescaled_scalars = valmap(
                pad,
                keymap(lambda name: f"rescaled_{name}", compute_scalars(frames, is_recon=False)),
            )
        return dict(
            true_depth=true_depth,
            session_name=session_name,
            subject_name=subject_name,
            timestamps=filled_ts - filled_ts[0],
            **filled_scalars,
            **recon_scalars,
            **rescaled_scalars,
        )
    except (OSError, KeyError) as err:
        print("Error with", str(path))
        print(err)
        return None

In [None]:
# Build one DataFrame per experiment (one row per frame: syllable label, session
# metadata and scalars) and write it to parquet, checkpointing periodically.
recon_frames_key = "win_size_norm_frames_v4"
rescaled_frames_key = "rescaled_frames"
df_version = 0
checkpoint_every = 35  # rewrite the partial parquet every N sessions
# Column dtypes applied to each session frame. Columns missing from a frame
# (timestamps/true_depth when scalar extraction failed) are skipped, because
# DataFrame.astype raises KeyError for unknown columns. Scalar columns keep
# the dtype they were built with.
column_dtypes = dict(
    syllables="int16[pyarrow]",
    file="string[pyarrow]",
    experiment="string[pyarrow]",
    uuid="string[pyarrow]",
    session_name="string[pyarrow]",
    subject_name="string[pyarrow]",
    timestamps="float32[pyarrow]",
    true_depth="float32[pyarrow]",
)
failed_sessions = []
with h5py.File(syllable_path, "r") as h5f:
    for experiment, uuids in exp_groups.items():
        out_path = data_folder / f"{experiment}_syllable_df_v{df_version:02d}.parquet"
        session_dfs = []
        for i, uuid in enumerate(tqdm(uuids)):
            path = uuid_map[uuid]
            extraction_data = extract_scalars(path, recon_frames_key, rescaled_frames_key)
            if extraction_data is None:
                # keep the syllables even when scalar extraction failed
                extraction_data = dict(session_name='', subject_name='')
            # ages are parsed from the folder name for ontogeny sessions only;
            # "longtogeny" is excluded explicitly because it contains "ontogeny"
            if "ontogeny" in experiment and "longtogeny" not in experiment:
                age = path.parents[2].name.split("_")[0]
            else:
                age = np.nan
            try:
                date = datetime.strptime(
                    path.parents[1].name.split("_")[-1], "%Y%m%d%H%M%S"
                )
                _df = pd.DataFrame(
                    dict(
                        experiment=experiment,
                        file=str(path),
                        syllables=h5f[uuid][()],
                        date=date,
                        uuid=uuid,
                        age=age,
                        **extraction_data,
                    )
                )
                _df = _df.astype(
                    {col: dtype for col, dtype in column_dtypes.items() if col in _df.columns}
                )
            except ValueError as e:
                # bad folder date, or syllables/scalars of different lengths, etc.
                print("failure", uuid, e)
                failed_sessions.append((uuid, path))
            else:
                session_dfs.append(_df)
                if i % checkpoint_every == 0:
                    pd.concat(session_dfs, ignore_index=True).to_parquet(out_path)
        if session_dfs:
            df = pd.concat(session_dfs, ignore_index=True)
            df.to_parquet(out_path)
        else:
            # every session failed; pd.concat([]) would raise
            df = pd.DataFrame()
        print(experiment, "length", len(df))

  0%|          | 0/224 [00:00<?, ?it/s]

  values = array(values, copy=False, ndmin=arr.ndim, dtype=arr.dtype)
  values = array(values, copy=False, ndmin=arr.ndim, dtype=arr.dtype)
  values = array(values, copy=False, ndmin=arr.ndim, dtype=arr.dtype)
  values = array(values, copy=False, ndmin=arr.ndim, dtype=arr.dtype)


In [None]:
# Column/dtype summary of `df`, the DataFrame left from the last experiment processed above.
df.info()

In [None]:
# Number of (uuid, path) sessions whose DataFrame construction raised ValueError.
len(failed_sessions)