# Notebook on How to Perform Neural Network Decomposition

In [5]:
#Package imports
import pandas as pd
import numpy as np 
import matplotlib.pyplot as plt
import seaborn as sns
import math

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from sklearn.model_selection import train_test_split

#Helper functions for specific feature engineering used in paper and NN decomp alg
import feature_eng

## Step 1: Import & Format Data for a Specific Adsorbate

In [6]:
#Load the combined CO/CHO/COOH dataset and keep only the rows of the one
#adsorbate studied in this run (COOH)
df = pd.read_csv('All_data.csv')
df_COOH = df.loc[df['Adsorbate'].eq('COOH')]

#Labels (y) for the MLP NN model: the 'Eads' column
y = df_COOH['Eads']

#Inputs (x) for the MLP NN model: every column except the adsorbate tag and the
#label, passed through the project's feature-embedding helper
x = feature_eng.feature_embedding(df_COOH.drop(['Adsorbate', 'Eads'], axis=1))

## Step 2: Train the Model

In [7]:
#Model Training + Fitting
# NOTE(review): `shuffle` is not used in any of the visible cells; the import is
# kept only in case code outside this view relies on it.
from random import shuffle

# One fixed seed for every stochastic step (data split, weight initialisation,
# batch shuffling) so that Restart Kernel -> Run All reproduces the same model.
# Bit-exact repeatability on GPU may additionally require deterministic ops.
SEED = 42
# Seeds Python's `random`, NumPy and TensorFlow in a single call.
tf.keras.utils.set_random_seed(SEED)

# Hold out 15% of the samples for testing; random_state pins which rows land in
# the test set.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.15, random_state=SEED
)
class MyModel(tf.keras.Model):
  """Two-layer, bias-free MLP.

  A 6-unit tanh Dense layer (`w1`) feeds a 1-unit linear Dense layer (`w3`);
  the forward pass sums the second layer's output over axis 1.
  """

  def __init__(self):
    """Create the two Dense layers; neither layer has a bias term."""
    super().__init__()
    # w1 is assigned before w3, exactly as in the original definition.
    self.w1 = tf.keras.layers.Dense(units=6, use_bias=False, activation='tanh')
    self.w3 = tf.keras.layers.Dense(units=1, use_bias=False, activation='linear')

  def call(self, inputs):
    """Run `inputs` through w1 then w3 and sum the result over axis 1.

    Assumes `inputs` is 2-D (samples, features), so each sample yields one
    value -- TODO confirm against feature_eng.feature_embedding.
    """
    prediction = self.w3(self.w1(inputs))
    return tf.math.reduce_sum(prediction, axis=1)
# Build the network and configure it: Adam optimizer, mean-squared-error loss,
# with MAE and MSE tracked as metrics.
model = MyModel()
model.compile(
    loss='mse',
    optimizer='adam',
    metrics=['mae', 'mse'],
)

# Fit silently (verbose=0) for 3000 epochs on the training split; the object
# returned by fit() is kept in `h`.
h = model.fit(
    x=x_train,
    y=y_train,
    epochs=3000,
    callbacks=[],
    verbose=0,
)

## Step 3: Extract Internal Weights and Recreate NN Manually for Decomposition

In [8]:
#Pull the trained weights out of the NN as NumPy arrays
#(presumably params[0] is the w1 kernel and params[1] the w3 kernel, in layer
# creation order -- TODO confirm)
params = model.trainable_variables
p1, p2 = (np.array(params[idx]) for idx in (0, 1))

#Perform NN decomposition:
#A site from 1-10 has to be specified (see SI of paper for visualization/elaboration);
#the first argument (3 here) is presumably that site number -- verify against feature_eng
avg_inf = feature_eng.NN_decomposition(3, p1, p2)

Average site influence across all elements: 0.3259 eV/atom


Repeat Steps 1–3 for each adsorbate dataset (CO, CHO, COOH) to reproduce the
figures published in ACS Catal. 2022, 12, 24, 14864–14871
(https://doi.org/10.1021/acscatal.2c03675).