# Example NAPS usage

In this notebook we'll install NAPS, pull the example data from the [GitHub repository](https://github.com/kocherlab/naps), and run `naps-track` against it.

NAPS pairs well with SLEAP's example notebook on remote training and inference, which you can find [here](https://colab.research.google.com/github/talmolab/sleap/blob/main/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb).

## Install NAPS

In [None]:
%%capture
!pip install -q naps-track

In [None]:
# If you have a trained SLEAP model, you can run inference directly on Colab: upload the model and run the command below.
# Colab's GPU access also lets you train there; see the SLEAP tutorials for details.
# !sleap-track example.mp4 -o "example.slp" -m models/bu --verbosity json --batch_size 4 --tracking.tracker simple --tracking.similarity iou --tracking.track_window 5 --tracking.post_connect_single_breaks 1

## Download the example data into Colab
Let's download the example dataset from the NAPS repository: the SLEAP project, its analysis file, and the video.


In [None]:
%%capture
!wget https://github.com/kocherlab/naps/raw/main/tests/data/example.slp
!wget https://github.com/kocherlab/naps/raw/main/tests/data/example.analysis.h5
!wget https://github.com/kocherlab/naps/raw/main/tests/data/example.mp4

In [None]:
!ls -lht

## NAPS tracking
Now let's track the files using `naps-track`.

In [None]:
!naps-track --slp-path example.slp --h5-path example.analysis.h5 --video-path example.mp4 --tag-node 0 \
 --start-frame 0 --end-frame 1203 --aruco-marker-set DICT_4X4_100 \
 --output-path example_output.analysis.h5 --aruco-error-correction-rate 0 \
 --aruco-adaptive-thresh-constant 12 --aruco-adaptive-thresh-win-size-max 30 \
 --aruco-adaptive-thresh-win-size-step 3 --aruco-perspective-rm-ignored-margin 0.13 \
 --aruco-adaptive-thresh-win-size-min 3 --half-rolling-window-size 11 

## Download

Now we can just download the output! This pulls the video, the output file, and the original project. 

Once you have these, create a new SLEAP project, import `example_output.analysis.h5`, and point it to the video to see the resulting tracks. To compare them with the original tracks, open the original project.

In [None]:
# # Zip the video and output
# !zip -0 -r naps_output.zip example.mp4 example.slp example_output.analysis.h5

# # Download
# from google.colab import files
# files.download("/content/naps_output.zip")

If you are not using Chrome, the download step may raise an error. In that case, download the files from the "Files" tab in the left side panel.

## After NAPS

### SLEAP GUI

Now that we have the files, we can either read the raw data directly or import it back into SLEAP. To view the tracks, open SLEAP (`sleap-label`) and create a new project. Then go to File > Import > SLEAP Analysis HDF5 and select the output file, here `example_output.analysis.h5`. SLEAP will prompt you for the video associated with the analysis file; select `example.mp4` and the tracks will be displayed.

### Directly reading the output H5

Now let's try reading in the .h5 directly and plotting a couple of basic features.

In [None]:
"""Read in the h5 and display basic info
"""

import h5py
import numpy as np

filename = "example_output.analysis.h5"
video_filename = "example.mp4"
output_filename = "output.mp4"

with h5py.File(filename, "r") as f:
    dset_names = list(f.keys())
    # Transpose so that the axes read (frames, nodes, 2, tracks)
    locations = f["tracks"][:].T
    node_names = [n.decode() for n in f["node_names"][:]]

print("===filename===")
print(filename)
print()

print("===HDF5 datasets===")
print(dset_names)
print()

print("===locations data shape===")
print(locations.shape)
print()

print("===nodes===")
for i, name in enumerate(node_names):
    print(f"{i}: {name}")
print()
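
The transpose in the cell above decides how every later cell indexes `locations`, so it is worth checking on a toy array. The sketch below assumes the on-disk layout used by SLEAP analysis files, `(tracks, 2, nodes, frames)`; the sizes and the names `tracks_on_disk` and `toy_locations` are made up for illustration and do not come from the example data.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
n_tracks, n_nodes, n_frames = 3, 2, 5

# Assumed on-disk layout of the "tracks" dataset: (tracks, 2, nodes, frames).
tracks_on_disk = np.full((n_tracks, 2, n_nodes, n_frames), np.nan)

# Place one known point: track 1, node 0, frame 4 sits at (x, y) = (10, 20).
tracks_on_disk[1, :, 0, 4] = (10.0, 20.0)

# .T reverses the axis order, giving (frames, nodes, 2, tracks).
toy_locations = tracks_on_disk.T
print(toy_locations.shape)

# After the transpose, index as [frame, node, coordinate, track].
x, y = toy_locations[4, 0, :, 1]
print(x, y)
```

With this layout, `locations[:, node, 0, track]` is the x trace of one node of one track over time, which is the indexing pattern the plotting and velocity helpers rely on.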

### Utility functions for cleaning up tracks, plotting, and showing the video

In [None]:
"""Resource functions
"""

from base64 import b64encode

import cv2
import matplotlib.pyplot as plt
import numpy as np
import palettable
import pandas as pd
import scipy.signal  # provides medfilt, used by smooth_median
import skvideo.io
from IPython.display import HTML
from tqdm import tqdm

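def flatten_features(x, axis=0):
    """Flatten an array to 2D with time as the first axis.

    Args:
        x: Array with the time axis at position `axis`.
        axis: Index of the time axis (default: 0).
    Returns:
        A tuple (x_2d, initial_shape): the array reshaped to (time, D), and the
        shape it had after the time axis was moved to the front.
    """
    if axis != 0:
        # Move time axis to the first dim
        x = np.moveaxis(x, axis, 0)

    # Flatten to 2D.
    initial_shape = x.shape
    x = x.reshape(len(x), -1)

    return x, initial_shape


def unflatten_features(x, initial_shape, axis=0):
    """Undo flatten_features, restoring the shape and the position of the time axis."""
    # Reshape.
    x = x.reshape(initial_shape)

    if axis != 0:
        # Move time axis back
        x = np.moveaxis(x, 0, axis)

    return x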


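def smooth_median(x, window=5, axis=0, inplace=False):
    """Apply a median filter along the time axis.

    Args:
        x: Timeseries of shape (time, ...) or with time axis specified by axis.
        window: Filter window in frames; scipy.signal.medfilt requires an odd size.
        axis: Time axis (default: 0).
        inplace: If False, work on a copy of multidimensional input.
    Returns:
        Smoothed timeseries of the same shape as the input.
    """
    if axis != 0 or x.ndim > 1:
        if not inplace:
            x = x.copy()

        # Reshape to (time, D)
        x, initial_shape = flatten_features(x, axis=axis)

        # Apply function to each slice
        for i in range(x.shape[1]):
            x[:, i] = smooth_median(x[:, i], window, axis=0)

        # Restore to original shape
        x = unflatten_features(x, initial_shape, axis=axis)
        return x

    y = scipy.signal.medfilt(x.copy(), window)
    y = y.reshape(x.shape)
    # Where the filter produced NaN but the input had a value, keep the input value
    mask = np.isnan(y) & (~np.isnan(x))
    y[mask] = x[mask]
    return y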


def fill_missing(x, kind="nearest", axis=0, **kwargs):
    """Fill missing values in a timeseries.
    Args:
        x: Timeseries of shape (time, ...) or with time axis specified by axis.
        kind: Type of interpolation to use. Defaults to "nearest".
        axis: Time axis (default: 0).
    Returns:
        Timeseries of the same shape as the input with NaNs filled in.
    Notes:
        This uses pandas.DataFrame.interpolate and accepts the same kwargs.
    """
    if x.ndim > 2:
        # Reshape to (time, D)
        x, initial_shape = flatten_features(x, axis=axis)

        # Interpolate.
        x = fill_missing(x, kind=kind, axis=0, **kwargs)

        # Restore to original shape
        x = unflatten_features(x, initial_shape, axis=axis)

        return x

    return pd.DataFrame(x).interpolate(method=kind, axis=axis, **kwargs).to_numpy()



def plot_trx(
    tracks,
    video_path=None,
    shift=0,
    frame_start=0,
    frame_end=100,
    trail_length=10,
    output_path="output.mp4",
    color_map=None,
    id_map=None,
    scale_factor=1,
    annotate=False,
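):
    """Render track trails, optionally over the video, and write them to an MP4.

    Args:
        tracks: Array of shape (frames, nodes, 2, tracks).
        video_path: Optional video to draw behind the trails.
        shift: Offset added to frame_start and frame_end when slicing tracks.
        frame_start, frame_end: Range of frames to render.
        trail_length: Number of preceding frames drawn as a trail.
        output_path: Path of the video to write.
        color_map, id_map: Optional lookups giving each track a color via
            color_map[id_map[track_idx]]; without them, colors follow the node.
        scale_factor: Multiplier applied to coordinates when color_map is set.
        annotate: If True, label each track with its index.
    """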
    ffmpeg_writer = skvideo.io.FFmpegWriter(
        f"{output_path}", inputdict={'-r':"20"}, outputdict={"-vcodec": "libx264"}
    )
    if video_path is not None:
        cap = cv2.VideoCapture(video_path)
        # Clamp at 0 so that frame_start=0 does not request a negative position
        cap.set(cv2.CAP_PROP_POS_FRAMES, max(frame_start - 1, 0))
    data = tracks[(frame_start+shift):(frame_end+shift), :, :, :]
    dpi = 300
    for frame_idx in tqdm(range(data.shape[0]), position=0, leave=True):
        fig, ax = plt.subplots(figsize=(3660 / dpi, 3660 / dpi), dpi=dpi)
        plt.gca().invert_yaxis()
        # plt.xlim((0, 3660))
        # plt.ylim((-3660, 0))
        data_subset = data[max((frame_idx - trail_length), 0) : frame_idx, :, :, :]
        for fly_idx in range(data_subset.shape[3]):
            if annotate and data_subset.shape[0] > 0:
                # if  ~(data_subset[-1, 0, 1, fly_idx] < 3660/2):
                plt.annotate(
                    fly_idx,
                    (data_subset[-1, 0, 0, fly_idx], data_subset[-1, 0, 1, fly_idx]),
                    size=18,
                    ha="left",
                    va="bottom",
                    color="#CB9E23",
                )
            for node_idx in range(data_subset.shape[1]):
                for idx in range(2, data_subset.shape[0]):
                    # if  data_subset[idx, node_idx, 1, fly_idx] < 3660/2:
                    #     continue
                    # Note that you need to use single steps or the data has "steps"
                    if color_map is None:
                        plt.plot(
                            data_subset[(idx - 2) : idx, node_idx, 0, fly_idx],
                            data_subset[(idx - 2) : idx, node_idx, 1, fly_idx],
                            linewidth= 4.5 * idx / data_subset.shape[0],
                            color=palettable.tableau.Tableau_20.mpl_colors[node_idx],
                        )
                    else:
                        color = color_map[id_map[fly_idx]]
                        (l,) = ax.plot(
                            data_subset[(idx - 2) : idx, node_idx, 0, fly_idx]
                            * scale_factor,
                            data_subset[(idx - 2) : idx, node_idx, 1, fly_idx]
                            * scale_factor,
                            linewidth=3 * idx / data_subset.shape[0],
                            color=color,
                        )
                        l.set_solid_capstyle("round")
        if video_path is not None:
            if cap.isOpened():
                res, frame = cap.read()
                if res:
                    # Draw the first channel of the frame as a grayscale background
                    plt.imshow(frame[:, :, 0], cmap="gray")
        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)
        fig.set_size_inches(3660 / dpi, 3660 / dpi, True)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.axis("off")
        fig.patch.set_visible(False)
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)
        fig.canvas.draw()
        # Grab the rendered canvas as an (H, W, 3) uint8 RGB array.
        # buffer_rgba() replaces tostring_rgb(), which newer Matplotlib releases deprecate.
        image_from_plot = np.ascontiguousarray(
            np.asarray(fig.canvas.buffer_rgba())[:, :, :3]
        )
        ffmpeg_writer.writeFrame(image_from_plot)
        plt.close()
    ffmpeg_writer.close()
    if video_path is not None:
        cap.release()

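def instance_node_velocities(fly_node_locations, start_frame, end_frame):
    """Compute per-frame speeds for every node.

    Args:
        fly_node_locations: Array of shape (frames, nodes, 2, tracks), or
            (frames, nodes, 2) for a single track.
        start_frame: First frame to include.
        end_frame: Frame to stop at (exclusive).
    Returns:
        Speeds in pixels per frame, of shape (frames, nodes, tracks) or
        (frames, nodes).
    """
    frame_count = len(range(start_frame, end_frame))
    if len(fly_node_locations.shape) == 4:
        fly_node_velocities = np.zeros(
            (frame_count, fly_node_locations.shape[1], fly_node_locations.shape[3])
        )
        for fly_idx in range(fly_node_locations.shape[3]):
            for n in tqdm(range(0, fly_node_locations.shape[1])):
                fly_node_velocities[:, n, fly_idx] = diff(
                    fly_node_locations[start_frame:end_frame, n, :, fly_idx]
                )
    else:
        fly_node_velocities = np.zeros((frame_count, fly_node_locations.shape[1]))
        # Loop over every node; stopping at shape[1] - 1 would leave the last node at zero
        for n in tqdm(range(0, fly_node_locations.shape[1])):
            fly_node_velocities[:, n] = diff(
                fly_node_locations[start_frame:end_frame, n, :]
            )

    return fly_node_velocities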


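def diff(node_loc, diff_func=np.gradient, **kwargs):
    """Compute the speed of a single node over time.

    Args:
        node_loc: Array of shape (frames, 2) holding x and y coordinates.
        diff_func: Function used to differentiate each coordinate over time
            (default: np.gradient).
        **kwargs: Extra keyword arguments passed to diff_func.
    Returns:
        Array of shape (frames,) with the norm of the per-coordinate derivatives.
    """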
    node_loc_vel = np.zeros_like(node_loc)
    for c in range(node_loc.shape[-1]):
        node_loc_vel[:, c] = diff_func(node_loc[:, c], **kwargs)

    node_vel = np.linalg.norm(node_loc_vel,axis=1)

    return node_vel

def show_video(video_path, video_width=1000):
    """Embed an MP4 file in the notebook output as a base64 data URL."""
    with open(video_path, "rb") as f:
        video_file = f.read()

    video_url = f"data:video/mp4;base64,{b64encode(video_file).decode()}"
    return HTML(f"""<video width={video_width} controls><source src="{video_url}"></video>""")
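
The `fill_missing` and `smooth_median` helpers above are defined but not called later in this notebook. As a rough sketch of what they do, the cell below applies the same two underlying calls (`pandas` interpolation, then `scipy.signal.medfilt`) to a toy one-dimensional trace. The trace values and the names `trace`, `filled`, and `smoothed` are invented for illustration, and linear interpolation is used here instead of the helper's default `"nearest"` to keep the example simple.

```python
import numpy as np
import pandas as pd
from scipy.signal import medfilt

# Toy x-coordinate trace for one node: one missing frame (NaN) and a one-frame spike.
trace = np.array([0.0, 1.0, np.nan, 3.0, 40.0, 5.0, 6.0])

# Step 1: fill the gap by interpolating along time.
# limit_direction="both" also fills NaNs at the start or end of the trace.
filled = (
    pd.Series(trace)
    .interpolate(method="linear", limit_direction="both")
    .to_numpy()
)

# Step 2: a median filter with an odd window suppresses the isolated spike.
# medfilt pads the edges with zeros, so the first and last samples can be biased.
smoothed = medfilt(filled, kernel_size=3)

print(filled)
print(smoothed)
```

Filling before smoothing matters: a median filter applied to a window containing NaN does not give a meaningful value, which is why `smooth_median` falls back to the input wherever its output is NaN.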

In [None]:
%%capture

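# Scale of the example video (presumably pixels per millimetre, judging by the name)
px_mm = 15.5

# Missingness filter: drop tracks that have no node defined in 80% or more of the frames
atleast_one_node_defined = np.any(~np.isnan(locations[:, :, 0, :]), axis=1)
no_nodes_defined = ~atleast_one_node_defined

missing_ct = np.sum(no_nodes_defined, axis=0)
missing_freq = missing_ct / no_nodes_defined.shape[0]
locations_filtered = locations[:, :, :, missing_freq < 0.8]

# Velocity filter: blank out every frame in which the Thorax node moves at
# px_mm * 5 pixels per frame or faster, or has an undefined speed
vel = instance_node_velocities(locations_filtered, 0, locations_filtered.shape[0])
mask_2d = ~(vel[:, node_names.index("Thorax"), :] < px_mm * 5)[:, np.newaxis, np.newaxis, :]
mask_4d = np.broadcast_to(mask_2d, locations_filtered.shape)
locations_filtered[mask_4d] = np.nan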
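
To see what the velocity filter in the cell above removes, here is a minimal sketch on synthetic data. It mirrors the same steps (a `np.gradient` speed per coordinate, a threshold, and a mask broadcast over nodes and coordinates), but the array `toy`, its sizes, and the jump are assumptions made up for this example.

```python
import numpy as np

# Synthetic stand-in for `locations`, shaped (frames, nodes, 2, tracks).
n_frames, n_nodes, n_tracks = 6, 2, 2
toy = np.zeros((n_frames, n_nodes, 2, n_tracks))

# Track 0 drifts 1 px per frame in x; track 1 jumps 500 px between frames 2 and 3.
toy[:, :, 0, 0] = np.arange(n_frames)[:, None]
toy[3:, :, 0, 1] = 500.0

# Speed of node 0: differentiate x and y over time, then take the norm.
dx = np.gradient(toy[:, 0, 0, :], axis=0)
dy = np.gradient(toy[:, 0, 1, :], axis=0)
speed = np.hypot(dx, dy)  # shape (frames, tracks)

# Same threshold as the cell above: 15.5 * 5 pixels per frame.
max_speed = 15.5 * 5

# ~(speed < max_speed) also flags NaN speeds, unlike speed >= max_speed.
too_fast = ~(speed < max_speed)

# Broadcast the (frames, tracks) mask over the node and coordinate axes.
mask = np.broadcast_to(too_fast[:, None, None, :], toy.shape)
cleaned = toy.copy()
cleaned[mask] = np.nan

print(np.argwhere(too_fast))
```

Because `np.gradient` uses central differences, a jump between two frames is spread across both of them, so the frames on either side of the jump are blanked for every node of that track.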

In [None]:
# Write a video!
plot_trx(locations_filtered, video_path = video_filename, output_path = output_filename, frame_start = 0, frame_end=100, trail_length=5, annotate=True, shift = 5)

In [None]:
# Now let's use a little Jupyter magic to display the video in the browser!
show_video(output_filename)