Copyright 2020 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
# @title Download Model {form-width: "25%"}

%mkdir tapnet/checkpoints

!wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/tapir_checkpoint.npy

%ls tapnet/checkpoints

mkdir: cannot create directory 'tapnet/checkpoints': File exists
--2023-07-11 02:29:31--  https://storage.googleapis.com/dm-tapnet/tapir_checkpoint.npy
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.163.128, 172.253.62.128, 142.250.31.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.163.128|:443... connected.


In [1]:
pip install --upgrade "jax[cpu]"


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [1]:
# @title Imports {form-width: "25%"}

import haiku as hk
import jax
import mediapy as media
import numpy as np
import tree

from tapnet import tapir_model
from tapnet.utils import transforms
from tapnet.utils import viz_utils

In [2]:
# @title Load Checkpoint {form-width: "25%"}

checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint.npy'

# The .npy file holds a pickled Python object (a mapping with 'params' and
# 'state'), so allow_pickle=True is required and .item() unwraps it.
# Unpickling can execute arbitrary code: only load checkpoints from a trusted
# source such as the official dm-tapnet bucket.
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params = ckpt_state['params']
state = ckpt_state['state']

In [3]:
# @title Build Model {form-width: "25%"}

def build_model(frames, query_points):
  """Run a TAPIR module on `frames` for the given `query_points`.

  The TAPIR module is constructed inside this function, so it is meant to be
  wrapped by `hk.transform_with_state` rather than called directly.

  Args:
    frames: video handed to TAPIR as `video`; `inference` passes a batched
      float32 array scaled to [-1, 1].
    query_points: batched float32 [t, y, x] query points, as built by
      `inference`.

  Returns:
    TAPIR's output mapping; `inference` reads its 'tracks', 'occlusion' and
    'expected_dist' entries.
  """
  tapir = tapir_model.TAPIR()
  return tapir(
      video=frames,
      query_points=query_points,
      is_training=False,
      # NOTE(review): presumably bounds how many query points are processed at
      # once (memory vs. speed) — confirm in tapir_model.
      query_chunk_size=64,
  )

# Haiku turns build_model into an (init, apply) pair; `inference` calls the
# jit-compiled apply as model_apply(params, state, rng, frames, query_points).
model = hk.transform_with_state(build_model)
model_apply = jax.jit(model.apply)

In [4]:
# @title Utility Functions {form-width: "25%"}

def preprocess_frames(frames):
  """Convert uint8 RGB frames into the float range the model consumes.

  Args:
    frames: [num_frames, height, width, 3], [0, 255], np.uint8

  Returns:
    frames: [num_frames, height, width, 3], [-1, 1], np.float32
  """
  # Linear map 0 -> -1 and 255 -> 1, computed in float32.
  as_float = frames.astype(np.float32)
  return as_float / 255 * 2 - 1


def postprocess_occlusions(occlusions, expected_dist, threshold=0.5):
  """Postprocess occlusion logits to a boolean visible flag.

  Both inputs are logits. A point is marked visible when the combined
  probability

      p_visible = (1 - sigmoid(occlusions)) * (1 - sigmoid(expected_dist))

  exceeds `threshold`.
  NOTE(review): expected_dist presumably scores a large expected localization
  error, so a high value also counts against visibility — confirm in tapir_model.

  Args:
    occlusions: [num_points, num_frames], [-inf, inf], np.float32
    expected_dist: [num_points, num_frames], [-inf, inf], np.float32
    threshold: probability above which a point counts as visible
      (default 0.5).

  Returns:
    visibles: [num_points, num_frames], bool np.ndarray
  """
  p_visible = (1 - jax.nn.sigmoid(occlusions)) * (1 - jax.nn.sigmoid(expected_dist))
  # Bring the result back to host NumPy like the other post-processed outputs.
  return np.asarray(p_visible > threshold)

def inference(frames, query_points):
  """Run TAPIR on one video and return per-point tracks and visibility.

  Args:
    frames: [num_frames, height, width, 3], [0, 255], np.uint8
    query_points: [num_points, 3], [0, num_frames/height/width], [t, y, x]

  Returns:
    tracks: the model's 'tracks' output with the batch dimension removed.
      The plotting cells in this notebook index it as
      [num_points, num_frames, 2], (x, y) pixel coordinates of `frames`.
      TODO confirm against tapir_model.
    visibles: [num_points, num_frames], bool
  """
  # Scale to the model's input range, cast queries to float, add batch dim.
  frames = preprocess_frames(frames)[None]
  query_points = query_points.astype(np.float32)[None]

  # A fixed PRNG key keeps repeated calls reproducible.
  rng = jax.random.PRNGKey(42)
  outputs, _ = model_apply(params, state, rng, frames, query_points)
  # Drop the batch dimension and copy every output to a host NumPy array.
  outputs = tree.map_structure(lambda x: np.array(x[0]), outputs)
  tracks = outputs['tracks']
  occlusions = outputs['occlusion']
  expected_dist = outputs['expected_dist']

  # Binarize occlusions
  visibles = postprocess_occlusions(occlusions, expected_dist)
  return tracks, visibles


def sample_random_points(frame_max_idx, height, width, num_points):
  """Draw `num_points` uniform random query points as [t, y, x] rows.

  t is in [0, frame_max_idx], y in [0, height), x in [0, width); the result
  is an int32 array of shape [num_points, 3].
  """
  # Draw order (y, then x, then t) is kept so seeded runs reproduce exactly.
  ys = np.random.randint(0, height, (num_points, 1))
  xs = np.random.randint(0, width, (num_points, 1))
  ts = np.random.randint(0, frame_max_idx + 1, (num_points, 1))
  return np.hstack([ts, ys, xs]).astype(np.int32)

In [5]:
# @title Load an Exemplar Video {form-width: "25%"}

%mkdir tapnet/examplar_videos

!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4

video = media.read_video('tapnet/examplar_videos/horsejump-high.mp4')
# media.show_video(video, fps=10)

mkdir: cannot create directory 'tapnet/examplar_videos': File exists
--2023-07-11 07:31:18--  https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.163.128, 172.253.62.128, 142.251.16.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.163.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 467706 (457K) [video/mp4]
Saving to: 'tapnet/examplar_videos/horsejump-high.mp4.6'


2023-07-11 07:31:18 (4.31 MB/s) - 'tapnet/examplar_videos/horsejump-high.mp4.6' saved [467706/467706]



In [87]:
import fastplotlib as fpl
from ipywidgets import IntSlider, VBox

In [88]:
from skimage.transform import resize

In [89]:
# Raw clip shape: (num_frames, height, width, channels)
np.shape(video)

(50, 480, 854, 3)

In [109]:
# Downsample every frame (480x854 -> 160x284, same aspect ratio). Smaller
# frames keep CPU-only TAPIR inference and the interactive plots light.
RESIZED_HW = (160, 284)  # (height, width)

num_frames, num_channels = video.shape[0], video.shape[-1]
video_resized = np.empty((num_frames, *RESIZED_HW, num_channels), dtype=np.uint8)

for i, frame in enumerate(video):
    # preserve_range=True returns floats in the original 0-255 range; round
    # (rather than truncate on assignment) before storing as uint8.
    resized = resize(frame, video_resized.shape[1:], preserve_range=True)
    video_resized[i] = np.clip(np.round(resized), 0, 255)

In [159]:
plot = fpl.Plot(size=(700, 400))

# Name the frame graphic so callbacks look it up by name, not by position.
plot.add_image(video_resized[0], name="video")


def update_video_frame(change):
    """Slider callback: display frame `change["new"]` of `video_resized`."""
    plot["video"].data = video_resized[change["new"]]


# Distinct names so the tracks viewer further down does not shadow them.
frame_slider = IntSlider(min=0, max=video_resized.shape[0] - 1, step=1, value=0)
frame_slider.observe(update_video_frame, "value")

VBox([plot.show(), frame_slider])

RFBOutputContext()

VBox(children=(VBox(children=(JupyterWgpuCanvas(css_height='400px', css_width='700px'), HBox(children=(Button(…

In [163]:
# Flip the y axis so the image shows row 0 at the top (presumably the default
# camera has y pointing up — confirm). Setting the sign explicitly instead of
# toggling with `*= -1` keeps this cell idempotent when re-run.
plot.camera.world.scale_y = -abs(plot.camera.world.scale_y)

### Draw a polygon around the horse and rider with the polygon selector tool in the plot above; query points are sampled from inside it.

In [164]:
from skimage.draw import polygon2mask

In [165]:
# Vertices of the polygon drawn on `plot`, presumably (x, y) in the pixel
# coordinates of the displayed `video_resized` frame — TODO confirm against the
# fastplotlib selector API.
vertices = np.asarray(plot.selectors[0].get_vertices())

# polygon2mask takes (row, col) = (y, x) vertices and an image shape of
# (height, width). Build the mask on the resized frames the polygon was drawn
# on, so every sampled point lies inside the image that is actually tracked.
mask = polygon2mask(video_resized.shape[1:3], vertices[:, ::-1])

# (y, x) coordinates of every pixel inside the polygon
pts = np.argwhere(mask).astype(np.float32)

# Random subset of query points. Seeded for reproducibility, and capped at the
# number of available pixels so a very small polygon does not raise.
n_query = min(100, pts.shape[0])
ixs = np.random.default_rng(0).choice(pts.shape[0], n_query, replace=False)

In [166]:
# pts rows are (y, x); the scatter takes (x, y), so swap the columns.
# The trailing ";" suppresses the returned graphic's repr in the output.
plot.add_scatter(pts[ixs][:, ::-1], sizes=10, colors="random");

<weakproxy at 0x7f61dc6a9f80 to ScatterGraphic at 0x7f6278201cd0>

In [167]:
# TAPIR queries are [t, y, x]; every sampled point is queried on frame 0.
query_times = np.zeros((n_query, 1))
query_points = np.hstack([query_times, pts[ixs]]).astype(np.int32)

tracks, visibles = inference(video_resized, query_points)

In [168]:
plot_tracks = fpl.Plot(size=(700, 400))

# Name the frame graphic so the callback looks it up by name, not by position.
plot_tracks.add_image(video_resized[0], name="video")

# First-frame (x, y) position of every tracked point, shape [num_points, 2].
# A contiguous copy, matching what np.vstack of the rows would produce.
pos0 = np.ascontiguousarray(tracks[:, 0])

plot_tracks.add_scatter(
    pos0,
    cmap="jet",
    sizes=5,
    name="pts",
)


def update_tracks_frame(change):
    """Slider callback: show frame `change["new"]` and move every point to its
    tracked position in that frame."""
    frame_ix = change["new"]
    plot_tracks["video"].data = video_resized[frame_ix]

    for i, track in enumerate(tracks):
        plot_tracks["pts"].data[i] = track[frame_ix]


# Distinct names so this viewer does not shadow the first viewer's callback.
tracks_slider = IntSlider(min=0, max=video_resized.shape[0] - 1, step=1, value=0)
tracks_slider.observe(update_tracks_frame, "value")

VBox([plot_tracks.show(), tracks_slider])

RFBOutputContext()

VBox(children=(VBox(children=(JupyterWgpuCanvas(css_height='400px', css_width='700px'), HBox(children=(Button(…

In [169]:
# Flip the y axis so the image shows row 0 at the top (presumably the default
# camera has y pointing up — confirm). Setting the sign explicitly instead of
# toggling with `*= -1` keeps this cell idempotent when re-run.
plot_tracks.camera.world.scale_y = -abs(plot_tracks.camera.world.scale_y)

Apply a Gaussian filter over time to smooth the tracks.

In [170]:
from scipy.ndimage import gaussian_filter1d

In [171]:
# Smooth each trajectory over time: a 1-D Gaussian (sigma = 1.5 frames) runs
# down axis 0, filtering the x column and the y column independently.
tracks_filt = [
    gaussian_filter1d(track[:, :2], 1.5, axis=0)
    for track in tracks
]

In [172]:
# Draw the smoothed trajectories. They use the same "jet" colormap as the
# scatter points, so each line's color matches its point.
plot_tracks.add_line_collection(
    tracks_filt,
    cmap="jet",
    thickness=2,
    name="tracks",
)

# Raise the points above the lines (larger z) so the lines do not hide them.
plot_tracks["pts"].position_z = 5

In [173]:
# x and y components of velocity (1st derivative) and acceleration (2nd
# derivative) for every smoothed track. np.gradient uses unit spacing, so the
# units are pixels/frame and pixels/frame^2.
x_velocity = [np.gradient(track[:, 0]) for track in tracks_filt]
y_velocity = [np.gradient(track[:, 1]) for track in tracks_filt]

x_accel = [np.gradient(vel) for vel in x_velocity]
y_accel = [np.gradient(vel) for vel in y_velocity]

In [158]:
for line_graphic, vel_vals in zip(plot_tracks["tracks"].graphics, y_velocity):
    # Velocity is signed, so use a diverging colormap. Image y grows downward,
    # so with the reversed blue-white-red map ("bwr_r"):
    #   red   = negative y velocity (moving up the image)
    #   white = middle of the value range
    #   blue  = positive y velocity (moving down the image)
    # NOTE(review): white marks exactly 0 only if the colormap is scaled
    # symmetrically around 0. Each line is also scaled to its own values, so
    # colors may not be comparable across tracks — confirm fastplotlib's scaling.
    line_graphic.cmap = "bwr_r"
    line_graphic.cmap.values = vel_vals

In [144]:
# Color each trajectory by its y acceleration (same diverging "bwr_r" map).
for line_graphic, acc_vals in zip(plot_tracks["tracks"].graphics, y_accel):
    line_graphic.cmap = "bwr_r"
    line_graphic.cmap.values = acc_vals

In [145]:
# Color each trajectory by its x acceleration (same diverging "bwr_r" map).
for line_graphic, acc_vals in zip(plot_tracks["tracks"].graphics, x_accel):
    line_graphic.cmap = "bwr_r"
    line_graphic.cmap.values = acc_vals