# 1. Dealing with atomistic data

## 1.1. Building graph data structure using `AtomsGraph`

- In `aml`, each data point is treated as a graph whose default node attributes are the species (`elems`) and the atomic positions (`pos`).
- `AtomsGraph` is the container class for atomistic graph data. It is essentially a `torch_geometric.data.Data` object with extra atomistic features, such as the ASE interface and neighbor-list construction described below.

### 1.1.1 Molecules or clusters without PBC

- All attributes are optional, but `elems` and `pos` are the ones typically provided for molecules.
- `pos` is an `(N, 3)` tensor and `elems` is an `(N,)` tensor, where `N` is the number of atoms.
- `elems` contains the atomic number of each atom.
- The attributes `n_atoms` and `cell` are created automatically.
  - `cell` is filled with zeros when no PBC is present.
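
A plain-Python sketch of the two automatic attributes, assuming from the printed shapes that `n_atoms` holds one atom count per structure and that a missing cell becomes a `(1, 3, 3)` block of zeros. `default_attributes` is a hypothetical helper, not part of `aml`, and nested lists stand in for tensors.

```python
def default_attributes(elems, pos):
    """Hypothetical helper: the extra attributes AtomsGraph reports for a molecule.

    elems: N atomic numbers          -> shape (N,)
    pos:   N rows of [x, y, z]       -> shape (N, 3)
    """
    if len(elems) != len(pos) or any(len(row) != 3 for row in pos):
        raise ValueError("expected elems of shape (N,) and pos of shape (N, 3)")
    n_atoms = [len(elems)]                        # one atom count per structure -> (1,)
    cell = [[[0.0, 0.0, 0.0] for _ in range(3)]]  # no PBC: all-zero cell -> (1, 3, 3)
    return {"n_atoms": n_atoms, "cell": cell}


# Water molecule with the same coordinates as In [1]
elems = [8, 1, 1]
pos = [[0.00, 0.00, 0.12], [0.00, 0.76, -0.48], [0.00, -0.76, -0.48]]
attrs = default_attributes(elems, pos)
print(attrs["n_atoms"])  # [3]
```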

In [1]:
import torch
from aml.data import AtomsGraph

# Water molecule
pos = torch.tensor(
    [[0.00,  0.00,  0.12],
     [0.00,  0.76, -0.48],
     [0.00, -0.76, -0.48]]
)
elems = torch.tensor([8, 1, 1], dtype=torch.long)
data = AtomsGraph(elems=elems, pos=pos)
print(data)

AtomsGraph(elems=[3], pos=[3, 3], n_atoms=[1], cell=[1, 3, 3])


### 1.1.2 Crystals with PBC

- Pass the unit cell matrix (`cell`) to create crystal data.
- `cell` is a `(1, 3, 3)` tensor, where the leading `1` is a batch dimension so that the cells of several structures can be stacked when batching.
  - For convenience, a `(3, 3)` tensor is also accepted.
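
In ASE each row of `cell` is a lattice vector, so Cartesian positions are fractional coordinates multiplied by `cell`; that `AtomsGraph` uses the same row convention is an assumption here. The plain-Python sketch below (hypothetical helpers `as_batched_cell` and `frac_to_cart`) adds the leading batch dimension to a `(3, 3)` cell and recovers the Pt positions of the next example from the four FCC fractional coordinates.

```python
def as_batched_cell(cell):
    """Hypothetical helper: accept a (3, 3) or (1, 3, 3) cell and return (1, 3, 3)."""
    if isinstance(cell[0][0], (int, float)):  # plain 3x3 matrix: add batch dimension
        cell = [cell]
    if len(cell) != 1 or len(cell[0]) != 3 or any(len(row) != 3 for row in cell[0]):
        raise ValueError("cell must have shape (3, 3) or (1, 3, 3)")
    return cell


def frac_to_cart(frac, cell):
    """Cartesian position = frac @ cell, where each row of cell is a lattice vector."""
    return [sum(frac[k] * cell[k][d] for k in range(3)) for d in range(3)]


a = 3.92  # Pt lattice constant in Angstrom
cell = as_batched_cell([[a, 0.0, 0.0], [0.0, a, 0.0], [0.0, 0.0, a]])
fcc_frac = [[0.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.5, 0.0, 0.5], [0.0, 0.5, 0.5]]
pos = [frac_to_cart(f, cell[0]) for f in fcc_frac]
print(pos[1])  # [1.96, 1.96, 0.0]
```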

In [2]:
# Pt FCC crystal
pos = torch.tensor(
    [[0.00, 0.00, 0.00],
     [1.96, 1.96, 0.00],
     [1.96, 0.00, 1.96],
     [0.00, 1.96, 1.96]]
)
elems = torch.tensor([78, 78, 78, 78], dtype=torch.long)
cell = torch.tensor(
    [[3.92, 0.00, 0.00],
     [0.00, 3.92, 0.00],
     [0.00, 0.00, 3.92]]
)
data = AtomsGraph(elems=elems, pos=pos, cell=cell)
print(data)

AtomsGraph(elems=[4], pos=[4, 3], cell=[1, 3, 3], n_atoms=[1])


### 1.1.3 Adding properties into data

- Any other properties (e.g. energy or forces) can be attached to the `data` object as keyword arguments.

In [3]:
energy = torch.tensor(-42.0)
force = torch.rand(pos.size(0), 3)
data = AtomsGraph(elems=elems, pos=pos, cell=cell, energy=energy, force=force)
print(data)

AtomsGraph(elems=[4], pos=[4, 3], cell=[1, 3, 3], energy=[1], force=[4, 3], n_atoms=[1])


### 1.1.4 Interface with `ASE`

- `AtomsGraph` has two methods for interacting with the `ASE` package:
- `to_ase()`: converts the data into an `Atoms` object.
- `from_ase()`: constructs the data from an `Atoms` object, optionally building a neighbor list (discussed later).

In [4]:
atoms = data.to_ase()
print("Atoms:", atoms)

data = AtomsGraph.from_ase(atoms, device="cpu")
print("Data:", data)

Atoms: Atoms(symbols='Pt4', pbc=True, cell=[3.9200000762939453, 3.9200000762939453, 3.9200000762939453], calculator=SinglePointCalculator(...))
Data: AtomsGraph(n_atoms=[1], elems=[4], pos=[4, 3], cell=[1, 3, 3], energy=[1], force=[4, 3], batch=[4])


- If pre-computed properties are available on the `Atoms` object (e.g. through its attached calculator), `AtomsGraph` reads them automatically.
- Currently supported auto-read properties: `"energy"`, `"force"`, `"stress"`.
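
ASE's `Atoms.get_stress()` returns the stress in Voigt order `(xx, yy, zz, yz, xz, xy)` by default, whereas `data.stress` below prints as a `(1, 3, 3)` tensor. A plain-Python sketch of that conversion follows; it assumes this is what happens when the stress is read, and `aml` may instead call `get_stress(voigt=False)` or `ase.stress.voigt_6_to_full_3x3_stress`. The helper names are hypothetical.

```python
def voigt_to_full(voigt):
    """Convert a 6-component Voigt stress (xx, yy, zz, yz, xz, xy) to a symmetric 3x3 matrix."""
    xx, yy, zz, yz, xz, xy = voigt
    return [
        [xx, xy, xz],
        [xy, yy, yz],
        [xz, yz, zz],
    ]


def batched_stress(voigt):
    """Add a leading batch dimension, giving the (1, 3, 3) shape printed for data.stress."""
    return [voigt_to_full(voigt)]


stress = batched_stress([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
for row in stress[0]:
    print(row)
# [1.0, 6.0, 5.0]
# [6.0, 2.0, 4.0]
# [5.0, 4.0, 3.0]
```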

In [5]:
import ase.build
from ase.calculators.emt import EMT

atoms = ase.build.bulk("Pt", "fcc", a=3.92, cubic=True)
atoms.rattle(0.05)
atoms.calc = EMT()
data = AtomsGraph.from_ase(atoms) # Energy, forces and stress are read from atoms.calc (no neighbor list yet)
print("Energy:", data.energy) # eV
print("Forces:", data.force) # eV/A
print("Stress:", data.stress) # eV/A^3

Energy: tensor([0.0335])
Forces: tensor([[ 0.2283, -0.0933, -0.4044],
        [-0.1799,  0.2257,  0.1325],
        [-0.2301, -0.4467,  0.2327],
        [ 0.1817,  0.3143,  0.0392]])
Stress: tensor([[[-0.0073,  0.0004,  0.0012],
         [ 0.0004, -0.0061, -0.0003],
         [ 0.0012, -0.0003, -0.0082]]])


### 1.1.5 Assign edges to graph data

- So far we have only considered graph-level and node-level data.
- The key ingredient of a graph is its **edges**, which encode the connectivity between nodes.
- Edges conceptually form an [adjacency matrix](https://en.wikipedia.org/wiki/Adjacency_matrix); here a [sparse matrix](https://en.wikipedia.org/wiki/Sparse_matrix) (coordinate-list) format is used to save memory.
    - `edge_index`: a `(2, E)` tensor, where `E` is the number of edges
    - `edge_index[0, i]` and `edge_index[1, i]` are the indices of the two nodes joined by edge `i` (in the example below, the neighbor and the center atom, respectively)
    - See [here](https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html#data-handling-of-graphs) for more details
- A typical way to construct edges is to connect each center atom to all of its neighbors within a cutoff radius (a neighbor list).
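
For a molecule without PBC, the idea can be sketched in plain Python: build the dense adjacency matrix from pairwise distances, then keep only its nonzero entries as a `(2, E)` coordinate list. The row layout follows the cell below (`edge_index[0]` holds neighbors, `edge_index[1]` holds centers). `neighbor_edges` is a hypothetical helper, and the brute-force O(N²) loop and strict `<` comparison are illustrative choices, not necessarily what `aml` does.

```python
import math


def neighbor_edges(pos, cutoff):
    """Hypothetical helper: edge_index as [neighbors, centers] for all pairs closer than cutoff."""
    n = len(pos)
    # Dense adjacency matrix: adj[i][j] = 1 if atom j is a neighbor of center atom i
    adj = [
        [1 if i != j and math.dist(pos[i], pos[j]) < cutoff else 0 for j in range(n)]
        for i in range(n)
    ]
    # Sparse coordinate-list form: keep only the nonzero entries
    neighbors, centers = [], []
    for i in range(n):
        for j in range(n):
            if adj[i][j]:
                neighbors.append(j)
                centers.append(i)
    return [neighbors, centers]  # shape (2, E)


water = [[0.00, 0.00, 0.12], [0.00, 0.76, -0.48], [0.00, -0.76, -0.48]]
print(neighbor_edges(water, cutoff=1.0))          # [[1, 2, 0, 0], [0, 0, 1, 2]]
print(len(neighbor_edges(water, cutoff=5.0)[0]))  # 6
```

With a 1.0 Å cutoff only the two O-H bonds (about 0.97 Å) survive, each stored once per direction; at 5.0 Å every ordered pair is an edge.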

In [6]:
# Build neighbor list
data.build_neighborlist(cutoff=3.0)
edge_index = data.edge_index
idx_center, idx_neighbor = edge_index[1], edge_index[0]
print("Center atom indices:", idx_center)
print("Neighbor atom indices:", idx_neighbor)
print("Offset vectors:", data.edge_shift) # Integer cell shifts of neighbors lying in periodic images outside the unit cell

Center atom indices: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
Neighbor atom indices: tensor([3, 2, 1, 3, 2, 3, 2, 1, 2, 1, 1, 3, 0, 2, 2, 3, 0, 2, 2, 0, 3, 3, 0, 3,
        0, 3, 0, 1, 3, 1, 3, 0, 0, 1, 3, 1, 0, 1, 2, 1, 0, 2, 1, 2, 1, 0, 0, 2])
Offset vectors: tensor([[ 0.,  0.,  0.],
        [ 0.,  0., -1.],
        [ 0.,  0., -1.],
        [-1.,  0.,  0.],
        [-1.,  0., -1.],
        [ 0., -1.,  0.],
        [ 0.,  0.,  0.],
        [ 0., -1.,  0.],
        [-1.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0., -1., -1.],
        [-1., -1.,  0.],
        [ 0.,  1.,  0.],
        [ 0.,  0.,  0.],
        [-1.,  0.,  0.],
        [-1.,  0.,  1.],
        [ 0.,  0.,  1.],
        [ 0.,  1.,  0.],
        [-1.,  1.,  0.],
        [ 0.,  1.,  1.],
        [ 0.,  0.,  1.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.],
        [-1.,  0.,  0.],
        [ 1.,  0., ...
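
With PBC a neighbor may be a periodic image of an atom, so each edge also carries an integer cell shift `S`, and the neighbor is assumed to sit at `pos[j] + S @ cell` (this sign convention appears consistent with the shifts printed above, but is not confirmed by `aml`). The plain-Python sketch below uses a hypothetical `periodic_edges` helper that only searches the 27 images with shifts in {-1, 0, 1}; that is sufficient only when positions are inside the cell and the cutoff is shorter than the cell's perpendicular widths. For the ideal Pt cell with a 3.0 Å cutoff, every atom has 12 nearest neighbors at a/√2 ≈ 2.77 Å, giving 48 edges, the same count as the rattled structure above.

```python
import itertools
import math


def periodic_edges(pos, cell, cutoff):
    """Hypothetical PBC neighbor list searching only shifts in {-1, 0, 1}^3.

    Returns (edge_index, edge_shift): edge_index = [neighbors, centers], and
    edge_shift[k] is the integer shift S placing neighbor j at pos[j] + S @ cell.
    """
    neighbors, centers, shifts = [], [], []
    for i, ri in enumerate(pos):
        for j, rj in enumerate(pos):
            for s in itertools.product((-1, 0, 1), repeat=3):
                if i == j and s == (0, 0, 0):
                    continue  # skip the atom itself
                # Cartesian offset of the periodic image: S @ cell
                offset = [sum(s[k] * cell[k][d] for k in range(3)) for d in range(3)]
                rj_image = [rj[d] + offset[d] for d in range(3)]
                if math.dist(ri, rj_image) < cutoff:
                    neighbors.append(j)
                    centers.append(i)
                    shifts.append(list(s))
    return [neighbors, centers], shifts


a = 3.92
cell = [[a, 0.0, 0.0], [0.0, a, 0.0], [0.0, 0.0, a]]
pos = [[0.0, 0.0, 0.0], [1.96, 1.96, 0.0], [1.96, 0.0, 1.96], [0.0, 1.96, 1.96]]
edge_index, edge_shift = periodic_edges(pos, cell, cutoff=3.0)
print(len(edge_index[0]))                          # 48
print([edge_index[1].count(i) for i in range(4)])  # [12, 12, 12, 12]
```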

`AtomsGraph.from_ase` can also build the neighbor-list edges while creating the graph, via the `neighborlist_cutoff` argument.

In [7]:
data = AtomsGraph.from_ase(atoms, neighborlist_cutoff=5.0)

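### 1.1.6 Dealing with batches of data
Multiple graphs can be collated into a single batch object with `Batch.from_data_list` from `torch_geometric`; the resulting `AtomsGraphBatch` keeps track of which atoms belong to which structure.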

In [8]:
from torch_geometric.data import Batch

atoms_1 = ase.build.bulk("Pt", "fcc", a=3.92, cubic=True)
atoms_2 = ase.build.bulk("Si", "diamond", a=5.43, cubic=True)

data_1 = AtomsGraph.from_ase(atoms_1, neighborlist_cutoff=5.0)
data_2 = AtomsGraph.from_ase(atoms_2, neighborlist_cutoff=5.0)

batch = Batch.from_data_list([data_1, data_2])
print(batch)
print(batch.batch) # 0-3: data_1 (Pt, 4 atoms), 4-11: data_2 (Si, 8 atoms)
print("Unit cell:", batch.cell)

AtomsGraphBatch(n_atoms=[2], elems=[12], pos=[12, 3], cell=[2, 3, 3], batch=[12], edge_index=[2, 392], edge_shift=[392, 3], ptr=[3])
tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
Unit cell: tensor([[[3.9200, 0.0000, 0.0000],
         [0.0000, 3.9200, 0.0000],
         [0.0000, 0.0000, 3.9200]],

        [[5.4300, 0.0000, 0.0000],
         [0.0000, 5.4300, 0.0000],
         [0.0000, 0.0000, 5.4300]]])
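
`Batch.from_data_list` concatenates node-level attributes along the first dimension, records each node's graph in `batch`, stores graph boundaries in `ptr`, and increments each graph's `edge_index` by the number of nodes in the graphs before it so the indices remain valid. The plain-Python sketch below (hypothetical `collate` helper, graphs as dicts of lists, toy edges) mimics that bookkeeping for `elems` and `edge_index` only; PyG handles every other attribute according to its `__cat_dim__`/`__inc__` rules.

```python
def collate(graphs):
    """Hypothetical helper: merge graphs with 'elems' (N,) and 'edge_index' [sources, targets]."""
    elems, sources, targets, batch, ptr = [], [], [], [], [0]
    for g_idx, g in enumerate(graphs):
        offset = ptr[-1]  # number of nodes in all previous graphs
        elems.extend(g["elems"])
        # Shift node indices so they point into the concatenated node list
        sources.extend(s + offset for s in g["edge_index"][0])
        targets.extend(t + offset for t in g["edge_index"][1])
        batch.extend([g_idx] * len(g["elems"]))
        ptr.append(offset + len(g["elems"]))
    return {"elems": elems, "edge_index": [sources, targets], "batch": batch, "ptr": ptr}


pt = {"elems": [78] * 4, "edge_index": [[1, 0], [0, 1]]}  # toy edges
si = {"elems": [14] * 8, "edge_index": [[3, 2], [2, 3]]}  # toy edges
out = collate([pt, si])
print(out["batch"])       # [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
print(out["ptr"])         # [0, 4, 12]
print(out["edge_index"])  # [[1, 0, 7, 6], [0, 1, 6, 7]]
```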


## 1.2. Using datapipes to create data

- The `torchdata` package provides `DataPipe`s for pipelining the data-creation workflow.
- Individual tasks are modular and can be composed into a single datapipe.
- Pipelines are evaluated lazily, so items are produced on demand and memory is saved.
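
Lazy evaluation means each stage runs only when a downstream consumer asks for the next item. Plain Python generators behave the same way, so the sketch below (no `torchdata` required; stage names are hypothetical) records when each stage actually executes.

```python
log = []  # records when each stage actually runs


def source(items):
    """First stage: yields raw items one at a time."""
    for x in items:
        log.append(f"read {x}")
        yield x


def add_number(dp, num):
    """Second stage: wraps the previous stage, like AddNumber in In [9]."""
    for x in dp:
        log.append(f"add {x}")
        yield x + num


pipeline = add_number(source([1, 2, 3]), 10)  # only builds the pipeline
print(log)    # []

first = next(pipeline)  # pulls exactly one item through both stages
print(first)  # 11
print(log)    # ['read 1', 'add 1']
```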

### 1.2.1 Basic datapipe tutorial

- `IterDataPipe` is an iterable-style datapipe that produces the dataset one item at a time.
- A custom datapipe wraps another datapipe and defines `__iter__` as a generator over it.

In [9]:
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe


class AddNumber(IterDataPipe):
    def __init__(self, dp: IterDataPipe, num: float):
        self.dp = dp
        self.num = num

    def __iter__(self):
        for data in self.dp:
            yield data + self.num

class MultiplyNumber(IterDataPipe):
    def __init__(self, dp: IterDataPipe, num: float):
        self.dp = dp
        self.num = num

    def __iter__(self):
        for data in self.dp:
            yield data * self.num

numbers = [1, 2, 3, 4]
my_datapipe = IterableWrapper(numbers)          # Makes ordinary list into equivalent IterDataPipe
my_datapipe = AddNumber(my_datapipe, 10.0)      # Apply pipeline 1
my_datapipe = MultiplyNumber(my_datapipe, 2.0)  # Apply pipeline 2

# Test
for data in my_datapipe:
    print(data)

22.0
24.0
26.0
28.0


### 1.2.2 Datapipes to create `AtomsGraph`

- `aml` provides several `DataPipe`s for creating `AtomsGraph` objects:
    - `ASEFileReader`: reads structure files (e.g. `xyz`, `traj`, `vasprun.xml`, ...) using `ase.io`
    - `AtomsGraphParser`: converts `Atoms` into `AtomsGraph`
    - `NeighborListBuilder`: builds neighbor-list edges for `AtomsGraph`
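
`ASEFileReader` delegates parsing to `ase.io`. To show what a reading stage yields, here is a plain-Python sketch of a lazy reader for the multi-frame XYZ format (an atom-count line, a comment line, then one `symbol x y z` line per atom). `iter_xyz` is hypothetical, handles only plain XYZ, and is not how `aml` reads files; the input coordinates are illustrative.

```python
import io


def iter_xyz(stream):
    """Hypothetical reader: yield (symbols, positions) per XYZ frame, one frame at a time."""
    while True:
        header = stream.readline()
        if not header.strip():
            return  # end of file (or trailing blank line)
        n_atoms = int(header)
        stream.readline()  # comment line, ignored here
        symbols, positions = [], []
        for _ in range(n_atoms):
            fields = stream.readline().split()
            symbols.append(fields[0])
            positions.append([float(v) for v in fields[1:4]])
        yield symbols, positions


# Toy two-frame input (illustrative coordinates)
text = """3
water
O 0.00 0.00 0.12
H 0.00 0.76 -0.48
H 0.00 -0.76 -0.48
4
ammonia
N 0.00 0.00 0.00
H 0.94 0.00 -0.38
H -0.47 0.81 -0.38
H -0.47 -0.81 -0.38
"""
for symbols, positions in iter_xyz(io.StringIO(text)):
    print(symbols)
# ['O', 'H', 'H']
# ['N', 'H', 'H', 'H']
```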

In [10]:
from aml.data.datapipes import ASEFileReader, AtomsGraphParser, NeighborListBuilder

file_srcs = ["data/molecules_1.xyz", "data/molecules_2.xyz"] # Multiple files supported
dp = IterableWrapper(file_srcs)

print("==== ASEFileReader ====")
dp = ASEFileReader(dp)
for atoms in dp:
    print(atoms)
print()

print("==== AtomsGraphParser ====")
dp = AtomsGraphParser(dp)
for data in dp:
    print(data)
print()

print("==== NeighborListBuilder ====")
dp = NeighborListBuilder(dp, cutoff=5.0)
for data in dp:
    print(data)
print()



==== ASEFileReader ====
Atoms(symbols='OH2', pbc=False)
Atoms(symbols='C2H2', pbc=False)
Atoms(symbols='C2H4', pbc=False)
Atoms(symbols='C6H6', pbc=False)
Atoms(symbols='C2H6', pbc=False)
Atoms(symbols='C2OH6', pbc=False)
Atoms(symbols='NH3', pbc=False)

==== AtomsGraphParser ====
AtomsGraph(n_atoms=[1], elems=[3], pos=[3, 3], cell=[1, 3, 3], batch=[3])
AtomsGraph(n_atoms=[1], elems=[4], pos=[4, 3], cell=[1, 3, 3], batch=[4])
AtomsGraph(n_atoms=[1], elems=[6], pos=[6, 3], cell=[1, 3, 3], batch=[6])
AtomsGraph(n_atoms=[1], elems=[12], pos=[12, 3], cell=[1, 3, 3], batch=[12])
AtomsGraph(n_atoms=[1], elems=[8], pos=[8, 3], cell=[1, 3, 3], batch=[8])
AtomsGraph(n_atoms=[1], elems=[9], pos=[9, 3], cell=[1, 3, 3], batch=[9])
AtomsGraph(n_atoms=[1], elems=[4], pos=[4, 3], cell=[1, 3, 3], batch=[4])

==== NeighborListBuilder ====
AtomsGraph(n_atoms=[1], elems=[3], pos=[3, 3], cell=[1, 3, 3], batch=[3], edge_index=[2, 6], edge_shift=[6, 3])
AtomsGraph(n_atoms=[1], elems=[4], pos=[4, 3], cell=[1, ...

In [11]:
# All steps combined into one helper
def make_dp(src, neighbor_cutoff):
    if isinstance(src, str):
        src = [src]
    dp = IterableWrapper(src)
    dp = ASEFileReader(dp)
    dp = AtomsGraphParser(dp)
    dp = NeighborListBuilder(dp, cutoff=neighbor_cutoff)
    return dp

dp = make_dp(["data/molecules_1.xyz", "data/molecules_2.xyz"], 5.0)
for data in dp:
    print(data)

AtomsGraph(n_atoms=[1], elems=[3], pos=[3, 3], cell=[1, 3, 3], batch=[3], edge_index=[2, 6], edge_shift=[6, 3])
AtomsGraph(n_atoms=[1], elems=[4], pos=[4, 3], cell=[1, 3, 3], batch=[4], edge_index=[2, 12], edge_shift=[12, 3])
AtomsGraph(n_atoms=[1], elems=[6], pos=[6, 3], cell=[1, 3, 3], batch=[6], edge_index=[2, 30], edge_shift=[30, 3])
AtomsGraph(n_atoms=[1], elems=[12], pos=[12, 3], cell=[1, 3, 3], batch=[12], edge_index=[2, 132], edge_shift=[132, 3])
AtomsGraph(n_atoms=[1], elems=[8], pos=[8, 3], cell=[1, 3, 3], batch=[8], edge_index=[2, 56], edge_shift=[56, 3])
AtomsGraph(n_atoms=[1], elems=[9], pos=[9, 3], cell=[1, 3, 3], batch=[9], edge_index=[2, 72], edge_shift=[72, 3])
AtomsGraph(n_atoms=[1], elems=[4], pos=[4, 3], cell=[1, 3, 3], batch=[4], edge_index=[2, 12], edge_shift=[12, 3])
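
Every molecule above has exactly `N(N - 1)` edges, the maximum for a directed graph without self-loops, which indicates that each molecule is small enough for all atom pairs to lie within the 5.0 Å cutoff. A quick check against the printed shapes:

```python
# (n_atoms, number of edges) read off the printed AtomsGraph shapes above
printed = [(3, 6), (4, 12), (6, 30), (12, 132), (8, 56), (9, 72), (4, 12)]

for n, n_edges in printed:
    # Fully connected directed graph without self-loops: every ordered pair (i, j), i != j
    assert n_edges == n * (n - 1), (n, n_edges)
print("all molecules fully connected at 5.0 A")
```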
