# Example: using graph datasets, `GraphLoader`, and GNN models (SchNet, EGNN, DimeNet)

In [1]:
import ase
import numpy as np
import torch
# Make floating-point tensors default to float64. Kept before the pyggnn
# imports as in the original ordering.
# NOTE(review): the ordering may matter if pyggnn creates tensors at import
# time; confirm before moving this line.
torch.set_default_dtype(torch.float64)

from pyggnn.data import List2GraphDataset, Hdf2GraphDataset, GraphLoader
from pyggnn.nn import ScaleShift
from pyggnn import EGNN, SchNet, DimeNet

DEVICE = torch.device("cpu")
CUTOFF = 4.0  # radius passed as `cutoff_radi` to the dataset and the models
SEED = 42
# Seed torch's global RNG so model weight initialisation and the loader's
# `shuffle=True` order are reproducible under Restart & Run All.
torch.manual_seed(SEED)

### Dummy atoms data

In [14]:
# Polonium in a simple cubic lattice: a single atom at the origin of a cube
# whose edges are 3.34 Å long.
po_lattice = 3.340 * torch.eye(3)
po_coords = torch.tensor([[0.0, 0.0, 0.0]])
po_types = ["Po"]

po = ase.Atoms(
    symbols=po_types,
    positions=po_coords,
    cell=po_lattice,
    pbc=True,
    # Extra per-structure properties stored in the atoms object's `info` dict.
    info=dict(
        energy=12.0,                    # a scalar value
        dos=np.array([0.1, 0.2, 0.3]),  # array-like values work too
    ),
)

# Silicon in the diamond structure: a two-atom basis in an FCC primitive cell.
si_lattice = torch.tensor(
    [
        [0.0, 2.734364, 2.734364],
        [2.734364, 0.0, 2.734364],
        [2.734364, 2.734364, 0.0],
    ]
)
si_coords = torch.tensor(
    [
        [1.367182, 1.367182, 1.367182],
        [0.0, 0.0, 0.0],
    ]
)
si_types = ["Si"] * 2

si = ase.Atoms(
    symbols=si_types,
    positions=si_coords,
    cell=si_lattice,
    pbc=True,
    # Extra per-structure properties stored in the atoms object's `info` dict.
    info=dict(energy=1.0, dos=np.array([0.5, 0.6, 0.7])),
)

### Make dataset and dataloader

In [3]:
# Build a graph dataset from in-memory ase.Atoms objects.
dataset = List2GraphDataset(
    # a list of atoms objects
    [po, si],
    cutoff_radi=CUTOFF,
    # names of the additional properties stored in each atoms' `info`
    property_names=["energy", "dos"],
)

# If you have an HDF5 file, you can use Hdf2GraphDataset instead.
# Set HDF5_PATH to that file to switch datasets. Leaving it as None keeps the
# in-memory dataset above, so the rest of the notebook runs unchanged.
# (Previously this assignment ran unconditionally with a placeholder "./"
# path and silently replaced the dataset built above.)
HDF5_PATH = None  # e.g. "./dataset.h5"
if HDF5_PATH is not None:
    dataset = Hdf2GraphDataset(
        hdf5_path=HDF5_PATH,
        cutoff_radi=CUTOFF,
        property_names=["energy", "dos"],
    )

In [4]:
# GraphLoader inherits from PyTorch's DataLoader, so it takes the same
# keyword arguments (batch size, shuffling, ...).
batch_size = 2
dataloader = GraphLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

### Define models and run a forward pass

In [15]:
# SchNet regressor with a single output per graph.
model = SchNet(
    # graph cutoff, the same value used to build the dataset
    cutoff_radi=CUTOFF,
    # feature sizes
    node_dim=128,
    edge_dim=128,
    n_gaussian=32,
    # depth
    n_conv_layer=4,
    # output head; the scaler is passed as a class, not an instance
    scaler=ScaleShift,
    out_dim=1,
)
# As the cell's last expression, this also displays the model's repr.
model.to(DEVICE)

SchNet(node_dim=128, edge_dim=128, n_gaussian=32, n_conv_layer=4, cutoff=4.0, out_dim=1)

In [16]:
# Forward pass over every batch: check that the prediction has the same shape
# as the target `energy`, then show the raw predictions.
for data in dataloader:
    # Rebind the result of `.to()`: tensor-style `.to()` returns the moved
    # object instead of moving in place, so rebinding is correct either way.
    data = data.to(DEVICE)
    out = model(data)
    print(out.shape == data["energy"].shape)
    print(out)

True
tensor([[0.1551],
        [0.4404]], grad_fn=<ScatterAddBackward0>)


In [17]:
# EGNN regressor with a single output per graph. Unlike SchNet and DimeNet,
# no `cutoff_radi` is passed here; the printed repr shows cutoff=None.
model = EGNN(
    out_dim=1,
    n_conv_layer=4,
    edge_dim=128,
    node_dim=128,
)
# As the cell's last expression, this also displays the model's repr.
model.to(DEVICE)

EGNN(node_dim=128, edge_dim=128, n_conv_layer=4, cutoff=None, out_dim=1)

In [18]:
# Forward pass over every batch: check that the prediction has the same shape
# as the target `energy`, then show the raw predictions.
for data in dataloader:
    # Rebind the result of `.to()`: tensor-style `.to()` returns the moved
    # object instead of moving in place, so rebinding is correct either way.
    data = data.to(DEVICE)
    out = model(data)
    print(out.shape == data["energy"].shape)
    print(out)

True
tensor([[ 0.6844],
        [16.1460]], grad_fn=<MmBackward0>)


In [None]:
# DimeNet regressor with a single output per graph.
model = DimeNet(
    # graph cutoff, the same value used to build the dataset
    cutoff_radi=CUTOFF,
    envelope_exponent=5,
    # basis sizes
    n_radial=16,
    n_spherical=4,
    n_bilinear=2,
    # feature sizes and depth
    node_dim=128,
    edge_dim=128,
    n_interaction=4,
    out_dim=1,
)
# As the cell's last expression, this also displays the model's repr.
model.to(DEVICE)

In [None]:
# Forward pass over every batch: check that the prediction has the same shape
# as the target `energy`, then show the raw predictions.
for data in dataloader:
    # Rebind the result of `.to()`: tensor-style `.to()` returns the moved
    # object instead of moving in place, so rebinding is correct either way.
    data = data.to(DEVICE)
    out = model(data)
    print(out.shape == data["energy"].shape)
    print(out)