# MPNN fingerprints

# Import packages

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import torch
from sklearn.decomposition import PCA

from chemprop import data, featurizers, models

# Change model input here

In [None]:
# Path to the model checkpoint (.ckpt) file to load.
# A checkpoint produced by the training notebook is written to the `checkpoints`
# folder, with a name similar to `checkpoints/epoch=19-step=180.ckpt`.
checkpoint_path = '../tests/data/example_model_v2.ckpt'

## Load model

In [None]:
# Restore the MPNN from the checkpoint file configured in `checkpoint_path`.
mpnn = models.MPNN.load_from_checkpoint(checkpoint_path)
# Bare expression as the last line of the cell: the notebook displays the model's repr.
mpnn

# Change data input here

In [None]:
# CSV file containing the molecules to compute fingerprints for.
test_path = '../tests/data/smis.csv'
# Name of the column in that CSV that holds the SMILES strings.
smiles_column = 'smiles'

## Load data

In [None]:
# Read the input table and pull out the SMILES column.
df_test = pd.read_csv(test_path)
smis = df_test.loc[:, smiles_column]

# Build one MoleculeDatapoint per SMILES string.
test_data = list(map(data.MoleculeDatapoint.from_smi, smis))

# Preview the first few datapoints in the cell output.
test_data[:5]

# Get featurizer

In [None]:
# Featurizer handed to the dataset below; constructed with its default settings.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Get datasets

In [None]:
# Wrap the datapoints in a dataset that uses the featurizer created above.
test_dset = data.MoleculeDataset(test_data, featurizer=featurizer)
# shuffle=False: do not shuffle, so batches presumably follow the order of `test_data`
# (and therefore the CSV rows) — confirm against the loader's documentation.
test_loader = data.MolGraphDataLoader(test_dset, shuffle=False)

# Calculate fingerprints

`models.MPNN.fingerprint(inputs : BatchMolGraph)` returns the learned fingerprints of a chemprop model given a batch of input molecules.

In [None]:
# Switch to evaluation mode so that any dropout / batch-norm layers in the model
# behave deterministically while computing fingerprints.
mpnn.eval()

# Inference only: disable autograd so no computation graph is built or kept alive.
with torch.no_grad():
    # Collect the per-batch fingerprints first and concatenate once afterwards;
    # calling torch.cat inside the loop re-copies every earlier row on each step.
    batch_fingerprints = [mpnn.fingerprint(batch.bmg) for batch in test_loader]

# torch.cat rejects an empty list, so fall back to an empty tensor when the
# loader yields no batches (the same result the incremental version produced).
fingerprints = torch.cat(batch_fingerprints, 0) if batch_fingerprints else torch.empty(0)

fingerprints.shape

# Using fingerprints

In [None]:
# Drop any autograd history before handing the tensor to scikit-learn.
fingerprints = fingerprints.detach()

# Project the fingerprints onto their first two principal components.
pca = PCA(n_components=2)
principalComponents = pca.fit_transform(fingerprints)

# Draw the two components against each other on a single 8x8 inch axes.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
first_component, second_component = principalComponents.T
ax.scatter(first_component, second_component)
plt.show()