# Notebook for on-the-fly stability predictions

## Setup

In [8]:
import os
import time

import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset

import sys
sys.path.insert(0, './src')

from mrna_dataset import MrnaDisplayDataset
from mrna_classifier import MrnaBaggingPuClassifier

from rdkit import Chem
from rdkit.Chem.AtomPairs.Pairs import GetAtomPairFingerprintAsBitVect, GetAtomPairFingerprint
from rdkit.Chem.rdmolfiles import MolFromSequence, MolFromSmiles

# Directory handed to the classifier as its `load_path`.
MODEL_DIR = './src/models/prod/prod_0.01' # change to use different model
# Built once at module level; predict() below reuses this `model` on every call.
model = MrnaBaggingPuClassifier(load_path=MODEL_DIR)

# get_ap / get_ap_smiles read the feature list from DATA_DIR/ap_features.csv.
DATA_DIR = './data/' # where ap_features.csv is located

def _ap_feature_vector(mol, source):
    """
    Returns a pandas series of AP counts for the RDKit molecule `mol`,
    indexed by the sorted feature ids listed in ap_features.csv.

    Features that do not occur in the molecule are zero; atom pairs that
    occur in the molecule but are not listed in ap_features.csv are dropped.

    `source` is the text the molecule was parsed from and is only used
    in the error message.

    Raises ValueError if `mol` is None, i.e. the parser rejected the input.
    """
    if mol is None:
        # RDKit parsers return None instead of raising on bad input;
        # fail here with a readable message rather than inside the
        # fingerprint call.
        raise ValueError('RDKit could not parse {!r}'.format(source))
    # Explicit integer dtype so a molecule without any atom pairs still
    # produces a numeric (not object-typed) series.
    counts = pd.Series(GetAtomPairFingerprint(mol).GetNonzeroElements(), dtype=np.int64)
    feature_ids = pd.read_csv(
        os.path.join(DATA_DIR, 'ap_features.csv'), dtype=np.int64, header=None
    ).values.flatten()
    # np.unique de-duplicates and sorts, giving a fixed feature order;
    # reindex keeps only those ids and zero-fills the ones that are absent.
    return counts.reindex(np.unique(feature_ids), fill_value=0)


def get_ap(seq):
    """
    Returns a pandas series of AP features for the sequence `seq`,
    which is parsed with RDKit's MolFromSequence.
    Includes features from ap_features.csv (which are zero
    if they don't exist in the molecule)

    Raises ValueError if `seq` cannot be parsed.
    """
    return _ap_feature_vector(MolFromSequence(seq), seq)


def get_ap_smiles(smiles):
    """
    Returns a pandas series of AP features for the SMILES string `smiles`,
    which is parsed with RDKit's MolFromSmiles.
    Includes features from ap_features.csv (which are zero
    if they don't exist in the molecule)

    Raises ValueError if `smiles` cannot be parsed.
    """
    return _ap_feature_vector(MolFromSmiles(smiles), smiles)


def predict(seq, smiles=False):
    """
    Returns the model's predicted probability for a single peptide,
    rounded to 5 decimal places.

    `seq` is featurized with get_ap_smiles when `smiles` is truthy
    (so it should then be a SMILES string) and with get_ap otherwise.
    """
    featurize = get_ap_smiles if smiles else get_ap
    # Wrap the single feature series in a list so the model sees one row.
    batch = np.array([featurize(seq)])
    x = torch.tensor(batch, dtype=torch.float32)
    proba = model.predict_proba(x).item()
    # np.round yields a numpy scalar; .item() unwraps it to a Python number.
    return np.round(proba, 5).item()

## Predict

In [13]:
# Peptide given as a sequence string (default smiles=False, so get_ap is used).
peptide = 'PGWLSE'
predict(peptide)

0.42641

In [15]:
# Same call with a different sequence string (six P residues).
peptide = 'PPPPPP'
predict(peptide)

0.8605

In [14]:
# Example cyclic peptide: sequence PGWLSE with an amide bond between its N and C
# ends (ring closure 3 in the SMILES). It is passed as SMILES, so smiles=True
# is needed to route it through get_ap_smiles.
peptide_smiles = 'N31[C@@]([H])(CCC1)C(=O)NCC(=O)N[C@@]([H])(CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@]([H])(CC(C)C)C(=O)N[C@@]([H])(CO)C(=O)N[C@@]([H])(CCC(=O)O)C3(=O)'
predict(peptide_smiles, smiles=True)

0.66503