Copyright 2020 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [1]:
# @title Imports {form-width: "25%"}

import haiku as hk
import jax
import mediapy as media
import numpy as np
import tree

from tapnet import tapir_model
from tapnet.utils import transforms
from tapnet.utils import viz_utils

In [2]:
# @title Load Checkpoint {form-width: "25%"}

# Path is relative to the current working directory.
checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint.npy'
# The file holds a pickled Python object wrapped in a 0-d array, hence
# allow_pickle=True and .item() to unwrap it.
# NOTE(review): unpickling can execute arbitrary code -- only load
# checkpoints obtained from a trusted source.
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
# The unwrapped object is indexed by the 'params' and 'state' keys.
params, state = ckpt_state['params'], ckpt_state['state']

In [3]:
# @title Build Model {form-width: "25%"}

def build_model(frames, query_points):
  """Instantiate TAPIR and run it on `frames` for the given `query_points`.

  The model is constructed inside this function and called in inference mode
  (`is_training=False`), processing queries in chunks of 64.

  Args:
    frames: video passed to the model as `video`.
    query_points: query points passed to the model as `query_points`.

  Returns:
    Whatever the TAPIR module returns for this call.
  """
  tapir = tapir_model.TAPIR()
  return tapir(
      video=frames,
      query_points=query_points,
      is_training=False,
      query_chunk_size=64,
  )

# Wrap the model-building function with Haiku so parameters and state are
# passed explicitly, then JIT-compile the resulting apply function.
model = hk.transform_with_state(build_model)
model_apply = jax.jit(model.apply)

In [4]:
# @title Utility Functions {form-width: "25%"}

def preprocess_frames(frames):
  """Rescale uint8 video frames to float32 values in [-1, 1].

  Args:
    frames: [num_frames, height, width, 3], [0, 255], np.uint8

  Returns:
    frames: [num_frames, height, width, 3], [-1, 1], np.float32
  """
  # Cast first so the arithmetic below is carried out in float32.
  as_float = frames.astype(np.float32)
  # Maps 0 -> -1 and 255 -> 1.
  return as_float / 255 * 2 - 1


def postprocess_occlusions(occlusions, expected_dist):
  """Turn occlusion and expected-distance logits into a boolean visible flag.

  A point is marked visible when the product of `1 - sigmoid(occlusions)` and
  `1 - sigmoid(expected_dist)` exceeds 0.5.

  Args:
    occlusions: [num_points, num_frames], [-inf, inf], np.float32
    expected_dist: [num_points, num_frames], [-inf, inf], np.float32

  Returns:
    visibles: [num_points, num_frames], bool
  """
  # A simpler rule would threshold the occlusion logit alone
  # (occlusions < 0); both terms are combined here instead.
  occlusion_term = 1 - jax.nn.sigmoid(occlusions)
  distance_term = 1 - jax.nn.sigmoid(expected_dist)
  return occlusion_term * distance_term > 0.5

def inference(frames, query_points):
  """Inference on one video.

  Args:
    frames: [num_frames, height, width, 3], [0, 255], np.uint8
    query_points: [num_points, 3], [0, num_frames/height/width], [t, y, x]

  Returns:
    tracks: [num_points, num_frames, 2], one position per point per frame.
      The rest of this notebook treats the last axis as [x, y] and draws it
      directly over `frames` -- TODO confirm the coordinate convention
      against the TAPIR documentation.
    visibles: [num_points, num_frames], bool
  """
  # Preprocess video to match model inputs format
  frames = preprocess_frames(frames)
  query_points = query_points.astype(np.float32)
  frames, query_points = frames[None], query_points[None]  # Add batch dimension

  # Model inference. The PRNG key is a fixed constant; the second return
  # value of the apply function is discarded.
  rng = jax.random.PRNGKey(42)
  outputs, _ = model_apply(params, state, rng, frames, query_points)
  # Drop the batch dimension and convert every output to a NumPy array.
  outputs = tree.map_structure(lambda x: np.array(x[0]), outputs)
  tracks = outputs['tracks']
  occlusions = outputs['occlusion']
  expected_dist = outputs['expected_dist']

  # Binarize occlusions
  visibles = postprocess_occlusions(occlusions, expected_dist)
  return tracks, visibles


def sample_random_points(frame_max_idx, height, width, num_points):
  """Draw random integer points, one [t, y, x] row per point.

  Args:
    frame_max_idx: largest frame index that may be sampled (inclusive).
    height: exclusive upper bound for the y coordinate.
    width: exclusive upper bound for the x coordinate.
    num_points: number of points to draw.

  Returns:
    points: [num_points, 3], np.int32, columns ordered (t, y, x).
  """
  column_shape = (num_points, 1)
  # Draws happen in (y, x, t) order using NumPy's global random state.
  ys = np.random.randint(0, height, column_shape)
  xs = np.random.randint(0, width, column_shape)
  ts = np.random.randint(0, frame_max_idx + 1, column_shape)
  return np.hstack((ts, ys, xs)).astype(np.int32)  # [num_points, 3]

In [5]:
import fastplotlib as fpl
from ipywidgets import IntSlider, VBox, Layout
from sidecar import Sidecar

In [6]:
from skimage.transform import resize

In [8]:
import pickle


# NOTE(review): absolute, machine-specific path -- adjust it to wherever the
# dataset pickle lives locally.
# NOTE(review): pickle.load can execute arbitrary code; only open a file
# obtained from a trusted source.
with open('/home/clewis7/Desktop/tapvid_davis/tapvid_davis.pkl', 'rb') as f:
    data = pickle.load(f)

In [9]:
# Names of the videos of interest; an entry from this list is used below as
# a key into `data`.
selected_vids = [
    "paragliding-launch", "kite-surf", "drift-chicane",
    "dance-twirl", "dog", "dogs-jump",
    "car-roundabout", "soapbox", "breakdance"
]

# **CAITLIN: SELECT VIDEO HERE**

In [10]:
# Choose which video to process by changing the index into `selected_vids`.
vid_name = selected_vids[0]
# NOTE(review): despite the name, no resizing happens here -- the frames are
# used exactly as stored under the "video" key.
video_resized = data[vid_name]["video"]

In [11]:
# Viewer on which the selection polygon is drawn.
plot = fpl.Plot(size=(700, 400))
plot.add_image(video_resized[0])  # first graphic added, i.e. plot.graphics[0]


def update_frame(change):
    """Slider callback: display the frame at the slider's new value."""
    plot.graphics[0].data = video_resized[change["new"]]


# One slider position per frame along the first axis of the video.
slider = IntSlider(
    value=0,
    min=0,
    max=video_resized.shape[0] - 1,
    step=1,
    layout=Layout(width="700px"),
)
slider.observe(update_frame, "value")

with Sidecar(title="draw"):
    display(VBox([plot.show(), slider]))

RFBOutputContext()

In [12]:
# Negate the camera's y scale, which flips the view vertically.
# NOTE(review): presumably so the image is shown with row 0 at the top -- confirm.
plot.camera.world.scale_y *= -1

### Draw a polygon around the horse and rider by clicking on the polygon tool!

In [13]:
from skimage.draw import polygon2mask

In [14]:
# sample random points within the polygon
vertices = plot.selectors[0].get_vertices()

# returns boolean mask
mask = polygon2mask(video_resized.shape[1:-1], vertices).T

# get points
pts = np.argwhere(mask).astype(np.float32)

# random sample of points
n_query = 100
ixs = np.random.choice(range(pts.shape[0]), n_query, replace=False)

In [15]:
# Show the sampled points on the viewer; np.fliplr swaps the two columns.
# NOTE(review): presumably converting (y, x) pixel indices into the (x, y)
# order the plot expects -- confirm.
plot.add_scatter(np.fliplr(pts[ixs]), sizes=10, colors="random")

  warn(f"converting {array.dtype} array to float32")


<weakproxy at 0x7fb468106d90 to ScatterGraphic at 0x7fb468137950>

In [16]:
# Build one query row per sampled point: a leading column of zeros (every
# point is queried on frame 0) followed by the two columns of `pts`, cast to
# int32 to match the [t, y, x] layout that `inference` documents.
query_points = np.column_stack([np.zeros(n_query), pts[ixs]]).astype(np.int32)
# Run the tracker over the whole video for these queries.
tracks, visibles = inference(video_resized, query_points)



In [17]:
# Second viewer: video frames with the tracked points drawn on top.
plot_tracks = fpl.Plot(size=(700, 400))
plot_tracks.add_image(video_resized[0])

# Position of every track on the first frame, one row per point.
pos0 = np.vstack([track[0] for track in tracks])

plot_tracks.add_scatter(pos0, cmap="jet", sizes=5, name="pts")


def update_frame(change):
    """Slider callback: show the selected frame and move each point onto it."""
    frame_ix = change["new"]
    plot_tracks.graphics[0].data = video_resized[frame_ix]

    for i, track in enumerate(tracks):
        plot_tracks["pts"].data[i] = track[frame_ix]


# One slider position per frame along the first axis of the video.
slider = IntSlider(
    value=0,
    min=0,
    max=video_resized.shape[0] - 1,
    step=1,
    layout=Layout(width="700px"),
)
slider.observe(update_frame, "value")

with Sidecar(title="tracks"):
    display(VBox([plot_tracks.show(), slider]))

RFBOutputContext()

In [18]:
# Negate the camera's y scale, which flips the view vertically.
# NOTE(review): presumably so the image is shown with row 0 at the top -- confirm.
plot_tracks.camera.world.scale_y *= -1

Apply a Gaussian filter to smooth the tracks

In [19]:
from scipy.ndimage import gaussian_filter1d

In [20]:
# Smooth each track over time: the two coordinate columns are filtered
# independently with a 1-D Gaussian (sigma 1.5) and re-joined as two columns.
tracks_filt = [
    np.column_stack(
        (
            gaussian_filter1d(track[:, 0], 1.5),  # x values
            gaussian_filter1d(track[:, 1], 1.5),  # y values
        )
    )
    for track in tracks
]

In [21]:
# Draw one line per smoothed track in the tracks viewer, registered under
# the name "tracks" so it can be looked up later.
plot_tracks.add_line_collection(
    tracks_filt, 
    cmap="jet", # same cmap, colors will match
    thickness=2,
    name="tracks"
)

# bring our points up so they're more visible
plot_tracks["pts"].position_z = 5

In [22]:
# Stack the per-point two-column arrays along a new third axis, so the last
# axis indexes points.
tracks_array = np.dstack(tracks_filt)
# Display the resulting shape as the cell output.
tracks_array.shape

(80, 2, 100)

Shape is `[n_frames, 2-xy, n_points]`

In [25]:
# Switch every line in the "tracks" collection to the "winter" colormap.
plot_tracks["tracks"][:].cmap = "winter"

In [27]:
import os

# np.save does not create missing directories, so make sure the output
# folder exists before writing.
os.makedirs("./tracks", exist_ok=True)
# One .npy file per video, named after the video.
np.save(f"./tracks/{vid_name}.npy", tracks_array)

In [33]:
ll -h ./tracks

total 64K
-rw-r--r-- 1 kushalk 63K Jul 12 03:56 paragliding-launch.npy
