# Joint embedding of fragmentation spectra and chemical compounds

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import keras
from keras.callbacks import History, ReduceLROnPlateau, EarlyStopping
from keras import backend as K
from livelossplot import PlotLossesKeras

from rdkit.Chem import AllChem
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole #Needed to show molecules
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions #Only needed if modifying default
from rdkit.Chem import Draw
from rdkit import Chem, DataStructs

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, normalize

from rdkit.Chem.Fingerprints import FingerprintMols
from rdkit.Chem import MACCSkeys

import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import csv

from functions import *

### Load GNPS+Massbank data

Data contains SMILES of known molecules and also their fragmentation spectra from massbank + gnps.

Collision energies are merged.

In [None]:
# Load the pickled GNPS+MassBank dataset. Other cells in this notebook read the
# keys 'spectra', 'smiles' and 'vocab' from the resulting object.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so this
# file must only ever come from a trusted source.
import pickle
with open('../data/gnps_massbank_data.p', 'rb') as f:
    data = pickle.load(f)

In [None]:
data['spectra'].shape

In [None]:
# with open('smiles_list.smi', 'w') as f:
#     for smile in data['smiles']:
#         f.write(smile + '\n')

### Create spectra embedding

Load a simple dense model. This was trained on the training data and is used to map

- From: fragmentation spectra 
- To: 100-dimensional representation

TODO: represent spectra as their LDA topic decomposition

In [None]:
# Load three previously saved Keras models from ../models (autoencoder, encoder
# and decoder, going by their file names) and print the autoencoder's layers.
input_spectra_autoencoder = keras.models.load_model('../models/spectra_autoencoder_gnps_massbank.h5')
input_spectra_encoder = keras.models.load_model('../models/spectra_encoder_gnps_massbank.h5')
input_spectra_decoder = keras.models.load_model('../models/spectra_decoder_gnps_massbank.h5')
input_spectra_autoencoder.summary()
# plot_model_in_notebook is not defined in this notebook; presumably it comes
# from `from functions import *` -- TODO confirm. Its result is left as the
# last expression of the cell so the notebook displays it.
svg = plot_model_in_notebook(input_spectra_autoencoder)
svg

In [None]:
# Encode every spectrum in the dataset with the loaded encoder and print the
# shape of the result. spectra_latent is computed again further down, after the
# spectra have been filtered to those whose SMILES produced a fingerprint.
spectra_latent = input_spectra_encoder.predict(data['spectra'])
print(spectra_latent.shape)

In [None]:
# spectra_decoded = input_spectra_decoder.predict(spectra_latent)
# for idx in range(10):
#     pos = np.nonzero(data['spectra'][idx])
# #     print(data['vocab'][pos])
#     plt.plot(data['vocab'], data['spectra'][idx])
#     plt.plot(data['vocab'], -spectra_decoded[idx])
#     plt.show()

### Create  Fingerprints of Molecules

In [None]:
def smiles_to_fingerprints(smiles, n_bits=2048, radius=3):
    """Convert SMILES strings into a dense binary Morgan-fingerprint matrix.

    Args:
        smiles: Sequence of SMILES strings.
        n_bits: Length of the Morgan bit vector. It also fixes the number of
            columns of the returned array, so that separate calls (e.g. for
            the training molecules and for decoys) always produce arrays of
            the same width, regardless of which bits happen to be set.
        radius: Morgan fingerprint radius passed to RDKit.

    Returns:
        A tuple ``(fingerprint_arr, valid_idx)``:
        ``fingerprint_arr`` is a float array of shape
        ``(len(valid_idx), n_bits)`` holding 1 at every on-bit, and
        ``valid_idx`` lists the positions in ``smiles`` that were converted.
        Entries that cannot be converted are reported with a printed message
        and skipped.
    """
    valid_idx = []
    fingerprints = []
    for i, smile in enumerate(smiles):
        try:
            mol = Chem.MolFromSmiles(smile)
            if mol is None:
                # RDKit returns None (rather than raising) for SMILES it cannot parse.
                raise ValueError('unparsable SMILES')
            # fp = FingerprintMols.FingerprintMol(mol)
            # fp = MACCSkeys.GenMACCSKeys(mol)
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
            bits = np.array(list(fp.GetOnBits()), dtype=int)
            if bits.size == 0:
                # A fingerprint without any on-bit carries no information;
                # treat it as invalid, as before.
                raise ValueError('empty fingerprint')
        except Exception:
            # Deliberately broad: this is a best-effort conversion and RDKit
            # can raise several exception types on malformed input.
            print('Invalid smile', smile)
            continue
        valid_idx.append(i)
        fingerprints.append(bits)

    fingerprint_arr = np.zeros((len(fingerprints), n_bits))
    for row, bits in enumerate(fingerprints):
        fingerprint_arr[row, bits] = 1

    return fingerprint_arr, valid_idx

In [None]:
# Spectra and SMILES must be parallel sequences; check this before the (slow)
# fingerprint computation so that a mismatch fails immediately.
assert len(data['smiles']) == len(data['spectra'])
fingerprint_arr, valid_idx = smiles_to_fingerprints(data['smiles'])
# Keep only the rows whose SMILES could be converted, so that spectra, smiles
# and fingerprint_arr stay aligned row by row.
spectra = data['spectra'][valid_idx]
# np.asarray makes the index-list selection work even if the SMILES are stored
# as a plain Python list.
smiles = np.asarray(data['smiles'])[valid_idx]

In [None]:
# # load simon's fingerprint
# fprints = {}
# with open('../data/smiles_sub.csv','r') as f:
#     reader = csv.reader(f)
#     for line in reader:
#         fprints[line[0]] = [int(i) for i in line[1:]]

In [None]:
# fingerprint_arr = np.zeros((len(smiles), 306+1))
# for i in range(len(smiles)):
#     smile = smiles[i]
#     fingerprint = fprints[smile]
#     for bit in fingerprint:
#         fingerprint_arr[i][bit] = 1

In [None]:
# Re-encode only the spectra whose SMILES produced a fingerprint, so that
# spectra_latent has one row per row of fingerprint_arr.
spectra_latent = input_spectra_encoder.predict(spectra)

In [None]:
fingerprint_arr.shape

In [None]:
spectra_latent[0]

In [None]:
fingerprint_arr[0]

### Try joint embedding

Objective: build a model that projects spectra and molecules into the same representation space, so that a spectrum is close to its own molecule in that space, and far away from dissimilar spectra and dissimilar molecules.

Each training point is a triplet of:
- a fragmentation spectrum, also called the anchor
- the compound correctly associated with that spectrum, also called the positive example
- a compound incorrectly associated with that spectrum, also called the negative example

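During training, we compute the scores (dot products) of the anchor with the positive and the negative example. The optimisation objective is to maximise the positive scores and minimise the negative scores; concretely, the loss is the margin (hinge) loss max(0, 1 - positive score + negative score), summed over the batch, which pushes each positive score above its negative score by a margin of 1. Before each training epoch, the negative examples are reshuffled randomly.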

See:

- https://pageperso.lis-lab.fr/benoit.favre/dl4nlp/tutorials/05-caption.pdf
- https://arxiv.org/abs/1511.06078

In [None]:
def get_embedding_model(input_dim_spectra, input_dim_molecule, embedding_dim):
    """Build the triplet joint-embedding model and its two stand-alone encoders.

    NOTE(review): Input, Dense, BatchNormalization, Activation, Lambda, dot,
    concatenate and Model are not imported explicitly in this notebook;
    presumably they come from `from functions import *` -- TODO confirm.

    Args:
        input_dim_spectra: Width of the spectra input vector.
        input_dim_molecule: Width of each molecule input vector (the positive
            and the negative molecule inputs share this width).
        embedding_dim: Number of units of every Dense layer, i.e. the size of
            the joint embedding.

    Returns:
        A tuple ``(embedding_model, spectra_encoder, smile_encoder)``.
        ``embedding_model`` takes ``[spectra, positive molecule, negative
        molecule]`` and outputs the concatenation of the spectra/positive and
        spectra/negative dot products; it is compiled with the margin loss
        defined below and the 'adam' optimizer. ``spectra_encoder`` and
        ``smile_encoder`` map a single spectra / molecule input to its
        L2-normalised embedding, reusing the layers of ``embedding_model``.
    """
    
    spectra_input = Input(shape=(input_dim_spectra,), name='spectra_input')
    smile_input = Input(shape=(input_dim_molecule,), name='positive_molecule')
    noise_input = Input(shape=(input_dim_molecule,), name='negative_molecule')
        
    # Spectra branch: Dense (no bias) -> BatchNorm -> ReLU -> Dense with ReLU.
    spectra_pipeline = Dense(embedding_dim, use_bias=False, name='spectra_weights')(spectra_input)
    spectra_pipeline = BatchNormalization(name='bn1')(spectra_pipeline)
    spectra_pipeline = Activation('relu', name='relu1')(spectra_pipeline)
    spectra_pipeline = Dense(embedding_dim, activation='relu', name='spectra_weights2')(spectra_pipeline)

    # Molecule branch: the layers are instantiated once and applied to both the
    # positive and the negative input, so both go through the same weights.
    smile_dense1 = Dense(embedding_dim, use_bias=False, name='molecule_weights') 
    bn = BatchNormalization(name='bn2')
    activation = Activation('relu', name='relu2')
    smile_dense2 = Dense(embedding_dim, activation='relu', name='molecule_weights2')
    smile_pipeline = smile_dense2(activation(bn(smile_dense1(smile_input))))
    noise_pipeline = smile_dense2(activation(bn(smile_dense1(noise_input))))        

    # Scores used for training: column 0 is the anchor/positive dot product,
    # column 1 the anchor/negative dot product. These are taken on the branch
    # outputs directly, i.e. before the L2 normalisation added below.
    positive_pair = dot([spectra_pipeline, smile_pipeline], axes=1)
    negative_pair = dot([spectra_pipeline, noise_pipeline], axes=1)
    concat_output = concatenate([positive_pair, negative_pair])
    embedding_model = Model(inputs=[spectra_input, smile_input, noise_input], outputs=concat_output)

    # The stand-alone encoders additionally L2-normalise their output rows.
    l2_norm1 = Lambda(lambda  x: K.l2_normalize(x, axis=1))   
    l2_norm2 = Lambda(lambda  x: K.l2_normalize(x, axis=1))       
    spectra_encoder = Model(inputs=spectra_input, outputs=l2_norm1(spectra_pipeline))
    smile_encoder = Model(inputs=smile_input, outputs=l2_norm2(smile_pipeline))
    
    # also see https://github.com/keras-team/keras/issues/150
    # Margin loss summed over the batch; y_true is not used, so any targets of
    # the right length can be passed to fit().
    def custom_loss(y_true, y_pred):
        positive = y_pred[:,0]
        negative = y_pred[:,1]
        return K.sum(K.maximum(0., 1. - positive + negative))
    
    # Fraction of triplets whose positive score exceeds the negative score;
    # y_true is not used here either.
    def accuracy(y_true, y_pred):
        positive = y_pred[:,0]
        negative = y_pred[:,1]
        return K.mean(positive > negative)
    
    embedding_model.compile(loss=custom_loss, optimizer='adam', metrics=[accuracy])
    return embedding_model, spectra_encoder, smile_encoder

In [None]:
# Size of the joint embedding space.
EMBEDDING_DIM = 50
# The model's input widths are taken from the data: the latent spectra width
# and the fingerprint width.
input_dim_spectra = spectra_latent.shape[1]
input_dim_molecule = fingerprint_arr.shape[1]
joint_embedding_model, spectra_encoder, smile_encoder = get_embedding_model(input_dim_spectra, 
                                                                            input_dim_molecule, 
                                                                            EMBEDDING_DIM)
joint_embedding_model.summary()
# Last expression of the cell, so the notebook displays whatever the helper returns.
plot_model_in_notebook(joint_embedding_model)

Preparing training and test data

In [None]:
def shuffle_together(a, b, c, d):
    """Reorder four equally long arrays with one shared random permutation.

    Rows that were at the same position in ``a``, ``b``, ``c`` and ``d``
    are still at the same position in the four returned arrays.

    Raises:
        AssertionError: if the inputs do not all have the same length.
    """
    for other in (b, c, d):
        assert len(a) == len(other)
    order = np.random.permutation(len(a))
    return tuple(arr[order] for arr in (a, b, c, d))

# Shuffle the four parallel arrays with one shared permutation, so that each
# spectrum stays paired with its latent vector, SMILES and fingerprint.
spectra, spectra_latent, smiles, fingerprint_arr = shuffle_together(spectra, spectra_latent, smiles, fingerprint_arr)

In [None]:
# Scale every row (one fingerprint / one latent spectrum) to unit L2 norm.
fingerprint_arr = normalize(fingerprint_arr, norm='l2', axis=1)
spectra_latent = normalize(spectra_latent, norm='l2', axis=1)

In [None]:
# 80/20 split by position: the first `pos` rows are used for training, the last
# `remaining` rows for testing.
pos = int(len(spectra_latent) * 0.8)
remaining = len(spectra_latent) - pos
print(pos, remaining)

# `noise` is a copy of the fingerprints and feeds the negative-molecule input.
# noise[:pos] and noise[-remaining:] below are slices of this one array, so an
# in-place shuffle of `noise` changes both X_train[2] and X_test[2].
noise = np.copy(fingerprint_arr)
# All-zero dummy targets: the loss set up in get_embedding_model never reads y_true.
fake_labels = np.zeros((len(spectra_latent), 1))

# Model inputs in the order [spectra, positive molecule, negative molecule].
X_train = [spectra_latent[:pos], fingerprint_arr[:pos], noise[:pos]]
Y_train = fake_labels[:pos]
X_test = [spectra_latent[-remaining:], fingerprint_arr[-remaining:], noise[-remaining:]]
Y_test = fake_labels[-remaining:]

# Raw spectra and SMILES split the same way as the model inputs above.
spectra_train = spectra[:pos]
spectra_test = spectra[-remaining:]
smiles_train = smiles[:pos]
smiles_test = smiles[-remaining:]

In [None]:
print(X_train[0].shape, X_train[1].shape, X_train[2].shape)
print(X_test[0].shape, X_test[1].shape, X_test[2].shape)

In [None]:
# rlr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,patience=10, min_lr=0.000001,
#                         verbose=1, epsilon=1e-5)
# early_stop = EarlyStopping(monitor='val_loss', min_delta=0, patience=3, verbose=1, mode='auto')
# callbacks = [rlr, early_stop]

# tensorboard = keras.callbacks.TensorBoard(log_dir='./graph', histogram_freq=0,  
#           write_graph=True, write_images=True)
# callbacks = [rlr, early_stop, tensorboard]
# if is_notebook():
#     callbacks.append(PlotLossesKeras())

TODO: generate mini-batch properly https://stackoverflow.com/questions/48568062/keras-custom-infinite-data-generator-with-shuffle

In [None]:
# actual training
# Each pass of this loop trains for one epoch on freshly drawn negatives.
for epoch in range(1000):
    # Only show Keras' progress output on every 100th epoch.
    verbose = 1 if epoch % 100 == 0 else 0
    if verbose:
        print('\nIteration %d' % epoch)
    # Re-draw the mismatched (negative) molecules. The train and the test
    # negatives are shuffled separately, each within its own slice of `noise`:
    # shuffling the whole `noise` array would move test molecules into the
    # training negatives (and training molecules into the validation
    # negatives), leaking test data into training.
    np.random.shuffle(X_train[2])
    np.random.shuffle(X_test[2])
    # validation_data is documented as a tuple (x_val, y_val).
    joint_embedding_model.fit(X_train, Y_train,
        validation_data=(X_test, Y_test), epochs=1,
        batch_size=32, verbose=verbose)

In [None]:
# joint_embedding_model.save('../models/joint_embedding_gnps_massbank.h5')
# spectra_encoder.save('../models/joint_spectra_encoder_gnps_massbank.h5')
# smile_encoder.save('../models/joint_smile_encoder_gnps_massbank.h5')

### Visualise the joint embedding results

In [None]:
def visualise_embedding(spectra_data, molecule_data, spectra_encoder, smile_encoder):
    """Plot spectra and molecules in the first two PCA axes of the joint space.

    Both inputs are embedded with their encoder, the embeddings are stacked
    (spectra first, then molecules) and projected with one PCA fitted on the
    stacked array. The explained-variance ratios are printed and the first
    two principal components are drawn as a scatter plot.

    Args:
        spectra_data: Input accepted by ``spectra_encoder.predict``.
        molecule_data: Input accepted by ``smile_encoder.predict``.
        spectra_encoder: Model embedding the spectra.
        smile_encoder: Model embedding the molecules.
    """
    embedded_spectra = spectra_encoder.predict(spectra_data)
    embedded_molecules = smile_encoder.predict(molecule_data)
    embedded_combined = np.concatenate([embedded_spectra, embedded_molecules], axis=0)

    PCA_COMPONENTS = 25
    # PCA cannot extract more components than there are samples or features,
    # so cap the requested number for small inputs / narrow embeddings.
    n_components = min(PCA_COMPONENTS, *embedded_combined.shape)
    pca = PCA(n_components=n_components)
    latent_proj = pca.fit_transform(embedded_combined)
    evr = pca.explained_variance_ratio_
    print('Explained variations -- first two PCs: %.2f' % np.sum(evr[:2]))
    print('Explained variations -- all components: %.2f' % np.sum(evr))
    print(evr)

    # Rows [0, n_spectra) are spectra and rows [n_spectra, end) are molecules.
    # The molecule slice has to start exactly at n_spectra; starting one row
    # later would silently drop the first molecule from the plot.
    n_spectra = len(embedded_spectra)
    spectra_proj = latent_proj[:n_spectra]
    molecule_proj = latent_proj[n_spectra:]

    plt.figure(figsize=(8, 8))
    plt.scatter(spectra_proj[:, 0], spectra_proj[:, 1], marker='x', c='red', s=1)
    plt.scatter(molecule_proj[:, 0], molecule_proj[:, 1], marker='.', c='blue', s=1)
    plt.title('Joint embedding of fragmentation spectra (red) and molecules (blue)')

Visualise embedding on training data

In [None]:
visualise_embedding(X_train[0], X_train[1], spectra_encoder, smile_encoder)

Visualise embedding on testing data

In [None]:
visualise_embedding(X_test[0], X_test[1], spectra_encoder, smile_encoder)

### Evaluation

In [None]:
# Embed the test spectra and the test (positive) molecules with the two
# encoders returned by get_embedding_model.
embedded_spectra = spectra_encoder.predict(X_test[0])
embedded_molecules = smile_encoder.predict(X_test[1])

In [None]:
x = embedded_spectra[0]
print(x)
print(np.dot(x, x))

In [None]:
plt.rcParams['figure.figsize'] = (8,8)

In [None]:
# scores[i, j] is the dot product between test spectrum i and test molecule j.
# Both encoders end in an L2-normalisation layer, so these are cosine
# similarities. Row i of X_test[0] and row i of X_test[1] belong together, so
# the matching pairs lie on the diagonal.
scores = np.dot(embedded_spectra, embedded_molecules.T)
print(scores.shape)
plt.matshow(scores)
plt.colorbar()
plt.xlabel('molecules')
plt.ylabel('spectra')
plt.title('Dot product')

In [None]:
def recall_at(n, scores, verbose=False):
    """Compute recall@n for a query-by-candidate score matrix.

    Row ``i`` of ``scores`` holds the scores of query ``i`` against every
    candidate, and candidate ``i`` is taken to be the correct match for
    query ``i``. A query counts as found when its own index is among the
    ``n`` highest-scoring candidates of its row.

    Args:
        n: Number of top candidates to retrieve per query. Values <= 0
            retrieve nothing.
        scores: 2-D NumPy array of shape (queries, candidates).
        verbose: If true, print the outcome for every query.

    Returns:
        A tuple ``(precision, found, total, results)``: ``precision`` is
        ``found / total`` (0.0 when there are no queries), ``found`` is the
        number of queries found (as a float), ``total`` the number of
        queries, and ``results`` maps each query index to its retrieved
        ``(candidate_index, score)`` pairs, best first.
    """
    found = 0.0
    total = len(scores)
    results = {}
    for i, row in enumerate(scores):
        ranking = row.argsort()
        # Guard n <= 0 explicitly: ranking[-0:] would select *every*
        # candidate and make each query count as found.
        top_idx = ranking[-n:][::-1] if n > 0 else ranking[:0]
        correct = i in top_idx
        if correct:
            found += 1
        retrieved = list(zip(top_idx, row[top_idx]))
        if verbose:
            print(i, correct, retrieved)
        results[i] = retrieved
    # Avoid a ZeroDivisionError when there are no queries at all.
    precision = found / total if total else 0.0
    return precision, found, total, results

In [None]:
prec, found, total, results = recall_at(10, scores)

In [None]:
print('Found %d/%d (%.2f)' % (found, total, prec))

In [None]:
def plot_spectra_and_molecule(idx, spectra, smiles, vocab=None):
    """Plot one spectrum and draw the molecule at the same index.

    Args:
        idx: Row to show; used to index both ``spectra`` and ``smiles``.
        spectra: Indexable collection of spectra; ``spectra[idx]`` is plotted
            against ``vocab``.
        smiles: Indexable collection of SMILES strings.
        vocab: x-axis values for the spectrum plot. Defaults to the
            notebook-level ``data['vocab']``, which is what this function
            always used before the parameter existed.
    """
    if vocab is None:
        # Fall back to the dataset loaded at the top of the notebook.
        vocab = data['vocab']
    plt.plot(vocab, spectra[idx])
    plt.show()
    smile = smiles[idx]
    print(smile)
    mol = Chem.MolFromSmiles(smile)
    Draw.MolToMPL(mol, size=(150, 150))
    plt.show()

In [None]:
def plot_results(idx, spectra_test, smiles_test, results):
    """Show a query spectrum/molecule followed by everything retrieved for it.

    Args:
        idx: Index of the query; also the key looked up in ``results``.
        spectra_test: Spectra indexed by the query and retrieved indices.
        smiles_test: SMILES indexed by the query and retrieved indices.
        results: Mapping from query index to ``(index, score)`` pairs.

    Note: sets the global matplotlib figure size to 4x4 as a side effect.
    """
    plt.rcParams.update({'figure.figsize': (4,4)})

    print('Query')
    plot_spectra_and_molecule(idx, spectra_test, smiles_test)

    print("Retrieved")
    for hit_idx, hit_score in results[idx]:
        print('Molecule %d score %.2f' % (hit_idx, hit_score))
        plot_spectra_and_molecule(hit_idx, spectra_test, smiles_test)

In [None]:
# plot_results(12, spectra_test, smiles_test, results)

### Add decoy compounds

In [None]:
# Load the decoy compounds from the 'table' key of an HDF5 file.
import pandas
import h5py
# NOTE(review): absolute path on the original author's machine; it has to be
# adjusted before this cell can run anywhere else.
decoy_data = pandas.read_hdf('/Users/joewandy/Dropbox/Analysis/autoencoder/data/pubchem_100k.h5', 'table')
# The 'structure' column is passed to smiles_to_fingerprints in the next cell,
# so it presumably holds SMILES strings -- TODO confirm.
decoy_smiles = decoy_data['structure'].values

In [None]:
decoy_fingerprint_arr, valid_decoy_idx = smiles_to_fingerprints(decoy_smiles)

In [None]:
valid_decoy_smiles = decoy_smiles[valid_decoy_idx]

In [None]:
decoy_fingerprint_arr = normalize(decoy_fingerprint_arr, norm='l2', axis=1)

In [None]:
embedded_decoy_molecules = smile_encoder.predict(decoy_fingerprint_arr)

In [None]:
# Recall@10 of the test spectra as more and more decoy molecules are appended
# to the candidate list. The true test molecules come first in the candidate
# list, so recall_at's "candidate i matches spectrum i" convention still holds.
# NOTE: this loop reassigns the notebook-level `scores`, `found`, `total` and
# `results` from the evaluation cells above.
recalls = []
increase = 10000
n_decoys = len(embedded_decoy_molecules)
# Step through the decoys in blocks of `increase`. The last step is clamped to
# the number of decoys that really exist, so decoy_counts (later used as the
# plot's x-axis) never reports more decoys than were actually added.
decoy_counts = [min(count, n_decoys) for count in range(0, n_decoys + increase, increase)]
for num_decoy in decoy_counts:
    combined_molecules = np.concatenate([embedded_molecules, embedded_decoy_molecules[0:num_decoy]], axis=0)
    scores = np.dot(embedded_spectra, combined_molecules.T)
    recall, found, total, results = recall_at(10, scores)
    recalls.append(recall)
    print('%.2f %d/%d %s' % (recall, found, total, scores.shape))

In [None]:
# Plot recall@10 against the number of decoy compounds added.
plt.plot(decoy_counts, recalls, linestyle='--', marker='o', color='b')
plt.title('Recall@10 with increasing decoy compounds')
plt.ylabel('Recall@10')
plt.xlabel('#decoy')
# Pass the on/off flag positionally: the keyword was renamed from `b` to
# `visible` in Matplotlib 3.5, so `b=True` fails on current releases while
# the positional form works on old and new versions alike.
plt.grid(True, which='both')
# y ticks every 0.05 across the observed recall range.
plt.yticks(np.arange(min(recalls), max(recalls)+0.05, 0.05))