# Attach mouse ID, experiment, and session metadata to each session's ARHMM-inferred syllables

In [13]:
import cv2
import h5py
import numpy as np
import pandas as pd
from copy import deepcopy
from toolz import concat, curry, compose, valmap, first, groupby
from toolz.curried import pluck
from tqdm.auto import tqdm
from pathlib import Path
from datetime import datetime

In [2]:
# Raw-data roots that are searched recursively for extracted sessions.
# Other datasets, currently excluded:
#   /n/groups/datta/min/dominance_v1
#   /n/groups/datta/min/community_v1
#   /n/groups/datta/min/wheel_062023
#   /n/groups/datta/min/cas_behavior_01
folders = [
    Path(root)
    for root in (
        '/n/groups/datta/Dana/Ontogeny/raw_data/Ontogeny_females',
        '/n/groups/datta/Dana/Ontogeny/raw_data/Ontogeny_males',
        '/n/groups/datta/Dana/Ontogeny/raw_data/longtogeny_pre_unet/Females',
        '/n/groups/datta/Dana/Ontogeny/raw_data/longtogeny_pre_unet/Males',
    )
]

In [3]:
# Index every extracted session by its UUID, so that the syllable file (whose
# datasets are looked up by UUID below) can be joined back to the originating
# results_00.h5. If a UUID appears twice, the last file found wins.
uuid_map = {}
candidate_files = concat(root.glob('**/results_00.h5') for root in folders)
for results_file in tqdm(candidate_files):
    with h5py.File(results_file, 'r') as h5f:
        uuid = h5f['metadata/uuid'][()].decode()
    uuid_map[uuid] = results_file

0it [00:00, ?it/s]

In [4]:
data_folder = Path('/n/groups/datta/win/longtogeny/data/ontogeny') / 'version_03'  # output/model root

In [5]:
syllable_path = data_folder.joinpath('all_data_pca', 'syllables.h5')  # per-UUID syllable labels

In [6]:
# Peek at the acquisition metadata fields available in one session file.
first_results_file = first(uuid_map.values())
with h5py.File(first_results_file, 'r') as h5f:
    acquisition_keys = list(h5f['metadata/acquisition'])
print(acquisition_keys)

['ColorDataType', 'ColorResolution', 'DepthDataType', 'DepthResolution', 'IsLittleEndian', 'NidaqChannels', 'NidaqSamplingRate', 'SessionName', 'StartTime', 'SubjectName']


In [7]:
# Experiment labels that get_experiment() is expected to produce for the folders above.
# TODO: add labels for the remaining datasets (dominance, community, wheel, cas, ...).
experiments = ('ontogeny_males', 'ontogeny_females', 'longtogeny_females', 'longtogeny_males')

In [8]:
def get_experiment(path: "Path | str") -> str:
    """Infer the experiment label of an extracted session from its file path.

    Rules, applied in order:

    * Path contains ``"longtogeny"``: the sex folder (``Males``/``Females``,
      case-insensitive) is read from ``path.parents[3]``, falling back to
      ``path.parents[2]``. The label is ``"longtogeny_<sex>"``.
    * Path contains ``"ontogeny"`` (case-insensitive) but not ``"community"``:
      the lowercased name of ``path.parents[3]``.
    * Otherwise: the name of ``path.parents[2]`` (case preserved).

    Args:
        path: path to a session file (a ``str`` is converted to ``Path``).

    Returns:
        The experiment label.

    Raises:
        ValueError: a longtogeny path has no males/females folder at either
            ``parents[3]`` or ``parents[2]``.
    """
    path = Path(path)
    path_str = str(path)
    if "longtogeny" in path_str:
        # Sessions sit at two possible depths below the sex folder.
        for depth in (3, 2):
            sex = path.parents[depth].name.lower()
            if sex in ("males", "females"):
                return f"longtogeny_{sex}"
        raise ValueError(
            "Could not find a 'Males'/'Females' folder at parents[3] or "
            f"parents[2] of longtogeny session path: {path}"
        )
    if "ontogeny" in path_str.lower() and "community" not in path_str:
        return path.parents[3].name.lower()
    return path.parents[2].name


@curry
def is_experiment(experiment, uuid):
    """Return True if the session identified by ``uuid`` belongs to ``experiment``.

    Curried, so ``is_experiment('ontogeny_males')`` yields a one-argument
    predicate over UUIDs. Looks the UUID up in the global ``uuid_map``.
    """
    session_path = uuid_map[uuid]
    label = get_experiment(session_path)
    return experiment == label


def insert_nans(timestamps, data, fps=30):
    """Pad dropped frames with NaN so samples lie on a regular 1/fps grid.

    A gap between consecutive timestamps is measured as
    ``k = floor(dt / (1/fps))`` frame periods. When ``k > 1``, ``k - 1`` NaN
    rows are inserted into ``data``. ``k - 1`` synthetic timestamps, spaced
    1/fps after the preceding sample, are inserted into the timestamps.

    Args:
        timestamps: 1-D sequence of sample times, in the time unit matching
            ``fps`` (e.g. seconds for frames per second).
        data: array with one row per timestamp along axis 0 (any ndim).
            Integer or bool data is promoted to float64 so it can hold NaN.
        fps: nominal frame rate.

    Returns:
        ``(filled_data, data_idx, filled_timestamps)``:
        ``filled_data`` is ``data`` with NaN rows inserted at gaps;
        ``data_idx`` is a float array giving, for each output row, its row
        index in the input (NaN for inserted rows); ``filled_timestamps`` is
        the timestamps with the synthetic times inserted.
    """
    timestamps = np.asarray(timestamps)
    # Integer timestamps would truncate the fractional sentinel and the
    # synthetic timestamps below, so work in float.
    if timestamps.dtype.kind in "biu":
        timestamps = timestamps.astype(np.float64)
    filled_timestamps = np.array(timestamps, copy=True)

    filled_data = np.array(data, copy=True)
    # NaN cannot be represented in integer/bool arrays (np.insert would cast
    # it to an arbitrary value), so promote to float.
    if filled_data.dtype.kind in "biu":
        filled_data = filled_data.astype(np.float64)

    data_idx = np.arange(len(timestamps)).astype('float64')
    if timestamps.size == 0:
        return filled_data, data_idx, filled_timestamps

    period = 1.0 / fps
    # Prepend a sample one period earlier so the first diff spans exactly one
    # frame and never counts as a gap.
    dt = np.diff(np.insert(timestamps, 0, timestamps[0] - period))
    missing_frames = np.floor(dt / period)

    row_shape = filled_data.shape[1:]
    # Walk gaps from the end so earlier indices stay valid as arrays grow.
    for idx in np.flatnonzero(missing_frames > 1)[::-1]:
        ninserts = int(missing_frames[idx] - 1)
        data_idx = np.insert(data_idx, idx, np.full(ninserts, np.nan))
        insert_timestamps = timestamps[idx - 1] + np.cumsum(np.full(ninserts, period))
        filled_data = np.insert(
            filled_data, idx, np.full((ninserts,) + row_shape, np.nan), axis=0
        )
        filled_timestamps = np.insert(filled_timestamps, idx, insert_timestamps)

    return filled_data, data_idx, filled_timestamps


def im_moment_features(frame):
    """Fit an ellipse, via image moments, to the largest blob with depth > 10.

    Args:
        frame: 2-D depth image.

    Returns:
        None if no contour is found. Otherwise a dict with:

        * ``orientation``: ellipse angle in radians, ``-0.5 * atan2(2*mu11, mu20 - mu02)``.
        * ``centroid``: ``[x, y]`` in pixels.
        * ``axis_length``: ``[major, minor]`` full axis lengths in pixels.

        Every value is NaN (lists of NaN for the list-valued keys) when the
        largest contour has zero area.
    """
    frame_mask = frame > 10
    # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns
    # (image, contours, hierarchy). Index -2 is the contours in both.
    cnts = cv2.findContours(
        frame_mask.astype('uint8'), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )[-2]
    if len(cnts) == 0:
        return None
    areas = np.array([cv2.contourArea(c) for c in cnts])
    moments = cv2.moments(cnts[areas.argmax()])

    if moments['m00'] == 0:
        return {
            'orientation': np.nan,
            'centroid': [np.nan, np.nan],
            'axis_length': [np.nan, np.nan],
        }

    num = 2 * moments['mu11']
    den = moments['mu20'] - moments['mu02']
    common = np.sqrt(4 * np.square(moments['mu11']) + np.square(den))
    spread = moments['mu20'] + moments['mu02']
    # Mathematically spread >= common; clamp so floating-point rounding cannot
    # turn a thin/round blob's minor axis into sqrt(negative) = NaN.
    minor_term = max(spread - common, 0.0)

    return {
        'orientation': -.5 * np.arctan2(num, den),
        'centroid': [moments['m10'] / moments['m00'], moments['m01'] / moments['m00']],
        'axis_length': [
            2 * np.sqrt(2) * np.sqrt((spread + common) / moments['m00']),
            2 * np.sqrt(2) * np.sqrt(minor_term / moments['m00']),
        ],
    }


def pxs_to_mm(coords, resolution=(512, 424), field_of_view=(70.6, 60), true_depth=673.1):
    """Convert pixel coordinates to real-world offsets from the image centre.

    Uses a pinhole model with a small-angle focal length: pixels per radian is
    ``resolution / (2 * deg2rad(field_of_view / 2))``. The offset is then
    ``true_depth * (pixel - centre) / focal``, so the output is in the units of
    ``true_depth`` (presumably mm, given the function name).

    Args:
        coords: (N, >=2) array of pixel coordinates, column 0 = x, column 1 = y.
            Any extra columns are returned as zeros.
        resolution: (width, height) of the image in pixels.
        field_of_view: (horizontal, vertical) field of view in degrees.
        true_depth: depth at which the offsets are measured.

    Returns:
        Array shaped like ``coords``. Float inputs keep their dtype; integer
        inputs are promoted to float64 so sub-pixel results are not truncated.
    """
    coords = np.asarray(coords)

    cx = resolution[0] // 2
    cy = resolution[1] // 2

    fw = resolution[0] / (2 * np.deg2rad(field_of_view[0] / 2))
    fh = resolution[1] / (2 * np.deg2rad(field_of_view[1] / 2))

    # zeros_like(int_coords) would silently truncate the mm values to ints.
    out_dtype = coords.dtype if np.issubdtype(coords.dtype, np.floating) else np.float64
    new_coords = np.zeros(coords.shape, dtype=out_dtype)
    new_coords[:, 0] = true_depth * (coords[:, 0] - cx) / fw
    new_coords[:, 1] = true_depth * (coords[:, 1] - cy) / fh

    return new_coords


def compute_scalars(frames, centroid, true_depth):
    """Measure mouse size from reconstructed depth frames.

    Each frame is fit with ``im_moment_features``. Its axes are converted to
    real-world units using a per-frame scale measured at the mouse centroid.

    Args:
        frames: sequence of 2-D depth frames, aligned row-for-row with ``centroid``.
        centroid: (n_frames, 2) array of centroid pixel coordinates (x, y).
        true_depth: depth value forwarded to ``pxs_to_mm`` (presumably the
            camera-to-floor distance; confirm against the extraction metadata).

    Returns:
        dict of numpy arrays ``recon_width``, ``recon_length``, ``recon_height``
        and ``recon_area``, with one entry per frame. Width, length and height
        are NaN for frames where no contour is found. Area is the count of
        pixels with 10 < depth < 110 times the mean of the x/y per-pixel scale.
        NOTE(review): that is a linear, not squared, scale factor; confirm it
        intentionally mirrors how the extraction's ``area_mm`` is defined.
    """
    # Per-frame size of one pixel along x and y at the centroid: how far a
    # one-pixel shift of the centroid moves it after conversion.
    mm_per_px = np.abs(
        pxs_to_mm(centroid + 1, true_depth=true_depth)
        - pxs_to_mm(centroid, true_depth=true_depth)
    )

    widths, lengths, heights, areas = [], [], [], []
    for idx, frame in enumerate(frames):
        scale = mm_per_px[idx]
        body_mask = (frame > 10) & (frame < 110)
        ellipse = im_moment_features(frame)
        if ellipse is None:
            widths.append(np.nan)
            lengths.append(np.nan)
            heights.append(np.nan)
        else:
            axes = ellipse['axis_length']
            widths.append(np.min(axes) * scale[1])
            lengths.append(np.max(axes) * scale[0])
            heights.append(np.mean(frame[body_mask]))
        areas.append(np.sum(body_mask) * scale.mean())

    measurements = {
        'recon_width': widths,
        'recon_length': lengths,
        'recon_height': heights,
        'recon_area': areas,
    }
    return {name: np.array(values) for name, values in measurements.items()}

In [10]:
# Group session UUIDs by experiment label, preserving discovery order.
exp_groups = {}
for session_uuid, session_path in uuid_map.items():
    exp_groups.setdefault(get_experiment(session_path), []).append(session_uuid)

In [16]:
# Build one parquet file per experiment. Each file aligns every session's
# ARHMM syllable labels (datasets in syllable_path keyed by UUID) with its
# acquisition metadata, extraction scalars, and scalars recomputed from
# reconstructed frames.
recon_frames_key = "win_size_norm_frames_v2"
# Only these experiments are processed; other groups in exp_groups are skipped.
experiments_to_process = {"longtogeny_males"}
# Rewrite the partial parquet every N sessions so progress survives a crash.
checkpoint_every = 35
failed_sessions = []  # (uuid, path) of sessions that could not be assembled
failure_reasons = {}  # uuid -> repr of the exception that made it fail

# Column dtypes applied to every per-session DataFrame.
session_column_dtypes = dict(
    syllables="int16[pyarrow]",
    file="string[pyarrow]",
    experiment="string[pyarrow]",
    uuid="string[pyarrow]",
    session_name="string[pyarrow]",
    subject_name="string[pyarrow]",
    timestamps="float32[pyarrow]",
)


def _read_h5_string(h5_file, key, default=""):
    """Return the decoded byte string stored at ``key``, or ``default`` if missing."""
    try:
        return h5_file[key][()].decode()
    except KeyError:
        return default


def _load_session_features(path):
    """Read one results_00.h5; return its metadata and NaN-gap-filled scalars.

    Returns:
        ``(session_name, subject_name, filled_ts, filled_scalars, recon_scalars)``.
        Missing SessionName/SubjectName become "". Any other missing dataset
        raises KeyError, so the session is failed instead of silently reusing
        values from a previous session.
    """
    with h5py.File(path, "r") as f:
        session_name = _read_h5_string(f, "metadata/acquisition/SessionName")
        subject_name = _read_h5_string(f, "metadata/acquisition/SubjectName")
        keep_scalars = [k for k in f["scalars"] if "mm" in k] + ["angle", "velocity_theta"]
        # /1000 so gaps are measured against insert_nans' 1/fps period
        # (timestamps presumably stored in milliseconds; confirm).
        ts = f["timestamps"][()] / 1000
        scalars = {k: f["scalars"][k][()] for k in keep_scalars}
        frames = f[recon_frames_key][()]
        centroid = np.array(
            [f["scalars/centroid_x_px"][()], f["scalars/centroid_y_px"][()]]
        ).T
        true_depth = f["metadata/extraction/true_depth"][()]

    def fill(values):
        return insert_nans(ts, values)[0]

    filled_scalars = valmap(fill, scalars)
    recon_scalars = valmap(fill, compute_scalars(frames, centroid, true_depth))
    filled_ts = insert_nans(ts, ts)[2]
    return session_name, subject_name, filled_ts, filled_scalars, recon_scalars


def _session_age(experiment, path):
    """Age label from the folder two levels above the file (ontogeny only), else None."""
    if "ontogeny" in experiment and "longtogeny" not in experiment:
        return path.parents[2].name.split("_")[0]
    return None


def _session_date(path):
    """Acquisition time parsed from the ``..._YYYYmmddHHMMSS`` suffix of the session folder."""
    return datetime.strptime(path.parents[1].name.split("_")[-1], "%Y%m%d%H%M%S")


with h5py.File(syllable_path, "r") as h5f:
    for experiment, uuids in exp_groups.items():
        if experiment not in experiments_to_process:
            continue
        out_path = data_folder / f"{experiment}_syllable_df.parquet"
        df = []
        for i, uuid in enumerate(tqdm(uuids)):
            path = uuid_map[uuid]
            try:
                session_name, subject_name, filled_ts, filled_scalars, recon_scalars = (
                    _load_session_features(path)
                )
                _df = pd.DataFrame(
                    dict(
                        experiment=experiment,
                        file=str(path),
                        syllables=h5f[uuid][()],
                        date=_session_date(path),
                        uuid=uuid,
                        age=_session_age(experiment, path),
                        session_name=session_name,
                        subject_name=subject_name,
                        timestamps=filled_ts - filled_ts[0],
                        **filled_scalars,
                        **recon_scalars,
                    )
                ).astype(session_column_dtypes)
            except (KeyError, ValueError, IndexError) as err:
                # KeyError: required dataset missing (session or syllable file);
                # ValueError: unparseable date or arrays of different lengths;
                # IndexError: session with no timestamps.
                failed_sessions.append((uuid, path))
                failure_reasons[uuid] = repr(err)
                continue
            df.append(_df)
            if i % checkpoint_every == 0:
                pd.concat(df, ignore_index=True).to_parquet(out_path)
        if not df:
            print(experiment, "no sessions could be loaded")
            continue
        df = pd.concat(df, ignore_index=True)
        df.to_parquet(out_path)
        print(experiment, "length", len(df))

  0%|          | 0/1096 [00:00<?, ?it/s]

longtogeny_males length 37649629


In [17]:
# Sanity check: number of syllable labels stored for the last UUID processed above.
with h5py.File(syllable_path, "r") as syllable_h5:
    n_labels = len(syllable_h5[uuid])
print(n_labels)

35954


In [18]:
len(uuid_map.keys())  # number of extracted sessions discovered under `folders`

2324

In [19]:
# Summarize the concatenated DataFrame from the processing loop (column dtypes).
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 37649629 entries, 0 to 37649628
Data columns (total 23 columns):
 #   Column          Dtype         
---  ------          -----         
 0   experiment      string        
 1   file            string        
 2   syllables       int16[pyarrow]
 3   date            datetime64[ns]
 4   uuid            string        
 5   age             object        
 6   session_name    string        
 7   subject_name    string        
 8   timestamps      float[pyarrow]
 9   area_mm         float32       
 10  centroid_x_mm   float32       
 11  centroid_y_mm   float32       
 12  height_ave_mm   float32       
 13  length_mm       float32       
 14  velocity_2d_mm  float32       
 15  velocity_3d_mm  float32       
 16  width_mm        float32       
 17  angle           float32       
 18  velocity_theta  float32       
 19  recon_width     float64       
 20  recon_length    float64       
 21  recon_height    float64       
 22  recon_area      