In [1]:
import yaml
import json

import pandas as pd
import numpy as np
import tensorflow as tf

from pathlib import Path
from pymatgen.core import Structure
from sklearn.model_selection import train_test_split
#from megnet.models import MEGNetModel
#from megnet.data.crystal import CrystalGraph




def read_pymatgen_dict(file):
    with open(file, "r") as f:
        d = json.load(f)
    return Structure.from_dict(d)


def energy_within_threshold(prediction, target):
    # compute absolute error on energy per system.
    # then count the no. of systems where max energy error is < 0.02.
    e_thresh = 0.02
    error_energy = tf.math.abs(target - prediction)

    success = tf.math.count_nonzero(error_energy < e_thresh)
    total = tf.size(target)
    return success / tf.cast(total, tf.int64)

def prepare_dataset(dataset_path):
    dataset_path = Path(dataset_path)
    targets = pd.read_csv(dataset_path / "targets.csv", index_col=0)
    struct = {
        item.name.strip(".json"): read_pymatgen_dict(item)
        for item in (dataset_path / "structures").iterdir()
    }

    data = pd.DataFrame(columns=["structures"], index=struct.keys())
    data = data.assign(structures=struct.values(), targets=targets)

    return train_test_split(data, test_size=0.25, random_state=666)

 
def prepare_model(cutoff, lr):
    nfeat_bond = 10
    r_cutoff = cutoff
    gaussian_centers = np.linspace(0, r_cutoff + 1, nfeat_bond)
    gaussian_width = 0.8
    
    return MEGNetModel(
        graph_converter=CrystalGraph(cutoff=r_cutoff),
        centers=gaussian_centers,
        width=gaussian_width,
        loss=["MAE"],
        npass=2,
        lr=lr,
        metrics=energy_within_threshold
    )


def main(config):
    train, test = prepare_dataset(config["datapath"])
    model = prepare_model(
        float(config["model"]["cutoff"]),
        float(config["model"]["lr"]), 
    )
    model.train(
        train.structures,
        train.targets,
        validation_structures=test.structures,
        validation_targets=test.targets,
        epochs=int(config["model"]["epochs"]),
        batch_size=int(config["model"]["batch_size"]),
    )


C:\Users\User\Anaconda3\lib\site-packages\numpy\.libs\libopenblas.PYQHXLVVQ7VESDPUVUADXEVJOBGHJPAY.gfortran-win_amd64.dll
C:\Users\User\Anaconda3\lib\site-packages\numpy\.libs\libopenblas.XWYDX2IKJW2NMTWSFYNGFUWKQU3LYTCZ.gfortran-win_amd64.dll
  stacklevel=1)


In [3]:
train, test = prepare_dataset('data\dichalcogenides_public')

In [7]:
temp = train['structures'].iloc[0]

In [18]:
train

Unnamed: 0,structures,targets
6142710b4e27a1844a5f07f4,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.1451
6141d01431cf3ef3d4a9edec,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.4315
6141d38dee0a3fd43fb47b49,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.4292
6141e2d4ee0a3fd43fb47cc5,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.3543
61422bfc4e27a1844a5f0682,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.1470
...,...,...
6141cf2031cf3ef3d4a9ed56,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.3552
614415764e27a1844a5f0aa0,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,1.1476
61421fb831cf3ef3d4a9f32c,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.4184
61421c9c31cf3ef3d4a9f30e,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,0.4232


In [10]:
temp_df = temp.as_dataframe()

In [13]:
temp_df['Species'].value_counts()

(S)     126
(Mo)     63
(W)       1
(Se)      1
Name: Species, dtype: int64

In [15]:
train['targets']

6142710b4e27a1844a5f07f4    1.1451
6141d01431cf3ef3d4a9edec    0.4315
6141d38dee0a3fd43fb47b49    0.4292
6141e2d4ee0a3fd43fb47cc5    0.3543
61422bfc4e27a1844a5f0682    1.1470
                             ...  
6141cf2031cf3ef3d4a9ed56    0.3552
614415764e27a1844a5f0aa0    1.1476
61421fb831cf3ef3d4a9f32c    0.4184
61421c9c31cf3ef3d4a9f30e    0.4232
61425a159cbada84a8676d25    1.8086
Name: targets, Length: 2224, dtype: float64

In [16]:
temp.as_dict()

{'@module': 'pymatgen.core.structure',
 '@class': 'Structure',
 'charge': None,
 'lattice': {'matrix': [[25.5225256, 0.0, 1.5628039641098191e-15],
   [-12.761262799999994, 22.10315553833868, 1.5628039641098191e-15],
   [0.0, 0.0, 14.879004]],
  'a': 25.5225256,
  'b': 25.5225256,
  'c': 14.879004,
  'alpha': 90.0,
  'beta': 90.0,
  'gamma': 119.99999999999999,
  'volume': 8393.668021812642},
 'sites': [{'species': [{'element': 'Mo', 'occu': 1.0}],
   'abc': [0.04166667, 0.08333333, 0.25],
   'xyz': [1.2761262868643541e-07, 1.8419295545177048, 3.719751],
   'label': 'Mo',
   'properties': {}},
  {'species': [{'element': 'Mo', 'occu': 1.0}],
   'abc': [0.04166667, 0.3333333333333333, 0.25],
   'xyz': [-3.1903156149249123, 7.367718512779559, 3.7197510000000005],
   'label': 'Mo',
   'properties': {}},
  {'species': [{'element': 'Mo', 'occu': 1.0}],
   'abc': [0.04166667, 0.45833333, 0.25],
   'xyz': [-4.785473422387369, 10.13061288139471, 3.719751000000001],
   'label': 'Mo',
   'properti

In [12]:
temp_df

Unnamed: 0,Species,a,b,c,x,y,z
0,(Mo),0.041667,0.083333,0.250000,1.276126e-07,1.841930,3.719751
1,(Mo),0.041667,0.333333,0.250000,-3.190316e+00,7.367719,3.719751
2,(Mo),0.041667,0.458333,0.250000,-4.785473e+00,10.130613,3.719751
3,(Mo),0.041667,0.583333,0.250000,-6.380631e+00,12.893507,3.719751
4,(Mo),0.041667,0.708333,0.250000,-7.975789e+00,15.656402,3.719751
...,...,...,...,...,...,...,...
186,(S),0.958333,0.416667,0.355174,1.914189e+01,9.209648,5.284635
187,(S),0.958333,0.541667,0.355174,1.754674e+01,11.972543,5.284635
188,(S),0.958333,0.666667,0.355174,1.595158e+01,14.735437,5.284635
189,(S),0.958333,0.791667,0.355174,1.435642e+01,17.498332,5.284635


In [17]:
temp_df.groupby('Species').count()

Unnamed: 0_level_0,a,b,c,x,y,z
Species,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
(S),126,126,126,126,126,126
(Se),1,1,1,1,1,1
(W),1,1,1,1,1,1
(Mo),63,63,63,63,63,63


In [None]:
def vectorize_structure(structure):
    structure_df = structure.as_dataframe()
    grouped_df = structure_df.groupby('Species')
    return vector