# Example of how to use GraphDataset, GraphDataLoader and GNNs

In [1]:
import ase
import numpy as np
import torch
torch.set_default_dtype(torch.float64)

from pyggnn.data import List2GraphDataset, Hdf2GraphDataset, GraphLoader
from pyggnn.nn import ScaleShift
from pyggnn import EGNN, SchNet, DimeNet, DimeNetPlusPlus

DEVICE = torch.device("cpu")
CUTOFF = 4.0

## Dummy atoms data

### Polonium with Simple Cubic Lattice

In [16]:
po_lattice = torch.eye(3) * 3.340  # Cubic lattice with edges of length 3.34 AA
po_coords = torch.tensor([[0., 0., 0.,]])
po_types = ['Po']

# make atoms object
po = ase.Atoms(
    symbols=po_types,
    positions=po_coords,
    cell=po_lattice,
    pbc=True,
    # Additional data to be put in `info` parameters of atoms object.
    info={
        # you can add value data
        "energy": -12.0,
        # you can also add array-like data
        "dos": np.array([0.1, 0.2, 0.3]),
    }
)

### Silicon with Diamond Structure

In [17]:
si_lattice = torch.tensor([
    [0.      , 2.734364, 2.734364],
    [2.734364, 0.      , 2.734364],
    [2.734364, 2.734364, 0.      ],
])
si_coords = torch.tensor([
    [1.367182, 1.367182, 1.367182],
    [0.      , 0.      , 0.      ],
])
si_types = ['Si', 'Si']

# make atoms object
si = ase.Atoms(
    symbols=si_types,
    positions=si_coords,
    cell=si_lattice,
    pbc=True,
    info={
        "energy": -1.,
        "dos": np.array([0.5, 0.6, 0.7]),
    }
)

### isolated CO2 molecule

In [19]:
co2_lattice = torch.tensor([
    [15.      , 0.       , 0.    ],
    [0.       , 15.      , 0.    ],
    [0.       , 0.       , 15.   ],
])
co2_coords = torch.tensor([
     [0.5     , 0.5      , 0.5    ],
     [0.421632, 0.5      , 0.5    ],
     [0.578368, 0.5      , 0.5    ],
])
co2_types = ['C', 'O', 'O']

# make atoms object
# without pbc
co2 = ase.Atoms(
    symbols=co2_types,
    positions=co2_coords,
    cell=co2_lattice,
    pbc=False,
    info={
        "energy": -10.,
        "dos": np.array([0.8, 0.9, 1.0]),
    }
)

## Make dataset and dataloader

In [21]:
dataset = List2GraphDataset(
    # passing a list of atoms objects
    [po, si, co2],
    cutoff_radi=CUTOFF,
    # passing additional property key names
    property_names=["energy", "dos"]
)

# if you have hdf5 file, you can also use Hdf52GraphDataset
dataset = Hdf2GraphDataset(
    hdf5_path="./",
    cutoff_radi=CUTOFF,
    property_names=["energy", "dos"]
)

In [25]:
batch_size = 3
# Inherits Pytorch's data loader, so the same keyword argument can be used.
dataloader = GraphLoader(dataset, batch_size=batch_size, shuffle=True)

# check each batch
for t in dataloader:
    print(t)

Data(edge_index=[2, 44], pos=[6, 3], atom_numbers=[6], lattice=[3, 3, 3], edge_shift=[44, 3], energy=[3, 1], dos=[3, 3], batch=[6])


## Define model & Training

In [26]:
model = SchNet(
    node_dim=128,
    edge_filter_dim=128,
    out_dim=1,
    n_conv_layer=4,
    n_gaussian=32,
    scaler=ScaleShift,
    cutoff_radi=CUTOFF,
)
model.to(DEVICE)

SchNet(node_dim=128, edge_filter_dim=128, n_gaussian=32, cutoff=4.0, out_dim=1, convolution_layers: SchNetConv * 4)

In [27]:
for data in dataloader:
    data.to(DEVICE)
    out = model(data)
    # get same shape of energy
    print(out.shape==data["energy"].shape)
    print(out)

True
tensor([[-1.7342],
        [-0.1156],
        [ 0.7659]], grad_fn=<ScatterAddBackward0>)


In [28]:
model = EGNN(
    node_dim=128,
    edge_dim=128,
    n_conv_layer=4,
    out_dim=1,
)
model.to(DEVICE)

EGNN(node_dim=128, edge_dim=128, cutoff=None, out_dim=1, convolution_layers: EGNNConv * 4)

In [29]:
for data in dataloader:
    data.to(DEVICE)
    out = model(data)
    # get same shape of energy
    print(out.shape==data["energy"].shape)
    print(out)

True
tensor([[ -1.9652],
        [ -1.3788],
        [-27.5988]], grad_fn=<MmBackward0>)


In [30]:
model = DimeNet(
    edge_message_dim=128,
    n_interaction=4,
    out_dim=1,
    n_radial=16,
    n_spherical=4,
    n_bilinear=2, 
    envelope_exponent=5,
    cutoff_radi=CUTOFF,
)
model.to(DEVICE)

DimeNet(edge_message_dim=128, n_radial=16, n_spherical=4, cutoff=4.0, out_dim=1, interaction_layers: DimNetInteraction * 4)

In [31]:
for data in dataloader:
    data.to(DEVICE)
    out = model(data)
    # get same shape of energy
    print(out.shape==data["energy"].shape)
    print(out)

True
tensor([[ 1.3469e+34],
        [ 1.7107e-02],
        [-3.3512e+00]], grad_fn=<ScatterAddBackward0>)


In [32]:
model = DimeNetPlusPlus(
    edge_message_dim=128,
    n_interaction=4,
    out_dim=1,
    n_radial=16,
    n_spherical=4,
    edge_down_dim=64,
    basis_embed_dim=64,
    out_up_dim=256,
    envelope_exponent=5,
    cutoff_radi=CUTOFF,
)
model.to(DEVICE)

DimeNetPlusPlus(edge_message_dim=128, n_radial=16, n_spherical=4, cutoff=4.0, out_dim=1, interaction_layers: DimNetPPInteraction * 4)

In [33]:
for data in dataloader:
    data.to(DEVICE)
    out = model(data)
    # get same shape of energy
    print(out.shape==data["energy"].shape)
    print(out)

True
tensor([[-8.8331e+35],
        [ 1.5179e-01],
        [-6.7252e-02]], grad_fn=<ScatterAddBackward0>)
