In [1]:
import ase
import ase.neighborlist
import numpy as np
import torch
import torch_geometric.data
import torch_geometric.loader
# Make float64 the global default so every tensor created below without an
# explicit dtype is double precision; keep the value around for explicit casts.
torch.set_default_dtype(torch.float64)
default_dtype = torch.get_default_dtype()

from pyggnn import EGNN

# Device the model is placed on.
DEVICE = torch.device("cpu")
# Radius passed to the neighbor search when building graph edges
# (same length unit as the atomic positions / lattice vectors).
CUTOFF = 4.0
# Seed for torch's global RNG: the dummy targets and the model's initial
# weights are both drawn from it, so fixing it makes "Restart & Run All"
# reproduce the same numbers.
SEED = 0
torch.manual_seed(SEED)

### Dummy atom data

In [2]:
# Polonium: simple cubic lattice with a single atom in the unit cell.
po_types = ["Po"]
po_coords = torch.zeros(1, 3)  # the only atom sits at the origin
po_lattice = 3.340 * torch.eye(3)  # cubic cell, edge length 3.34 AA

# Periodic in all three directions (pbc=True).
po = ase.Atoms(
    symbols=po_types,
    positions=po_coords,
    cell=po_lattice,
    pbc=True,
)

In [3]:
# Silicon in the diamond structure, described by a two-atom cell.
si_types = ["Si"] * 2

# Lattice vectors: zeros on the diagonal, 2.734364 everywhere else.
si_lattice = 2.734364 * (torch.ones(3, 3) - torch.eye(3))

# Cartesian positions of the two atoms in the cell.
si_coords = torch.tensor([
    [1.367182] * 3,
    [0.0] * 3,
])

# Periodic in all three directions (pbc=True).
si = ase.Atoms(
    symbols=si_types,
    positions=si_coords,
    cell=si_lattice,
    pbc=True,
)

### Make dataset and dataloader

In [4]:
# Convert every crystal into a torch_geometric `Data` graph:
# nodes are atoms, edges connect atoms closer than CUTOFF (periodic images included).
dataset = []

crystals = [po, si]
# One random target per crystal. Iterating over the first axis yields (1, 1)
# slices, so the number of targets always follows the number of crystals.
dummy_energies = torch.randn(len(crystals), 1, 1)  # dummy energies for example

for crystal, energy in zip(crystals, dummy_energies):
    # edge_src and edge_dst are the indices of the central and neighboring atom, respectively
    # edge_shift indicates whether the neighbors are in different images / copies of the unit cell
    edge_src, edge_dst, edge_shift = ase.neighborlist.neighbor_list(
        "ijS",
        a=crystal,
        cutoff=CUTOFF,
        self_interaction=True,
    )

    data = torch_geometric.data.Data(
        # Floating-point fields are cast to `default_dtype` explicitly so they stay
        # consistent with `edge_shift` if the default dtype is ever changed.
        pos=torch.tensor(crystal.get_positions(), dtype=default_dtype),
        # We add a leading dimension for batching: (3, 3) -> (1, 3, 3)
        lattice=torch.tensor(crystal.cell.array, dtype=default_dtype).unsqueeze(0),
        atomic_num=torch.tensor(crystal.numbers, dtype=torch.long),  # Using atomic num
        # Shape (2, num_edges). An explicit dtype replaces the legacy
        # `torch.LongTensor` constructor and does not depend on the integer
        # width of the arrays returned by the neighbor search.
        edge_index=torch.tensor(np.stack([edge_src, edge_dst]), dtype=torch.long),
        edge_shift=torch.tensor(edge_shift, dtype=default_dtype),
        energy=energy,  # dummy energy (assumed to be normalized "per atom")
    )
    dataset.append(data)

print(dataset)

[Data(edge_index=[2, 7], pos=[1, 3], lattice=[1, 3, 3], atomic_num=[1], edge_shift=[7, 3], energy=[1, 1]), Data(edge_index=[2, 34], pos=[2, 3], lattice=[1, 3, 3], atomic_num=[2], edge_shift=[34, 3], energy=[1, 1])]


In [5]:
# Number of crystals collated into one batch.
batch_size = 2
# Loader that merges individual `Data` graphs into batched graphs;
# the default (non-shuffled) iteration order is spelled out explicitly.
dataloader = torch_geometric.loader.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
)

### Define model

In [6]:
# Build the EGNN and place it on DEVICE in one expression.
model = EGNN(
    out_dim=1,            # one predicted value per graph
    n_conv_layer=5,       # number of EGNNConv layers
    node_dim=256,
    edge_dim=256,
    aggr="add",
    activation="swish",
    beta=1.0,             # NOTE(review): presumably the swish beta -- confirm in pyggnn
).to(DEVICE)
# Last expression of the cell: shows the module summary.
model

EGNN(
  (node_initialize): AtomicNum2Node(100, 256)
  (convs): ModuleList(
    (0): EGNNConv()
    (1): EGNNConv()
    (2): EGNNConv()
    (3): EGNNConv()
    (4): EGNNConv()
  )
  (output): Node2Property(
    (node_transform): Sequential(
      (0): Dense(in_features=256, out_features=256, bias=True)
      (1): Swish()
      (2): Dense(in_features=256, out_features=256, bias=True)
    )
    (predict): Sequential(
      (0): Dense(in_features=256, out_features=256, bias=True)
      (1): Swish()
      (2): Dense(in_features=256, out_features=1, bias=False)
    )
  )
)

In [7]:
# Run one forward pass per batch and check the prediction shape against the targets.
for data in dataloader:
    # The model was moved to DEVICE, so the batch has to live there as well;
    # without this the forward pass fails as soon as DEVICE is not the CPU.
    data = data.to(DEVICE)
    out = model(data)
    # get same shape of energy
    print(out.shape == data["energy"].shape)

True


In [8]:
# Show the model output of the last batch processed by the loop above
# (same shape as that batch's `energy` tensor).
out

tensor([[ -2.7241],
        [-38.0073]], grad_fn=<MmBackward0>)