pytorch-image-models/timm/layers/attention.py
Lines 75 to 76 in 019550e
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
I'm using the Ambarella transfer toolchain, which only supports 4D tensors. This part uses a 5D tensor, so it raises an error every time the self-attention function is called.
I'm proposing a replacement here that uses only 4D tensors:
H, D = self.num_heads, self.head_dim
qkv = self.qkv(x)  # (B, N, 3*H*D)
# Slice the fused projection along the channel dim instead of reshaping to 5D
q = qkv[:, :, :H*D].reshape(B, N, H, D).transpose(1, 2)       # (B, H, N, D)
k = qkv[:, :, H*D:2*H*D].reshape(B, N, H, D).transpose(1, 2)  # (B, H, N, D)
v = qkv[:, :, 2*H*D:].reshape(B, N, H, D).transpose(1, 2)     # (B, H, N, D)
I know the 5D tensor version is more readable, but the 4D tensor version solved my problem.
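For anyone who wants to double-check the change, below is a minimal sketch (not part of the original issue) that compares the 4D split against the original 5D reshape/permute/unbind. The tensor sizes and the standalone nn.Linear standing in for self.qkv are assumptions purely for illustration:

import torch

# Assumed sizes for the check: batch, tokens, heads, head dim
B, N, H, D = 2, 16, 8, 64
C = H * D
qkv_proj = torch.nn.Linear(C, 3 * C)  # stand-in for self.qkv
x = torch.randn(B, N, C)
qkv = qkv_proj(x)  # (B, N, 3*H*D)

# Original 5D version
qkv5 = qkv.reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)
q5, k5, v5 = qkv5.unbind(0)  # each (B, H, N, D)

# Proposed 4D version
q4 = qkv[:, :, :H * D].reshape(B, N, H, D).transpose(1, 2)
k4 = qkv[:, :, H * D:2 * H * D].reshape(B, N, H, D).transpose(1, 2)
v4 = qkv[:, :, 2 * H * D:].reshape(B, N, H, D).transpose(1, 2)

# Only view/rearrange ops are involved, so the results should be bit-identical
assert torch.equal(q5, q4) and torch.equal(k5, k4) and torch.equal(v5, v4)
print("4D split matches the 5D reshape")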