In [2]:
from matminer.datasets import load_dataset

# dielectric = load_dataset("matbench_dielectric")
# log_gvrh = load_dataset("matbench_log_gvrh")
# jdft2d = load_dataset("matbench_jdft2d")

boltztrap_mp = load_dataset("boltztrap_mp")

# matbench_dielectric
# matbench_expt_gap
# matbench_expt_is_metal
# matbench_glass
# matbench_jdft2d
# matbench_steels

In [41]:
import pandas as pd
# dielectric.iloc[0].structure 
# boltztrap_mp.columns

# expt_gap.composition

pd.Series(list(set(boltztrap_mp.formula) & set(expt_is_metal.composition)))


0        SrCrO4
1          PrF3
2           SnS
3      Tl9SbSe6
4       K4GeSe4
         ...   
362         TlS
363      CrPbO4
364    LiNiP2O7
365         LiF
366          KF
Length: 367, dtype: object

In [49]:
boltztrap_mp.head()
boltztrap_mp.columns

Index(['mpid', 'pf_n', 'pf_p', 's_n', 's_p', 'formula', 'm_n', 'm_p',
       'structure'],
      dtype='object')

In [8]:
import torch
from torch_geometric.data import Data
import numpy as np
import pandas as pd
from tqdm import tqdm

periodic_table = pd.read_csv('data/periodic_table.csv')

# Find the largest element in the dataset
# Returns the atomic number of the largest element
def get_largest_element(df):
    largest_element = 0
    for idx, entry in df.iterrows():
        struct = entry.structure
        for site in struct._sites:
            symbol = str(list(site._species._data.keys())[0])
            atomic_number = periodic_table.AtomicNumber[periodic_table['Symbol'] == symbol].values[-1]
            if atomic_number > largest_element:
                largest_element = atomic_number
    return largest_element

# Create a fully connected graph from a pymatgen structure
# Returns an array of shape (2, n_edges) where each column is an edge
# This is the format required by the PyTorch Geometric library
def make_edge_indices(entry):
    n_nodes = len(entry.structure._sites)
    edge_index = []
    for i in range(n_nodes):
        for j in range(n_nodes):
            if i != j:
                edge_index.append([i, j])
    return torch.tensor(edge_index).transpose(0, 1)

# Create a graph from a Pandas DataFrame of pymatgen structures
# Returns a tuple of (feature_matrix, coord_matrix, label)
# feature_matrix is a matrix of shape (n_nodes, n_features)
# coord_matrix is a matrix of shape (n_nodes, 3)
# label is a scalar
def get_features_and_coords(df):
    largest_element = get_largest_element(df)
    data = []

    for idx, entry in tqdm(df.iterrows(), desc="Building material graphs"):
        struct = entry.structure

        feature_matrix = []
        coord_matrix = []

        # Features
        for site in struct._sites:
            feature_vec = [0 for _ in range(largest_element)] # create a vector of zeros
            symbol = str(list(site._species._data.keys())[0])
            atomic_number = periodic_table.AtomicNumber[periodic_table['Symbol'] == symbol].values[-1]
            feature_vec[atomic_number - 1] = 1 # one-hot encode atomic number
            feature_matrix.append(feature_vec)

        # Coordinates
        for site in struct._sites:
            coords = site._frac_coords
            coord_matrix.append(coords)

        coord_matrix = torch.FloatTensor(np.array(coord_matrix))
        feature_matrix = torch.FloatTensor(np.array(feature_matrix))

        # Labels
        labels = {}
        for col in df.columns:
            if col != 'structure':
                labels[col] = torch.tensor(entry[col])

        if (feature_matrix is not None) and (len(feature_matrix) > 1): 
            edge_index=make_edge_indices(entry)
            if len(edge_index.shape) > 1:
                datum = Data(x=feature_matrix, edge_index=edge_index, y=labels, pos=coord_matrix)
                data.append(datum)

    return data

In [9]:
from tqdm import tqdm
df = boltztrap_mp.drop(columns=['mpid', 'formula'])
dataset = get_features_and_coords(df)


Building material graphs: 873it [00:02, 241.67it/s]

In [8]:


from sklearn.neighbors import NearestNeighbors

nbrs = NearestNeighbors(n_neighbors=5, algorithm='ball_tree').fit(dataset[0].pos)


In [17]:
_, indices = nbrs.kneighbors(dataset[0].pos.numpy())
print (indices.shape)
print (indices)

(56, 5)
[[ 0 36 23  4 44]
 [ 1 37 26 22 28]
 [ 2 38 21 54 31]
 [ 3 39 24 20 18]
 [ 4 36 55 23 16]
 [ 5 37 54 22 17]
 [ 6 38 53 21 18]
 [ 7 39 52 20 19]
 [ 8 40 13 27 31]
 [ 9 41 12 26 30]
 [10 42 25 29 53]
 [11 43 24 28  5]
 [12  9 44 36 23]
 [13 41 30 26  8]
 [14 46 38 21  6]
 [15 43 28 24 47]
 [16 51 44 23  4]
 [17 50 45 22  5]
 [18 49 46 21  3]
 [19 48 47 20  7]
 [20 47 39  3 32]
 [21 46 38  2 33]
 [22 45 37  1 34]
 [23 44 36  0 35]
 [24 43 39 28  3]
 [25 42 29 10 35]
 [26 41 37 30  1]
 [27 40 31  8 54]
 [28 43 47 32 24]
 [29 42 25 10  7]
 [30 41 45 34 26]
 [31 40 27  2 38]
 [32 47 51 28 20]
 [33 46 21 18 38]
 [34 45 30 22 41]
 [35 44 23 25 42]
 [36 23  0  4 44]
 [37 22  1 26  5]
 [38 21  2  6 46]
 [39 20  3 24  7]
 [40 31 27  8 54]
 [41 30 26 13  9]
 [42 29 25 10 35]
 [43 28 24 15 11]
 [44 23 35 16 36]
 [45 22 34 30 17]
 [46 21 33 18 38]
 [47 20 32 28 19]
 [48 19 15 52 47]
 [49 18 14 10 53]
 [50 17 13 54 45]
 [51 16 32 47  6]
 [52  7 15 48 39]
 [53  6 16 14 10]
 [54  5  2 40  8]
 [

In [42]:
d2  = dataset[0].pos.numpy()

In [43]:
# alps = (1/5) * (dataset[indices[0, 1]].y['pf_n'] + dataset[indices[0, 2]].y['pf_n'] + dataset[indices[0, 3]].y['pf_n'] + dataset[indices[0, 4]].y['pf_n'])
# print (alps)

alps = [abs(d2[i].y['pf_n'] - (1/4) * (d2[indices[i, 1]].y['pf_n'] + d2[indices[i, 2]].y['pf_n'] + d2[indices[i, 3]].y['pf_n'] + d2[indices[i, 4]].y['pf_n'])) for i in range(len(d2))]
print (alps)

# dataset[0].y['pf_n']


AttributeError: 'numpy.ndarray' object has no attribute 'y'

In [41]:
print (sum(alps) / len(dataset))

tensor(0.0022)


In [None]:
max_atom = 0
for i in dataset:
    if len(i.x) > max_atom:
        max_atom = len(i.x)

print(max_atom)

In [6]:
print (boltztrap_mp.structure)

print (max([len(i.structure) for i in boltztrap_mp.structure]))

0       [[ 0.08245398 10.58009491 11.61923254] O, [3.1...
1       [[2.84699546 0.94899849 0.        ] F, [0.9489...
2       [[-2.8085287   7.06608376  1.25800196] Na, [1....
3       [[2.52287112 1.45658029 7.18290524] Li, [-2.85...
4       [[0.86371064 5.14343422 4.80450637] P, [4.5512...
                              ...                        
9031    [[0.90152298 0.90152298 2.11482411] Li, [2.704...
9032    [[0.90289168 4.52569433 1.31102215] S, [2.7086...
9033    [[-1.59315710e-08  2.22451952e+00  1.81128203e...
9034    [[1.85042092 1.068341   9.80765447] S, [-2.183...
9035    [[4.99072785 2.06137412 4.38830252] O, [2.0884...
Name: structure, Length: 8924, dtype: object
