## Convolutions with Neighbors
The computational efficiency of convolutions comes from only interacting data that is considered "local". For an image, pixels are local if they fall within the extent of the convolutional filter. For point-based data, we have a lot of flexibility in how we define the pair-wise interactions used for convolutions. In this notebook, we illustrate how to use two data preprocessing classes that create the "edges" and "edge attributes" used by the `e3nn.point.message_passing.Convolution` class for atoms that interact within a specified radius. We give examples for both molecules and crystals (with periodic boundary conditions).
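As a minimal sketch of what "local within a radius" means for points with no boundary conditions, the brute-force NumPy function below returns an `edge_index` of shape [2, E] and an `edge_attr` of shape [E, 3], the same layout described for `DataNeighbors` later in this notebook. The function name `radius_graph` and the convention `edge_attr = pos[target] - pos[source]` are our assumptions for illustration, not e3nn's implementation. It uses O(N²) memory, so it only suits small molecules.

```python
import numpy as np


def radius_graph(pos, r_max, self_interaction=True):
    """Hypothetical brute-force neighbor list with no boundary conditions.

    pos: [N, 3] Cartesian coordinates.
    Returns edge_index [2, E] (row 0 = source, row 1 = target) and
    edge_attr [E, 3] with edge_attr[k] = pos[target] - pos[source].
    """
    pos = np.asarray(pos, dtype=float)
    # disp[i, j] = pos[j] - pos[i] for every ordered pair (i, j)
    disp = pos[None, :, :] - pos[:, None, :]
    dist = np.linalg.norm(disp, axis=-1)
    mask = dist <= r_max
    if not self_interaction:
        np.fill_diagonal(mask, False)  # drop the zero-length i -> i edges
    src, dst = np.nonzero(mask)
    return np.stack([src, dst]), disp[src, dst]


# Three points on a line: only the first two are within r_max of each other.
pos = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.5, 0.0, 0.0]]
edge_index, edge_attr = radius_graph(pos, r_max=1.2)
print(edge_index.tolist())  # [[0, 0, 1, 1, 2], [0, 1, 0, 1, 2]]
```

Every point keeps a self-interaction edge by default, mirroring the default described for the e3nn classes below.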

```python
import glob

import torch
import e3nn
import e3nn.point.data_helpers as dh

torch.set_default_dtype(torch.float64)  # torch.float64 is the datatype we use for e3nn
```

### Import data from files
I'm using `pymatgen` simply because it's the package I'm most accustomed to. Feel free to use what you like.


```python
import pymatgen
from pymatgen.core.structure import Molecule, Structure
```

```python
# Maximum radius (in the units of the coordinates, Å for pymatgen) for convolution and neighbor lists
r_max = 3.

# Molecules
mol_species = []
mol_coords = []
mol_atom_type_set = set()
for filename in glob.glob('data/*.xyz'):
    mol = Molecule.from_file(filename)
    mol_coords.append(torch.tensor(mol.cart_coords))
    species = list(map(str, mol.species))
    mol_atom_type_set = mol_atom_type_set.union(set(species))
    mol_species.append(species)
mol_atom_type_list = sorted(list(mol_atom_type_set))

A = len(mol_atom_type_list)
mol_Rs_in = [(A, 0, 1)]  # A scalars (L=0) with even parity: one channel per atom type

## Create one-hot encoding of atom types
mol_features = []
for coords, species in zip(mol_coords, mol_species):
    N, _ = coords.shape
    feature = torch.zeros(N, A)
    atom_type_indices = [mol_atom_type_list.index(specie) for specie in species]
    feature[range(N), atom_type_indices] = 1.
    mol_features.append(feature)

## Create random labels
mol_labels = torch.randn(len(mol_features), 1)

# Crystals
xtal_species = []
xtal_coords = []
xtal_lattices = []
xtal_atom_type_set = set()
for filename in glob.glob('data/*.cif'):
    xtal = Structure.from_file(filename)
    xtal_coords.append(torch.tensor(xtal.cart_coords))
    species = list(map(str, xtal.species))
    xtal_atom_type_set = xtal_atom_type_set.union(set(species))
    xtal_species.append(species)
    xtal_lattices.append(torch.tensor(xtal.lattice.matrix.copy()))
xtal_atom_type_list = sorted(list(xtal_atom_type_set))

A = len(xtal_atom_type_list)
xtal_Rs_in = [(A, 0, 1)]  # A scalars (L=0) with even parity: one channel per atom type

## Create one-hot encoding of atom types
xtal_features = []
for coords, species in zip(xtal_coords, xtal_species):
    N, _ = coords.shape
    feature = torch.zeros(N, A)
    atom_type_indices = [xtal_atom_type_list.index(specie) for specie in species]
    feature[range(N), atom_type_indices] = 1.
    xtal_features.append(feature)

## Create random labels
xtal_labels = torch.randn(len(xtal_features), 1)

Rs_out = [(1, 0, 1)]  # Single scalar with even parity for both molecules and crystals
```
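The one-hot loops above build, for each structure, an [N, A] feature matrix, which is why `Rs_in = [(A, 0, 1)]` (A scalars with L = 0 and even parity) has dimension A. As a sketch under our own assumptions, here is the same encoding in NumPy, together with a hypothetical `rs_dim` helper that counts dimensions as mul × (2L + 1), the rule we believe e3nn's `rs.dim` applies to representation lists. A dictionary lookup replaces `list.index`, so each atom costs O(1) instead of O(A).

```python
import numpy as np


def rs_dim(Rs):
    """Hypothetical helper: total dimension of a representation list.

    Each entry is (mul, l) or (mul, l, parity); an order-l irrep has 2l + 1 components.
    """
    return sum(mul * (2 * l + 1) for mul, l, *_ in Rs)


def one_hot(species, atom_types):
    """Return an [N, len(atom_types)] one-hot matrix for a list of species strings."""
    index = {atom_type: i for i, atom_type in enumerate(atom_types)}
    features = np.zeros((len(species), len(atom_types)))
    features[np.arange(len(species)), [index[s] for s in species]] = 1.0
    return features


atom_types = sorted({"C", "H", "O"})  # ['C', 'H', 'O']
features = one_hot(["C", "H", "H", "O"], atom_types)
Rs_in = [(len(atom_types), 0, 1)]     # one even scalar channel per atom type
print(features.shape[1], rs_dim(Rs_in))  # 3 3
```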

### Building neighbor lists with `e3nn.point.data_helpers`
In `e3nn.point.data_helpers` we've defined two classes that inherit from `torch_geometric.data.Data` and generate the neighbor lists (edges) needed for point-wise convolutions: `DataNeighbors` for neighbors with no boundary conditions (as one might use for molecules) and `DataPeriodicNeighbors` for 3D periodic boundary conditions (as one might use for crystals), using built-in functions from `ase` and `pymatgen`, respectively.

```python
# Create DataNeighbors objects
mol_data = [
    dh.DataNeighbors(feature, coord, r_max, y=label, Rs_in=mol_Rs_in, Rs_out=Rs_out)
    for (feature, coord, label) in zip(mol_features, mol_coords, mol_labels)
]

# Create DataPeriodicNeighbors objects
xtal_data = [
    dh.DataPeriodicNeighbors(feature, coord, lattice, r_max, y=label, Rs_in=xtal_Rs_in, Rs_out=Rs_out)
    for (feature, coord, lattice, label) in zip(xtal_features, xtal_coords, xtal_lattices, xtal_labels)
]
```

With thousands of examples, these objects can be time-consuming to generate, so let's save them to a file for use in other notebooks.

```python
torch.save(mol_data, 'data/mol_data.torch')
torch.save(xtal_data, 'data/xtal_data.torch')
```

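### Accessing attributes
Each object stores the following attributes, accessed as usual (e.g. `mol_data[0].edge_index`). Below, N is the number of atoms and E the number of edges.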

#### DataNeighbors

| Class attribute | Description |
|-----------------|-------------|
| `x` | features on atoms, [N, rs.dim(Rs_in)] |
| `pos` | atomic Cartesian coordinates, [N, 3] |
| `edge_index` | pairs of indices of neighboring atoms, in the direction from source to target, [2, E] |
| `edge_attr` | relative position vector from source to target, [E, 3] |
| `y` | training labels, either [N, rs.dim(Rs_out)] if per atom or [rs.dim(Rs_out)] if per structure |
| `Rs_in` | representation list of input |
| `Rs_out` | representation list of output |
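To make the source-to-target direction concrete, here is a toy sketch of one message-passing step: each edge carries a message from its source atom, scaled by a function of the edge length, and the messages are summed at the target. This is a deliberately simplified stand-in (the name `aggregate_messages` and the scalar `radial_fn` are our assumptions); as we understand it, `e3nn.point.message_passing.Convolution` also uses the direction of `edge_attr`, through spherical harmonics, so that its output stays equivariant.

```python
import numpy as np


def aggregate_messages(x, edge_index, edge_attr, radial_fn):
    """Toy scalar message passing: out[t] = sum over edges s -> t of radial_fn(|r_st|) * x[s].

    x: [N, C] node features; edge_index: [2, E] (source, target); edge_attr: [E, 3].
    """
    x = np.asarray(x, dtype=float)
    src, dst = np.asarray(edge_index)
    lengths = np.linalg.norm(np.asarray(edge_attr, dtype=float), axis=1)  # [E]
    messages = radial_fn(lengths)[:, None] * x[src]  # one message per edge, read from the source
    out = np.zeros_like(x)
    np.add.at(out, dst, messages)  # unbuffered sum, so repeated targets accumulate
    return out


x = [[1.0], [2.0], [4.0]]
edge_index = [[0, 2, 1], [1, 1, 0]]  # edges 0->1, 2->1, 1->0
edge_attr = [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]
out = aggregate_messages(x, edge_index, edge_attr, radial_fn=lambda r: np.exp(-r))
```

Atom 2 has no incoming edges, so its output stays zero; swapping the two rows of `edge_index` would reverse the flow of information.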

#### DataPeriodicNeighbors

| Class attribute | Description |
|-----------------|-------------|
| `x` | features on atoms, [N, rs.dim(Rs_in)] |
| `pos` | atomic Cartesian coordinates, [N, 3] |
| `edge_index` | pairs of indices of neighboring atoms, in the direction from source to target, [2, E] |
| `edge_attr` | relative position vector from source to target, [E, 3] |
| `y` | training labels, either [N, rs.dim(Rs_out)] if per atom or [rs.dim(Rs_out)] if per structure |
| `lattice` | lattice vectors of the crystal as rows of a [3, 3] matrix |
| `Rs_in` | representation list of input |
| `Rs_out` | representation list of output |

Note that for crystals there may be several edges with the same (source, target) pair in `edge_index` but different `edge_attr`, representing neighbors that are different periodic "images" of the same atom. For example, the Si (silicon) structure has two atoms in the unit cell, and each atom is tetrahedrally coordinated by four images of the other. Additionally, "self-interactions" (edges between an atom and itself) are included by default.
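Both effects can be reproduced with a brute-force periodic neighbor search: loop over enough lattice translations ("images") that every vector shorter than `r_max` is found, and keep every (atom, image) pair within the cutoff. The sketch below is our own NumPy illustration, not how `DataPeriodicNeighbors` is implemented. It assumes lattice vectors as rows of a [3, 3] matrix (as in pymatgen's `lattice.matrix`) and `edge_attr = image of target - source`. The number of images along each lattice direction is bounded by `ceil(r_max / h)`, where `h` is the perpendicular height of the cell in that direction; this bound holds once atoms are wrapped into the home cell.

```python
import itertools

import numpy as np


def periodic_radius_graph(pos, lattice, r_max, self_interaction=True):
    """Hypothetical brute-force neighbor list under 3D periodic boundary conditions.

    pos: [N, 3] Cartesian coordinates; lattice: [3, 3] with lattice vectors as rows.
    Returns edge_index [2, E] and edge_attr [E, 3], where
    edge_attr[k] = (pos[target] + image translation) - pos[source].
    """
    pos = np.asarray(pos, dtype=float)
    lattice = np.asarray(lattice, dtype=float)
    # Wrap atoms into the home cell (fractional coordinates in [0, 1)).
    frac = np.mod(pos @ np.linalg.inv(lattice), 1.0)
    pos = frac @ lattice
    # Perpendicular cell height along each lattice direction: volume / area of the opposite face.
    volume = abs(np.linalg.det(lattice))
    heights = [volume / np.linalg.norm(np.cross(lattice[(i + 1) % 3], lattice[(i + 2) % 3]))
               for i in range(3)]
    n_max = [int(np.ceil(r_max / h)) for h in heights]

    sources, targets, vectors = [], [], []
    for shift in itertools.product(*(range(-n, n + 1) for n in n_max)):
        translation = np.array(shift) @ lattice
        # disp[i, j] = (pos[j] + translation) - pos[i]
        disp = pos[None, :, :] + translation - pos[:, None, :]
        mask = np.linalg.norm(disp, axis=-1) <= r_max
        if shift == (0, 0, 0) and not self_interaction:
            np.fill_diagonal(mask, False)  # drop the zero-length i -> i edges
        src, dst = np.nonzero(mask)
        sources.append(src)
        targets.append(dst)
        vectors.append(disp[src, dst])

    src, dst = np.concatenate(sources), np.concatenate(targets)
    order = np.argsort(src, kind="stable")  # group edges by source atom
    return np.stack([src, dst])[:, order], np.concatenate(vectors)[order]


# Illustrative input: diamond-structure Si primitive cell with a = 5.43 Å.
a = 5.43
lattice = a / 2 * np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
pos = np.array([[0.0, 0.0, 0.0], [a / 4, a / 4, a / 4]])
edge_index, edge_attr = periodic_radius_graph(pos, lattice, r_max=3.0)
# 10 edges: each atom has one self-interaction plus four images of the other atom
```

With `r_max = 3.` only the nearest-neighbor shell (a√3/4 ≈ 2.35 Å) falls inside the cutoff; the next shell, at a/√2 ≈ 3.84 Å, does not.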

```python
print(xtal_data[1].edge_index)
print(xtal_data[1].edge_attr)  # Note that the self-interactions have a zero relative vector.
```

```
tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1, 0, 0, 0, 0]])
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-2.2329e+00,  7.8946e-01,  5.8458e-10],
        [ 3.3751e-10, -2.3684e+00,  5.8458e-10],
        [ 1.1165e+00,  7.8946e-01, -1.9338e+00],
        [ 1.1165e+00,  7.8946e-01,  1.9338e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-1.1165e+00, -7.8946e-01, -1.9338e+00],
        [-1.1165e+00, -7.8946e-01,  1.9338e+00],
        [-3.3751e-10,  2.3684e+00, -5.8458e-10],
        [ 2.2329e+00, -7.8946e-01, -5.8458e-10]])
```
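One quick way to read the output above is to count how many times each (source, target) pair occurs. Here the printed `edge_index` is copied in as a plain list so the snippet runs on its own; with the real tensor, `torch.unique(edge_index, dim=1, return_counts=True)` should give the same counts.

```python
from collections import Counter

# edge_index as printed above for xtal_data[1] (Si), copied as a plain list
edge_index = [[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
              [0, 1, 1, 1, 1, 1, 0, 0, 0, 0]]

pair_counts = Counter(zip(*edge_index))  # (source, target) -> number of edges
for (source, target), count in sorted(pair_counts.items()):
    kind = "self-interaction" if source == target else "images of the other atom"
    print(f"{source} -> {target}: {count} ({kind})")
# 0 -> 0: 1 (self-interaction)
# 0 -> 1: 4 (images of the other atom)
# 1 -> 0: 4 (images of the other atom)
# 1 -> 1: 1 (self-interaction)
```

The four `0 -> 1` edges all point to atom 1 but carry different `edge_attr` vectors, one per tetrahedral neighbor.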
