Skip to content

Commit

Permalink
Little cleanup in _tensornet.py
Browse files Browse the repository at this point in the history
Signed-off-by: Tsz Wai Ko <47970742+kenko911@users.noreply.github.com>
  • Loading branch information
kenko911 committed May 16, 2024
1 parent b132d4c commit 9b1e406
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions src/matgl/models/_tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,10 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, **kwa
g.edata["bond_vec"] = bond_vec.to(g.device)
g.edata["bond_dist"] = bond_dist.to(g.device)

# This asserts convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor]

# Expand distances with radial basis functions
edge_attr = self.bond_expansion(g.edata["bond_dist"])
g.edata["edge_attr"] = edge_attr
# Embedding from edge-wise tensors to node-wise tensors
# Embedding layer
X, edge_feat, state_feat = self.tensor_embedding(g, state_attr)
# Interaction layers
for layer in self.layers:
Expand Down

0 comments on commit 9b1e406

Please sign in to comment.