# Example of how to use GraphDataset, GraphDataLoader and GNNs

In [1]:
import ase
from ase.db import connect
import h5py
import numpy as np
import torch
torch.set_default_dtype(torch.float64)

from pyggnn.data import List2GraphDataset, Hdf2GraphDataset, Db2GraphDataset, DataKeys, GraphLoader
from pyggnn.nn import ScaleShift
from pyggnn import EGNN, SchNet, DimeNet, DimeNetPlusPlus

DEVICE = torch.device("cpu")
CUTOFF = 4.0

## Dummy atoms data

### Polonium with Simple Cubic Lattice

In [2]:
po_lattice = torch.eye(3) * 3.340  # Cubic lattice with edges of length 3.34 AA
po_coords = torch.tensor([[0., 0., 0.,]])
po_types = ['Po']

# make atoms object
po = ase.Atoms(
    symbols=po_types,
    positions=po_coords,
    cell=po_lattice,
    pbc=True,
    # Additional data to be put in `info` parameters of atoms object.
    info={
        # you can add value data
        "energy": -12.0,
        # you can also add array-like data
        "dos": np.array([0.1, 0.2, 0.3]),
    }
)

### Silicon with Diamond Structure

In [3]:
si_lattice = torch.tensor([
    [0.      , 2.734364, 2.734364],
    [2.734364, 0.      , 2.734364],
    [2.734364, 2.734364, 0.      ],
])
si_coords = torch.tensor([
    [1.367182, 1.367182, 1.367182],
    [0.      , 0.      , 0.      ],
])
si_types = ['Si', 'Si']

# make atoms object
si = ase.Atoms(
    symbols=si_types,
    positions=si_coords,
    cell=si_lattice,
    pbc=True,
    info={
        "energy": -1.,
        "dos": np.array([0.5, 0.6, 0.7]),
    }
)

### isolated CO2 molecule

In [4]:
co2_lattice = torch.tensor([
    [15.      , 0.       , 0.    ],
    [0.       , 15.      , 0.    ],
    [0.       , 0.       , 15.   ],
])
co2_coords = torch.tensor([
     [0.5     , 0.5      , 0.5    ],
     [0.421632, 0.5      , 0.5    ],
     [0.578368, 0.5      , 0.5    ],
])
co2_types = ['C', 'O', 'O']

# make atoms object
# without pbc
co2 = ase.Atoms(
    symbols=co2_types,
    positions=co2_coords,
    cell=co2_lattice,
    pbc=False,
    info={
        "energy": -10.,
        "dos": np.array([0.8, 0.9, 1.0]),
    }
)

## Make dataset and dataloader

Data sets can be created in three ways.  
- List of ase.Atoms
- Database
- HDF5 file

In [5]:
# you can make dataset from list of atoms
dataset = List2GraphDataset(
    # passing a list of atoms objects
    [po, si, co2],
    cutoff_radi=CUTOFF,
    # passing additional property key names
    property_names=["energy", "dos"]
)

dataset[0]

Data(edge_index=[2, 6], pos=[1, 3], atom_numbers=[1], lattice=[1, 3, 3], edge_shift=[6, 3], energy=[1, 1], dos=[1, 3])

In [6]:
# from atoms, make database for data persistency
db = connect('example.db')
for atoms in [po, si, co2]:
    # additional data is set `data` parameters
    db.write(atoms=atoms, data=atoms.info)

# you can make dataset from database
dataset = Db2GraphDataset(
    "example.db",
    cutoff_radi=CUTOFF,
    property_names=["energy", "dos"]
)

dataset[0]

Data(edge_index=[2, 6], pos=[1, 3], atom_numbers=[1], lattice=[1, 3, 3], edge_shift=[6, 3], energy=[1, 1], dos=[1, 3])

In [7]:
# it can also be read from Hdf5 grouped by system
with h5py.File("example.hdf5", "w") as f:
    for atoms in [po, si, co2]:
        g = f.create_group(str(atoms.symbols))
        g.create_dataset(DataKeys.Lattice, data=np.array(atoms.cell))
        g.create_dataset(DataKeys.Position, data=atoms.positions)
        g.create_dataset(DataKeys.Atom_numbers, data=atoms.numbers)
        # you can set additional information in `attrs` or `dataset`
        g.attrs["energy"] = atoms.info["energy"]
        g.create_dataset("dos", data=atoms.info["dos"])
    
dataset = Hdf2GraphDataset(
    hdf5_path="example.hdf5",
    cutoff_radi=CUTOFF,
    property_names=["energy", "dos"]
)

dataset[0]

Data(edge_index=[2, 6], pos=[3, 3], atom_numbers=[3], lattice=[1, 3, 3], edge_shift=[6, 3], energy=[1, 1], dos=[1, 3])

In [8]:
batch_size = 2
# Inherits Pytorch's data loader, so the same keyword argument can be used.
dataloader = GraphLoader(dataset, batch_size=batch_size, shuffle=True)

# check each batch
for t in dataloader:
    print(t)

Data(edge_index=[2, 38], pos=[5, 3], atom_numbers=[5], lattice=[2, 3, 3], edge_shift=[38, 3], energy=[2, 1], dos=[2, 3], batch=[5])
Data(edge_index=[2, 6], pos=[1, 3], atom_numbers=[1], lattice=[1, 3, 3], edge_shift=[6, 3], energy=[1, 1], dos=[1, 3], batch=[1])


## Define model & Training

In [9]:
model = SchNet(
    node_dim=128,
    edge_filter_dim=128,
    out_dim=1,
    n_conv_layer=4,
    n_gaussian=32,
    scaler=ScaleShift,
    cutoff_radi=CUTOFF,
)
model.to(DEVICE)

SchNet(node_dim=128, edge_filter_dim=128, n_gaussian=32, cutoff=4.0, out_dim=1, convolution_layers: SchNetConv * 4)

In [10]:
for data in dataloader:
    data.to(DEVICE)
    out = model(data)
    # get same shape of energy
    print(out.shape==data["energy"].shape)
    print(out)

True
tensor([[ 0.9356],
        [-2.7731]], grad_fn=<ScatterAddBackward0>)
True
tensor([[1.6955]], grad_fn=<ScatterAddBackward0>)


In [11]:
model = EGNN(
    node_dim=128,
    edge_dim=128,
    n_conv_layer=4,
    out_dim=1,
)
model.to(DEVICE)

EGNN(node_dim=128, edge_dim=128, cutoff=None, out_dim=1, convolution_layers: EGNNConv * 4)

In [12]:
for data in dataloader:
    data.to(DEVICE)
    out = model(data)
    # get same shape of energy
    print(out.shape==data["energy"].shape)
    print(out)

True
tensor([[ 0.1691],
        [-2.3914]], grad_fn=<MmBackward0>)
True
tensor([[-1.4147]], grad_fn=<MmBackward0>)


In [13]:
model = DimeNet(
    edge_message_dim=128,
    n_interaction=4,
    out_dim=1,
    n_radial=16,
    n_spherical=4,
    n_bilinear=2, 
    envelope_exponent=5,
    cutoff_radi=CUTOFF,
)
model.to(DEVICE)

DimeNet(edge_message_dim=128, n_radial=16, n_spherical=4, cutoff=4.0, out_dim=1, interaction_layers: DimNetInteraction * 4)

In [14]:
for data in dataloader:
    data.to(DEVICE)
    out = model(data)
    # get same shape of energy
    print(out.shape==data["energy"].shape)
    print(out)

True
tensor([[ 0.0217],
        [-0.0235]], grad_fn=<ScatterAddBackward0>)
True
tensor([[-2.2971e+33]], grad_fn=<ScatterAddBackward0>)


In [15]:
model = DimeNetPlusPlus(
    edge_message_dim=128,
    n_interaction=4,
    out_dim=1,
    n_radial=16,
    n_spherical=4,
    edge_down_dim=64,
    basis_embed_dim=64,
    out_up_dim=256,
    envelope_exponent=5,
    cutoff_radi=CUTOFF,
)
model.to(DEVICE)

DimeNetPlusPlus(edge_message_dim=128, n_radial=16, n_spherical=4, cutoff=4.0, out_dim=1, interaction_layers: DimNetPPInteraction * 4)

In [16]:
for data in dataloader:
    data.to(DEVICE)
    out = model(data)
    # get same shape of energy
    print(out.shape==data["energy"].shape)
    print(out)

True
tensor([[-0.9449],
        [-0.2211]], grad_fn=<ScatterAddBackward0>)
True
tensor([[7.3053e+35]], grad_fn=<ScatterAddBackward0>)
