From f084a93f28ad71c35f8fbdf2aeb3b2fc551a4c7a Mon Sep 17 00:00:00 2001 From: Nikita Karaev Date: Thu, 4 Jan 2024 16:53:22 +0000 Subject: [PATCH] fix multi-batch inference --- cotracker/predictor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cotracker/predictor.py b/cotracker/predictor.py index dc2cd96..65875fe 100644 --- a/cotracker/predictor.py +++ b/cotracker/predictor.py @@ -23,11 +23,11 @@ def __init__(self, checkpoint="./checkpoints/cotracker2.pth"): @torch.no_grad() def forward( self, - video, # (1, T, 3, H, W) + video, # (B, T, 3, H, W) # input prompt types: # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame. # *backward_tracking=True* will compute tracks in both directions. - # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates. + # - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates. # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask. # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks. queries: torch.Tensor = None, @@ -120,13 +120,14 @@ def _compute_sparse_tracks( queries = torch.cat( [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts], dim=2, - ) + ).repeat(B, 1, 1) if add_support_grid: grid_pts = get_points_on_a_grid( self.support_grid_size, self.interp_shape, device=video.device ) grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2) + grid_pts = grid_pts.repeat(B, 1, 1) queries = torch.cat([queries, grid_pts], dim=1) tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6) @@ -173,7 +174,7 @@ def _compute_backward_tracks(self, video, queries, tracks, visibilities): inv_visibilities = inv_visibilities.flip(1) arange = torch.arange(video.shape[1], device=queries.device)[None, :, None] - mask = (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2) + mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2) tracks[mask] = inv_tracks[mask] visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]