In [1]:
import torch
import e3nn
import ase
import ase.neighborlist
import torch_geometric
import torch_geometric.data

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run the whole notebook in double precision; later cells build float tensors with this dtype.
torch.set_default_dtype(torch.float64)
default_dtype = torch.get_default_dtype()

In [3]:
from alignn.data import get_train_val_loaders
from repo_utils.data_utils import mp
import repo_utils.debbug_utils as deb

In [4]:
# Load the raw records from ./dataset and show which fields one record carries.
crystals_dict_list = mp.load_json_sample(pdirname="dataset")
first_record = crystals_dict_list[0]
print(first_record.keys())

Loading the dataset...
Loading complete: the number of the loaded data is 6923
dict_keys(['id', 'desc', 'formula', 'e_hull', 'gap pbe', 'mu_b', 'elastic anisotropy', 'bulk modulus', 'shear modulus', 'atoms', 'e_form'])


In [5]:
from jarvis.core.atoms import Atoms

In [6]:
# Give every chemical symbol a dense integer id, in order of first appearance,
# and pair each structure (converted with jarvis' `ase_converter`) with the
# record's "e_form" value, which is used as the regression target.
type_encoding = {}
crystals = []
for crystal_dict in crystals_dict_list:
    target = crystal_dict["e_form"]
    crystal = Atoms.from_dict(crystal_dict["atoms"]).ase_converter()
    for symbol in crystal.symbols:
        # len() is evaluated before insertion, so a new symbol gets the next free id.
        type_encoding.setdefault(symbol, len(type_encoding))
    crystals.append([crystal, target])
num_atom_types = len(type_encoding)

# Row i of the identity matrix is the one-hot vector for type id i.
type_onehot = torch.eye(len(type_encoding))
print(num_atom_types)

85


In [7]:
# Build one torch_geometric graph per crystal. An edge joins two atoms (possibly in
# different periodic images) whose distance is below `radial_cutoff`.
radial_cutoff = 3.5

dataset = []

for crystal, target in crystals:
    # edge_src / edge_dst: indices of the central and the neighboring atom of each edge.
    # edge_shift: number of cell boundaries crossed by the edge, i.e. which periodic
    # image of the unit cell the neighbor lives in.
    # self_interaction=True also keeps each atom as its own (zero-shift) neighbor.
    edge_src, edge_dst, edge_shift = ase.neighborlist.neighbor_list(
        "ijS", a=crystal, cutoff=radial_cutoff, self_interaction=True
    )

    data = torch_geometric.data.Data(
        # Explicit dtypes: without them pos/lattice would inherit float64 from numpy
        # while edge_shift/x follow `default_dtype`, and the einsum that combines
        # them in the model's preprocess step would fail if `default_dtype` changed.
        pos=torch.tensor(crystal.get_positions(), dtype=default_dtype),
        lattice=torch.tensor(crystal.cell.array, dtype=default_dtype).unsqueeze(0),  # leading dim for batching
        # One-hot encoding of each atom's chemical species (ids from `type_encoding`).
        x=type_onehot[[type_encoding[symbol] for symbol in crystal.symbols]],
        edge_index=torch.stack(
            [torch.tensor(edge_src, dtype=torch.long), torch.tensor(edge_dst, dtype=torch.long)],
            dim=0,
        ),
        edge_shift=torch.tensor(edge_shift, dtype=default_dtype),
        # Regression target: the record's "e_form" value.
        # NOTE(review): presumably a per-atom formation energy, which matches the
        # mean pooling used by the model — confirm against the dataset docs.
        energy=target,
    )

    dataset.append(data)

In [8]:
# Show the attribute shapes of the first graph.
print(next(iter(dataset)))

Data(x=[36, 85], edge_index=[2, 220], pos=[36, 3], lattice=[1, 3, 3], edge_shift=[220, 3], energy=-0.7881006506944451)


In [9]:
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated

batch_size = 48
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15
split_seed = 42  # fixed so that Restart & Run All reproduces the same split

assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-9, "split ratios must sum to 1"

# Calculate the lengths of each split
num_samples = len(dataset)
train_size = int(train_ratio * num_samples)
val_size = int(val_ratio * num_samples)
# The test split takes the remainder so the three sizes always add up to num_samples.
test_size = num_samples - train_size - val_size

# Split the dataset into train, validation, and test sets with a seeded generator.
train_set, val_set, test_set = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create data loaders for each split; only the training loader is shuffled.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)



In [10]:
# Pull one mini-batch to see how graphs are concatenated: the batch itself, the
# per-atom example index, the stacked positions and the stacked one-hot features.
data = next(iter(train_loader))
for field in (data, data.batch, data.pos, data.x):
    print(field)

DataBatch(x=[1524, 85], edge_index=[2, 19418], pos=[1524, 3], lattice=[48, 3, 3], edge_shift=[19418, 3], energy=[48], batch=[1524], ptr=[49])
tensor([ 0,  0,  0,  ..., 46, 46, 47])
tensor([[ 0.1469,  1.5310,  9.5913],
        [ 3.0665,  7.0898, 14.0848],
        [ 1.0984,  2.8551,  6.0809],
        ...,
        [ 3.8533,  1.1292,  4.8165],
        [ 6.9829,  3.3875,  2.0793],
        [ 0.0000,  0.0000,  0.0000]])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [11]:
from e3nn.nn.models.v2103.gate_points_networks import SimpleNetwork
from typing import Dict, Union
import torch_scatter

class SimplePeriodicNetwork(SimpleNetwork):
    """e3nn ``SimpleNetwork`` adapted to periodic crystals.

    * ``preprocess`` builds edge vectors as ``pos[dst] - pos[src] + edge_shift @ lattice``,
      so neighbors that live in another periodic image get the correct displacement.
    * With ``pool_nodes`` true, per-atom outputs are averaged per example with
      ``torch_scatter.scatter_mean`` instead of being reduced inside SimpleNetwork.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``SimpleNetwork``, intercepting ``pool_nodes``.

        SimpleNetwork uses ``pool_nodes`` to decide whether to reduce all atom
        contributions per example. Here we want a mean instead, so when
        ``pool_nodes`` is true (or omitted, which is treated the same way)
        SimpleNetwork's own pooling is switched off, ``num_nodes`` is set to 1
        and ``forward`` takes the per-example mean itself. When ``pool_nodes``
        is false, ``kwargs`` are passed through unchanged and per-atom outputs
        are returned.
        """
        pool = bool(kwargs.get('pool_nodes', True))
        if pool:
            kwargs['pool_nodes'] = False
            # NOTE(review): num_nodes presumably only matters for SimpleNetwork's own
            # pooling, which is disabled here; 1.0 keeps it neutral — confirm in e3nn.
            kwargs['num_nodes'] = 1.
        super().__init__(**kwargs)
        # Assigned after Module.__init__ so it is an ordinary instance attribute.
        self.pool = pool

    @staticmethod
    def _batch_index(data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Return the example index of every atom; all zeros for a single un-batched example."""
        if 'batch' in data and data['batch'] is not None:
            return data['batch']
        return data['pos'].new_zeros(data['pos'].shape[0], dtype=torch.long)

    # Overwriting preprocess method of SimpleNetwork to adapt for periodic boundary data
    def preprocess(self, data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Return ``(batch, node_features, edge_src, edge_dst, edge_vec)`` for the network."""
        batch = self._batch_index(data)

        edge_src = data['edge_index'][0]  # Edge source
        edge_dst = data['edge_index'][1]  # Edge destination

        # Computed inside the autograd graph so gradients can flow back to positions.
        # Relative displacement plus the lattice translation of the neighbor's image:
        # edge_shift (per edge) is multiplied by the lattice of the edge's own example.
        edge_batch = batch[edge_src]
        edge_vec = (data['pos'][edge_dst]
                    - data['pos'][edge_src]
                    + torch.einsum('ni,nij->nj', data['edge_shift'], data['lattice'][edge_batch]))

        return batch, data['x'], edge_src, edge_dst, edge_vec

    def forward(self, data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Run the network; mean-pool atom outputs per example when ``self.pool`` is set."""
        # Resolve the example index exactly as ``preprocess`` does, so an un-batched
        # ``Data`` (no ``batch`` attribute) or a plain dict also works when pooling.
        batch = self._batch_index(data) if self.pool else None
        output = super().forward(data)
        if self.pool:
            return torch_scatter.scatter_mean(output, batch, dim=0)  # Take mean over atoms per example
        return output

In [12]:
model = SimplePeriodicNetwork(
    # One-hot scalars (L=0, even parity) per atom; the width must equal the number of
    # species found in the data (the columns of `type_onehot`), so derive it
    # instead of hard-coding 85.
    irreps_in=f"{num_atom_types}x0e",
    irreps_out="1x0e",  # Single scalar (L=0 and even parity) output: the predicted energy
    max_radius=radial_cutoff,  # Cutoff radius for convolution (same as the neighbor list)
    # Normalization based on the typical number of neighbors.
    # NOTE(review): the first graph has 220 edges / 36 atoms (~6 per atom); consider
    # using the dataset average instead of 10.0.
    num_neighbors=10.0,
    pool_nodes=True,  # Mean-pool atom outputs into one prediction per crystal
)

In [13]:
import numpy as np
import torch
import torch.nn as nn

from dgllife.utils import EarlyStopping, Meter
# Training hyper-parameters.
n_epochs = 300
learning_rate = 0.005
weight_decay = 0.0001
metric = "mae"  # name passed to dgllife's Meter.compute_metric each epoch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
print(device)
# The batches are sent to `device` before every forward pass, so the model's
# parameters must live there too; without this a fresh kernel on a CUDA machine
# fails with a device mismatch in the next cell. Displays the model summary.
model.to(device)

cuda


SimplePeriodicNetwork(
  (mp): MessagePassing(
    (layers): ModuleList(
      (0): Compose(
        (first): Convolution(
          (sc): FullyConnectedTensorProduct(85x0e x 1x0e -> 150x0e+50x1o+50x2e | 12750 paths | 12750 weights)
          (lin1): FullyConnectedTensorProduct(85x0e x 1x0e -> 85x0e | 7225 paths | 7225 weights)
          (fc): FullyConnectedNet[10, 100, 255]
          (tp): TensorProduct(85x0e x 1x0e+1x1o+1x2e -> 85x0e+85x1o+85x2e | 255 paths | 255 weights)
          (lin2): FullyConnectedTensorProduct(85x0e+85x1o+85x2e x 1x0e -> 150x0e+50x1o+50x2e | 21250 paths | 21250 weights)
          (lin3): FullyConnectedTensorProduct(85x0e+85x1o+85x2e x 1x0e -> 1x0e | 85 paths | 85 weights)
        )
        (second): Gate (150x0e+50x1o+50x2e -> 50x0e+50x1o+50x2e)
      )
      (1): Compose(
        (first): Convolution(
          (sc): FullyConnectedTensorProduct(50x0e+50x1o+50x2e x 1x0e -> 250x0e+50x1o+50x1e+50x2o+50x2e | 17500 paths | 17500 weights)
          (lin1): FullyCon

In [17]:
# Smoke test: one forward pass on the inspected batch gives one prediction per crystal.
data = data.to(device)
model(data)

tensor([[ 0.0064],
        [-0.0516],
        [-0.0541],
        [-0.0695],
        [-0.1049],
        [-0.0162],
        [ 0.0566],
        [-0.0174],
        [-0.0523],
        [-0.0584],
        [-0.0801],
        [ 0.0191],
        [ 0.0661],
        [-0.0463],
        [-0.0485],
        [-0.0651],
        [-0.1024],
        [ 0.0067],
        [-0.0836],
        [ 0.0083],
        [-0.0599],
        [-0.0759],
        [-0.0260],
        [-0.0430],
        [-0.0193],
        [ 0.0046],
        [-0.0371],
        [-0.0629],
        [-0.0283],
        [-0.0727],
        [-0.0310],
        [-0.0223],
        [-0.0641],
        [-0.0654],
        [-0.0641],
        [-0.0114],
        [-0.0605],
        [ 0.0305],
        [-0.0751],
        [-0.0410],
        [-0.0645],
        [-0.0753],
        [ 0.1047],
        [-0.0302],
        [-0.0761],
        [ 0.0248],
        [-0.0448],
        [ 0.0172]], device='cuda:0', grad_fn=<DivBackward0>)

In [22]:
def run_a_train_epoch(args, epoch, model, data_loader,
                      loss_criterion, optimizer):
    """Run one optimisation pass over `data_loader` and print the epoch's training metric.

    `args` is accepted for interface compatibility but is not used. The function
    reads the notebook globals `device`, `metric` and `n_epochs`. Each batch's loss
    is the mean of whatever `loss_criterion` returns.
    """
    model.train()
    meter = Meter()
    for graphs in data_loader:
        targets = graphs["energy"]
        graphs = graphs.to(device)
        # One target per example, shaped as a column to match the model output.
        targets = targets.reshape([-1, 1]).to(device)

        predictions = model(graphs)
        batch_loss = loss_criterion(predictions, targets).mean()

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        meter.update(predictions, targets)

    score = np.mean(meter.compute_metric(metric))
    print('epoch {:d}/{:d}, training {} {:.4f}'.format(
        epoch + 1, n_epochs, metric, score))

In [23]:
def run_an_eval_epoch(args, model, data_loader, loss_criterion):
    """Evaluate `model` on every batch of `data_loader` without gradient tracking.

    Args:
        args: unused; kept for interface compatibility with the callers.
        model: network mapping a batch to predictions shaped like the reshaped labels.
        data_loader: iterable of batches supporting ``batch["energy"]`` and ``.to(device)``.
        loss_criterion: callable ``(prediction, labels) -> loss``; its output is
            averaged per batch with ``.mean()``.

    Reads the notebook globals `device` and `metric`.

    Returns:
        ``(total_score, eval_loss)``. ``total_score`` is the mean of the dgllife
        ``Meter`` metric named by `metric` over the whole loader. ``eval_loss`` is a
        0-dim tensor: the example-weighted mean of the batch losses over the whole
        loader (previously only the last batch's loss was returned).

    Raises:
        ValueError: if `data_loader` yields no examples.
    """
    model.eval()
    eval_meter = Meter()
    loss_sum = None
    num_examples = 0
    with torch.no_grad():
        for batch_data in data_loader:
            labels = batch_data["energy"]
            data = batch_data.to(device)
            labels = labels.reshape([-1, 1]).to(device)
            prediction = model(data)
            batch_loss = (loss_criterion(prediction, labels)).mean()
            # Weight by batch size so a short final batch does not skew the average.
            n_in_batch = labels.shape[0]
            weighted = batch_loss * n_in_batch
            loss_sum = weighted if loss_sum is None else loss_sum + weighted
            num_examples += n_in_batch
            eval_meter.update(prediction, labels)
    if num_examples == 0:
        raise ValueError("run_an_eval_epoch: data_loader yielded no examples")
    total_score = np.mean(eval_meter.compute_metric(metric))
    eval_loss = loss_sum / num_examples
    return total_score, eval_loss

In [20]:
# Element-wise squared error; the epoch helpers reduce it with .mean() themselves.
loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

In [24]:
# Alternate one training pass with one validation pass and report the validation metric.
for epoch_index in range(n_epochs):
    run_a_train_epoch(None, epoch_index, model, train_loader, loss_fn, optimizer)
    val_score, val_loss = run_an_eval_epoch(None, model, val_loader, loss_fn)
    print('epoch {:d}/{:d}, validation {} {:.4f}'.format(
        epoch_index + 1, n_epochs, metric, val_score))

epoch 1/300, training mae 0.5155
epoch 1/300, validation mae 0.3896
epoch 2/300, training mae 0.3723
epoch 2/300, validation mae 0.3533
epoch 3/300, training mae 0.3587
epoch 3/300, validation mae 0.3454
epoch 4/300, training mae 0.3404
epoch 4/300, validation mae 0.3246
epoch 5/300, training mae 0.3191
epoch 5/300, validation mae 0.3008
epoch 6/300, training mae 0.2901
epoch 6/300, validation mae 0.2779
epoch 7/300, training mae 0.2752
epoch 7/300, validation mae 0.2663
epoch 8/300, training mae 0.2631
epoch 8/300, validation mae 0.2516
epoch 9/300, training mae 0.2501
epoch 9/300, validation mae 0.2480
epoch 10/300, training mae 0.2478
epoch 10/300, validation mae 0.2421
epoch 11/300, training mae 0.2379
epoch 11/300, validation mae 0.2357
epoch 12/300, training mae 0.2349
epoch 12/300, validation mae 0.2365
epoch 13/300, training mae 0.2310
epoch 13/300, validation mae 0.2342
epoch 14/300, training mae 0.2321
epoch 14/300, validation mae 0.2271
epoch 15/300, training mae 0.2283
epoc

In [None]:
# Final held-out evaluation on the test split, which was never used for training or
# validation. (Evaluating `train_loader` here would just report the training score.)
test_score, test_loss = run_an_eval_epoch(None, model, test_loader, loss_fn)
print('test {} {:.4f}, test loss {:.4f}'.format(
    metric, test_score, test_loss))