## M3GNet property models

In this notebook, I will use the `matbench` datasets to train property models with the M3GNet model framework.

```
Dunn, et al. npj Comput. Mater., 6(1), 1-10. 2020
```

In [8]:
from matbench.bench import MatbenchBenchmark


dataset = ['matbench_dielectric']

mb = MatbenchBenchmark(subset=dataset)

2022-11-13 10:42:53 INFO     Initialized benchmark 'matbench_v0.1' with 1 tasks: 
['matbench_dielectric']


In [15]:
import tensorflow as tf
from sklearn.model_selection import train_test_split

from m3gnet.models import M3GNet
from m3gnet.trainers import Trainer

for task in mb.tasks:
    print(f"Training {task.dataset_name}")
    task.load()
    for fold in task.folds:
        print(f"Fold {fold}")
        # load train and validation data
        train_val_inputs, train_val_outputs = task.get_train_and_val_data(fold)
        
        # split train and val
        train_inputs, val_inputs, train_outputs, val_outputs = train_test_split(
            train_val_inputs, train_val_outputs, test_size=0.1
        )
        
        # load test 
        test_inputs, test_outputs = task.get_test_data(fold, include_target=True)
        
        # initialize a model, note that refractive index (dielectric property)
        # here is an intensive property
        model = M3GNet(n_blocks=1, is_intensive=True)
        
        # Trainer
        trainer = Trainer(model=model, optimizer=tf.keras.optimizers.Adam(1e-3))
        
        # Train
        trainer.train(train_inputs, train_outputs,
                     validation_graphs_or_structures=val_inputs,
                     validation_targets=val_outputs,
                     epochs=10,
                     batch_size=4)
        
        # Test on test data
        test_predict_outputs = model.predict_structures(test_inputs, batch_size=16)
        
        # only train on the first fold as an example
        break

Training matbench_dielectric
2022-11-13 10:49:23 INFO     Dataset matbench_dielectric already loaded; not reloading dataset.


Fold 0
Epoch 1/10
      2/Unknown - 3s 2s/step - loss: 5.8985
    856/Unknown - 41s 46ms/step - loss: 3.3114
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
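
The loop above loads `test_outputs` and computes `test_predict_outputs` but never compares them. A minimal sketch of that comparison follows, using mean absolute error as the metric. The helper names `flatten` and `mean_absolute_error` are hypothetical and belong to neither `matbench` nor `m3gnet`. The sketch also assumes the predictions may come back nested, one single-element row per structure, which is why they are flattened first.

```python
from math import fsum


def flatten(values):
    """Flatten nested predictions such as [[1.2], [3.4]] into [1.2, 3.4].

    Hypothetical helper: objects exposing `.tolist()` (for example NumPy
    arrays or pandas Series) are converted to plain lists first.
    """
    if hasattr(values, "tolist"):
        values = values.tolist()
    flat = []
    for value in values:
        if isinstance(value, (list, tuple)):
            flat.extend(flatten(value))
        else:
            flat.append(float(value))
    return flat


def mean_absolute_error(targets, predictions):
    """Return the mean of |target - prediction| over all samples."""
    y_true = flatten(targets)
    y_pred = flatten(predictions)
    if len(y_true) != len(y_pred):
        raise ValueError("targets and predictions differ in length")
    if not y_true:
        raise ValueError("cannot score an empty set of predictions")
    return fsum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Toy values only; these are not results from the training run above.
toy_targets = [1.0, 2.0, 4.0]
toy_predictions = [[1.5], [2.0], [3.0]]
print(mean_absolute_error(toy_targets, toy_predictions))
```

Inside the fold loop, the equivalent call would be `mean_absolute_error(test_outputs, test_predict_outputs)`, assuming both hold one value per test structure in the same order. The `matbench` documentation also describes `task.record(fold, predictions)` for storing per-fold predictions in the benchmark object, which is probably the better route if the results are meant to be compared against the published leaderboard.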

