In [None]:
# @title Install Code and Dependencies {form-width: "25%"}
!pip install git+https://github.com/google-deepmind/tapnet.git

/bin/sh: line 1: pip: command not found


In [None]:
# @title Download TAPVid-DAVIS Dataset {form-width: "25%"}
!wget --no-check-certificate https://storage.googleapis.com/dm-tapnet/tapvid_davis.zip
!unzip tapvid_davis.zip

In [None]:
# @title Download Model {form-width: "25%"}

%mkdir -p tapnet/checkpoints
!wget -P tapnet/checkpoints --no-check-certificate https://storage.googleapis.com/dm-tapnet/tapnext/tapnext_ckpt.npz
%ls tapnet/checkpoints

In [None]:
# @title Load TAP-Vid-DAVIS dataset on 256x256 {form-width: "25%"}
davis_dataset = evaluation_datasets.create_davis_dataset(
    davis_points_path='tapvid_davis/tapvid_davis.pkl',
    query_mode='first',
    full_resolution=False,
    resolution=(256, 256),
)

# Materialise the dataset iterator into a list so later cells can iterate over
# it repeatedly; the shapes of every array in each video batch are printed.
cached_dataset = []
for j, batch in enumerate(davis_dataset):
  cached_dataset.append(batch)
  print(
      'video id',
      j,
      jax.tree_util.tree_map(lambda x: x.shape, batch),
  )

In [None]:
# Passed to the backbone as both `dtype_mm` and `dtype_ssm`.
compute_dtype = "float32"

# Backbone configuration.
ssm_vit_backbone = ssm_vit.Model(
    variant="B/8",  # try S/16, B/8, L/16
    patch_size=(1, 8, 8),
    pool_type="queries",
    posemb="learn",
    posemb_full="sincos2d",
    rep_size=True,
    dropout=0.0,
    lru_width=768,
    remat=False,
    dtype_mm=compute_dtype,
    dtype_ssm=compute_dtype,
    query_scale=1,
    spatiotemporal_attn=False,
)

# Tracker module wrapping the backbone; used by the forward function below via
# `model.apply(..., method=model.forward_step)`.
model = video_ssm_tracker.TAPNextTracker(backbone=ssm_vit_backbone)

In [None]:
def tracker_forward(params, frame, queries=None, state=None):
  """Runs one `forward_step` of the tracker on `frame`.

  All queries are assumed to be in frame 0: the first call passes `queries`
  and no `state`; every subsequent call passes the `state` returned by the
  previous call and no `queries`.

  Args:
    params: Model parameters, passed to `model.apply` under the 'params' key.
    frame: Frame(s) forwarded to the model as `frames`.
    queries: Query points; required when `state` is None.
    state: Tracking state from the previous call, or None on the first call.

  Returns:
    A tuple `(tracks, visible, state, track_logits, visible_logits)` read from
    the result of `model.forward_step`.
  """
  apply_kwargs = dict(
      variables={'params': params},
      frames=frame,
      method=model.forward_step,
      mutable='intermediates',
  )
  if state is not None:
    apply_kwargs['state'] = state
  else:
    # First step: assume all query points are valid, so the padding mask is
    # all ones (one entry per query, taken from the first coordinate slot).
    assert queries is not None
    apply_kwargs['query_points'] = queries
    apply_kwargs['query_padding'] = jnp.ones_like(queries)[..., 0]

  # The second return value holds the mutated 'intermediates' and is unused.
  result, _ = model.apply(**apply_kwargs)

  return (
      result.tracks,
      result.visible,
      result.state,
      result.track_logits,
      result.visible_logits,
  )


In [None]:
# @title Function for per frame evaluation {form-width: "25%"}

import tqdm

def get_window(
    coord, softmax, radius=8
):
  """Slices the window of `softmax` around `coord` and returns its coordinates.

  Note: coord is assumed to be a raster coordinate.

  Args:
    coord: Scalar position along one axis.
    softmax: 1-D array of probabilities along that axis. It must have at least
      `2 * radius + 1` entries.
    radius: Half-width of the window; the window has `2 * radius + 1` entries.

  Returns:
    A tuple `(window, window_coord)`: the sliced probabilities and, for each
    entry, its coordinate (`index + 0.5`).
  """
  window_size = radius * 2 + 1
  start = jnp.array(jnp.floor(coord - radius - 0.5), jnp.int32)
  # jax.lax.dynamic_slice silently clamps the start index so the slice stays
  # inside the array. Clamp `start` the same way on BOTH sides up front, so the
  # coordinates computed below always describe the entries actually returned
  # (previously only the lower bound was clamped, which shifted the coordinates
  # for points close to the far border).
  max_start = max(softmax.shape[0] - window_size, 0)
  start = jnp.minimum(jnp.maximum(start, 0), max_start)
  softmax = jax.lax.dynamic_slice(softmax, [start], [window_size])
  coord = start + 0.5 + jnp.arange(window_size)
  return softmax, coord


def get_certainty(
    coord_yx, track_logits, radius: int = 8
):
  """Get certainty from coordinate logits for a single point/frame.

  The logits are split into a y half and an x half, each is turned into a
  probability distribution, and the joint probability mass lying within
  `radius` of the predicted position is summed.

  Args:
    coord_yx: Predicted position as `[y, x]`, in raster coordinates.
    track_logits: 1-D logits; the first half is for y, the second half for x.
    radius: Radius of the circle in which probability mass is summed.

  Returns:
    Array of shape `(1,)` holding the summed probability mass.
  """
  logits_y, logits_x = jnp.split(track_logits, 2, axis=-1)
  track_softmax_y = jax.nn.softmax(logits_y)
  track_softmax_x = jax.nn.softmax(logits_x)
  # Forward `radius` so the extracted window always covers the whole circle;
  # with the window left at its default size, any radius above 8 would be
  # silently truncated to a 17-pixel window.
  sm_y, coord_y = get_window(coord_yx[0], track_softmax_y, radius)
  sm_x, coord_x = get_window(coord_yx[1], track_softmax_x, radius)
  # Outer product of the two 1-D distributions: joint probability per pixel.
  sm = sm_y[:, jnp.newaxis] * sm_x[jnp.newaxis, :]
  grid_x, grid_y = jnp.meshgrid(coord_x, coord_y)
  grid = jnp.stack([grid_y, grid_x], axis=-1)
  # Mask of window pixels whose coordinate lies inside the circle; the small
  # epsilon keeps points exactly on the boundary inside.
  in_radius = (
      jnp.sum(jnp.square(grid - coord_yx), axis=-1) <= jnp.square(radius) + 1e-8
  )
  return jnp.sum(sm * in_radius)[jnp.newaxis]


def tracker_certainty(
    tracks, track_logits, radius: int = 8
):
  """Computes the certainty of every point in every frame.

  Args:
    tracks: Tracks in [y, x], raster coordinates.
    track_logits: Logits for each track, y logits first followed by x logits,
      same number of logits as pixels in image.
    radius: Radius of the circle in which probability mass is summed.

  Returns:
    Certainty probability between 0 and 1, as returned by `get_certainty`
    mapped over all leading axes of `tracks`.
  """
  certainty_fn = functools.partial(get_certainty, radius=radius)
  # `get_certainty` handles a single point, so wrap it in one vmap per leading
  # axis of `tracks` (every axis except the trailing coordinate axis).
  num_leading_axes = len(tracks.shape) - 1
  for _ in range(num_leading_axes):
    certainty_fn = jax.vmap(certainty_fn)
  return certainty_fn(tracks, track_logits)


def run_eval_per_frame(
    modelf,
    params,
    batch,
    get_trackwise_metrics=True,
    radius=8,
    threshold=0.5,
    use_certainty=False,
):
  """Runs the tracker frame by frame on one batch and computes TAP-Vid metrics.

  The first frame is processed together with the query points; every later
  frame is processed with the tracking state returned by the previous step.

  Args:
    modelf: Forward function. Called as `modelf(params, frame=..., queries=...)`
      for the first frame and `modelf(params, frame=..., state=...)` afterwards;
      must return `(tracks, visible, state, track_logits, visible_logits)`.
    params: Model parameters forwarded to `modelf`.
    batch: Dict with the keys 'video', 'query_points', 'occluded' and
      'target_points'.
    get_trackwise_metrics: Forwarded to `compute_tapvid_metrics`.
    radius: Radius used by `tracker_certainty` (only when `use_certainty`).
    threshold: Threshold applied to `sigmoid(visible_logits) * certainty`
      (only when `use_certainty`).
    use_certainty: If True, a point counts as visible only when the product of
      its visibility probability and its certainty exceeds `threshold`;
      otherwise the model's own `visible` output is used.

  Returns:
    A tuple `(tracks, occluded, scalars)`: the predicted tracks with the last
    axis reversed, the predicted occlusion flags, and the metrics summed over
    their first axis.
  """

  def to_host(x):
    # Pull the value off the accelerator and store it as a numpy array.
    return np.array(jax.device_get(x))

  pred_tracks, pred_visible, tracking_state, track_logits, visible_logits = (
      modelf(
          params,
          frame=batch['video'][:, :1],
          queries=batch['query_points'],
      )
  )
  # First-frame outputs are moved to host like those of every later frame.
  pred_tracks, pred_visible = [to_host(pred_tracks)], [to_host(pred_visible)]
  pred_track_logits = [to_host(track_logits)]
  pred_visible_logits = [to_host(visible_logits)]
  for frame in range(1, batch['video'].shape[1]):
    (
        curr_tracks,
        curr_visible,
        tracking_state,
        curr_track_logits,
        curr_visible_logits,
    ) = modelf(
        params,
        frame=batch['video'][:, frame : frame + 1],
        state=tracking_state,
    )
    # NOTE: a leftover `pdb.set_trace()` used to sit here and halted the
    # evaluation on every frame; it has been removed.
    pred_tracks.append(to_host(curr_tracks))
    pred_visible.append(to_host(curr_visible))
    pred_track_logits.append(to_host(curr_track_logits))
    pred_visible_logits.append(to_host(curr_visible_logits))
  # Per-frame outputs are joined along axis 2 (presumably the time axis of the
  # model outputs -- TODO confirm against the tracker's output layout).
  tracks = np.concatenate(pred_tracks, axis=2)
  pred_visible = np.concatenate(pred_visible, axis=2)
  track_logits = np.concatenate(pred_track_logits, axis=2)
  visible_logits = np.concatenate(pred_visible_logits, axis=2)

  if use_certainty:
    # The certainty is only needed on this path, so it is only computed here.
    pred_certainty = tracker_certainty(tracks, track_logits, radius)
    pred_visible_and_certain = (
        jax.nn.sigmoid(visible_logits) * pred_certainty
    ) > threshold
    occluded = np.logical_not(pred_visible_and_certain.squeeze(-1))
  else:
    occluded = np.logical_not(pred_visible.squeeze(-1))

  # The last axis of `tracks` is reversed before scoring (tracker_certainty
  # documents tracks as [y, x], so this presumably yields [x, y] -- confirm
  # against compute_tapvid_metrics).
  scalars = evaluation_datasets.compute_tapvid_metrics(
      batch['query_points'],
      batch['occluded'],
      batch['target_points'],
      occluded + 0.0,  # bool -> float
      tracks[..., ::-1],
      query_mode='first',
      get_trackwise_metrics=get_trackwise_metrics,
  )
  return (
      tracks[..., ::-1],
      occluded,
      jax.tree.map(lambda x: np.array(np.sum(x, axis=0)), scalars),
  )

# @title Function to convert raw data to the input format {form-width: "25%"}
def _prepare_eval_batch(sample):
  """Copies `sample['davis']` and adds the derived evaluation fields."""
  batch = sample['davis'].copy()
  per_point_shape = batch['target_points'].shape[:3] + (1,)
  batch['visible'] = np.logical_not(batch['occluded'])[..., None]
  batch['padding'] = np.ones(batch['query_points'].shape[:2], dtype=np.bool_)
  batch['loss_mask'] = np.ones(per_point_shape, dtype=np.float32)
  batch['appearance'] = np.ones(per_point_shape, dtype=np.float32)
  return batch


def _time_reversed_copy(batch):
  """Returns a copy of `batch` with video, targets and queries flipped in time."""
  backward_batch = {k: v.copy() for k, v in batch.items()}
  for key in ['visible', 'appearance', 'loss_mask', 'target_points']:
    backward_batch[key] = np.flip(backward_batch[key], axis=2)
  # NOTE(review): 'occluded' is not flipped here although 'visible' is --
  # confirm whether consumers of the backward batch read 'occluded'.
  backward_batch['video'] = np.flip(backward_batch['video'], axis=1)
  # Element 0 of each query is mirrored around the video length (axis 1 of
  # 'video'), so it keeps pointing at the same frame after the flip.
  backward_queries = (
      backward_batch['video'].shape[1]
      - backward_batch['query_points'][..., 0]
      - 1
  )
  backward_batch['query_points'][..., 0] = backward_queries
  return backward_batch


def deterministic_eval(cached_dataset, strided=False):
  """Yields evaluation batches built from the cached dataset samples.

  Args:
    cached_dataset: Iterable of samples; each sample is a dict whose 'davis'
      entry holds the raw batch.
    strided: If False, yield one prepared batch per sample. If True, yield a
      `(batch, backward_batch)` pair, where `backward_batch` is the
      time-reversed copy of `batch`.

  Yields:
    The prepared batch, or the `(batch, backward_batch)` pair when `strided`.
  """
  for sample in tqdm.tqdm(cached_dataset):
    batch = _prepare_eval_batch(sample)
    if not strided:
      yield batch
    else:
      yield batch, _time_reversed_copy(batch)

In [None]:
from big_vision import utils
import collections

# Forked from the main code for simplicity, but once submitted it will be better
# just to adhoc-import this function.
def recover_tree(keys, values):
  """Rebuilds a nested dict from flat, '/'-separated names and their values.

  Useful for inspecting saved checkpoints without access to the exact source
  code of the experiment, e.g. to extract and reuse a subtree of the
  checkpoint such as the parameters.

  Args:
    keys: a list of keys, where '/' is used as separator between nodes.
    values: a list of leaf values.

  Returns:
    A nested tree-like dict.
  """
  tree = {}
  children_by_prefix = {}
  for name, value in zip(keys, values):
    head, separator, rest = name.partition("/")
    if separator:
      # Nested entry: park it under its first path component for the
      # recursive pass below.
      children_by_prefix.setdefault(head, []).append((rest, value))
    else:
      tree[name] = value
  for head, children in children_by_prefix.items():
    child_keys = [child_key for child_key, _ in children]
    child_values = [child_value for _, child_value in children]
    tree[head] = recover_tree(child_keys, child_values)
  return tree


ckpt_path = 'tapnet/checkpoints/tapnext_ckpt.npz'
# The checkpoint is loaded as a flat mapping; `recover_tree` turns its
# '/'-separated names into a nested dict.
flat_params = utils.npload(ckpt_path)
k, v = zip(*flat_params.items())
loaded_params = recover_tree(k, v)


In [None]:
# @title Per-frame inference

standard_eval_scalars_list = []
preds = []
for batch in deterministic_eval(cached_dataset):
  tracks, occluded, scores = run_eval_per_frame(
      tracker_forward,
      loaded_params,
      batch,
      get_trackwise_metrics=False,
      use_certainty=False,
  )
  standard_eval_scalars_list.append(scores)
  preds.append((tracks, occluded))


# Print the mean over all videos of each metric, one value per line, in this
# order: average_jaccard, occlusion_accuracy, average_pts_within_thresh.
print('')
for metric_name in (
    'average_jaccard',
    'occlusion_accuracy',
    'average_pts_within_thresh',
):
  print(
      np.mean([
          video_scores[metric_name]
          for video_scores in standard_eval_scalars_list
      ])
  )