In [1]:
%matplotlib notebook
import haiku as hk
import jax
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
import tree

from PIL import Image
from tapnet import tapir_model
from tapnet.utils import transforms
from tapnet.utils import viz_utils

# Load Checkpoint

In [2]:
# Load the checkpoint: the .npy file holds a pickled dict with 'params' and
# 'state' entries, which are the two values passed to model_apply later on.
# NOTE(security): allow_pickle=True can execute arbitrary code while loading;
# only use checkpoint files that come from a trusted source.
checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint_panning.npy'
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params = ckpt_state['params']
state = ckpt_state['state']

# Build Model

In [3]:
def build_model(frames, query_points):
  """Build a TAPIR model and run it on `frames` for the given query points.

  Meant to be wrapped by `hk.transform_with_state`; it is not called directly.

  Args:
    frames: video passed to the model as `video`.
    query_points: points to track, passed to the model as `query_points`.

  Returns:
    The model outputs, exactly as returned by the TAPIR module.
  """
  tapir = tapir_model.TAPIR(
      bilinear_interp_with_depthwise_conv=False,
      pyramid_level=0,
  )
  return tapir(
      video=frames,
      is_training=False,
      query_points=query_points,
      query_chunk_size=64,
  )

# Wrap build_model with Haiku so it can be applied as a function of
# (params, state, rng, *inputs), then JIT-compile that apply function.
model = hk.transform_with_state(build_model)
model_apply = jax.jit(model.apply)

# Utility Functions

In [4]:
def preprocess_frames(frames):
  """Rescale 8-bit video frames to the [-1, 1] range expected by the model.

  Args:
    frames: [num_frames, height, width, 3], [0, 255], np.uint8

  Returns:
    frames: [num_frames, height, width, 3], [-1, 1], np.float32
  """
  as_float = frames.astype(np.float32)
  # 0 maps to -1 and 255 maps to +1.
  return as_float / 255 * 2 - 1


def postprocess_occlusions(occlusions, expected_dist):
  """Turn occlusion and expected-distance logits into a boolean visible flag.

  A point is flagged visible when
  (1 - sigmoid(occlusions)) * (1 - sigmoid(expected_dist)) exceeds 0.5.

  Args:
    occlusions: [num_points, num_frames], [-inf, inf], np.float32
    expected_dist: [num_points, num_frames], [-inf, inf], np.float32

  Returns:
    visibles: [num_points, num_frames], bool
  """
  occlusion_term = 1 - jax.nn.sigmoid(occlusions)
  distance_term = 1 - jax.nn.sigmoid(expected_dist)
  return occlusion_term * distance_term > 0.5

def inference(frames, query_points):
  """Run the tracker on one video and return point tracks with visibility.

  Args:
    frames: [num_frames, height, width, 3], [0, 255], np.uint8
    query_points: [num_points, 3], [0, num_frames/height/width], [t, y, x]

  Returns:
    tracks: [num_points, num_frames, 2], [x, y] in pixel coordinates of the
      input frames. NOTE(review): this follows the TAPIR output convention and
      matches how the tracking cell rescales and paints the tracks; confirm
      against the installed tapnet version.
    visibles: [num_points, num_frames], bool
  """
  # Preprocess video to match model inputs format
  frames = preprocess_frames(frames)
  query_points = query_points.astype(np.float32)
  frames, query_points = frames[None], query_points[None]  # Add batch dimension

  # Model inference; the PRNG key is fixed, so every call uses the same key.
  rng = jax.random.PRNGKey(42)
  outputs, _ = model_apply(params, state, rng, frames, query_points)
  # Drop the batch dimension and convert every output to a NumPy array.
  outputs = tree.map_structure(lambda x: np.array(x[0]), outputs)
  tracks, occlusions, expected_dist = outputs['tracks'], outputs['occlusion'], outputs['expected_dist']

  # Binarize occlusions
  visibles = postprocess_occlusions(occlusions, expected_dist)
  return tracks, visibles


def sample_random_points(frame_max_idx, height, width, num_points):
  """Draw `num_points` random integer points, returned in (t, y, x) order.

  Uses NumPy's global random state. y is drawn from [0, height), x from
  [0, width) and t from [0, frame_max_idx].

  Returns:
    points: [num_points, 3], np.int32, columns [t, y, x]
  """
  column_shape = (num_points, 1)
  # Draw order (y, x, t) is kept so results for a given seed do not change.
  ys = np.random.randint(0, height, column_shape)
  xs = np.random.randint(0, width, column_shape)
  ts = np.random.randint(0, frame_max_idx + 1, column_shape)
  return np.hstack([ts, ys, xs]).astype(np.int32)  # [num_points, 3]

# Load Videos

In [5]:
# Read the clip; axes 1 and 2 of the returned array are taken as height and width.
video = media.read_video('Preprocessed_Videos/AI_GeMo_late_F-/35_F-_2_c.mp4')
height, width = video.shape[1:3]
# Preview only: the playback rate is fixed at 30 fps here, independent of video.metadata.fps.
media.show_video(video, fps=30)

0
This browser does not support the video tag.


In [6]:
# Sanity check: frame size of the loaded video, printed as height then width.
print(height, width)

848 480


In [7]:
# Select Any Points at Any Frame

In [8]:
# Index of the frame shown for point selection; it is also used as the query time of the clicked points.
select_frame = 0  #@param {type:"slider", min:0, max:49, step:1}

# Generate a colormap with 20 points, no need to change unless select more than 20 points
colormap = viz_utils.get_colors(20)

# Show the selected frame; clicks on this figure are collected by the on_click handler.
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(video[select_frame])
ax.axis('off')
ax.set_title('You can select more than 1 points. After select enough points, run the next cell.')

# Clicked points accumulate here as [x, y] pixel arrays; re-running this cell clears them.
select_points = []

 # Event handler for mouse clicks
def on_click(event):
  """Record a left click inside the image axes and mark it on the figure.

  The click position is rounded to the nearest pixel and appended to
  `select_points` as an [x, y] array, then drawn in that point's color.
  """
  if event.button == 1 and event.inaxes == ax:  # Left mouse button clicked
    x, y = int(np.round(event.xdata)), int(np.round(event.ydata))

    select_points.append(np.array([x, y]))

    # Wrap around the colormap so a click beyond len(colormap) points is still
    # drawn, instead of raising IndexError after the point was already stored.
    color = colormap[(len(select_points) - 1) % len(colormap)]
    color = tuple(np.array(color) / 255.0)
    ax.plot(x, y, 'o', color=color, markersize=5)
    plt.draw()

# Register the click handler on this figure's canvas; show without blocking so
# points can be clicked after the cell has finished running.
fig.canvas.mpl_connect('button_press_event', on_click)
plt.show(block=False)

<IPython.core.display.Javascript object>

# Predict Point Tracks for the Selected Points

In [9]:
# Show the clicked points as [x, y] pixel coordinates on the selection frame.
print(select_points)
# TODO: write selected points in global coordinates in file

[array([265, 162]), array([ 91, 244]), array([338, 112]), array([329, 303]), array([137, 629]), array([274, 579]), array([252, 691]), array([252, 759]), array([324, 748]), array([335, 810])]


In [10]:
def expand2square(pil_img, background_color):
    """Pad a PIL image to a square, centering it along its shorter side.

    Args:
        pil_img: PIL image to pad.
        background_color: fill color for the added border.

    Returns:
        `pil_img` itself when it is already square; otherwise a new image
        whose side equals the longer edge, with `pil_img` pasted centered
        (offsets use floor division).
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    # Exactly one of these offsets is non-zero: the one along the shorter side.
    offset = ((side - width) // 2, (side - height) // 2)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas

In [11]:
#-- crop and resize image to 256x256 based on selected points
resize_height = 256  # @param {type: "integer"}
resize_width = 256  # @param {type: "integer"}

# Fail early with a clear message instead of max()/min() raising on an empty list.
if not select_points:
    raise ValueError('No points selected: click at least one point on the selection figure first.')

# get video metadata (reused below so the written videos keep the source fps/bitrate)
fps_value = video.metadata.fps
bps_value = video.metadata.bps

# get extreme selected points (select_points entries are [x, y])
x_coords = [pt[0] for pt in select_points]
y_coords = [pt[1] for pt in select_points]
y_max, y_min = max(y_coords), min(y_coords) # largest and smallest selected y
x_max, x_min = max(x_coords), min(x_coords) # largest and smallest selected x

# get spread of selected points
x_spread = x_max-x_min
y_spread = y_max-y_min

#- crop around selected points
margin = 0.15 # 15% margin
x_margin = round(margin*x_spread)
y_margin = round(margin*y_spread)
# the larger of the two margins is applied on every side, in both x and y
largest_margin = max(x_margin, y_margin)

# crop bounds are clamped to the frame, so the crop can be smaller than requested near the edges
min_x_crop, max_x_crop = max(0, x_min-largest_margin-1), min(width, x_max+largest_margin+1)
min_y_crop, max_y_crop  = max(0, y_min-largest_margin-1), min(height, y_max+largest_margin+1)

cropped_width = max_x_crop-min_x_crop
cropped_height = max_y_crop-min_y_crop

frames_original = video.__array__()
frames_cropped = frames_original[:,min_y_crop:max_y_crop,min_x_crop:max_x_crop,:]

#- expand to square (background padding) and resize to desired dimensions
frames_crop_squared = []
frames_crop_resized = []
for frame in frames_cropped:
    im = Image.fromarray(frame, mode="RGB")
    squared_frame = expand2square(im, (255, 255, 255))
    # PIL's resize takes (width, height). Passing (height, width) only worked
    # because both values are 256; it would transpose the target size otherwise.
    resized_frame = squared_frame.resize((resize_width, resize_height))
    frames_crop_squared.append(np.asarray(squared_frame))
    frames_crop_resized.append(np.asarray(resized_frame))

media.write_video('cropped_vid.mp4', frames_cropped, fps=fps_value, bps=bps_value)
media.write_video('cropped_resized_vid.mp4', frames_crop_resized, fps=fps_value, bps=bps_value)

In [12]:
def convert_select_points_to_query_points(frame, points):
  """Turn clicked [x, y] points into [t, y, x] query points.

  Args:
    frame: frame index written into the time column of every query point.
    points: [num_points, 2], in [x, y]
  Returns:
    query_points: [num_points, 3], np.float32, in [t, y, x]
  """
  xy = np.stack(points)
  query_points = np.empty((xy.shape[0], 3), dtype=np.float32)
  query_points[:, 0] = frame
  # Columns 1 and 0 of the input, in that order: [x, y] becomes [y, x].
  query_points[:, 1:] = xy[:, 1::-1]
  return query_points


# NOTE(review): `frames` is not used by any later cell visible here (tracking runs on
# `frames_crop_resized`); it is kept so that anything relying on it keeps working.
frames = media.resize_video(video, (resize_height, resize_width))
query_points = convert_select_points_to_query_points(select_frame, select_points) #-DC: still in global reference: from xy to tyx

# draw query_points on original image
# NOTE: this rebinds the module-level `fig`/`ax` that the click handler compares against.
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(np.array(video[select_frame]))
# Point i uses colormap[i], matching the color it was drawn with when clicked.
# The previous `ii - 1` index gave point 0 the last color. The modulo keeps the
# lookup valid when there are more points than colors.
color_list = [tuple(np.array(colormap[ii % len(colormap)]) / 255.0) for ii in range(len(select_points))]
ax.scatter(query_points[:, 2], query_points[:, 1], marker="o", color=color_list, s=25)
plt.show(block=False)

<IPython.core.display.Javascript object>

In [13]:
# shift selected points according to cropped image
# Only y and x move with the crop origin; the time column is carried over
# unchanged rather than reset to 0, so a non-zero select_frame is preserved.
query_points_crop = query_points - np.array([0.0, min_y_crop, min_x_crop])

# draw query_points on cropped image
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(np.array(frames_cropped[select_frame]))
ax.scatter(query_points_crop[:, 2], query_points_crop[:, 1], marker="o", color=color_list, s=25)
plt.show(block=False)

<IPython.core.display.Javascript object>

In [14]:
# shift according to squared
# Same centering as expand2square: the shorter side is padded equally on both
# ends (floor division), so exactly one of the two shifts is non-zero, or
# neither when the crop is already square. The time column is left untouched
# rather than reset to 0.
square_side = max(cropped_width, cropped_height)
x_shift = (square_side - cropped_width) // 2
y_shift = (square_side - cropped_height) // 2
query_points_sq = query_points_crop + np.array([0.0, y_shift, x_shift])

# draw query_points on squared image
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(np.array(frames_crop_squared[select_frame]))
ax.scatter(query_points_sq[:, 2], query_points_sq[:, 1], marker="o", color=color_list, s=25)
plt.show(block=False)

<IPython.core.display.Javascript object>

In [15]:
# Rescale the query points from the padded square (side = longer crop edge) to
# the resized frame size; the time axis is mapped 1 -> 1, so it keeps its scale.
square_side = max(cropped_height, cropped_width)
query_points_fin = transforms.convert_grid_coordinates(
    query_points_sq,
    (1, square_side, square_side),
    (1, resize_height, resize_width),
    coordinate_format='tyx',
)

# draw query_points on cropped and resized image
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(np.array(frames_crop_resized[select_frame]))
ax.scatter(query_points_fin[:, 2], query_points_fin[:, 1], marker="o", color=color_list, s=25)
plt.show(block=False)

<IPython.core.display.Javascript object>

# Track points

In [16]:
# Run the tracker on the cropped and resized clip.
frames_for_model = np.stack(frames_crop_resized, axis=0)
tracks, visibles = inference(frames_for_model, query_points_fin)

# Visualize sparse point tracks
# Map the tracks from the resized frame size back onto the padded square frames.
square_side = max(cropped_height, cropped_width)
tracks = transforms.convert_grid_coordinates(
    tracks,
    (resize_width, resize_height),
    (square_side, square_side),
)
print("done 1")
frames_for_painting = np.stack(frames_crop_squared, axis=0)
video_viz = viz_utils.paint_point_track(frames_for_painting, tracks, visibles, colormap)
print("done 2")
media.write_video('tracked_points_video.mp4', video_viz, fps=fps_value, bps=bps_value)
print("done 3")
# media.show_video(video_viz, fps=30)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


done 1
done 2
done 3
