# Neural Message Passing for Quantum Chemistry

Ref: https://arxiv.org/pdf/1704.01212.pdf

Assumptions:
1. Hidden states are updated only for atoms; edges (bonds) keep their raw input features throughout message passing.

In [None]:
# IPython setup: automatically reload edited project modules (utils, data,
# mpnn) before each cell runs, and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import torch.optim as optim
import matplotlib.pyplot as plt

from utils import rolling_mean
from data import DataPreprocessor
from mpnn.M import MfuncMLP
from mpnn.R import RfuncMLP
from mpnn.U import UfuncMLP
from mpnn.mpnn import MPNN

In [None]:
# Sizes of the raw input features fed to the network. These were previously
# repeated as literals in every constructor call below.
# NOTE(review): confirm 75 atom / 6 edge features match the featurizer in `data`.
atom_feat_dim = 75
edge_feat_dim = 6

m_dim = 32  # output size of the message functions / message input of the updates
h_dim = 24  # output size of the update functions (atom hidden state)

# Message functions M: the "start" variant consumes raw atom features, the
# "hid" variant consumes h_dim-sized hidden states. Both consume the raw edge
# features, because edge states are never updated.
Mfunc_start = MfuncMLP(
    inp_atom_features=atom_feat_dim,
    inp_edge_features=edge_feat_dim,
    out_size_atom=m_dim,
)
Mfunc_hid = MfuncMLP(
    inp_atom_features=h_dim,
    inp_edge_features=edge_feat_dim,
    out_size_atom=m_dim,
)

# Update functions U: combine an atom's current features/state with its
# m_dim-sized message to produce a new h_dim-sized hidden state.
Ufunc_start = UfuncMLP(
    inp_atom_features=atom_feat_dim,
    inp_atom_m_state=m_dim,
    out_size_atom=h_dim,
)
Ufunc_hid = UfuncMLP(
    inp_atom_features=h_dim,
    inp_atom_m_state=m_dim,
    out_size_atom=h_dim,
)

# Readout R over the final h_dim-sized atom states.
# NOTE(review): `hid=10` is presumably the readout MLP's hidden width — confirm in mpnn.R.
Rfunc = RfuncMLP(inp_size=h_dim, hid=10)

In [None]:
# Assemble the network from its parts: the message functions (start / hidden
# variants), the update functions (start / hidden variants), then the readout.
# NOTE(review): the trailing 2 is presumably the number of message-passing
# steps — confirm against mpnn.mpnn.MPNN.
mpnn = MPNN(
    Mfunc_start, Mfunc_hid,
    Ufunc_start, Ufunc_hid,
    Rfunc,
    2,
)

In [None]:
# Dataset identifier handed to DataPreprocessor.
# NOTE(review): presumably a file path relative to the notebook — confirm in data.DataPreprocessor.
DATASET = 'data.test'

In [None]:
# Build the preprocessor for DATASET with both of its filtering options enabled.
# NOTE(review): what filter_dots / filter_atoms remove is defined in
# data.DataPreprocessor — check there before changing them.
data = DataPreprocessor(
    DATASET,
    filter_dots=True,
    filter_atoms=True,
)

In [None]:
# Load the raw data.
# NOTE(review): presumably reads the DATASET passed to the constructor — confirm in data.DataPreprocessor.
data.load_dataset()

In [None]:
# Apply filtering to the loaded data.
# NOTE(review): presumably governed by the filter_dots / filter_atoms flags
# given to the constructor; must run after load_dataset().
data.filter_data()

In [None]:
# Unpack the train / validation / test splits: each split yields its
# molecules (presumably SMILES strings, per the names) and matching labels.
(
    train_smiles, train_labels,
    valid_smiles, valid_labels,
    test_smiles, test_labels,
) = data.get_data()

In [None]:
# Per-step training losses: appended to by the training loop and plotted afterwards.
losses = []

In [None]:
# Train one molecule at a time (each [i:i + 1] slice is a batch of size 1)
# for several passes over the first `n_steps` training examples, recording
# every step's loss.
# NOTE(review): the second argument to make_opt_step (2) presumably mirrors the
# message-passing step count given to MPNN — confirm against make_opt_step.
n_epochs = 5
# Keep the original cap of 1000 steps per epoch, but never slice past the end
# of the training data: out-of-range slices are silently empty and would hand
# make_opt_step an empty batch.
n_steps = min(1000, len(train_smiles), len(train_labels))
for _ in range(n_epochs):
    for i in range(n_steps):
        batch = (train_smiles[i:i + 1], train_labels[i:i + 1])
        loss = mpnn.make_opt_step(batch, 2)
        print(i, loss)
        losses.append(loss)

In [None]:
# Plot the smoothed training loss. Per-step losses come from size-1 batches
# and are noisy, so they are averaged over a rolling window first.
# NOTE(review): rolling_mean comes from the project's utils module; assumed to
# average over `loss_window` consecutive values — confirm there.
loss_window = 150
plt.plot(rolling_mean(losses, loss_window))
plt.xlabel('position in smoothed loss series')
plt.ylabel(f'training loss (rolling mean, window={loss_window})')

In [None]:
# Inspect the model's optimizer state after training.
# NOTE(review): `mpnn.opt` is assumed to be a torch.optim optimizer (the
# notebook imports torch.optim). If so, state_dict() returns its per-parameter
# state plus param_groups (e.g. learning rate) — confirm in mpnn.mpnn.MPNN.
mpnn.opt.state_dict()