In [None]:
import sys,os,argparse
from IPython.display import HTML

# Optional file whose whitespace-separated content is replayed as a command line.
CONFIG_FILE = '.config_ipynb'

# Command line used when CONFIG_FILE does not exist.
DEFAULT_ARGV = [
    'evaluate.py',
    '--data_csv', "~/git/ppptr/combined_holdout.csv",
    '--checkpoint', 'prosit_transformer-val_loss=0.334872_epoch=005.ckpt',
]

# The notebook fakes a command line by overwriting sys.argv and then lets
# argparse read it.  Note that argparse skips the first token (the program
# name), so a config file is expected to start with a placeholder such as
# 'evaluate.py', just like DEFAULT_ARGV does.
if not os.path.isfile(CONFIG_FILE):
    print("No config file found, using default values")
    sys.argv = list(DEFAULT_ARGV)
else:
    print("Reading config file")
    with open(CONFIG_FILE) as config_handle:
        sys.argv = config_handle.read().split()

parser = argparse.ArgumentParser()
parser.add_argument(
    "--data_csv",
    type=str,
    help="CSV file containing Sequences, Encodings, mIRT, SpectraEncoding, Charge",
)
parser.add_argument(
    "--checkpoint",
    type=str,
    help="Checkpoint file to use for evaluation",
)

# parse_known_args (rather than parse_args) tolerates tokens it does not
# recognise; they are returned separately in `unknown`.
args, unknown = parser.parse_known_args()

# Later cells read the settings through this plain dict.
dict_args = vars(args)

# Inert template (a bare string, never executed) showing how the settings could
# be written by hand instead of going through argparse.
"""
dict_args = {
    'data_csv': "~/git/ppptr/combined_holdout.csv",
    "checkpoint": '~/Downloads/prosit_transformer-epoch=20-step=14174.ckpt'
}

"""
print(dict_args)

In [None]:
from random import sample

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from scipy.cluster.hierarchy import dendrogram, linkage
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl

from transprosit import model
from transprosit.datamodules import PeptideDataset
from transprosit import spectra
from transprosit import encoding_decoding
from transprosit import constants

# Fix the global seed through pytorch-lightning so that stochastic steps later
# in the notebook are repeatable across runs.
# NOTE(review): presumably this also covers `random.sample`, which is used for
# the row subsampling further down — confirm against the pytorch-lightning docs.
pl.seed_everything(2020)

In [None]:
# Restore the trained network from the checkpoint path chosen in the config cell.
mod = model.PepTransformerModel.load_from_checkpoint(dict_args['checkpoint'])

# This notebook only evaluates the network: switch it to eval mode so that
# dropout and similar training-only behaviour is disabled during prediction.
mod.eval()

# Show the loaded network.  (`model` is the imported transprosit module, so
# printing it would only show the module path, not the architecture.)
print(mod)

In [None]:
# "input_csv"
# Read the evaluation table from the CSV passed as --data_csv.
df = pd.read_csv(dict_args['data_csv'])
print(df)

# Keep only the rows whose sequence encoding has exactly 25 entries.
# Planned change for the next release (from the original author): compare
# against constants.MAX_SEQUENCE instead of the literal 25, i.e.
# df = df[[len(eval(x)) == constants.MAX_SEQUENCE for x in df["SequenceEncoding"]]].copy().reset_index()
# Note that the input sequence does not really have to be fixed.
# NOTE(review): eval() executes whatever text is stored in the CSV column;
# only use this with trusted files (ast.literal_eval would be the safe choice).
has_expected_length = [len(eval(encoding)) == 25 for encoding in df["SequenceEncoding"]]
# reset_index() is called without drop=True, so the previous index is kept as
# an extra column of the frame.
df = df[has_expected_length].copy().reset_index()

# Randomly keep at most 500 rows so the evaluation and the plots stay fast.
n_rows_kept = min(500, len(df))
picked_rows = sample(list(df.index), n_rows_kept)
df = df.loc[picked_rows].copy().reset_index()

# Wrap the subsampled frame in a dataset and serve it in batches of 32.
ds = PeptideDataset(df)
dl = DataLoader(ds, batch_size=32)

In [None]:
def _normalise_spectra(spectra):
    """Clip negative intensities to zero and scale every row to a maximum of 1.

    Args:
        spectra: 2-D tensor with one spectrum per row.

    Returns:
        A new tensor of the same shape.  Rows whose maximum is zero after
        clipping are returned as all zeros; dividing them by their maximum
        would otherwise fill them with NaN.
    """
    clipped = spectra.clamp(min=0)
    row_max = clipped.max(dim=1, keepdim=True).values
    # Use 1 as divisor wherever the row has no positive peak.
    safe_max = torch.where(row_max > 0, row_max, torch.ones_like(row_max))
    return clipped / safe_max


# Get results: run the model over the whole loader and collect, batch by batch,
# the predictions (out_*) and the matching values stored in the dataset (in_*).
out_yhat_irts = []
in_yhat_irts = []
out_yhat_spectra = []
in_yhat_spectra = []
in_charges = []

# Inference only, so no gradients need to be tracked.
with torch.no_grad():
    for batch in dl:
        encoded_sequence, charge, encoded_spectra, norm_irt = batch
        in_charges.append(charge)
        in_yhat_irts.append(norm_irt)
        in_yhat_spectra.append(encoded_spectra)

        yhat_irt, yhat_spectra = mod(encoded_sequence, charge)
        out_yhat_irts.append(yhat_irt)
        out_yhat_spectra.append(yhat_spectra)

# Stitch the per-batch pieces back into one tensor per quantity.
out_yhat_irts = torch.cat(out_yhat_irts)
in_yhat_irts = torch.cat(in_yhat_irts)

# Bring predicted and ground-truth spectra onto the same 0..1 scale.
out_yhat_spectra = _normalise_spectra(torch.cat(out_yhat_spectra))
in_yhat_spectra = _normalise_spectra(torch.cat(in_yhat_spectra))


# Visualize the comparison of predictions and ground truth

In [None]:
# Retention time: model output on x, value stored in the dataset on y.
plt.scatter(out_yhat_irts, in_yhat_irts, marker = ".", alpha = 0.1)
plt.xlabel("Predicted iRT")
plt.ylabel("Ground-truth normalised iRT")
plt.show()

In [None]:
# Fragment intensities: ground truth on x, model output on y (note that this is
# the opposite axis order of the retention-time plot).
plt.scatter(in_yhat_spectra.flatten(), out_yhat_spectra.flatten(), marker = ".", alpha = 0.1)
plt.xlabel("Ground-truth intensity (normalised)")
plt.ylabel("Predicted intensity (normalised)")
plt.show()

In [None]:
# Histogram of the p=2 pairwise distance between the ground-truth spectra and
# the predicted spectra (presumably one distance per peptide row).
dist = torch.nn.PairwiseDistance(p=2, keepdim=True)
spectrum_distances = dist(in_yhat_spectra, out_yhat_spectra).flatten().numpy()
plt.hist(spectrum_distances, bins=100)
plt.show()

In [None]:

# Mirror plots: predicted spectrum pointing up (blue), ground truth pointing
# down (red), one figure per peptide.
# At most 200 figures are drawn; the cap is lowered when fewer peptides
# survived the filtering/subsampling, so short inputs no longer fail.
n_mirror_plots = min(200, len(df))

for i in range(n_mirror_plots):
    peptide = df["Sequences"][i]
    predicted = encoding_decoding.decode_fragment_tensor(peptide, out_yhat_spectra[i,:])
    ground_truth = encoding_decoding.decode_fragment_tensor(peptide, in_yhat_spectra[i,:])

    plt.title(f"{df['Sequences'][i]}, {df['Charges'][i]}")
    # Vertical guide at mass 0 spanning the full intensity range of both halves.
    plt.vlines(0, -1, 1, color = "gray")

    plt.vlines(predicted['Mass'], 0, predicted['Intensity'], color = "blue")
    plt.vlines(ground_truth['Mass'], 0, -ground_truth['Intensity'], color="red")
    plt.axhline(0, color='black')
    plt.show()



# Visualize embeddings for the amino acids

In [None]:
# Bare expression: the notebook displays the repr of the model's encoder
# sub-module as this cell's output.
mod.encoder

In [None]:
# Raw weight matrix of the encoder's amino-acid embedding, shown as an image
# stretched to fill the axes, with a colour scale next to it.
aa_embedding_matrix = mod.encoder.aa_encoder.weight.data
plt.imshow(aa_embedding_matrix, aspect = "auto")
plt.colorbar()

In [None]:
# Row labels for the embedding matrix: "#" for the first row, followed by the
# keys of constants.ALPHABET in their stored order.
AAS = [*constants.ALPHABET.keys()]
AA_NAMES = ["#", *AAS]

# Label the embedding rows, then drop the "#" row by selecting only the
# amino-acid labels.
aa_weight_df = pd.DataFrame(
    mod.encoder.aa_encoder.weight.data.numpy(),
    index = AA_NAMES,
)
aa_weights = aa_weight_df.loc[AAS]

# Clustered heat map of the embeddings (rows and columns both clustered with
# Ward linkage, colour scale clipped to +/- 0.05).
p = sns.clustermap(
    aa_weights,
    z_score = None,
    col_cluster = True,
    cmap = 'viridis',
    figsize = (5,5),
    dendrogram_ratio = (0.1, 0.1),
    method = "ward",
    vmin = -0.05,
    vmax = 0.05,
)


In [None]:
# Ward-linkage hierarchical clustering of the amino-acid embedding rows.
Z = linkage(aa_weights, method='ward')

# Draw the tree sideways on a fresh, narrow figure; branches above the colour
# threshold are drawn in grey.
plt.subplots(figsize=(3, 6))
dendrogram(
    Z,
    labels=aa_weights.index,
    orientation="left",
    color_threshold=1.8,
    above_threshold_color='grey',
    distance_sort='ascending',
)
plt.show()

In [None]:
# The scaler works column-wise; transposing before and after therefore
# standardises each amino acid's embedding vector on its own.
x = StandardScaler().fit_transform(aa_weights.values.T).T

# Project the standardised embeddings onto their first two principal components.
pca = PCA(n_components = 2)
pca_weights = pca.fit_transform(x)
print(pca_weights.shape)

# Scatter of the two components with each point labelled by its amino acid
# (label shifted 0.5 to the right of the marker).
plt.subplots(figsize=(5, 5))
plt.scatter(pca_weights[:, 0], pca_weights[:, 1])
for (pc1, pc2), aa_name in zip(pca_weights, aa_weights.index):
    plt.text(pc1 + 0.5, pc2, aa_name)

# Visualizing the encodings of the ions

In [None]:
# Bare expression: the notebook displays the repr of the model's decoder
# sub-module as this cell's output.
mod.decoder

In [None]:
# Decoder embedding matrix with one row per entry of
# constants.FRAG_EMBEDING_LABELS.
frag_weight_df = pd.DataFrame(
    mod.decoder.trans_decoder_embedding.weight.data.numpy(),
    index = constants.FRAG_EMBEDING_LABELS,
)
# Select the rows by label, in label order.
frag_weights = frag_weight_df.loc[list(constants.FRAG_EMBEDING_LABELS)]

# Clustered heat map of the fragment embeddings (Ward linkage on both axes).
p = sns.clustermap(
    frag_weights,
    z_score = None,
    col_cluster = True,
    cmap = 'viridis',
    figsize = (10,10),
    dendrogram_ratio = (0.1, 0.1),
    method = "ward",
)


In [None]:
# Ward-linkage hierarchical clustering of the fragment embedding rows.
Z = linkage(frag_weights, method='ward')

# Tall figure so that every fragment label gets its own readable line.
plt.subplots(figsize=(5, 34))
dendrogram(
    Z,
    labels=frag_weights.index,
    orientation="left",
    color_threshold=40,
    above_threshold_color='grey',
    distance_sort='ascending',
    leaf_font_size = 12,
)
plt.show()

# Visualizing activations on the different layers of the encoder and decoder

In [None]:
# Stores filled by the forward hooks: hooked module -> output of its most
# recent forward call.
encoder_visualisation = dict()
decoder_visualisation = dict()


def make_hook(target):
    """Build a forward hook that records module outputs in ``target``.

    Args:
        target: Mutable mapping that receives one entry per hooked module.

    Returns:
        A callable with the ``(module, inputs, output)`` signature; it stores
        ``output`` under the key ``module`` and returns ``None``.
    """
    def _record_output(module, inputs, output):
        # Keyed by the module object, so a later call replaces the older entry.
        target[module] = output

    return _record_output

# Handles returned by register_forward_hook; they are kept so that the hooks
# can be detached again later (see the snippet at the end of this cell).
handles = []

# One shared hook per store: every encoder self-attention module writes its
# output into encoder_visualisation ...
encoder_hook = make_hook(encoder_visualisation)
for layer, encoder_layer in enumerate(mod.encoder.transformer_encoder.layers):
    print(f"Adding hook to encoder layer: {layer}")
    handle = encoder_layer.self_attn.register_forward_hook(encoder_hook)
    # Store the handle itself; appending `handles` would only nest the list
    # inside itself and lose the handle.
    handles.append(handle)

# ... and every decoder self-attention module into decoder_visualisation.
decoder_hook = make_hook(decoder_visualisation)
for layer, decoder_layer in enumerate(mod.decoder.trans_decoder.layers):
    print(f"Adding hook to decoder layer: {layer}")
    handle = decoder_layer.self_attn.register_forward_hook(decoder_hook)
    handles.append(handle)

"""
# Use this to remove the handles
for h in handles:
    h.remove()
"""

In [None]:
# Pick the batch to visualise: batch 25 when the loader has that many batches,
# otherwise the last batch it yields.  With at most 500 rows and a batch size
# of 32 the loader never reaches batch 25, so unpacking only inside the `if`
# meant this cell silently relied on variables left over from the inference
# cell.  Unpacking after the loop yields the same batch without depending on
# that leftover state.
for batch_num, batch in enumerate(dl):
    if batch_num == 25:
        break

encoded_sequence, charge, encoded_spectra, norm_irt = batch

# Run the selected batch through the model so that the forward hooks store the
# self-attention outputs belonging to exactly this batch.
yhat_irt, yhat_spectra = mod(encoded_sequence, charge)

## Plotting the activation on the transformer encoder

In [None]:
# Second element of every captured encoder self-attention output, in the
# order the entries were stored by the hooks.
# NOTE(review): the panel titles below assume entry i belongs to layer i — confirm.
encoder_attn_maps = [captured[1] for captured in encoder_visualisation.values()]
encoder_attn_maps[0].shape # shape is [batch, 25, 25]

# Lookup used to turn encoded positions back into letters; index 0 is shown as "_".
padded_alphabet = ["_"] + AAS

sequences = []
last_seq = ""
for pep in range(10):
    # Row of `df` matching item `pep` of the selected batch (32 = DataLoader batch size).
    df_position = 32*batch_num + pep
    sequence = df['Sequences'][df_position]
    # Skip a peptide that repeats the one plotted just before it.
    if sequence == last_seq:
        continue

    last_seq = sequence
    seq_len = len(sequence)
    residues = list(sequence)

    fig, axs = plt.subplots(1, 4, figsize=(25, 6))

    print(sequence)
    print(seq_len)
    recoded_seq = "".join(padded_alphabet[token] for token in encoded_sequence[pep])
    print(recoded_seq)
    sequences.append(recoded_seq)

    fig.suptitle(sequence + " " + str(df['Charges'][df_position]) + "+")

    for i in range(4):
        panel = axs[i]
        # Crop the map to the real peptide length (drops the remaining positions).
        attn = encoder_attn_maps[i][pep, :, :].detach().numpy()[0:seq_len, 0:seq_len]

        panel.set_title(f'Layer {i}')
        panel.imshow(attn, vmin = 0, vmax = 0.2)
        panel.set_xticks(np.arange(seq_len))
        panel.set_yticks(np.arange(seq_len))
        panel.set_xticklabels(residues)
        panel.set_yticklabels(residues)

        # Mark the diagonal (each residue attending to itself) with an "O".
        for k in range(seq_len):
            panel.text(k, k, "O", ha="center", va="center", color="w")

    plt.show()

sequences

## Plotting the activation on the transformer decoder

In [None]:
# Shape of the second element of every captured decoder self-attention output.
[captured[1].shape for captured in decoder_visualisation.values()]

In [None]:
"""
>>> [ [x[0].shape, x[1].shape] for x in list(decoder_visualisation.values()) ]

[[torch.Size([150, 32, 516]), torch.Size([32, 150, 150])],
 [torch.Size([150, 32, 516]), torch.Size([32, 150, 150])],
 [torch.Size([150, 32, 516]), torch.Size([32, 150, 150])],
 [torch.Size([150, 32, 516]), torch.Size([32, 150, 150])]]

>>> # would be all the LAYER ACTIVATIONS
>>> activations_averages = [x[0] for x in list(decoder_visualisation.values())] 

>>> # would be all the SELF ATTENTION AVERAGES
>>> self_attn_averages = [x[1] for x in list(decoder_visualisation.values())] 

>>> # would give a list of the self attention layers for the first peptide in the batch.
>>> [x[0, ...] for x in self_attn_averages] 
...

>>> [x[0, ...].shape for x in self_attn_averages] 
[torch.Size([150, 150]),
 torch.Size([150, 150]),
 torch.Size([150, 150]),
 torch.Size([150, 150])]


Each attention head returns

Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.

> Actual attention output
attn_output
> Average self attention
attn_output_weights.sum(dim=1) / num_heads

here, N = 32; E=512, L=150, S=150

"""

# Monospace font keeps the padded interaction labels aligned.
# Note: this changes the matplotlib setting for every later figure as well.
plt.rcParams['font.family'] = 'monospace'
sequences = []
last_seq = ""
labs = constants.FRAG_EMBEDING_LABELS

# Second element of every captured decoder self-attention output.
self_attn_averages = [x[1] for x in list(decoder_visualisation.values())]

# The labels of all off-diagonal (row, column) fragment pairs do not depend on
# the peptide or the layer, so they are built once here instead of inside the
# innermost loop.  The order is row-major and matches `off_diagonal` below.
n_frags = constants.NUM_FRAG_EMBEDINGS
off_diagonal = ~np.eye(n_frags, dtype=bool)
pair_labels = [
    f"{labs[xind] : <6}{labs[yind] : >6}"
    for xind in range(n_frags)
    for yind in range(n_frags)
    if xind != yind
]

for pep in range(10):
    # Row of `df` matching item `pep` of the selected batch (32 = DataLoader batch size).
    df_position = 32*batch_num + pep
    sequence = df['Sequences'][df_position]
    # Skip a peptide that repeats the one plotted just before it.
    if sequence == last_seq:
        continue

    last_seq = sequence
    # Top row: one attention map per layer plus the ground-truth spectrum.
    # Bottom row: the interaction ranking of each layer (last panel stays empty).
    fig, axs = plt.subplots(2,5, figsize=(25, 10))

    print(sequence)
    print(len(sequence))
    # print(encoded_sequence[pep])
    recoded_seq = "".join([(["_"] + AAS)[i] for i in encoded_sequence[pep]])
    print(recoded_seq)
    sequences.append(recoded_seq)

    fig.suptitle(sequence + " " + str(df['Charges'][df_position]) + "+")

    # NOTE(review): eval() executes whatever text is stored in the CSV column;
    # only use this with trusted files (ast.literal_eval would be the safe choice).
    spec = eval(df['SpectraEncoding'][df_position])
    indices = list(range(len(spec)))

    spectrum_ax = axs[0,4]
    spectrum_ax.vlines(x = indices, ymin = 0, ymax = spec, color = "black")
    spectrum_ax.set(yticklabels=[])
    spectrum_ax.tick_params(left=False)
    spectrum_ax.set_title('Encoded Ground Truth Spectrum')

    # Annotate only the peaks with an intensity of at least 0.1.
    for ind, inten in zip(indices, spec):
        if inten < 0.1:
            continue
        spectrum_ax.text(x = ind + 5, y = inten - 0.01, s = labs[ind], color = "blue")

    peptide_self_attn_avgs = [x[pep, ...] for x in self_attn_averages]

    for i in range(4):
        act_vals = peptide_self_attn_avgs[i].detach().numpy()

        axs[0,i].set_title(f'Layer {i}')
        axs[0,i].imshow(act_vals, vmin = 0., vmax = 0.03)

        # All off-diagonal attention values, in the same order as pair_labels.
        vals = act_vals[:n_frags, :n_frags][off_diagonal]
        interaction_df = pd.DataFrame({'Interaction': pair_labels, 'Value': vals})

        # Show the 5 weakest and the 15 strongest interactions.
        ordered_df = interaction_df.sort_values(by='Value')
        plotting_df = pd.concat([ordered_df[:5], ordered_df[-15:]]).copy()
        my_range = range(1, len(plotting_df.index) + 1)

        # Horizontal lollipop chart: one line plus marker per interaction.
        axs[1,i].hlines(y=my_range, xmin=0, xmax=plotting_df['Value'], color='black')
        axs[1,i].set_yticks(my_range)
        axs[1,i].set_yticklabels(plotting_df['Interaction'])
        axs[1,i].plot(plotting_df['Value'], my_range, "D")

    plt.tight_layout()
    plt.show()

sequences