In [1]:
import torch
import torchvision
import tqdm
from tapnet import evaluation_datasets



2025-11-22 23:10:08.808623: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763881808.817618    2740 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763881808.820672    2740 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1763881808.829107    2740 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763881808.829114    2740 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763881808.829115    2740 computation_placer.cc:177] computation placer alr

In [2]:
# TAP-Vid DAVIS in 'first' query mode, with videos resized to 256x256.
davis_dataset = evaluation_datasets.create_davis_dataset(
    davis_points_path='tapvid_davis/tapvid_davis.pkl',
    query_mode='first',
    full_resolution=False,
    resolution=(256, 256),
)

# Keep every sample in a list so later cells can index and re-iterate them.
cached_dataset = list()
for j, batch in enumerate(davis_dataset):
  cached_dataset += [batch]
  print('video id', j)

video id 0
video id 1
video id 2
video id 3
video id 4
video id 5
video id 6
video id 7
video id 8
video id 9
video id 10
video id 11
video id 12
video id 13
video id 14
video id 15
video id 16
video id 17
video id 18
video id 19
video id 20
video id 21
video id 22
video id 23
video id 24
video id 25
video id 26
video id 27
video id 28
video id 29


In [3]:
import numpy as np
from tapnet.tapnext.tapnext_torch import TAPNext
from tapnet.tapnext.tapnext_torch_utils import restore_model_from_jax_checkpoint, tracker_certainty
import torch.nn.functional as F

In [4]:
def run_eval_per_frame(
    model,
    batch,
    get_trackwise_metrics=True,
    radius=8,
    threshold=0.5,
    use_certainty=False,
):
  """Tracks the query points one frame at a time and scores the result.

  The model is called once per frame. The first call receives the query
  points; every later call receives only the next frame together with the
  state returned by the previous call.

  Args:
    model: Callable used as `model(video=..., query_points=...)` on the first
      frame and `model(video=..., state=...)` afterwards. Each call must return
      `(tracks, track_logits, visible_logits, state)`.
    batch: Dict of torch tensors with at least the keys 'video',
      'query_points', 'occluded' and 'target_points'. Frames are indexed along
      axis 1 of 'video'.
    get_trackwise_metrics: Forwarded to
      `evaluation_datasets.compute_tapvid_metrics`.
    radius: Forwarded to `tracker_certainty`. Only used when `use_certainty`
      is True.
    threshold: Cut-off applied to `sigmoid(visible_logits) * certainty`. Only
      used when `use_certainty` is True.
    use_certainty: If True, a point is treated as visible only when its
      visibility probability times its certainty exceeds `threshold`. If
      False, a point is visible when its visibility logit is positive.

  Returns:
    A tuple `(tracks, occluded, scalars)`:
      tracks: numpy array of predicted positions with the last axis reversed
        relative to the model output ([batch, queries, frames, 2] in this
        notebook's runs).
      occluded: boolean numpy array, True where a point is predicted occluded
        ([batch, queries, frames] in this notebook's runs).
      scalars: dict of TAP-Vid metrics, each summed over axis 0.
  """
  video = batch['video']
  frame_tracks, frame_track_logits, frame_visible_logits = [], [], []

  def _collect(tracks_t, track_logits_t, visible_logits_t):
    # Store CPU copies of one frame's outputs.
    frame_tracks.append(tracks_t.cpu())
    frame_track_logits.append(track_logits_t.cpu())
    frame_visible_logits.append(visible_logits_t.cpu())

  def _join_frames(chunks):
    # Concatenate the per-frame outputs along axis 1, then swap axes 1 and 2.
    return torch.cat(chunks, dim=1).transpose(1, 2)

  with torch.no_grad():
    # Frame 0: the only call that is given the query points.
    (
        curr_tracks,
        curr_track_logits,
        curr_visible_logits,
        tracking_state,
    ) = model(video=video[:, :1], query_points=batch['query_points'])
    _collect(curr_tracks, curr_track_logits, curr_visible_logits)

    for frame in tqdm.tqdm(range(1, video.shape[1])):
      # ***************************************************
      # HERE WE RUN POINT TRACKING IN PURELY ONLINE FASHION
      # ***************************************************
      (
          curr_tracks,
          curr_track_logits,
          curr_visible_logits,
          tracking_state,
      ) = model(
          video=video[:, frame : frame + 1],
          state=tracking_state,
      )
      # ***************************************************
      _collect(curr_tracks, curr_track_logits, curr_visible_logits)

    tracks = _join_frames(frame_tracks)
    visible_logits = _join_frames(frame_visible_logits)

    if use_certainty:
      # Certainty is only computed when it is actually used for the decision.
      track_logits = _join_frames(frame_track_logits)
      pred_certainty = tracker_certainty(tracks, track_logits, radius)
      visible = (F.sigmoid(visible_logits) * pred_certainty) > threshold
    else:
      visible = visible_logits > 0
    occluded = ~(visible.squeeze(-1))

  # Reverse the last (coordinate) axis once and reuse it for both the metric
  # call and the return value.
  # NOTE(review): presumably converts the model's coordinate order to the one
  # compute_tapvid_metrics expects -- confirm against tapnet.
  tracks_np = tracks.numpy()[..., ::-1]
  occluded_np = occluded.numpy()

  scalars = evaluation_datasets.compute_tapvid_metrics(
      batch['query_points'].cpu().numpy(),
      batch['occluded'].cpu().numpy(),
      batch['target_points'].cpu().numpy(),
      occluded_np + 0.0,  # bool -> float copy
      tracks_np,
      query_mode='first',
      get_trackwise_metrics=get_trackwise_metrics,
  )
  return (
      tracks_np,
      occluded_np,
      {k: v.sum(0) for k, v in scalars.items()},
  )


# @title Convert raw cached samples into the model input format {form-width: "25%"}
def deterministic_eval(cached_dataset, strided=False):
  """Yields model-ready batches built from cached TAP-Vid DAVIS samples.

  Args:
    cached_dataset: Iterable of dicts, each holding a 'davis' entry with the
      numpy arrays 'video', 'query_points', 'target_points' and 'occluded'.
    strided: If False, yield one batch per sample. If True, yield a
      `(batch, backward_batch)` pair per sample, where `backward_batch` is the
      same sample played backwards in time.

  Yields:
    `batch` when `strided` is False, otherwise `(batch, backward_batch)`.
  """
  for sample in cached_dataset:
    batch = _prepare_davis_batch(sample)
    if strided:
      yield batch, _time_reversed_batch(batch)
    else:
      yield batch


def _prepare_davis_batch(sample):
  """Adds the derived 'visible', 'padding', 'loss_mask', 'appearance' keys."""
  # Shallow copy: the new keys must not leak into the cached sample. The
  # original arrays (including 'video') are passed through unchanged.
  batch = sample['davis'].copy()
  batch['visible'] = np.logical_not(batch['occluded'])[..., None]
  batch['padding'] = np.ones(batch['query_points'].shape[:2], dtype=np.bool_)
  ones_shape = batch['target_points'].shape[:3] + (1,)
  batch['loss_mask'] = np.ones(ones_shape, dtype=np.float32)
  batch['appearance'] = np.ones(ones_shape, dtype=np.float32)
  return batch


def _time_reversed_batch(batch):
  """Returns a deep copy of `batch` with the time axis reversed."""
  backward_batch = {k: v.copy() for k, v in batch.items()}
  # Per-point arrays carry time on axis 2. 'occluded' is flipped together with
  # 'visible' so the reversed batch's ground truth stays self-consistent.
  per_point_keys = [
      'visible', 'appearance', 'loss_mask', 'target_points', 'occluded',
  ]
  for key in per_point_keys:
    # np.flip returns a negative-stride view; make it contiguous so the result
    # can be handed to torch.from_numpy without an extra copy.
    backward_batch[key] = np.ascontiguousarray(
        np.flip(backward_batch[key], axis=2)
    )
  # The video carries time on axis 1.
  backward_batch['video'] = np.ascontiguousarray(
      np.flip(backward_batch['video'], axis=1)
  )
  # Component 0 of each query point is its frame index: mirror it.
  num_frames = backward_batch['video'].shape[1]
  backward_batch['query_points'][..., 0] = (
      num_frames - backward_batch['query_points'][..., 0] - 1
  )
  return backward_batch

In [5]:
# Build TAPNext for 256x256 inputs and load the weights from the .npz checkpoint.
model = TAPNext(image_size=(256, 256))
ckpt_path = 'bootstapnext_ckpt.npz'
model = restore_model_from_jax_checkpoint(model, ckpt_path)
# The notebook only evaluates the model, so switch it to eval mode as well as
# moving it to the GPU. Both calls return the module, so the cell still
# displays its structure.
model.cuda().eval()

TAPNext(
  (lin_proj): Conv2d(3, 768, kernel_size=(8, 8), stride=(8, 8))
  (blocks): ModuleList(
    (0-11): 12 x TRecViTBlock(
      (ssm_block): ResidualBlock(
        (temporal_pre_norm): RMSNorm()
        (recurrent_block): RecurrentBlock(
          (linear_y): Linear(in_features=768, out_features=768, bias=True)
          (linear_x): Linear(in_features=768, out_features=768, bias=True)
          (linear_out): Linear(in_features=768, out_features=768, bias=True)
          (conv_1d): CausalConv1D()
          (rg_lru): RGLRU(
            (input_gate): BlockDiagonalLinear()
            (a_gate): BlockDiagonalLinear()
          )
        )
        (channel_pre_norm): RMSNorm()
        (mlp_block): MLPBlock(
          (ffw_up): Einsum()
          (ffw_down): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (vit_block): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (

In [6]:
# First-query evaluation: run the per-frame tracker on every cached video and
# keep both the per-video metrics and the raw predictions.
standard_eval_scalars_list = list()
preds = list()
for batch in deterministic_eval(cached_dataset):
  batch = {k: torch.from_numpy(v).cuda().float() for k, v in batch.items()}
  with torch.amp.autocast('cuda', dtype=torch.float16, enabled=True):
    tracks, occluded, scores = run_eval_per_frame(
        model, batch, get_trackwise_metrics=False, use_certainty=False
    )
  preds.append((tracks, occluded))
  standard_eval_scalars_list.append(scores)


# Report each metric averaged over videos.
print('')
for label, metric in (
    ('AJ', 'average_jaccard'),
    ('OA', 'occlusion_accuracy'),
    ('PTS', 'average_pts_within_thresh'),
):
  print(
      label,
      np.mean([per_video[metric] for per_video in standard_eval_scalars_list]),
  )

100%|██████████████████████████████████████████████████████████████| 89/89 [00:03<00:00, 23.71it/s]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████████████████████████████████████████████████████████| 74/74 [00:03<00:00, 23.51it/s]
100%|██████████████████████████████████████████████████████████████| 39/39 [00:01<00:00, 23.44it/s]
100%|██████████████████████████████████████████████████████████████| 83/83 [00:03<00:00, 23.47it/s]
100%|██████████████████████████████████████████████████████████████| 51/51 [00:02<00:00, 23.56it/s]
100%|██████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 23.55it/s]
100%|██████████████████████████████████████████████████████████████| 33/33 [00:01<00:00, 23.55it/s]
100%|██████████████████████████████████████████████████████████████| 98/98 [00:04<00:00, 23.55it/s]
100%|██████████████████████████████████████████████████████████████| 65/65 [00:02<00:00, 23.48it/s]
100%|████████████████████████


AJ 0.6654442090194489
OA 0.9216233752944766
PTS 0.7948194112043738





In [7]:
# Same DAVIS data as the first-query cell, but in 'strided' query mode.
davis_dataset_strided = evaluation_datasets.create_davis_dataset(
    davis_points_path='tapvid_davis/tapvid_davis.pkl',
    query_mode='strided',
    full_resolution=False,
    resolution=(256, 256),
)

# Keep every sample in a list: the strided evaluation cell indexes it by video.
cached_dataset_strided = list()
for j, batch in enumerate(davis_dataset_strided):
  cached_dataset_strided += [batch]
  print(
      'video id',
      j,
  )

video id 0
video id 1
video id 2
video id 3
video id 4
video id 5
video id 6
video id 7
video id 8
video id 9
video id 10
video id 11
video id 12
video id 13
video id 14
video id 15
video id 16
video id 17
video id 18
video id 19
video id 20
video id 21
video id 22
video id 23
video id 24
video id 25
video id 26
video id 27
video id 28
video id 29


In [8]:
import jax

# Strided evaluation: track each video forwards and backwards in time, then
# stitch the two so every query is covered both before and after its own frame.
eval_results = []
for vid, (fbatch, bbatch) in enumerate(
    deterministic_eval(cached_dataset_strided, strided=True)
):
  fbatch = {k: torch.from_numpy(v).cuda().float() for k, v in fbatch.items()}
  # .copy(): the time-reversed arrays may be negative-stride views, which
  # torch.from_numpy does not accept.
  bbatch = {
      k: torch.from_numpy(v.copy()).cuda().float() for k, v in bbatch.items()
  }
  with torch.amp.autocast('cuda', dtype=torch.float16, enabled=True):
    ftracks, foccluded, _ = run_eval_per_frame(
        model, fbatch, get_trackwise_metrics=False, use_certainty=False
    )
    btracks, boccluded, _ = run_eval_per_frame(
        model, bbatch, get_trackwise_metrics=False, use_certainty=False
    )
  # Put the backward predictions back into forward time order.
  btracks, boccluded = np.flip(btracks, axis=2), np.flip(boccluded, axis=2)
  # tracks = [1, q, t, 2]
  for q in range(fbatch['query_points'].shape[1]):
    # Frames before the query frame come from the backward pass.
    t = int((fbatch['query_points'][0, q, 0]).item())
    ftracks[0, q, :t] = btracks[0, q, :t]
    foccluded[0, q, :t] = boccluded[0, q, :t]
  tracks, occluded = ftracks, foccluded

  # Score against the untouched cached ground truth for this video.
  ground_truth = cached_dataset_strided[vid]['davis']
  scalars = evaluation_datasets.compute_tapvid_metrics(
      ground_truth['query_points'],
      ground_truth['occluded'],
      ground_truth['target_points'],
      occluded + 0.,
      tracks,
      query_mode='strided',
      get_trackwise_metrics=False,
  )
  # `scalars` is a flat dict of arrays, so a dict comprehension is enough to
  # sum each metric over axis 0.
  eval_results.append(
      {name: np.array(np.sum(value, axis=0)) for name, value in scalars.items()}
  )

# Report each metric averaged over videos.
print('')
for label, metric in (
    ('AJ', 'average_jaccard'),
    ('OA', 'occlusion_accuracy'),
    ('PTS', 'average_pts_within_thresh'),
):
  print(label, np.mean([per_video[metric] for per_video in eval_results]))

100%|██████████████████████████████████████████████████████████████| 89/89 [00:03<00:00, 22.78it/s]
100%|██████████████████████████████████████████████████████████████| 89/89 [00:03<00:00, 22.75it/s]
100%|██████████████████████████████████████████████████████████████| 74/74 [00:03<00:00, 22.01it/s]
100%|██████████████████████████████████████████████████████████████| 74/74 [00:03<00:00, 22.02it/s]
100%|██████████████████████████████████████████████████████████████| 39/39 [00:01<00:00, 22.87it/s]
100%|██████████████████████████████████████████████████████████████| 39/39 [00:01<00:00, 22.84it/s]
100%|██████████████████████████████████████████████████████████████| 83/83 [00:04<00:00, 18.86it/s]
100%|██████████████████████████████████████████████████████████████| 83/83 [00:04<00:00, 18.87it/s]
100%|██████████████████████████████████████████████████████████████| 51/51 [00:02<00:00, 20.37it/s]
100%|██████████████████████████████████████████████████████████████| 51/51 [00:02<00:00, 20.39it/s]



AJ 0.7074496686823425
OA 0.9184522766619835
PTS 0.8315253897988535





In [12]:
# Rebuilds the 'first' query-mode DAVIS cache with the same arguments as the
# dataset cell at the top of the notebook.
davis_dataset = evaluation_datasets.create_davis_dataset(
    davis_points_path='tapvid_davis/tapvid_davis.pkl',
    query_mode='first',
    full_resolution=False,
    resolution=(256, 256),
)
cached_dataset = list()
for j, batch in enumerate(davis_dataset):
  cached_dataset += [batch]
  print('video id', j)

video id 0
video id 1
video id 2
video id 3
video id 4
video id 5
video id 6
video id 7
video id 8
video id 9
video id 10
video id 11
video id 12
video id 13
video id 14
video id 15
video id 16
video id 17
video id 18
video id 19
video id 20
video id 21
video id 22
video id 23
video id 24
video id 25
video id 26
video id 27
video id 28
video id 29


In [15]:
# A bare expression is only displayed when it is the last line of a cell, so
# the sample count has to be printed explicitly.
print(len(cached_dataset))
# Keys available in one cached DAVIS sample.
cached_dataset[0]['davis'].keys()

dict_keys(['video', 'query_points', 'target_points', 'occluded'])

In [18]:
# NOTE(review): these two assignments clear the lists filled by the full
# first-query evaluation cell, and nothing is appended to them here. Confirm
# whether the reset is intended.
standard_eval_scalars_list = []
preds = []


# Sanity check on the first cached video only. The batch is built by
# deterministic_eval rather than by a second copy of its preparation code.
for batch in deterministic_eval(cached_dataset[:1]):
  batch = {k: torch.from_numpy(v).cuda().float() for k, v in batch.items()}
  with torch.amp.autocast('cuda', dtype=torch.float16, enabled=True):
    tracks, occluded, scores = run_eval_per_frame(
        model, batch, get_trackwise_metrics=False, use_certainty=False
    )
  print(tracks.shape, occluded.shape, scores)

100%|██████████████████████████████████████████████████████████████| 89/89 [00:03<00:00, 23.68it/s]

(1, 5, 90, 2) (1, 5, 90) {'occlusion_accuracy': np.float64(0.9617977528089887), 'pts_within_1': np.float64(0.2966292134831461), 'jaccard_1': np.float64(0.1765498652291105), 'pts_within_2': np.float64(0.6337078651685393), 'jaccard_2': np.float64(0.4696969696969697), 'pts_within_4': np.float64(0.9168539325842696), 'jaccard_4': np.float64(0.814968814968815), 'pts_within_8': np.float64(0.9977528089887641), 'jaccard_8': np.float64(0.9573991031390134), 'pts_within_16': np.float64(1.0), 'jaccard_16': np.float64(0.9617977528089887), 'average_jaccard': np.float64(0.6760825011685795), 'average_pts_within_thresh': np.float64(0.7689887640449438)}



