distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook)
Hey Phil, thanks for another great implementation 😄. Regarding the distance calculation in LFQ (quoted above), I think this only holds if you're comparing self.codebook to the quantization of x (i.e. both have constant norm).
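For reference, here is the expansion behind that claim (my own working, not part of the original comment):

$$\lVert x_i - c_j \rVert^2 = \lVert x_i \rVert^2 - 2\, x_i \cdot c_j + \lVert c_j \rVert^2$$

The $\lVert x_i \rVert^2$ term is constant over $j$ and never changes the argmin, but $\lVert c_j \rVert^2$ does, so keeping only the $-2\, x_i \cdot c_j$ cross term recovers the true nearest codebook entry only when every codebook row has the same norm. The toy example below illustrates this: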
import torch
xs = torch.randn(10,3)
ys = torch.randn(10,3)
#xs,ys = map(lambda x: x/torch.norm(x,dim=-1,keepdim=True), (xs,ys))
print(torch.argmin(torch.cdist(xs,ys),dim=-1))
print(torch.argmin(-torch.einsum("i d, c d -> ... i c", xs, ys),dim=-1))
# output:
#tensor([8, 4, 7, 2, 3, 0, 4, 7, 1, 4])
#tensor([3, 9, 3, 3, 3, 2, 5, 9, 9, 5])

With both sets normalized to unit norm, the two argmins agree:

import torch
xs = torch.randn(10,3)
ys = torch.randn(10,3)
xs,ys = map(lambda x: x/torch.norm(x,dim=-1,keepdim=True), (xs,ys))
print(torch.argmin(torch.cdist(xs,ys),dim=-1))
print(torch.argmin(-torch.einsum("i d, c d -> ... i c", xs, ys),dim=-1))
# output:
#tensor([4, 7, 7, 7, 4, 7, 7, 8, 8, 8])
#tensor([4, 7, 7, 7, 4, 7, 7, 8, 8, 8])
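For completeness, a small sketch (my own, using placeholder tensors x and codebook rather than the library's actual buffers) of adding the ||c_j||^2 term back, so the argmin matches torch.cdist even when the codebook rows have different norms:

import torch

torch.manual_seed(0)
x = torch.randn(10, 3)          # inputs, shape (n, d)
codebook = torch.randn(7, 3)    # codes with varying norms, shape (c, d)

# cross term only, as in the snippet quoted above
shortcut = -2 * torch.einsum('i d, c d -> i c', x, codebook)

# full squared distance: ||x||^2 - 2 <x, c> + ||c||^2
# (the ||x||^2 term is constant per row, so only ||c||^2 affects the argmin)
full = x.pow(2).sum(-1, keepdim=True) + shortcut + codebook.pow(2).sum(-1)

print(torch.argmin(torch.cdist(x, codebook), dim=-1))  # reference nearest neighbours
print(torch.argmin(full, dim=-1))                      # matches cdist
print(torch.argmin(shortcut, dim=-1))                  # can disagree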