diff --git a/README.md b/README.md index 86ab419..14eb5c1 100644 --- a/README.md +++ b/README.md @@ -17,15 +17,15 @@ import torch from geometric_vector_perceptron import GVP model = GVP( - dim_coors_in = 1024, + dim_vectors_in = 1024, dim_feats_in = 512, - dim_coors_out = 256, + dim_vectors_out = 256, dim_feats_out = 512 ) -feats, coors = (torch.randn(1, 512), torch.randn(1, 1024, 3)) +feats, vectors = (torch.randn(1, 512), torch.randn(1, 1024, 3)) -feats_out, coors_out = model(feats, coors) # (1, 256), (1, 512, 3) +feats_out, vectors_out = model(feats, vectors) # (1, 256), (1, 512, 3) ``` ## Citations diff --git a/geometric_vector_perceptron/geometric_vector_perceptron.py b/geometric_vector_perceptron/geometric_vector_perceptron.py index 98fe20d..1456f1f 100644 --- a/geometric_vector_perceptron/geometric_vector_perceptron.py +++ b/geometric_vector_perceptron/geometric_vector_perceptron.py @@ -6,37 +6,37 @@ class GVP(nn.Module): def __init__( self, *, - dim_coors_in, + dim_vectors_in, dim_feats_in, dim_feats_out, - dim_coors_out, + dim_vectors_out, feats_activation = nn.Sigmoid(), - coors_activation = nn.Sigmoid() + vectors_activation = nn.Sigmoid() ): super().__init__() - self.dim_coors_in = dim_coors_in + self.dim_vectors_in = dim_vectors_in self.dim_feats_in = dim_feats_in - self.dim_coors_out = dim_coors_out - dim_h = max(dim_coors_in, dim_coors_out) + self.dim_vectors_out = dim_vectors_out + dim_h = max(dim_vectors_in, dim_vectors_out) - self.Wh = nn.Parameter(torch.randn(dim_coors_in, dim_h)) - self.Wu = nn.Parameter(torch.randn(dim_h, dim_coors_out)) + self.Wh = nn.Parameter(torch.randn(dim_vectors_in, dim_h)) + self.Wu = nn.Parameter(torch.randn(dim_h, dim_vectors_out)) - self.coors_activation = coors_activation + self.vectors_activation = vectors_activation self.to_feats_out = nn.Sequential( nn.Linear(dim_h + dim_feats_in, dim_feats_out), feats_activation ) - def forward(self, feats, coors): - b, n, _, v, c = *feats.shape, *coors.shape + def forward(self, 
feats, vectors): + b, n, _, v, c = *feats.shape, *vectors.shape - assert c == 3 and v == self.dim_coors_in, 'coordinates have wrong dimensions' + assert c == 3 and v == self.dim_vectors_in, 'vectors have wrong dimensions' assert n == self.dim_feats_in, 'scalar features have wrong dimensions' - Vh = einsum('b v c, v h -> b h c', coors, self.Wh) + Vh = einsum('b v c, v h -> b h c', vectors, self.Wh) Vu = einsum('b h c, h u -> b u c', Vh, self.Wu) sh = torch.norm(Vh, p = 2, dim = -1) @@ -45,6 +45,6 @@ def forward(self, feats, coors): s = torch.cat((feats, sh), dim = 1) feats_out = self.to_feats_out(s) - coors_out = self.coors_activation(vu) * Vu + vectors_out = self.vectors_activation(vu) * Vu - return feats_out, coors_out + return feats_out, vectors_out diff --git a/setup.py b/setup.py index 3bd0628..bcd5d16 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'geometric-vector-perceptron', packages = find_packages(), - version = '0.0.1', + version = '0.0.2', license='MIT', description = 'Geometric Vector Perceptron - Pytorch', author = 'Phil Wang', diff --git a/tests.py b/tests.py index 4486ebc..2af1a0c 100644 --- a/tests.py +++ b/tests.py @@ -7,21 +7,26 @@ def random_rotation(): q, r = torch.qr(torch.randn(3, 3)) return q +def diff_matrix(vectors): + b, _, d = vectors.shape + diff = vectors[..., None, :] - vectors[:, None, ...] 
+ return diff.reshape(b, -1, d) + def test_equivariance(): R = random_rotation() model = GVP( - dim_coors_in = 1024, + dim_vectors_in = 1024, dim_feats_in = 512, - dim_coors_out = 256, + dim_vectors_out = 256, dim_feats_out = 512 ) feats = torch.randn(1, 512) - coors = torch.randn(1, 1024, 3) + vectors = torch.randn(1, 32, 3) - feats_out, coors_out = model(feats, coors) - feats_out_r, coors_out_r = model(feats, coors @ R) + feats_out, vectors_out = model(feats, diff_matrix(vectors)) + feats_out_r, vectors_out_r = model(feats, diff_matrix(vectors @ R)) - err = ((coors_out @ R) - coors_out_r).max() + err = ((vectors_out @ R) - vectors_out_r).max() assert err < TOL, 'equivariance must be respected'