Skip to content

Commit

Permalink
add example of how to use pair finding without model fitting (#40)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicholas Lubbers <hippynn@lanl.gov>
  • Loading branch information
lubbersnick and Nicholas Lubbers committed Oct 23, 2023
1 parent 7b93c04 commit 450a517
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions examples/periodic_pairfinding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch
from hippynn.graphs import GraphModule
from hippynn.graphs.nodes.inputs import SpeciesNode, PositionsNode, CellNode
from hippynn.graphs.nodes.indexers import acquire_encoding_padding
from hippynn.graphs.nodes.pairs import PeriodicPairIndexer


n_atom = 30
n_system = 30
n_dim = 3
distance_cutoff = 0.3

if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"

floatX = torch.float32

# Set up input nodes
sp = SpeciesNode("Z")
pos = PositionsNode("R")
cell = CellNode("C")

# Set up and compile calculation
enc, pidxer = acquire_encoding_padding(sp, species_set=[0, 1])
pairfinder = PeriodicPairIndexer("pair finder", (pos, enc, pidxer, cell), dist_hard_max=distance_cutoff)
computer = GraphModule([sp, pos, cell], [*pairfinder.children])
computer.to(device)

# Get some random inputs
species_tensor = torch.ones(n_system, n_atom, device=device, dtype=torch.int64)
pos_tensor = torch.rand(n_system, n_atom, 3, device=device, dtype=floatX)
cell_tensor = torch.eye(3, 3, device=device, dtype=floatX).unsqueeze(0).expand(n_system, n_dim, n_dim).clone()

# Run calculation
outputs = computer(species_tensor, pos_tensor, cell_tensor)

# Print outputs
output_as_dict = {c.name: o for c, o in zip(pairfinder.children, outputs)}
for k, v in output_as_dict.items():
print(k, v.shape, v.dtype, v.min(), v.max())

0 comments on commit 450a517

Please sign in to comment.