### Export predicted labels from DEG outputs to per-trial CSV files

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
import h5py
from tqdm import tqdm

In [15]:
# NOTE(review): this cell used to display `trial_ids[0]`, but it sits above (and was run
# after, as In [15]) the cell that defines `trial_ids`, so it raised NameError on
# Restart Kernel -> Run All. Inspect `trial_ids[0]` after that cell instead.

In [2]:
# Collect the test-split videos (sorted by path); each video's file stem is its trial id.
test_data = sorted(Path('/data/caitlin/cross_validation/k2/test_vids/').glob('*.avi'))
trial_ids = [video_path.stem for video_path in test_data]

In [3]:
# Root of the k2 DEG `DATA` directory. Each trial's outputs file is read from, and its
# exported predictions CSV is written to, the sub-directory `data_path / <trial_id>`.
data_path = Path('/data/caitlin/cross_validation/k2/k2_deg/DATA/')

In [4]:
def get_prediction_probabilities(trial_id, data_dir=None, model_index: int = 1):
    """Load per-frame probabilities and per-class thresholds from a trial's DEG outputs file.

    Reads ``<data_dir>/<trial_id>/<trial_id>_outputs.h5`` and, from the top-level group at
    position ``model_index`` of the file's keys, the ``P`` and ``thresholds`` datasets.

    Parameters
    ----------
    trial_id : str
        Trial identifier: the name of the trial's sub-directory and the outputs-file prefix.
    data_dir : path-like, optional
        Root DATA directory. Defaults to the notebook-level ``data_path``.
    model_index : int, default 1
        Position of the prediction group within ``list(f.keys())``.
        NOTE(review): this is positional, so it depends on the key ordering h5py reports
        for the file -- confirm that index 1 is the intended model group.

    Returns
    -------
    probabilities : np.ndarray
        The ``P`` dataset with negative values set to 0 (NaNs are left as-is).
        Presumably shaped (n_frames, n_classes) -- TODO confirm.
    thresholds : np.ndarray
        The ``thresholds`` dataset; if it is stored as 2-D, only its last row is returned.
    prediction_model_name : str
        Key of the group that was read.
    keys : list of str
        All top-level keys of the outputs file.
    """
    root = data_path if data_dir is None else Path(data_dir)
    # Build the file name directly: chaining `.with_name(...).with_suffix('.h5')` replaced
    # everything after the first '.' in trial_id (e.g. 'a.b' -> 'a.h5').
    filename = root / trial_id / f'{trial_id}_outputs.h5'
    with h5py.File(filename, "r") as f:
        keys = list(f.keys())
        prediction_model_name = keys[model_index]
        group = f[prediction_model_name]

        probabilities = group['P'][:]
        # Negative probabilities are invalid; clamp them to 0.
        probabilities[probabilities < 0] = 0

        thresholds = group['thresholds'][:]
        if thresholds.ndim == 2:
            # This should not happen; fall back to the last row of thresholds.
            thresholds = thresholds[-1, :]

    return probabilities, thresholds, prediction_model_name, keys

In [5]:
def find_bout_indices(predictions_trace: np.ndarray,
                      bout_length: int,
                      positive: bool = True,
                      eps: float = 1e-6) -> np.ndarray:
    """Return the frame indices of every bout that is exactly ``bout_length`` frames long.

    A bout is a run of 1s (``positive=True``) or of 0s (``positive=False``) that is bordered
    on both sides by the opposite value or by the start/end of the trace.

    Parameters
    ----------
    predictions_trace : np.ndarray
        1-D trace, expected to contain 0/1 values. It is not modified.
    bout_length : int
        Exact run length to look for.
    positive : bool, default True
        Look for runs of 1s if True, runs of 0s if False.
    eps : float, default 1e-6
        Tolerance used when comparing the filter response to 1.

    Returns
    -------
    np.ndarray of int
        Indices of all frames belonging to matching bouts, in ascending order
        (empty if there are none).
    """
    trace = np.asarray(predictions_trace)
    if not positive:
        trace = np.logical_not(trace).astype(int)
    # Filter [-L/2, 1/L, ..., 1/L, -L/2]: for a 0/1 trace the response is 1 only when the
    # L middle frames are all 1 and both flanking frames are 0.
    center = np.ones(bout_length) / bout_length
    filt = np.concatenate([[-bout_length / 2], center, [-bout_length / 2]])
    if len(trace) < bout_length:
        # Too short to contain such a bout (also avoids np.convolve on an empty array).
        return np.array([]).astype(int)
    # Zero-pad one frame on each side so bouts touching the edges still count, then use
    # 'valid' mode so response[s] scores the candidate bout trace[s:s + bout_length].
    # mode='same' would mis-align results when len(trace) < bout_length + 2, because
    # np.convolve swaps its arguments when the second one is longer.
    padded = np.concatenate([[0], trace, [0]])
    response = np.convolve(padded, filt, mode='valid')
    # precision: == 1 has false negatives when the response is 0.99999999998 or similar
    starts = np.flatnonzero(np.abs(response - 1) < eps)
    if starts.size == 0:
        return np.array([]).astype(int)
    # Expand each bout start into the indices of all bout_length frames of that bout.
    return (starts[:, None] + np.arange(bout_length)).ravel()

In [6]:
def export_predictions(trial_id, bout_length: int = 1,
                       class_names=("background", "lift", "handopen", "grab", "sup", "atmouth", "chew")):
    """Threshold a trial's DEG probabilities, remove short bouts, and write the labels to CSV.

    The CSV is written to ``<data_path>/<trial_id>/<trial_id>_predictions.csv``. It has one
    row per frame, one 0/1 column per class, and the default pandas integer index.

    Parameters
    ----------
    trial_id : str
        Trial identifier passed to ``get_prediction_probabilities``.
    bout_length : int, default 1
        For each length ``1 .. bout_length``, runs of 0s of exactly that length are set
        to 1 (gap filling), then runs of 1s of exactly that length are set to 0.
    class_names : sequence of str
        Column names for the CSV, one per class column of the probabilities.

    Returns
    -------
    pathlib.Path
        Path of the written CSV file.

    Raises
    ------
    ValueError
        If ``len(class_names)`` differs from the number of class columns.
    """
    probabilities, thresholds, _, _ = get_prediction_probabilities(trial_id)
    # thresholds holds one value per class and broadcasts across frames
    # (assumes probabilities is (n_frames, n_classes) -- TODO confirm).
    predictions = (probabilities > thresholds).astype(int)
    _, n_classes = predictions.shape
    if n_classes != len(class_names):
        raise ValueError(
            f"{trial_id}: got {n_classes} class columns but {len(class_names)} class names "
            f"{list(class_names)}")
    for k in range(n_classes):
        # Basic slicing gives a view, so the in-place edits below update `predictions`.
        trace = predictions[:, k]
        for bout_len in range(1, bout_length + 1):
            # First remove "false negatives": fill short gaps inside behavior bouts.
            trace[find_bout_indices(trace, bout_len, positive=False)] = 1
            # Then remove "false positives": very short runs of 1s.
            trace[find_bout_indices(trace, bout_len)] = 0
    df = pd.DataFrame(data=predictions, columns=list(class_names))
    prediction_fname = data_path / trial_id / f'{trial_id}_predictions.csv'
    df.to_csv(prediction_fname)
    return prediction_fname

In [7]:
# Write one <trial_id>_predictions.csv per test trial; tqdm reports progress over the list.
for trial_id in tqdm(trial_ids, total=len(trial_ids)):
    export_predictions(trial_id)

100%|████████████████████████████████████████████████████████████████████████████| 300/300 [00:01<00:00, 267.57it/s]
