Skip to content

Commit

Permalink
fix multi-batch inference
Browse files · Browse the repository at this point in the history
  • Loading branch information
nikitakaraevv committed Jan 4, 2024
1 parent 3716e36 commit f084a93
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions cotracker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
@torch.no_grad()
def forward(
self,
video, # (1, T, 3, H, W)
video, # (B, T, 3, H, W)
# input prompt types:
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
# *backward_tracking=True* will compute tracks in both directions.
# - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
# - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
# - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
queries: torch.Tensor = None,
Expand Down Expand Up @@ -120,13 +120,14 @@ def _compute_sparse_tracks(
queries = torch.cat(
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
dim=2,
)
).repeat(B, 1, 1)

if add_support_grid:
grid_pts = get_points_on_a_grid(
self.support_grid_size, self.interp_shape, device=video.device
)
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
grid_pts = grid_pts.repeat(B, 1, 1)
queries = torch.cat([queries, grid_pts], dim=1)

tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
Expand Down Expand Up @@ -173,7 +174,7 @@ def _compute_backward_tracks(self, video, queries, tracks, visibilities):
inv_visibilities = inv_visibilities.flip(1)
arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]

mask = (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)

tracks[mask] = inv_tracks[mask]
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
Expand Down

0 comments on commit f084a93

Please sign in to comment.