In [1]:
from __future__ import annotations

import dataclasses
import io
import os

import einops
import jax
import numpy as np
from tqdm import tqdm

from tapnet.tapvid import evaluation_datasets
from tapnet.trajan import track_autoencoder
from tapnet.utils import model_utils
from tapnet.utils import transforms
from tapnet.utils import viz_utils

### load parameters

In [2]:
def npload(fname):
  """Loads a ``.npy`` or ``.npz`` file with pickling disabled.

  Args:
    fname: Path or file-like object accepted by ``np.load``.

  Returns:
    The array itself when ``np.load`` yields an ``ndarray``; otherwise a plain
    ``dict`` mapping every archive member name to its fully loaded array.
  """
  loaded = np.load(fname, allow_pickle=False)
  if isinstance(loaded, np.ndarray):
    return loaded
  # An archive is read lazily through an open file handle. Materialise every
  # member while it is open, then close it deterministically instead of
  # leaving the handle to the garbage collector.
  with loaded:
    return dict(loaded)


def recover_tree(flat_dict):
  """Rebuilds a nested dict from a flat dict keyed by '/'-joined paths.

  Every key is split on '/'. All components but the last name nested dicts,
  created on demand; the last component is the key under which the value is
  stored in the innermost dict.

  Args:
    flat_dict: Mapping from path strings (e.g. 'a/b/c') to values.

  Returns:
    The nested dictionary, e.g. {'a': {'b': {'c': value}}}.
  """
  root = {}
  for path, leaf in flat_dict.items():
    *parents, leaf_name = path.split('/')
    branch = root
    for name in parents:
      # Descend into the child dict, creating it the first time it is seen.
      branch = branch.setdefault(name, {})
    branch[leaf_name] = leaf
  return root

# Restore the TRAJAN weights: read the flat '/'-keyed checkpoint archive, then
# rebuild the nested parameter tree from it.
trajan_checkpoint_path = "/restricted/projectnb/cs599dg/mwakeham/tapnet/checkpoints/track_autoencoder_ckpt.npz"
flat_params = npload(trajan_checkpoint_path)
params = recover_tree(flat_params)

### process data functions

In [3]:
# @title Preprocessor for Tracks

@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ProcessTracksForTrackAutoencoder:
  """Samples tracks and fills out support_tracks, query_points etc.

  TrackAutoencoder format which will be output from this transform:
   video: float["*B T H W 3"]
   support_tracks: float["*B QS T 2"]
   support_tracks_visible: float["*B QS T 1"]
   query_points: float["*B Q 3"]
  """

  # note that we do not use query points in the encoding, so it is expected
  # that num_support_tracks >> num_target_tracks

  num_support_tracks: int
  num_target_tracks: int

  # If true, assume that everything after the boundary_frame is padding,
  # so don't sample any query points after the boundary_frame, and only sample
  # target tracks that have at least one visible frame before the boundary.
  before_boundary: bool = True
  episode_length: int = 150

  # Keys.
  video_key: str = "video"
  tracks: str = "tracks"  # [time, num_points, 2]
  visible_key: str = "visible"  # [time, num_points, 1]

  def random_map(self, features):

    # set tracks xy and compute visibility
    tracks_xy = features[self.tracks][..., :2]
    tracks_xy = np.asarray(tracks_xy, np.float32)
    boundary_frame = features["video"].shape[0]

    # visibles already post-processed by compute_point_tracks.py
    visibles = np.asarray(features[self.visible_key], np.float32)

    # pad to 'episode_length' frames
    if self.before_boundary:
      # if input video is longer than episode_length, crop to episode_length
      if self.episode_length - visibles.shape[0] < 0:
        visibles = visibles[: self.episode_length]
        tracks_xy = tracks_xy[: self.episode_length]

      visibles = np.pad(
          visibles,
          [[0, self.episode_length - visibles.shape[0]], [0, 0]],
          constant_values=0.0,
      )
      tracks_xy = np.pad(
          tracks_xy,
          [[0, self.episode_length - tracks_xy.shape[0]], [0, 0], [0, 0]],
          constant_values=0.0,
      )

    # Samples indices for support tracks and target tracks.
    num_input_tracks = tracks_xy.shape[1]
    idx = np.arange(num_input_tracks)
    np.random.shuffle(idx)

    assert (
        num_input_tracks >= self.num_support_tracks + self.num_target_tracks
    ), (
        (
            f"num_input_tracks {num_input_tracks} must be greater than"
            f" num_support_tracks {self.num_support_tracks} + num_target_tracks"
            f" {self.num_target_tracks}"
        ),
    )

    idx_support = idx[-self.num_support_tracks :]
    idx_target = idx[: self.num_target_tracks]

    # Gathers support tracks from `features`.  Features are of shape
    # [time, num_points, 2]
    support_tracks = tracks_xy[..., idx_support, :]
    support_tracks_visible = visibles[..., idx_support]

    # Gathers target tracks from `features`.
    target_tracks = tracks_xy[..., idx_target, :]
    target_tracks_visible = visibles[..., idx_target]

    # transpose to [num_points, time, ...]
    support_tracks = np.transpose(support_tracks, [1, 0, 2])
    support_tracks_visible = np.expand_dims(
        np.transpose(support_tracks_visible, [1, 0]), axis=-1
    )

    # [time, point_id, x/y] -> [point_id, time, x/y]
    target_tracks = np.transpose(target_tracks, [1, 0, 2])
    target_tracks_visible = np.transpose(target_tracks_visible, [1, 0])

    # Sample query points as random visible points
    num_target_tracks = target_tracks_visible.shape[0]
    num_frames = target_tracks_visible.shape[1]
    random_frame = np.zeros(num_target_tracks, dtype=np.int64)

    for i in range(num_target_tracks):
      visible_indices = np.where(target_tracks_visible[i] > 0)[0]
      if len(visible_indices) > 0:
        # Choose a random frame index from the visible ones
        random_frame[i] = np.random.choice(visible_indices)
      else:
        # If no frame is visible for a track, default to frame 0
        # (or handle as appropriate for your use case)
        random_frame[i] = 0

    # Create one-hot encoding based on the randomly selected frame for each track
    idx = np.eye(num_frames, dtype=np.float32)[
        random_frame
    ]  # [num_target_tracks, num_frames]

    # Use the one-hot index to select the coordinates at the chosen frame
    target_queries_xy = np.sum(
        target_tracks * idx[..., np.newaxis], axis=1
    )  # [num_target_tracks, 2]

    # Stack frame index and coordinates: [t, x, y]
    target_queries = np.stack(
        [
            random_frame.astype(np.float32),
            target_queries_xy[..., 0],
            target_queries_xy[..., 1],
        ],
        axis=-1,
    )  # [num_target_tracks, 3]

    # Add channel dimension to target_tracks_visible
    target_tracks_visible = np.expand_dims(target_tracks_visible, axis=-1)

    # Updates `features` to contain these *new* features and add batch dim.
    features_new = {
        "support_tracks": support_tracks[None, :],
        "support_tracks_visible": support_tracks_visible[None, :],
        "query_points": target_queries[None, :],
        "target_points": target_tracks[None, :],
        "boundary_frame": np.array([boundary_frame]),
        "target_tracks_visible": target_tracks_visible[None, :],
        "target_tracks_indices": idx_target[None, :]
    }
    features.update(features_new)
    return features

### load model and preprocessor

In [4]:
# Track sampler: per view, 64 support tracks and 32 target tracks are drawn
# from the input tracks.
preprocessor = ProcessTracksForTrackAutoencoder(
    num_support_tracks=64,
    num_target_tracks=32,
    video_key="video",
    before_boundary=True,
)

# `trajan_checkpoint_path` and `params` are already defined by the
# "load parameters" cell; the checkpoint is not read and unflattened a second
# time here.
model = track_autoencoder.TrackAutoEncoder(decoder_scan_chunk_size=32)

@jax.jit
def forward(params, inputs):
    """Runs the jit-compiled TRAJAN model on one preprocessed batch.

    Args:
        params: Parameter tree passed to `model.apply` under the 'params' key.
        inputs: Feature dict forwarded unchanged to the model.

    Returns:
        The `(outputs, latents)` pair produced by `model.apply`.
    """
    variables = {'params': params}
    reconstruction, encoded = model.apply(variables, inputs)
    return reconstruction, encoded

### load existing tracks and run trajan

In [23]:
# Scene to evaluate. It selects the input sub-directory under
# trajectory_examples/ and the output sub-directory under
# trajectory_examples_trajan_results/.
scene = "tennis"

In [None]:
def normalize_tracks(tracks, width=640, height=360):
    """Scales pixel-space track coordinates by the frame size.

    Args:
        tracks: Array-like whose last axis starts with (x, y) in pixels; any
            further channels are passed through unchanged.
        width: Frame width in pixels; x is divided by it.
        height: Frame height in pixels; y is divided by it.

    Returns:
        A new floating-point array of the same shape; the input is not
        modified. Floating inputs keep their dtype, any other dtype is
        converted to float32.
    """
    tracks = np.asarray(tracks)
    # Integer pixel coordinates must be promoted first: writing the quotient
    # back into an integer copy would truncate every coordinate to 0.
    if np.issubdtype(tracks.dtype, np.floating):
        float_dtype = tracks.dtype
    else:
        float_dtype = np.float32
    tracks_normalized = tracks.astype(float_dtype)  # astype returns a copy
    tracks_normalized[..., 0] /= width
    tracks_normalized[..., 1] /= height
    return tracks_normalized

data_dir = f"/restricted/projectnb/cs599dg/mwakeham/trajectory_examples/{scene}"

# Every "example*" entry of the scene directory, in sorted order.
example_dirs = sorted(d for d in os.listdir(data_dir) if d.startswith("example"))

# Results keyed as [example_dir]["view<k>"].
all_outputs = {}
batch_data = {}
all_latents = {}

# Placeholder video. The preprocessor reads only its frame count; whether the
# model consumes the pixels is not visible in this notebook. It is allocated
# once and shared by every batch instead of ~415 MB being created per view.
dummy_video = np.zeros((150, 360, 640, 3), dtype=np.float32)

for example_dir in tqdm(example_dirs):
    full_path = os.path.join(data_dir, example_dir)

    # Views are numbered from 0 within each example. The counter must restart
    # here: otherwise it is undefined on a fresh kernel and, once set, later
    # examples would start searching past their first view.
    view_idx = 0

    # Process consecutive views until the next tracks file is missing.
    while True:
        tracks_path = os.path.join(full_path, f"tracks_view{view_idx}.npy")
        visible_path = os.path.join(full_path, f"visible_view{view_idx}.npy")

        if not os.path.exists(tracks_path):
            break

        tracks = np.load(tracks_path)
        visible = np.load(visible_path)

        tracks = normalize_tracks(tracks, width=640, height=360)

        # The preprocessor expects visibility as [time, num_points].
        if visible.ndim == 3 and visible.shape[-1] == 1:
            visible = visible.squeeze(-1)

        batch = {
            "video": dummy_video,
            "tracks": tracks.astype(np.float32),
            "visible": visible.astype(np.float32),
        }

        batch = preprocessor.random_map(batch)
        # The raw tracks are dropped before the batch is passed to the model.
        batch.pop("tracks", None)
        outputs, latents = forward(params, batch)

        view_key = f'view{view_idx}'
        all_outputs.setdefault(example_dir, {})[view_key] = outputs
        batch_data.setdefault(example_dir, {})[view_key] = batch
        all_latents.setdefault(example_dir, {})[view_key] = latents

        view_idx += 1

print(f"Processed {len(all_outputs)} examples")

### calculate metrics

In [None]:
# Per-view metrics keyed as [example][view], plus flat score lists used for
# the scene-level averages.
all_metrics = {}
jaccard_scores = []
occlusion_scores = []

# Metric keys averaged into the reported "average Jaccard".
jaccard_keys = [f'jaccard_{d}' for d in (1, 2, 4, 8, 16)]

for example_name, views_data in tqdm(all_outputs.items()):
    example_metrics = all_metrics[example_name] = {}

    for view_name, outputs in views_data.items():
        batch = batch_data[example_name][view_name]

        # Tracks were normalised by the frame size; map them back to pixels.
        height, width = 360, 640
        pixel_grid = (width, height)

        reconstructed_tracks = transforms.convert_grid_coordinates(
            outputs.tracks[0], (1, 1), pixel_grid
        )
        target_tracks_vis = transforms.convert_grid_coordinates(
            batch['target_points'][0], (1, 1), pixel_grid
        )

        # NOTE(review): this result is passed below as `pred_occluded` although
        # it is named "visibles" - confirm the polarity returned by
        # `postprocess_occlusions` against the tapnet sources.
        reconstructed_visibles = model_utils.postprocess_occlusions(
            outputs.visible_logits, outputs.certain_logits
        )

        # Only the frames before the boundary are scored.
        video_length = batch['boundary_frame'][0]
        scored_frames = slice(None, video_length)

        num_batches = reconstructed_visibles.shape[0]
        num_targets = batch['target_tracks_visible'].shape[1]
        query_points = np.zeros((num_batches, num_targets, 1))

        # Compute TapVid metrics
        metrics = evaluation_datasets.compute_tapvid_metrics(
            query_points=query_points,
            gt_occluded=1 - batch['target_tracks_visible'][..., scored_frames, 0],
            gt_tracks=target_tracks_vis[None, ..., scored_frames, :],
            pred_occluded=reconstructed_visibles[..., scored_frames, 0],
            pred_tracks=reconstructed_tracks[..., scored_frames, :],
            query_mode='strided',
            get_trackwise_metrics=False,
        )

        jaccard = np.mean([metrics[key] for key in jaccard_keys])
        occlusion_acc = metrics['occlusion_accuracy'].mean()

        jaccard_scores.append(jaccard)
        occlusion_scores.append(occlusion_acc)

        example_metrics[view_name] = {
            'average_jaccard': jaccard,
            'occlusion_accuracy': occlusion_acc,
            'full_metrics': metrics
        }

# Scene-level averages over every processed view.
final_avg_jaccard = np.mean(jaccard_scores)
final_avg_occlusion = np.mean(occlusion_scores)

print(f"Final Average Jaccard: {final_avg_jaccard:.4f}")
print(f"Final Average Occlusion Accuracy: {final_avg_occlusion:.4f}")
print(f"Processed {len(jaccard_scores)} total views")

### save all results

In [None]:
results_dir = f"/restricted/projectnb/cs599dg/mwakeham/trajectory_examples_trajan_results/{scene}"
os.makedirs(results_dir, exist_ok=True)

print(f"Saving metrics to: {results_dir}")

# One sub-directory per example, holding a metrics archive and a latents
# archive for each view.
for example_name, views_data in tqdm(all_outputs.items()):
    example_dir = os.path.join(results_dir, example_name)
    os.makedirs(example_dir, exist_ok=True)

    for view_name in views_data:
        view_suffix = view_name.replace('view', '')
        metrics = all_metrics[example_name][view_name]
        latents = all_latents[example_name][view_name]

        # `full_metrics` is a dict, so NumPy stores it as a pickled object
        # array; reading it back requires np.load(..., allow_pickle=True).
        output_path = os.path.join(example_dir, f"trajan_metrics_view{view_suffix}.npz")
        np.savez(
            output_path,
            average_jaccard=metrics['average_jaccard'],
            occlusion_accuracy=metrics['occlusion_accuracy'],
            full_metrics=metrics['full_metrics']
        )

        latents_path = os.path.join(example_dir, f"trajan_latents_view{view_suffix}.npz")
        np.savez(
            latents_path,
            encoded_latents=latents
        )

# Scene-level summary. It is also a dict, hence saved as a pickled object
# array.
summary = {
    'final_avg_jaccard': final_avg_jaccard,
    'final_avg_occlusion': final_avg_occlusion,
    'total_views_processed': len(jaccard_scores),
    'all_jaccard_scores': jaccard_scores,
    'all_occlusion_scores': occlusion_scores,
    'scene': scene
}

np.save(os.path.join(results_dir, "overall_summary.npy"), summary)