Cannot identify high-norm tokens #373

Open

riccardorenzulli opened this issue Feb 9, 2024 · 6 comments
@riccardorenzulli

Hello,

I'm having trouble identifying high-norm tokens, as described in the "Vision Transformers Need Registers" paper. I've seen that this is also discussed in #293.

I used the L2 norm and the code from #306.
To get the embedding vectors of the last layer, I use
x_layers = model.get_intermediate_layers(img, [len(model.blocks)-1]).

I tried with ViT-G/14 on the full ImageNet validation set, with and without registers; however, as you can see in the images below, the norms of the model without registers never exceed 150, the threshold the paper uses for high-norm tokens.
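
For reference, the computation boils down to something like this (a minimal sketch; `model` is the dinov2 backbone and `img` the preprocessed image tensor mentioned above, and the variable names are just illustrative):

# Minimal sketch of the norm computation described above.
# Note: get_intermediate_layers() returns only the patch tokens (no cls token).
x_layers = model.get_intermediate_layers(img, [len(model.blocks) - 1])
patch_tokens = x_layers[0]             # shape: (batch, num_patches, embed_dim)
norms = patch_tokens.norm(dim=-1)      # L2 norm of each patch token
print(norms.min().item(), norms.max().item())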

[images: token-norm plots for ViT-G/14 with and without registers]

Did anyone succeed in reproducing the results of the main paper and identifying these high-norm tokens?

@AndreaBrg

I'm having the same issue; @patricklabatut, any idea?

@heyoeyo

heyoeyo commented Feb 17, 2024

For what it's worth, I've seen this 'high norm' pattern occur with the dinov2-based image encoder used in the depth-anything model. It happens on the vit-l, vit-b and to some extent even on the vit-s model. A similar pattern appears using the 'ViT-L/14 distilled' backbone (from the dinov2 listing), but it's only visible on internal blocks.

Here are the norms of the different output blocks for vit-l (the depth-anything version) running on a picture of a turtle:

vit-l block norms: [image blocknorms_depth_anything_vitl14_turtle_504]

Here are the reported min/max norms for the last few blocks:

Block   Min     Max
14      4.68    6.44
15      5.87    9.12
16      6.98    11.04
17      9.17    175.08
18      12.43   320.31
19      16.54   509.3
20      22.94   517.46
21      33.25   532.22
22      50.88   569.05
23      99.46   389.04

Here are some more examples:

vit-l block norms at half resolution: [image blocknorms_depth_anything_vitl14_turtle_252]

vit-b block norms: [image blocknorms_depth_anything_vitb14_turtle_504]

vit-s block norms: [image blocknorms_depth_anything_vits14_turtle_504]

Original ViT-L/14 distilled block norms: [image blocknorms_dinov2_vitl_orig_turtle_504]

And the last few block min/max norms for comparison:

Block   Min     Max
14      4.59    7.11
15      5.83    9.58
16      7.13    10.4
17      9.51    188.77
18      12.61   342.56
19      16.85   538.95
20      23.87   549.13
21      30.23   563.75
22      39.58   618.84
23      66.03   121.01
beit-large-512 block norms: [image blocknorms_dpt_beit_large_512_turtle_512]

Input image (downscaled for display): [image turtle]

Some notes:

  • The final output block always 'recovers' slightly, so the output range isn't as exaggerated. Worth noting, though: the depth-anything models use the last 4 blocks as outputs, so it's odd that the final block has unique behavior
  • When using the original 'ViT-L/14 distilled' weights, the pattern disappears on the final block, but is visible on internal blocks
  • I've tried vit-b on a few images and it always places the high norm tokens in the top-left...
  • For vit-l, the placement of the high norm tokens is generally different for different images, and will change (for the same image) when adjusting the input resolution (see the two vit-l images)
  • vit-s doesn't exhibit the pattern, but internally (i.e. before the final block) it does seem to be trending towards having hot-spots
  • By comparison, beit-large-512 (and smaller variants) don't seem to show the pattern at all. Edit: On closer inspection, the last few beit blocks all have min and max norm > 200, but no major outliers

Obviously it's not a conclusive result since I've only tried this on a few images, but it does seem similar to the effect described in the 'registers' paper.

@heyoeyo

heyoeyo commented Feb 19, 2024

As a quick follow-up, I've tried this with the original dinov2 model & weights and got the same results. The original weights always have smaller norms on their final output (compared to the depth-anything weights), but vit-b & vit-l both show high norms internally. Results from vit-g have high norms even on the final output.

Here is an animation of the vit-g block norms (first-to-last) showing qualitatively similar results to the paper:
[animation vitg_blocknorm_anim]

The 'with registers' versions of the models don't completely get rid of high norms in the later layers, but they do get rid of outliers.

For anyone wanting to try this, here's some code that uses the dinov2 repo/models and prints out the min & max norms for each block. Just make sure to set an image path and model name at the top of the script (use any of the pretrained backbone names from the repo listing):

Code for printing block norms
import cv2
import torch
import numpy as np
from dinov2.layers.block import Block

# Setup
model_name = "dinov2_vitl14" # compare with "dinov2_vitl14_reg"
image_path = "path/to/image.jpg"
device, dtype = "cuda", torch.float32
img_size_wh = (518, 518)
img_mean, img_std = [0.485,0.456,0.406], [0.229, 0.224, 0.225]

# Load & prepare image
orig_img_bgr = cv2.imread(image_path)
img_rgb = cv2.cvtColor(orig_img_bgr, cv2.COLOR_BGR2RGB)
img_rgb = cv2.resize(img_rgb, dsize=img_size_wh)
img_rgb = (np.float32(img_rgb / 255.0) - img_mean) / img_std
img_rgb = np.transpose(img_rgb, (2, 0, 1))
img_tensor = torch.from_numpy(img_rgb).unsqueeze(0).to(device, dtype)

# Load model
model = torch.hub.load("facebookresearch/dinov2", model=model_name)
model.to(device, dtype)
model.eval()

# Capture transformer block outputs
captures = []
hook_func = lambda m, inp, out: captures.append(out)
for m in model.modules():
    if isinstance(m, Block):
        m.register_forward_hook(hook_func)
with torch.no_grad(): model(img_tensor)

# Figure out how many global tokens we'll need to remove
# (assuming we only get norms of image-patch tokens)
has_cls_token = model.cls_token is not None
num_global_tokens = model.num_register_tokens + int(has_cls_token)

# Print out norm info
print(f"Block norms (min & max) for {model_name}")
for idx, output in enumerate(captures):
    patch_tokens = output[:, num_global_tokens:, :] # Remove cls & reg tokens
    norms = patch_tokens.norm(dim=2).cpu().float().numpy()
    min_str = str(round(norms.min())).rjust(3)
    max_str = str(round(norms.max())).rjust(3)
    print(f"B{idx}:".rjust(4), f"[{min_str}, {max_str}]")

And here's some code that can be added to the end of the code above for generating the visualizations (it pops up a window, so you need to be running the code locally).

Code for visualizations
# Figure out patch sizing, for converting back to image-like shape
input_hw = img_tensor.shape[2:]
patch_size_hw = model.patch_embed.patch_size
patch_grid_hw = [x // p for x, p in zip(input_hw, patch_size_hw)]

# For displaying as an image
for idx, output in enumerate(captures):
    
    # Get tokens into image-like shape 
    patch_tokens = output[:, num_global_tokens:, :]
    imglike_tokens = torch.transpose(patch_tokens, 1, 2)
    imglike_tokens = torch.unflatten(imglike_tokens, 2, patch_grid_hw).squeeze().float()
    imglike_norms = imglike_tokens.norm(dim=0)
         
    # Make image easier to view
    min_norm, max_norm = imglike_norms.min(), imglike_norms.max()
    norm_disp = ((imglike_norms - min_norm) / (max_norm - min_norm)).cpu().float().numpy()
    norm_disp = cv2.resize(norm_disp, dsize=None, fx=8, fy=8, interpolation=cv2.INTER_NEAREST_EXACT)
    cmap_disp = cv2.applyColorMap(np.uint8(255*norm_disp), cv2.COLORMAP_VIRIDIS)
    
    cv2.imshow("Block norms", cmap_disp)
    cv2.waitKey(250)

cv2.destroyAllWindows()

@AndreaBrg

@heyoeyo Thanks for the thorough explanations. I'll take a look.

@riccardorenzulli
Author

Thank you very much @heyoeyo for your help and insights. We discovered that the problem in our code was the norm argument defaulting to True in x_layers = model.get_intermediate_layers(img, [len(model.blocks)-1]). By passing norm=False and collecting the embeddings for all layers, we get the same results as yours.

As you pointed out, surprisingly, the last-layer norms of the model without registers are not that high, while for the model with registers the norms become high but without outliers. I was surprised about this, especially given Figures 7 and 15 of the paper.
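
For anyone hitting the same thing, here's a minimal sketch of the corrected call (assuming the same model/img setup as in the scripts above; variable names are just illustrative):

# Sketch of the fix described above: pass norm=False so the final LayerNorm
# isn't applied, and collect the outputs of every block rather than just the last.
all_block_idxs = list(range(len(model.blocks)))
x_layers = model.get_intermediate_layers(img, all_block_idxs, norm=False)
for idx, patch_tokens in enumerate(x_layers):
    norms = patch_tokens.norm(dim=-1)  # L2 norm of each patch token
    print(f"block {idx}: min {norms.min().item():.1f}, max {norms.max().item():.1f}")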

@heyoeyo

heyoeyo commented Feb 21, 2024

I was surprised about this, especially given Figures 7 and 15 of the paper.

Agreed! The output layer of the vitl-reg model has norms in the 150-400 range for the few images I've tried, as opposed to the <50 range reported by the paper.

I also find figure 3 vs 7 & 15 to be confusing, as fig. 3 suggests a non-register high-norm range of ~200-600 (consistent with what I've seen), whereas fig 7 & 15 show a 100-200 range for the high-norm tokens. Though I may be misinterpreting the plots.
