# Point Inputs with Periodic Boundary Conditions

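This example shows how to feed point inputs with periodic boundary conditions (e.g. crystal data) into a Euclidean neural network built with e3nn. For a specific application, you should adapt this code to a network design better suited to your problem.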

```python
import torch
import e3nn
import ase
import ase.neighborlist
import torch_geometric
import torch_geometric.data

default_dtype = torch.float64
torch.set_default_dtype(default_dtype)
```



## Example Crystal Structures

First, we create a few crystal structures with periodic boundary conditions.

```python
# A lattice is a 3 x 3 matrix
# The first index is the lattice vector (a, b, c)
# The second index is a Cartesian index over (x, y, z)

# Polonium with Simple Cubic Lattice
po_lattice = torch.eye(3) * 3.340  # Cubic lattice with edges of length 3.34 Å
po_coords = torch.tensor([[0., 0., 0.]])
po_types = ['Po']

# Silicon with Diamond Structure
si_lattice = torch.tensor([
    [0.      , 2.734364, 2.734364],
    [2.734364, 0.      , 2.734364],
    [2.734364, 2.734364, 0.      ]
])
si_coords = torch.tensor([
    [1.367182, 1.367182, 1.367182],
    [0.      , 0.      , 0.      ]
])
si_types = ['Si', 'Si']

po = ase.Atoms(symbols=po_types, positions=po_coords, cell=po_lattice, pbc=True)
si = ase.Atoms(symbols=si_types, positions=si_coords, cell=si_lattice, pbc=True)

print("po", po)
print("si", si)
```

```text
po Atoms(symbols='Po', pbc=True, cell=[3.34, 3.34, 3.34])
si Atoms(symbols='Si2', pbc=True, cell=[[0.0, 2.734364, 2.734364], [2.734364, 0.0, 2.734364], [2.734364, 2.734364, 0.0]])
```
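The comments above fix the lattice convention: rows are lattice vectors, columns are Cartesian components. A minimal sketch of what that implies, assuming the two diamond-silicon atoms sit at fractional coordinates (1/4, 1/4, 1/4) and (0, 0, 0): multiplying fractional row vectors by the lattice matrix reproduces the Cartesian `si_coords` above, and the absolute determinant gives the cell volume.

```python
import torch

# Rows are lattice vectors a, b, c; columns are Cartesian x, y, z.
si_lattice = torch.tensor([
    [0.0,      2.734364, 2.734364],
    [2.734364, 0.0,      2.734364],
    [2.734364, 2.734364, 0.0     ],
], dtype=torch.float64)

# Assumed fractional coordinates of the two Si atoms (the usual 1/4 diamond offset).
si_frac = torch.tensor([[0.25, 0.25, 0.25],
                        [0.0,  0.0,  0.0 ]], dtype=torch.float64)

# Cartesian position = sum_k frac_k * (lattice vector k), i.e. row vector @ lattice.
si_cart = si_frac @ si_lattice
print(si_cart)

# Cell volume = |det(lattice)|; for this FCC primitive cell it equals 2 * 2.734364**3.
volume = torch.linalg.det(si_lattice).abs()
print(volume.item())
```

With the transposed convention (columns as lattice vectors) the same product would silently give wrong positions for non-symmetric cells, which is why the row convention matters for the edge-shift arithmetic later on.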


## Creating and Storing Periodic Graph Data

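We use the [`ase.neighborlist.neighbor_list`](https://wiki.fysik.dtu.dk/ase/ase/neighborlist.html#ase.neighborlist.neighbor_list) algorithm with a `radial_cutoff` distance to decide which edges the graph includes to represent interactions with neighboring atoms. Note that for convolutional networks, the number of layers determines the receptive field, i.e. how far any given atom can "see". So even though we use `radial_cutoff = 3.5`, a two-layer network can see up to `2 * 3.5 = 7` distance units (in this case Angstroms), and a three-layer network up to `3 * 3.5 = 10.5` distance units. We then store the data in [`torch_geometric.data.Data`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) objects, which we batch below with PyG's `DataLoader`.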
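The `num_layers * radial_cutoff` figure is an upper bound: information only travels along chains of edges. A minimal, standard-library-only sketch, assuming the infinite simple cubic Po lattice from above (a = 3.34 Å, so only the 6 face neighbors fall inside the 3.5 Å cutoff), compares the lattice sites reachable in k hops with all sites lying within k times the cutoff:

```python
import itertools
import math

A = 3.34             # simple cubic lattice constant of the Po example, in Angstrom
RADIAL_CUTOFF = 3.5  # same cutoff as in the text

def dist(site):
    """Distance in Angstrom from the origin to an integer lattice site."""
    return A * math.sqrt(sum(c * c for c in site))

# Integer offsets of all neighbors inside the cutoff (here: the 6 face neighbors).
STEPS = [s for s in itertools.product(range(-2, 3), repeat=3)
         if s != (0, 0, 0) and dist(s) <= RADIAL_CUTOFF]

def reachable(num_layers):
    """Lattice sites reachable from the origin in at most num_layers message-passing hops."""
    frontier = {(0, 0, 0)}
    seen = set(frontier)
    for _ in range(num_layers):
        frontier = {tuple(p + d for p, d in zip(site, step))
                    for site in frontier for step in STEPS}
        frontier -= seen
        seen |= frontier
    return seen

def within(radius):
    """All lattice sites within `radius` Angstrom of the origin."""
    n = math.ceil(radius / A)
    return {s for s in itertools.product(range(-n, n + 1), repeat=3) if dist(s) <= radius}

for layers in (1, 2, 3):
    hops = reachable(layers)
    print(layers, len(hops), len(within(layers * RADIAL_CUTOFF)), round(max(dist(s) for s in hops), 3))
```

For two layers this gives 25 reachable sites versus 33 sites inside 7 Å: the body-diagonal neighbors at a√3 ≈ 5.79 Å lie within `2 * 3.5` but need three hops.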


```python
radial_cutoff = 3.5  # Only include edges for neighboring atoms within a radius of 3.5 Angstroms.
type_encoding = {'Po': 0, 'Si': 1}
type_onehot = torch.eye(len(type_encoding))

dataset = []

dummy_energies = torch.randn(2, 1, 1)  # dummy energies for example

for crystal, energy in zip([po, si], dummy_energies):
    # edge_src and edge_dst are the indices of the central and neighboring atom, respectively
    # edge_shift indicates whether the neighbors are in different images / copies of the unit cell
    edge_src, edge_dst, edge_shift = ase.neighborlist.neighbor_list("ijS", a=crystal, cutoff=radial_cutoff, self_interaction=True)

    data = torch_geometric.data.Data(
        pos=torch.tensor(crystal.get_positions()),
        lattice=torch.tensor(crystal.cell.array).unsqueeze(0),  # We add a dimension for batching
        x=type_onehot[[type_encoding[atom] for atom in crystal.symbols]],  # One-hot encoding of the atom type (Po or Si)
        edge_index=torch.stack([torch.as_tensor(edge_src, dtype=torch.long),
                                torch.as_tensor(edge_dst, dtype=torch.long)], dim=0),
        edge_shift=torch.tensor(edge_shift, dtype=default_dtype),
        energy=energy  # dummy energy (assumed to be normalized "per atom")
    )

    dataset.append(data)

print(dataset)
```

```text
[Data(x=[1, 2], edge_index=[2, 7], pos=[1, 3], lattice=[1, 3, 3], edge_shift=[7, 3], energy=[1, 1]), Data(x=[2, 2], edge_index=[2, 10], pos=[2, 3], lattice=[1, 3, 3], edge_shift=[10, 3], energy=[1, 1])]
```


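The first [`torch_geometric.data.Data`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) object is simple cubic polonium, which has 7 edges: 6 to its nearest neighbors (all periodic images of the single Po atom) and 1 self edge, `6 + 1 = 7`. The second [`torch_geometric.data.Data`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) object is diamond silicon, which has 10 edges: each of its 2 atoms has 4 nearest neighbors and 1 self edge, `2 * 4 + 2 * 1 = 10`. The lattice of each structure has shape `[1, 3, 3]`, so that when we batch examples, the batched lattice has shape `[batch_size, 3, 3]`.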
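These edge counts can be checked without ASE. A minimal brute-force sketch: the helper `brute_force_neighbor_list` is hypothetical (not part of ASE or e3nn), it enumerates periodic images and follows the same convention as ASE's `"ijS"` output, edge vector `= pos[j] - pos[i] + S @ cell`, and it assumes that shifts of up to ±2 cells suffice for this cutoff. It is O(N² × images) and only meant for tiny cells.

```python
import itertools
import torch

def brute_force_neighbor_list(pos, lattice, cutoff, max_shift=2):
    """Return (edge_src, edge_dst, edge_shift) for every pair (including i == j with
    zero shift, i.e. the self edge) whose periodic distance is <= cutoff."""
    src, dst, shifts = [], [], []
    for shift in itertools.product(range(-max_shift, max_shift + 1), repeat=3):
        offset = torch.tensor(shift, dtype=pos.dtype) @ lattice  # Cartesian translation of this image
        for i in range(pos.shape[0]):
            for j in range(pos.shape[0]):
                if torch.linalg.norm(pos[j] - pos[i] + offset) <= cutoff:
                    src.append(i)
                    dst.append(j)
                    shifts.append(shift)
    return torch.tensor(src), torch.tensor(dst), torch.tensor(shifts, dtype=pos.dtype)

po_pos = torch.zeros(1, 3, dtype=torch.float64)
po_lat = torch.eye(3, dtype=torch.float64) * 3.340
si_pos = torch.tensor([[1.367182, 1.367182, 1.367182],
                       [0.0, 0.0, 0.0]], dtype=torch.float64)
si_lat = torch.tensor([[0.0, 2.734364, 2.734364],
                       [2.734364, 0.0, 2.734364],
                       [2.734364, 2.734364, 0.0]], dtype=torch.float64)

po_edges = brute_force_neighbor_list(po_pos, po_lat, cutoff=3.5)
si_edges = brute_force_neighbor_list(si_pos, si_lat, cutoff=3.5)
print(len(po_edges[0]), len(si_edges[0]))
```

Both counts agree with the `edge_index` shapes printed above (`[2, 7]` and `[2, 10]`). The zero-shift pairs with `i == j` are the self edges that `self_interaction=True` adds; for Po, the other six edges also have `i == j`, but with non-zero shifts.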



## Graph Batches

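PyG's `DataLoader` builds batches from structures with different numbers of atoms. When iterated over, it yields batch objects that concatenate several `torch_geometric.data.Data` examples.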

```python
import torch_geometric.loader  # the DataLoader lives here in PyG 2.x (formerly torch_geometric.data.DataLoader)

batch_size = 2
dataloader = torch_geometric.loader.DataLoader(dataset, batch_size=batch_size)

for data in dataloader:
    print(data)
    print(data.batch)
    print(data.pos)
    print(data.x)
```

```text
DataBatch(x=[3, 2], edge_index=[2, 17], pos=[3, 3], lattice=[2, 3, 3], edge_shift=[17, 3], energy=[2, 1], batch=[3], ptr=[3])
tensor([0, 1, 1])
tensor([[0.0000, 0.0000, 0.0000],
        [1.3672, 1.3672, 1.3672],
        [0.0000, 0.0000, 0.0000]])
tensor([[1., 0.],
        [0., 1.],
        [0., 1.]])
```




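`data.batch` is the batch index: a tensor of shape `[num_nodes]` (here `[3]`, the total number of atoms in the batch) that stores which example each point or atom belongs to. Since our batch contains only two examples, the batch tensor contains only the numbers `0` and `1`. The batch index is typically passed to scatter operations to aggregate [per-example values](https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html), e.g. the total energy of a single crystal structure.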

For more details on batching in `torch_geometric`, see the PyG documentation on mini-batching.
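A minimal sketch of what scatter-sum and scatter-mean do with the batch index, written with plain `torch` (`Tensor.index_add_` and `torch.bincount`) instead of `torch_scatter`. The per-atom values are made up for illustration; the batch layout is the one printed above.

```python
import torch

# Hypothetical per-atom outputs: atom 0 is Po, atoms 1 and 2 are Si.
node_out = torch.tensor([[1.0], [2.0], [4.0]], dtype=torch.float64)
batch = torch.tensor([0, 1, 1])  # same layout as data.batch above
num_graphs = int(batch.max()) + 1

# Scatter-sum: add every atom's row into the row of the example it belongs to.
total = torch.zeros(num_graphs, 1, dtype=node_out.dtype).index_add_(0, batch, node_out)

# Scatter-mean: divide each per-example sum by that example's number of atoms.
counts = torch.bincount(batch, minlength=num_graphs).to(node_out.dtype).unsqueeze(1)
mean = total / counts

print(total)  # per-example sums
print(mean)   # per-example means, i.e. "per atom" quantities
```

The sum grows with the number of atoms (an extensive quantity such as total energy), while the mean does not. The mean is the natural match for the dummy energies above, which are assumed to be normalized per atom.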

## Relative Distance Vectors for Edges with Periodic Boundaries
To compute the vector associated with each edge of a given [`torch_geometric.data.Data`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) object representing a single example, we use the following expression:


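```python
# `data` still holds the last *batch* from the loop above, whose lattice has shape [2, 3, 3]
# and cannot be paired edge-by-edge with 17 edge shifts. Pick a single example instead.
data = dataset[1]  # diamond silicon: 2 atoms, 10 edges, lattice of shape [1, 3, 3]

edge_src, edge_dst = data['edge_index'][0], data['edge_index'][1]
edge_vec = (data['pos'][edge_dst] - data['pos'][edge_src]
            + torch.einsum('ni,nij->nj', data['edge_shift'], data['lattice']))  # size-1 lattice dim broadcasts over edges
```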

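The first line of the `edge_vec` definition is simply how one usually computes the relative distance vector between two points. The second line adds the contribution from crossing unit-cell boundaries, i.e. when the two atoms belong to different images of the unit cell: each integer entry of `edge_shift` counts how many lattice vectors to add. As we will see below, we can modify this expression to also use the `data['batch']` tensor when working with batched data.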
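A minimal sketch of the batched version, using hand-built tensors that mimic the batch printed earlier (atom 0 is Po, atoms 1 and 2 are Si) and three hypothetical edges. The key step is giving each edge the lattice of its own example through `batch[edge_src]`.

```python
import torch

dtype = torch.float64

pos = torch.tensor([[0.0, 0.0, 0.0],
                    [1.367182, 1.367182, 1.367182],
                    [0.0, 0.0, 0.0]], dtype=dtype)
batch = torch.tensor([0, 1, 1])
lattice = torch.stack([
    torch.eye(3, dtype=dtype) * 3.340,                      # Po, simple cubic
    torch.tensor([[0.0, 2.734364, 2.734364],
                  [2.734364, 0.0, 2.734364],
                  [2.734364, 2.734364, 0.0]], dtype=dtype),  # Si, FCC primitive cell
])  # shape [2, 3, 3], one lattice per example

# Hypothetical edges (global atom indices after batching):
#   0 -> 0 with shift (1, 0, 0): Po to its own image one cell along a
#   2 -> 1 with shift (0, 0, 0): Si at the origin to Si at (1/4, 1/4, 1/4), same cell
#   2 -> 1 with shift (0, 0, -1): the same pair across the -c cell boundary
edge_index = torch.tensor([[0, 2, 2],
                           [0, 1, 1]])
edge_shift = torch.tensor([[1.0, 0.0, 0.0],
                           [0.0, 0.0, 0.0],
                           [0.0, 0.0, -1.0]], dtype=dtype)

edge_src, edge_dst = edge_index[0], edge_index[1]
edge_batch = batch[edge_src]  # example index of every edge
# r_ij = pos_j - pos_i + sum_k S_k * a_k, using the lattice of the edge's own example
edge_vec = (pos[edge_dst] - pos[edge_src]
            + torch.einsum('ni,nij->nj', edge_shift, lattice[edge_batch]))
print(edge_vec)
print(edge_vec.norm(dim=1))
```

Indexing `lattice[edge_batch]` turns the `[batch_size, 3, 3]` lattice into one `[3, 3]` matrix per edge, so the `n` subscript has the same size in both einsum operands. Passing the stacked lattice directly would make einsum compare `batch_size` against the number of edges and fail.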

## One Approach: Adding a Preprocessing Method to the Network

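Although `edge_vec` could be stored in the [`torch_geometric.data.Data`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) object, it can also be computed by adding a preprocessing method to the network. Computing it inside the network keeps it in the computation graph, so gradients can be backpropagated to the atomic positions. For this example, we create a modified version of the example network `SimpleNetwork` ([documentation here](https://docs.e3nn.org/en/stable/api/nn/models/v2103.html#simple-network), [source code here](https://github.com/e3nn/e3nn/blob/main/e3nn/nn/models/v2103/gate_points_networks.py)). `SimpleNetwork` is a good starting point for checking your data pipeline, but you should replace it with a network tailored to your specific application.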

```python
from e3nn.nn.models.v2103.gate_points_networks import SimpleNetwork
from typing import Dict, Tuple, Union
import torch_scatter

class SimplePeriodicNetwork(SimpleNetwork):
    def __init__(self, **kwargs):
        """The keyword `pool_nodes` is used by SimpleNetwork to determine
        whether we sum over all atom contributions per example. In this example,
        we want to use a mean operation instead, so we override this behavior.
        """
        pool = bool(kwargs['pool_nodes'])
        if pool:
            kwargs['pool_nodes'] = False  # disable SimpleNetwork's own (summing) pooling
            kwargs['num_nodes'] = 1.      # placeholder, since SimpleNetwork no longer pools
        super().__init__(**kwargs)
        self.pool = pool  # whether forward() takes the mean over atoms per example

    # Overriding the preprocess method of SimpleNetwork to handle periodic boundary data
    def preprocess(self, data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]
                   ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        if 'batch' in data:
            batch = data['batch']
        else:
            batch = data['pos'].new_zeros(data['pos'].shape[0], dtype=torch.long)

        edge_src = data['edge_index'][0]  # Edge source
        edge_dst = data['edge_index'][1]  # Edge destination

        # We need to compute this in the computation graph to backprop to positions
        # We are computing the relative distances + unit cell shifts from periodic boundaries
        edge_batch = batch[edge_src]  # example index of each edge, used to pick its lattice
        edge_vec = (data['pos'][edge_dst]
                    - data['pos'][edge_src]
                    + torch.einsum('ni,nij->nj', data['edge_shift'], data['lattice'][edge_batch]))

        return batch, data['x'], edge_src, edge_dst, edge_vec

    def forward(self, data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
        output = super().forward(data)
        # If pool_nodes was set to True, use scatter_mean to aggregate
        if self.pool:
            if 'batch' in data:
                batch = data['batch']
            else:
                batch = output.new_zeros(output.shape[0], dtype=torch.long)
            return torch_scatter.scatter_mean(output, batch, dim=0)  # Take mean over atoms per example
        return output
```

We define the network and then run it.

```python
net = SimplePeriodicNetwork(
    irreps_in="2x0e",  # One-hot scalars (L=0 and even parity) on each atom to represent atom type
    irreps_out="1x0e",  # Single scalar (L=0 and even parity) to output (for example) energy
    max_radius=radial_cutoff,  # Cutoff radius for convolution
    num_neighbors=10.0,  # scaling factor based on the typical number of neighbors
    pool_nodes=True,  # We pool nodes to predict total energy
)
```



When we apply the network to our data, we get one scalar per example.

```python
for data in dataloader:
    print(net(data).shape)  # One scalar per example
```

```text
torch.Size([2, 1])
```
