In [1]:
import warnings

# Silence all Python warnings for the rest of the session. The filter is
# installed before the megnet import below, so warnings emitted while that
# import runs are hidden as well.
warnings.filterwarnings('ignore')

from megnet.models import MEGNetModel


### Load formation energy model

In [2]:
# Load the pre-trained formation energy model from its saved HDF5 file.
# The path is relative, so it is resolved against the notebook's working directory.
model_form = MEGNetModel.from_file('../mvl_models/mp-2018.6.1/formation_energy.hdf5')

2024-01-22 20:24:40.155562: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2024-01-22 20:24:40.169098: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2024-01-22 20:24:40.169184: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2024-01-22 20:24:40.169652: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags

In [3]:
# Show the layer objects that make up the loaded model.
model_form.layers

[<keras.engine.input_layer.InputLayer at 0x7f4a5f33d4f0>,
 <keras.layers.embeddings.Embedding at 0x7f4a5f33d460>,
 <keras.engine.input_layer.InputLayer at 0x7f4a5f2b9550>,
 <keras.engine.input_layer.InputLayer at 0x7f4a5f2b94f0>,
 <keras.layers.core.dense.Dense at 0x7f4a5f33d550>,
 <keras.layers.core.dense.Dense at 0x7f4a5f2b9fd0>,
 <keras.layers.core.dense.Dense at 0x7f4a5e9dc370>,
 <keras.layers.core.dense.Dense at 0x7f4a5e9dc6d0>,
 <keras.layers.core.dense.Dense at 0x7f4a5e9dca30>,
 <keras.layers.core.dense.Dense at 0x7f4a5e9dcd90>,
 <keras.engine.input_layer.InputLayer at 0x7f4a5e9e0130>,
 <keras.engine.input_layer.InputLayer at 0x7f4a5e9e0160>,
 <keras.engine.input_layer.InputLayer at 0x7f4a5e9e05b0>,
 <keras.engine.input_layer.InputLayer at 0x7f4a5e9e0940>,
 <megnet.layers.graph.megnet.MEGNetLayer at 0x7f4a5e9dce20>,
 <keras.layers.merge.Add at 0x7f4a5e9e56a0>,
 <keras.layers.merge.Add at 0x7f4a5e9e5700>,
 <keras.layers.merge.Add at 0x7f4a5e9e5790>,
 <keras.layers.core.dense.Dens

In [6]:
# List the layer names in order; the embedding layer is looked up by name later.
[layer.name for layer in model_form.layers]

['input_1',
 'embedding_1',
 'input_2',
 'input_3',
 'dense_1',
 'dense_3',
 'dense_5',
 'dense_2',
 'dense_4',
 'dense_6',
 'input_4',
 'input_5',
 'input_6',
 'input_7',
 'meg_net_layer_1',
 'add_1',
 'add_2',
 'add_3',
 'dense_7',
 'dense_9',
 'dense_11',
 'dense_8',
 'dense_10',
 'dense_12',
 'meg_net_layer_2',
 'add_4',
 'add_5',
 'add_6',
 'dense_13',
 'dense_15',
 'dense_17',
 'dense_14',
 'dense_16',
 'dense_18',
 'meg_net_layer_3',
 'add_7',
 'add_8',
 'set2_set_1',
 'set2_set_2',
 'add_9',
 'concatenate_1',
 'dense_19',
 'dense_20',
 'dense_21']

In [5]:
# Name of the first layer of the loaded model.
model_form.layers[0].name

'input_1'

### Get the embedding layer

In [7]:
# Keep the first layer of the pre-trained model whose name begins with 'embedding'.
embedding_layer = [
    layer
    for layer in model_form.layers
    if layer.name.startswith('embedding')
][0]

# The first entry of the layer's weight list is used as the embedding matrix.
embedding = embedding_layer.get_weights()[0]
print('Embedding matrix dimension is ', embedding.shape)

Embedding matrix dimension is  (95, 16)


In [8]:
# Display the extracted embedding matrix.
embedding

array([[-0.01925377, -0.04383501,  0.01868666, ..., -0.02677364,
        -0.03300881, -0.04033415],
       [-0.49307713,  0.48247465, -0.2530202 , ..., -0.40325513,
         0.30468777,  0.18270242],
       [-0.9741373 , -0.1879723 , -0.6057493 , ..., -0.44998255,
         0.5049042 ,  0.17792016],
       ...,
       [ 1.0661256 ,  0.45345035,  0.41983265, ...,  0.5088932 ,
         0.02797919, -0.16767566],
       [ 0.81127083, -0.21892792,  0.19341795, ...,  0.5296463 ,
        -0.21794377, -0.00269936],
       [ 1.0503706 ,  0.22049105,  0.28066757, ...,  0.25579533,
         0.16338132, -0.2446297 ]], dtype=float32)

The embedding matrix has shape 95 x 16: it has one 16-dimensional row for each index from 0 to 94, so that every atomic number up to the maximum in the MP database (94) can find its corresponding row in the embedding matrix.

### Construct a new model and set embeddings

In [9]:
# Build a fresh, untrained MEGNet model.
# The two positional arguments are kept exactly as in the original call
# (presumably the bond and global-state feature sizes -- confirm against the
# MEGNetModel signature).
# The embedding table size is read from the pre-trained matrix rather than
# hard-coded (95 x 16 here), so the weights extracted above always fit the new
# model's embedding layer when they are copied over.
model = MEGNetModel(
    100,
    2,
    nvocal=embedding.shape[0],
    embedding_dim=embedding.shape[1],
)

In [10]:
# Locate the position of the atom embedding layer among the new model's layers
# (first layer whose name begins with 'atom_embedding').
embedding_layer_index = [
    position
    for position, layer in enumerate(model.layers)
    if layer.name.startswith('atom_embedding')
][0]

# Copy the pre-trained embedding matrix into that layer.
model.layers[embedding_layer_index].set_weights([embedding])

# Mark the layer as non-trainable so that training leaves the embeddings alone.
# NOTE(review): Keras documents that `trainable` changes are picked up at
# compile time; if the model is already compiled at this point, confirm the
# freeze is honoured (e.g. inspect the model's trainable weights) before training.
model.layers[embedding_layer_index].trainable = False

In [11]:
# Position of the atom embedding layer within model.layers.
embedding_layer_index

1

In [13]:
# Show the layer objects of the newly constructed model.
model.layers

[<keras.engine.input_layer.InputLayer at 0x7f49d0616730>,
 <keras.layers.embeddings.Embedding at 0x7f49d0616580>,
 <keras.engine.input_layer.InputLayer at 0x7f49d0616070>,
 <keras.engine.input_layer.InputLayer at 0x7f49d061da30>,
 <keras.layers.core.dense.Dense at 0x7f49d062eb80>,
 <keras.layers.core.dense.Dense at 0x7f49d05ad9a0>,
 <keras.layers.core.dense.Dense at 0x7f49d053d7c0>,
 <keras.layers.core.dense.Dense at 0x7f49d061da60>,
 <keras.layers.core.dense.Dense at 0x7f49d05348b0>,
 <keras.layers.core.dense.Dense at 0x7f49d0542dc0>,
 <keras.engine.input_layer.InputLayer at 0x7f49d0606bb0>,
 <keras.engine.input_layer.InputLayer at 0x7f49d060fb80>,
 <keras.engine.input_layer.InputLayer at 0x7f49d060fca0>,
 <keras.engine.input_layer.InputLayer at 0x7f49d060f1c0>,
 <megnet.layers.graph.megnet.MEGNetLayer at 0x7f49d062e0d0>,
 <keras.layers.merge.Add at 0x7f49d0557e50>,
 <keras.layers.merge.Add at 0x7f49d0542790>,
 <keras.layers.merge.Add at 0x7f49d051d520>,
 <keras.layers.core.dense.Dens

Now `model` has the same embeddings as the pre-trained model, and, because the embedding layer is frozen, those weights should not change during training.