Skip to content

Commit

Permalink
Fix interpolate parameters to allow tracing (#247)
Browse files (browse the repository at this point in the history)
Pass scale factor as a tuple of floats to F.interpolate() to allow tracing.
  • Loading branch information
patricklabatut committed Sep 30, 2023
1 parent e7df9fc commit 44abdbe
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions dinov2/models/vision_transformer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -177,13 +177,16 @@ def interpolate_pos_encoding(self, x, w, h):
         # see discussion at https://github.com/facebookresearch/dino/issues/8
         w0, h0 = w0 + 0.1, h0 + 0.1

+        sqrt_N = math.sqrt(N)
+        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
         patch_pos_embed = nn.functional.interpolate(
-            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
-            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+            scale_factor=(sx, sy),
             mode="bicubic",
         )

-        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        assert int(w0) == patch_pos_embed.shape[-2]
+        assert int(h0) == patch_pos_embed.shape[-1]
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

Expand Down

0 comments on commit 44abdbe

Please sign in to comment.