## Training M3GNet 

This notebook shows how to train machine-learning interatomic potentials with M3GNet on crystalline and molecular data.

**Crystalline data**

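We use a subsample of the MPF 2021.2.8 data from the M3GNet publication. The pickle file loaded below maps each material ID to lists of structures, energies, forces, and stresses.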


In [1]:
import warnings
warnings.filterwarnings("ignore")

import pickle

with open('ZnS_SiO2.p', 'rb') as f:
    data = pickle.load(f)

In [2]:
len(data)

169

In [3]:
from itertools import chain
import numpy as np

def get_data(ids):
    """Flatten the per-material records for the given IDs into flat lists."""
    structures = list(chain.from_iterable(data[i]['structure'] for i in ids))
    energies = list(chain.from_iterable(data[i]['energy'] for i in ids))
    forces = list(chain.from_iterable(data[i]['force'] for i in ids))

    # Convert stresses from kbar to GPa (1 kbar = 0.1 GPa) and flip the sign convention
    stresses = list(chain.from_iterable(np.array(data[i]['stress']) * -0.1 for i in ids))
    return structures, energies, forces, stresses

mp_ids = list(data.keys())  # here we have 169 materials
np.random.seed(42)
np.random.shuffle(mp_ids)

# Split by material: 150 for training, the remaining 19 for validation
train_ids = mp_ids[:150]
val_ids = mp_ids[150:]

train_structures, train_energies, train_forces, train_stresses = get_data(train_ids)
val_structures, val_energies, val_forces, val_stresses = get_data(val_ids)
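
To make the flattening and unit conversion in `get_data` concrete, here is a minimal sketch on made-up data. The `toy_data` dictionary, its IDs, the `flatten` helper, and all numbers are hypothetical stand-ins that assume the same layout as the pickle file (one list of frames per material); plain strings take the place of structure objects.

```python
from itertools import chain

import numpy as np

# Hypothetical stand-in for the pickled dataset: two materials with
# two frames and one frame respectively. Stresses are 3x3 tensors in kbar.
toy_data = {
    "mat-a": {
        "structure": ["a-frame-0", "a-frame-1"],
        "energy": [-10.0, -10.5],
        "stress": [np.eye(3) * 5.0, np.eye(3) * -2.0],
    },
    "mat-b": {
        "structure": ["b-frame-0"],
        "energy": [-3.2],
        "stress": [np.eye(3) * 10.0],
    },
}


def flatten(ids, key):
    """Concatenate the per-material lists stored under `key`."""
    return list(chain.from_iterable(toy_data[i][key] for i in ids))


ids = ["mat-a", "mat-b"]
structures = flatten(ids, "structure")
energies = flatten(ids, "energy")

# 1 kbar = 0.1 GPa; the factor of -1 switches the sign convention.
stresses = [np.array(s) * -0.1 for s in flatten(ids, "stress")]

print(len(structures), len(energies), len(stresses))
print(stresses[0][0, 0])
```

After flattening, every list has one entry per frame rather than one per material, so structures, energies, and stresses stay aligned by index.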

**M3GNet crystal model training**

In [24]:
from m3gnet.models import M3GNet, Potential
from m3gnet.trainers import PotentialTrainer
import tensorflow as tf

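# n_blocks=1 uses a single graph-convolution block to keep the model small;
# is_intensive=False because the total energy is an extensive property.
potential = Potential(M3GNet(n_blocks=1, is_intensive=False))

trainer = PotentialTrainer(potential=potential,
                           optimizer=tf.keras.optimizers.Adam(2e-3))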

In [None]:
# Huber loss with delta=0.1, scaled by 100 (tf was imported in the previous cell)
hub = tf.keras.losses.Huber(0.1)

def loss(x, y):
    return 100 * hub(x, y)

trainer.train(graphs_or_structures=train_structures,
              energies=train_energies,
              forces=train_forces,
              stresses=train_stresses,
              validation_graphs_or_structures=val_structures,
              val_energies=val_energies,
              val_forces=val_forces,
              val_stresses=val_stresses,
              batch_size=8,
              force_loss_ratio=0.1,
              stress_loss_ratio=0.01,
              loss=loss,
              epochs=10,
              fit_per_element_offset=True)
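
For intuition about the loss settings above, here is a minimal NumPy sketch of a Huber loss with `delta=0.1` scaled by 100, together with one plausible way the three loss terms are combined. The helper names `huber`, `scaled_loss`, and `total_loss` are illustrative and all numbers are made up. The weighted sum `energy + force_loss_ratio * force + stress_loss_ratio * stress` is an assumption about how the trainer applies the two ratios, not a copy of its code.

```python
import numpy as np


def huber(y_true, y_pred, delta=0.1):
    """Mean Huber loss: quadratic for |error| <= delta, linear beyond it."""
    err = np.abs(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))
    quadratic = 0.5 * err**2
    linear = delta * (err - 0.5 * delta)
    return float(np.mean(np.where(err <= delta, quadratic, linear)))


def scaled_loss(y_true, y_pred):
    """Mirrors the notebook's `loss`: 100 times the Huber loss."""
    return 100.0 * huber(y_true, y_pred, delta=0.1)


def total_loss(e_loss, f_loss, s_loss, force_loss_ratio=0.1, stress_loss_ratio=0.01):
    """Assumed combination: weighted sum of energy, force, and stress terms."""
    return e_loss + force_loss_ratio * f_loss + stress_loss_ratio * s_loss


# Made-up errors: one in the quadratic regime, one in the linear regime.
small = scaled_loss([0.0], [0.05])
large = scaled_loss([0.0], [1.0])
print(small, large)
print(total_loss(e_loss=1.0, f_loss=2.0, s_loss=3.0))
```

With a small `delta`, errors larger than 0.1 are penalized linearly rather than quadratically, which limits the influence of outlier frames. The ratios set how strongly force and stress errors count relative to the energy error.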