In [1]:
# model
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as tg
import torch_scatter
import e3nn
from e3nn import o3
from typing import Dict, Union

# crystal structure data
from ase import Atom, Atoms
from ase.neighborlist import neighbor_list
from ase.visualize.plot import plot_atoms

# data pre-processing and visualization
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# utilities
import time
from tqdm import tqdm
# from utils.utils_data import (load_data, train_valid_test_split, plot_example, plot_predictions, plot_partials,
#                               palette, colors, cmap)
# from utils.utils_model import Network, visualize_layers, train
# from utils.utils_plot import plotly_surface, plot_orbitals, get_middle_feats
from utils.data import (load_data)

bar_format = '{l_bar}{bar:10}{r_bar}{bar:-10b}'
tqdm.pandas(bar_format=bar_format)
default_dtype = torch.float32
torch.set_default_dtype(default_dtype)

In [2]:
# load data
df, species = load_data('data/bec_run.db')
df.head()

Unnamed: 0,structure,bec,energy,forces,formula,species
0,"(Atom('Nd', [0.21386, 5.14073, 5.74185], magmo...","[[[4.18062096, -0.10635331, 0.05667826], [-0.0...",-160.954128,"[[0.04883505, -0.02411291, 0.01005567], [0.016...",H4Nd4Ni4O12,"[O, H, Nd, Ni]"
1,"(Atom('Nd', [0.48857, 5.18121, 5.82727], magmo...","[[[3.94453436, 0.05827721, 0.08160375], [0.097...",-160.83221,"[[0.0062039, -0.01181304, -0.00304418], [-0.00...",H4Nd4Ni4O12,"[O, H, Nd, Ni]"
2,"(Atom('Nd', [0.51902, 5.11215, 6.25313], magmo...","[[[4.23309052, -0.06272848, -0.0245231], [0.20...",-161.126651,"[[-0.00153404, -0.01371113, 0.03815604], [-0.0...",H4Nd4Ni4O12,"[O, H, Nd, Ni]"
3,"(Atom('Nd', [0.22578, 5.15645, 5.72975], magmo...","[[[4.31953816, 0.39699484, 0.10816991], [0.492...",-161.306738,"[[-0.02449503, -0.019128, -0.02941035], [0.001...",H4Nd4Ni4O12,"[O, H, Nd, Ni]"
4,"(Atom('Nd', [0.48497, -0.06402, 5.75643], magm...","[[[4.12645607, 0.18284011, 0.09932547], [0.383...",-160.415418,"[[0.00508027, -0.03106301, 0.08675875], [-0.10...",H4Nd4Ni4O12,"[O, H, Nd, Ni]"


In [3]:
df.dtypes

structure     object
bec           object
energy       float64
forces        object
formula       object
species       object
dtype: object

In [5]:
df.iloc[0]['bec']

array([[[ 4.18062096e+00, -1.06353310e-01,  5.66782600e-02],
        [-6.56474800e-02,  4.10805767e+00,  3.90021400e-02],
        [ 3.58588300e-02,  2.28088000e-02,  4.05099451e+00]],

       [[ 4.39752881e+00, -1.79209270e-01,  6.05955000e-03],
        [-3.17797760e-01,  3.69352964e+00,  1.34514200e-02],
        [-1.18605100e-02,  1.28390100e-02,  3.97240425e+00]],

       [[ 4.21032064e+00,  1.95080080e-01,  1.81588520e-01],
        [ 1.99591500e-02,  4.26366292e+00,  1.84628820e-01],
        [ 1.19836250e-01,  1.74157690e-01,  4.05328229e+00]],

       [[ 4.01634342e+00,  1.24787020e-01, -2.19450880e-01],
        [ 2.98161900e-01,  4.35724060e+00, -1.35950910e-01],
        [-1.94406960e-01, -1.76541150e-01,  4.16347780e+00]],

       [[ 1.87083165e+00,  1.04859660e-01, -4.87251000e-03],
        [-1.01829890e-01,  1.82847470e+00, -1.07960940e-01],
        [-9.08903000e-03,  1.73633770e-01,  1.74551643e+00]],

       [[ 1.82997296e+00, -2.28738060e-01,  1.01038970e-01],
        [ 2.20

In [27]:
Atoms.fromdict({'numbers': [56, 52], 
 'positions': [[0.0, 0.0, 0.0], [2.220446049250313e-16, 6.022430681109324, -4.440892098500626e-16]], 
 'cell': [[1.4195005246127737, 4.014953787406216, 2.4586470300000003], [1.4195005246127737, 4.014953787406216, -2.4586470299999994], [-2.839001049225547, 4.014953787406216, -0.0]], 
 'pbc': [True, True, True]})

Atoms(symbols='BaTe', pbc=True, cell=[[1.4195005246127737, 4.014953787406216, 2.4586470300000003], [1.4195005246127737, 4.014953787406216, -2.4586470299999994], [-2.839001049225547, 4.014953787406216, -0.0]])

In [3]:
# one-hot encoding atom type and mass
type_encoding = {}
specie_am = []
for Z in tqdm(range(1, 119), bar_format=bar_format):
    specie = Atom(Z)
    type_encoding[specie.symbol] = Z - 1
    specie_am.append(specie.mass)

type_onehot = torch.eye(len(type_encoding))
am_onehot = torch.diag(torch.tensor(specie_am))

100%|██████████| 118/118 [00:00<00:00, 160534.50it/s]                                                                         


In [6]:
# build data
def build_data(entry, type_encoding, type_onehot, r_max=3.5):
    symbols = list(entry.structure.symbols).copy()
    positions = torch.from_numpy(entry.structure.positions.copy())
    lattice = torch.from_numpy(entry.structure.cell.array.copy()).unsqueeze(0)

    # edge_src and edge_dst are the indices of the central and neighboring atom, respectively
    # edge_shift indicates whether the neighbors are in different images or copies of the unit cell
    edge_src, edge_dst, edge_shift = neighbor_list("ijS", a=entry.structure, cutoff=r_max, self_interaction=True)
    
    # compute the relative distances and unit cell shifts from periodic boundaries
    edge_batch = positions.new_zeros(positions.shape[0], dtype=torch.long)[torch.from_numpy(edge_src)]
    edge_vec = (positions[torch.from_numpy(edge_dst)]
                - positions[torch.from_numpy(edge_src)]
                + torch.einsum('ni,nij->nj', torch.tensor(edge_shift, dtype=default_dtype), lattice[edge_batch]))

    # compute edge lengths (rounded only for plotting purposes)
    edge_len = np.around(edge_vec.norm(dim=1).numpy(), decimals=2)
    
    data = tg.data.Data(
        pos=positions, lattice=lattice, symbol=symbols,
        x=am_onehot[[type_encoding[specie] for specie in symbols]],   # atomic mass (node feature)
        z=type_onehot[[type_encoding[specie] for specie in symbols]], # atom type (node attribute)
        edge_index=torch.stack([torch.LongTensor(edge_src), torch.LongTensor(edge_dst)], dim=0),
        edge_shift=torch.tensor(edge_shift, dtype=default_dtype),
        edge_vec=edge_vec, edge_len=edge_len,
        bec=torch.from_numpy(entry.bec).unsqueeze(0)
    )
    
    return data

r_max = 3.5 # cutoff radius
df['data'] = df.progress_apply(lambda x: build_data(x, type_encoding, type_onehot, r_max), axis=1)

  1%|          | 1/89 [00:00<00:00, 667.88it/s]                                                                               


AttributeError: 'str' object has no attribute 'symbols'