Commit d34a924: vectors, not coors
lucidrains committed Jan 14, 2021
1 parent 7e5e774
Showing 4 changed files with 31 additions and 26 deletions.
README.md (8 changes: 4 additions & 4 deletions)

@@ -17,15 +17,15 @@ import torch
 from geometric_vector_perceptron import GVP
 
 model = GVP(
-    dim_coors_in = 1024,
+    dim_vectors_in = 1024,
     dim_feats_in = 512,
-    dim_coors_out = 256,
+    dim_vectors_out = 256,
     dim_feats_out = 512
 )
 
-feats, coors = (torch.randn(1, 512), torch.randn(1, 1024, 3))
+feats, vectors = (torch.randn(1, 512), torch.randn(1, 1024, 3))
 
-feats_out, coors_out = model(feats, coors) # (1, 256), (1, 512, 3)
+feats_out, vectors_out = model(feats, vectors) # (1, 512), (1, 256, 3)
 ```
 
 ## Citations
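As a quick sanity check of the renamed API, here is a minimal sketch of the README example with the expected output shapes asserted explicitly (assuming the package behaves as the README claims at this commit; the scalar branch is sized by `dim_feats_out`, the vector branch by `dim_vectors_out`):

```python
# Sketch: the README example with output shapes asserted.
import torch
from geometric_vector_perceptron import GVP

model = GVP(
    dim_vectors_in = 1024,   # number of input vector channels, each a 3-vector
    dim_feats_in = 512,      # number of input scalar features
    dim_vectors_out = 256,
    dim_feats_out = 512
)

feats_out, vectors_out = model(torch.randn(1, 512), torch.randn(1, 1024, 3))

assert feats_out.shape == (1, 512)       # (batch, dim_feats_out)
assert vectors_out.shape == (1, 256, 3)  # (batch, dim_vectors_out, 3)
```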
geometric_vector_perceptron/geometric_vector_perceptron.py (30 changes: 15 additions & 15 deletions)

@@ -6,37 +6,37 @@ class GVP(nn.Module):
     def __init__(
         self,
         *,
-        dim_coors_in,
+        dim_vectors_in,
         dim_feats_in,
         dim_feats_out,
-        dim_coors_out,
+        dim_vectors_out,
         feats_activation = nn.Sigmoid(),
-        coors_activation = nn.Sigmoid()
+        vectors_activation = nn.Sigmoid()
     ):
         super().__init__()
-        self.dim_coors_in = dim_coors_in
+        self.dim_vectors_in = dim_vectors_in
         self.dim_feats_in = dim_feats_in
 
-        self.dim_coors_out = dim_coors_out
-        dim_h = max(dim_coors_in, dim_coors_out)
+        self.dim_vectors_out = dim_vectors_out
+        dim_h = max(dim_vectors_in, dim_vectors_out)
 
-        self.Wh = nn.Parameter(torch.randn(dim_coors_in, dim_h))
-        self.Wu = nn.Parameter(torch.randn(dim_h, dim_coors_out))
+        self.Wh = nn.Parameter(torch.randn(dim_vectors_in, dim_h))
+        self.Wu = nn.Parameter(torch.randn(dim_h, dim_vectors_out))
 
-        self.coors_activation = coors_activation
+        self.vectors_activation = vectors_activation
 
         self.to_feats_out = nn.Sequential(
             nn.Linear(dim_h + dim_feats_in, dim_feats_out),
             feats_activation
         )
 
-    def forward(self, feats, coors):
-        b, n, _, v, c = *feats.shape, *coors.shape
+    def forward(self, feats, vectors):
+        b, n, _, v, c = *feats.shape, *vectors.shape
 
-        assert c == 3 and v == self.dim_coors_in, 'coordinates have wrong dimensions'
+        assert c == 3 and v == self.dim_vectors_in, 'coordinates have wrong dimensions'
         assert n == self.dim_feats_in, 'scalar features have wrong dimensions'
 
-        Vh = einsum('b v c, v h -> b h c', coors, self.Wh)
+        Vh = einsum('b v c, v h -> b h c', vectors, self.Wh)
         Vu = einsum('b h c, h u -> b u c', Vh, self.Wu)
 
         sh = torch.norm(Vh, p = 2, dim = -1)
@@ -45,6 +45,6 @@ def forward(self, feats, coors):
         s = torch.cat((feats, sh), dim = 1)
 
         feats_out = self.to_feats_out(s)
-        coors_out = self.coors_activation(vu) * Vu
+        vectors_out = self.vectors_activation(vu) * Vu
 
-        return feats_out, coors_out
+        return feats_out, vectors_out
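To make the shape bookkeeping in this diff concrete, here is a self-contained sketch of the same forward computation with small hypothetical dimensions. The `vu` norm lives in the collapsed lines between the two hunks, so the `keepdim = True` used here to make broadcasting work is an assumption of this sketch, not a claim about the hidden code:

```python
# Standalone sketch of the forward pass, with hypothetical sizes:
# dim_vectors_in = 4, dim_vectors_out = 2, dim_feats_in = 3, dim_feats_out = 3.
import torch
from torch import einsum

vectors = torch.randn(1, 4, 3)   # (batch, dim_vectors_in, 3)
feats   = torch.randn(1, 3)      # (batch, dim_feats_in)

dim_h = max(4, 2)                # hidden vector channels, as in __init__
Wh = torch.randn(4, dim_h)       # (dim_vectors_in, dim_h)
Wu = torch.randn(dim_h, 2)       # (dim_h, dim_vectors_out)

# Mixing happens over the channel axis only; the xyz axis passes through
# untouched, which is what keeps the vector branch rotation-equivariant.
Vh = einsum('b v c, v h -> b h c', vectors, Wh)   # (1, dim_h, 3)
Vu = einsum('b h c, h u -> b u c', Vh, Wu)        # (1, 2, 3)

sh = Vh.norm(p = 2, dim = -1)                     # (1, dim_h): rotation-invariant norms
vu = Vu.norm(p = 2, dim = -1, keepdim = True)     # (1, 2, 1): keepdim assumed here

s = torch.cat((feats, sh), dim = 1)               # (1, 7): scalar-branch input
feats_out = torch.sigmoid(torch.nn.Linear(7, 3)(s))  # stand-in for to_feats_out
vectors_out = torch.sigmoid(vu) * Vu              # (1, 2, 3): norm-gated, equivariant
```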
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'geometric-vector-perceptron',
   packages = find_packages(),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'Geometric Vector Perceptron - Pytorch',
   author = 'Phil Wang',
tests.py (17 changes: 11 additions & 6 deletions)

@@ -7,21 +7,26 @@ def random_rotation():
     q, r = torch.qr(torch.randn(3, 3))
     return q
 
+def diff_matrix(vectors):
+    b, _, d = vectors.shape
+    diff = vectors[..., None, :] - vectors[:, None, ...]
+    return diff.reshape(b, -1, d)
+
 def test_equivariance():
     R = random_rotation()
 
     model = GVP(
-        dim_coors_in = 1024,
+        dim_vectors_in = 1024,
         dim_feats_in = 512,
-        dim_coors_out = 256,
+        dim_vectors_out = 256,
         dim_feats_out = 512
     )
 
     feats = torch.randn(1, 512)
-    coors = torch.randn(1, 1024, 3)
+    vectors = torch.randn(1, 32, 3)
 
-    feats_out, coors_out = model(feats, coors)
-    feats_out_r, coors_out_r = model(feats, coors @ R)
+    feats_out, vectors_out = model(feats, diff_matrix(vectors))
+    feats_out_r, vectors_out_r = model(feats, diff_matrix(vectors @ R))
 
-    err = ((coors_out @ R) - coors_out_r).max()
+    err = ((vectors_out @ R) - vectors_out_r).max()
     assert err < TOL, 'equivariance must be respected'
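The new `diff_matrix` helper is the point of the rename: the model consumes general vector features rather than raw coordinates, and the test derives those features as pairwise differences, which are translation-invariant and rotate with the points (32 points give 32 × 32 = 1024 difference vectors, matching `dim_vectors_in = 1024`). A standalone check of those two properties, mirroring the helper above:

```python
# Sketch: the two properties of pairwise-difference features that the
# equivariance test relies on.
import torch

def diff_matrix(vectors):
    b, _, d = vectors.shape
    diff = vectors[..., None, :] - vectors[:, None, ...]
    return diff.reshape(b, -1, d)

pts = torch.randn(1, 32, 3)
R, _ = torch.qr(torch.randn(3, 3))   # random orthogonal matrix, as in random_rotation

assert diff_matrix(pts).shape == (1, 1024, 3)   # 32 * 32 = dim_vectors_in

# Translation drops out of the differences entirely
assert torch.allclose(diff_matrix(pts + 5.), diff_matrix(pts), atol = 1e-6)

# Rotating the points rotates every difference vector the same way
assert torch.allclose(diff_matrix(pts @ R), diff_matrix(pts) @ R, atol = 1e-6)
```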
