
High resolution image result with NaN features #33

Closed
TurtleSmoke opened this issue Apr 20, 2023 · 8 comments

Comments

@TurtleSmoke

Hello,

I'm having an issue with Dinov2 while trying to use it with high-resolution images like the one available at this link. The problem is that the features returned by the model contain NaN values. This issue occurs with all four available models and is consistently present for images around the same size.

I would like to know if you have any ideas about what could be causing this problem. Here's a minimal example:

import torch
import numpy as np
import torchvision.transforms as T
from PIL import Image
import hubconf

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dino = hubconf.dinov2_vits14().to(device)  # Same issue with larger models
img = Image.open('4k.png')
pw, ph = np.array(img.size) // 14  # number of patches along width and height

transform = T.Compose([
    # Resize so both sides are multiples of the patch size (14)
    T.Resize((14 * ph, 14 * pw), interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

tensor = transform(img)[:3].unsqueeze(0).to(device)  # keep RGB channels only, add batch dim
with torch.no_grad():
    features = dino.forward_features(tensor)['x_norm_patchtokens'][0]

print(features)  # NaN
@woctezuma

woctezuma commented Apr 20, 2023

The linked image is 2144x1319. Maybe this is the issue:

The models can accept larger images provided the image shapes are multiples of the patch size (14).
If this condition is not met, the model will crop to the closest smaller multiple of the patch size.

Edit: Never mind, you already resize the image to a shape that is a multiple of the patch size (14). Besides, the model would simply have cropped the image.

@ccharest93

You could write a debug function that checks the embedding tensor for NaN entries, then insert it between the various layers of the model to see in which layer the NaNs first appear; it might help us help you.
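For example, something along these lines (untested sketch; it assumes the transformer blocks are exposed as dino.blocks, and reuses dino and tensor from the snippet above):

import torch

def add_nan_hooks(model):
    # Register a forward hook on every transformer block and report any block
    # whose output contains NaN entries.
    def make_hook(name):
        def hook(module, inputs, output):
            out = output[0] if isinstance(output, tuple) else output
            if torch.isnan(out).any():
                print(f"NaNs in output of {name}")
        return hook

    # assumes model.blocks is iterable (a ModuleList of blocks or block chunks)
    return [block.register_forward_hook(make_hook(f"block {i}"))
            for i, block in enumerate(model.blocks)]

handles = add_nan_hooks(dino)
with torch.no_grad():
    dino.forward_features(tensor)
for h in handles:
    h.remove()  # clean up the hooks afterwards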

@TurtleSmoke
Author

After debugging the issue further, I found that the problematic function causing NaN values in the features is memory_efficient_attention, which is part of the xFormers library used by Dinov2. Here is the relevant code snippet from Dinov2's attention.py file:

class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        if attn_bias is not None:
            self_att_op = fmha.MemoryEfficientAttentionFlashAttentionOp
        else:
            self_att_op = None
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=self_att_op)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

The output tensor is full of NaNs during the forward pass of the first block.
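To isolate the kernel, something like the following standalone check could be used (untested sketch; the shapes are illustrative, with 6 heads and a head dimension of 64 as in dinov2_vits14, and N should be set to the actual number of patch tokens):

import torch
from xformers.ops import memory_efficient_attention

device = torch.device("cuda")
B, H, D = 1, 6, 64    # batch size, heads, head dim (ViT-S/14 values)
N = 20000             # illustrative sequence length; use your patch-token count

# xFormers expects tensors of shape (batch, seqlen, heads, head_dim)
q = torch.randn(B, N, H, D, device=device)
k = torch.randn(B, N, H, D, device=device)
v = torch.randn(B, N, H, D, device=device)

out = memory_efficient_attention(q, k, v)
print(torch.isnan(out).any())  # True would point at the attention kernel itself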

@ccharest93

#19 (comment) shows how to replace memory_efficient_attention with standard attention, which could fix your issue.
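One possible way to do that (not necessarily what the linked comment does) is to force the fallback path visible in the snippet above by pretending xFormers is unavailable, so MemEffAttention.forward falls back to the plain Attention.forward. A rough sketch, assuming you run from the repository checkout where the module is dinov2/layers/attention.py, and reusing dino and tensor from the first snippet:

import torch
import dinov2.layers.attention as attention_module

# Flip the module-level flag so MemEffAttention takes the standard-attention
# fallback (super().forward) instead of calling memory_efficient_attention.
attention_module.XFORMERS_AVAILABLE = False

with torch.no_grad():
    features = dino.forward_features(tensor)['x_norm_patchtokens'][0]
print(torch.isnan(features).any())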

@TurtleSmoke
Author

Unfortunately, I need this function precisely to save memory: with standard attention, this image requires nearly 100 GB of RAM, which is more than I have available.

@TurtleSmoke
Author

I somehow managed to find an image that works without memory_efficient_attention but not with it. I think xFormers might be struggling to handle attention when there are simply too many patches to deal with. It's a bit confusing, to be honest.

@TurtleSmoke
Author

It seems that I did not do enough research before creating this issue: [0.0.18] memory_efficient_attention NaNs when seqlen>32768 #719

I'll try upgrading and see whether that resolves the problem.
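For reference, a quick way to estimate the attention sequence length a given image produces with these models (a minimal sketch; patch size 14, plus one class token):

from PIL import Image

patch_size = 14
w, h = Image.open('4k.png').size
num_patches = (w // patch_size) * (h // patch_size)
seq_len = num_patches + 1  # +1 for the [CLS] token
print(seq_len)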

@TurtleSmoke
Author

TurtleSmoke commented Apr 21, 2023

That resolved the problem; it works with xformers==0.0.19.dev516.
