# This notebook contains the pipeline for setting up MPNN prefilters

# Step 0: Imports and Variable Set-up

### Imports

In [None]:
### DESIGN-00

###################################
#### Same as rifdock notebook #####
###################################
# Utilities
import sys, os, time, glob, random, shutil, subprocess, math

# Data Processing
import numpy as np
import pandas as pd

# string is a utility library which contains some useful string constants. For instance, string.digits == '0123456789'
import string

# re is a module that enables the use of "regular expressions," a common format for matching strings to patterns
import re

# matplotlib, with its submodule pyplot (or plt), is a popular tool for plotting data and making publication-quality graphs.
import matplotlib.pyplot as plt

# seaborn, commonly abbreviated as sns, is a library used to make publication-quality plots with lots of options; it is built on top of matplotlib.
import seaborn as sns

# sklearn, aka scikit-learn, is a library for simple machine learning in Python. We use it for a few statistical calculations.
from sklearn.metrics import roc_curve, auc
from sklearn.cluster import KMeans

# scipy is a general science module for python with many useful statistical tools. We use it to fit curves to data
from scipy.optimize import curve_fit

# Functions written by Chris Norn / Cameron for maximum likelihood fitting are stored in this module
from maximum_likelihood import *

# PyRosetta is the Python-based interface for running Rosetta for protein modeling and design
import pyrosetta
from pyrosetta import *
from pyrosetta.rosetta import *

#######################################
#######################################

# This last bit is here to prevent certain functions from returning a bunch of warnings we don't care about
import warnings
warnings.filterwarnings('ignore')

------------------------------------------------------------------------------------------------------------------------<br>
This cell collects the outputs from the "predictor" script and saves them all to a single `.csv` (per target DNA) <br>
NOTE: this cell takes a few minutes to run, due to pandas inefficiencies

In [None]:
### PRE-01

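output_path = 'prefilter.csv'


# Collect the paths to all output .csv files
# NOTE: path_to_prefilter_calibration_outputs is not set in this notebook; define it before running this cell
csv_fs = glob.glob(f'{path_to_prefilter_calibration_outputs}/*.csv')

# Load every file and stack them into a single DataFrame with a fresh row index
df = pd.concat([pd.read_csv(f) for f in csv_fs], sort=False, ignore_index=True)

# Then save the DataFrame to a csv (index=False avoids writing an extra unnamed index column)
df.to_csv(output_path, index=False)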

------------------------------------------------------------------------------------------------------------------------<br>
This cell loads in the outputs from the previous cell and filters out very bad results from the dataset so they don't skew the results
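
As a minimal illustration of the boolean-mask filtering used in this cell, the sketch below applies the same two cuts to a small made-up table. The values in `toy_df` are invented for illustration and are not pipeline outputs.

```python
import pandas as pd

# Toy data (invented values, not real pipeline output)
toy_df = pd.DataFrame({
    "tag": ["a", "b", "c", "d"],
    "ddg": [-25.0, 3.0, -12.0, -8.0],
    "contact_molecular_surface": [310.0, 150.0, 5.0, 240.0],
})

# A comparison on a column yields a boolean Series (the "mask");
# & combines two masks element-wise
mask = (toy_df["ddg"] < 0) & (toy_df["contact_molecular_surface"] > 10)

# Indexing with the mask keeps only the rows where it is True
toy_filtered = toy_df[mask]
print(toy_filtered["tag"].tolist())
```

Row `b` is dropped for its positive `ddg` and row `c` for its tiny contact surface, so only `a` and `d` remain.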

In [None]:
### PRE-02

# Read in the DataFrame from where we saved it
df = pd.read_csv('prefilter.csv')


# Filter only for cases with negative ddg
# ddg is the change in Gibbs free energy of binding between protein and DNA (as computed by Rosetta)
# This notation selects only the rows of a DataFrame for which the [bracketed] statement evaluates to True
df = df[ df['ddg'] < 0 ]


# Filter also for contact molecular surface > 10
# This metric measures the surface area along which the protein and DNA are in contact with each other
df = df[ df['contact_molecular_surface'] > 10 ]

------------------------------------------------------------------------------------------------------------------------<br>
This cell splits the data into "predictor" and "pilot" <br>
The "predictor" data have only fast metrics. <br>
The "pilot" data have both the fast and slow metrics, which are actually the same score terms with one key difference: <br>
&emsp; "pilot" examples were run through a Rosetta "relax" protocol, which moves the protein backbone and sidechains to minimize free energy. <br>
We will try to predict whether the pilot data pass the full filters based on their fast metrics only. <br>
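
The split-and-merge step can be sketched on toy data as follows. The table contents are invented, and the sketch uses pandas' built-in `add_suffix` plus a rename rather than a hand-written helper, so treat it as one possible way to express the idea rather than the notebook's exact method.

```python
import pandas as pd

# Toy table (invented): the same tag can appear in a fast "predictor" run and a relaxed "pilot" run
toy_df = pd.DataFrame({
    "tag": ["x1", "x2", "x1", "x3"],
    "ddg": [-10.0, -5.0, -18.0, -7.0],
    "is_prefilter": [True, True, False, False],
})

# Predictor rows: suffix every column, then restore the plain "tag" name as the merge key
toy_predictor = toy_df[toy_df["is_prefilter"]].add_suffix("_pred")
toy_predictor = toy_predictor.rename(columns={"tag_pred": "tag"})

# Pilot rows: keep the original column names
toy_pilot = toy_df[~toy_df["is_prefilter"]]

# An inner merge keeps only tags present in both sets
toy_merged = pd.merge(toy_pilot, toy_predictor, how="inner", on="tag")
print(toy_merged[["tag", "ddg", "ddg_pred"]])
```

Only `x1` appears in both sets, so the merged table has a single row holding the relaxed `ddg` next to the fast `ddg_pred`.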

In [None]:
### PRE-03

# Define a function that can add a suffix to the names of the columns in a DataFrame
# For instance, a column named "FOO" could become "FOO_bar"
# This is useful for combining two dataframes with the same column names, such as our "predictor" and "pilot" data
def suffix_all_columns(df, suffix):
    cols = list(df.columns)
    for i in range(len(cols)):
        cols[i] = cols[i] + suffix
    df.columns = cols
    return df

# Divide the DataFrame into our two sets based on the value in the column 'is_prefilter', which contains boolean values
# .copy() gives each subset its own data, so renaming or adding columns does not trigger pandas' SettingWithCopyWarning
predictor_df = df[   df['is_prefilter'] ].copy()
pilot_df     = df[ ~ df['is_prefilter'] ].copy()
#                 ^^^ the ~ operator performs logical negation, element-wise on pandas objects

# Add the "_pred" suffix to the column names of the predictor DataFrame
predictor_df = suffix_all_columns(predictor_df, "_pred")

# Restore the un-suffixed "tag" column
predictor_df['tag'] = predictor_df['tag_pred']


# Merge the predictor data back into the pilot dataframe
pilot_df = pd.merge(pilot_df, predictor_df, how='inner', on='tag')

# Print the sizes of the datasets
print("Length of predictor dataframe:\t", len(predictor_df))
print("Length of pilot dataframe:\t", len(pilot_df))
print()

------------------------------------------------------------------------------------------------------------------------<br>
This cell is now going to apply a specified set of filters to the dataset, using the available metrics. <br>
Our goal is, essentially, to distinguish good docks/designs from bad ones based on the quality of the interface. <br>
The key calculations we base this on are "ddg" and "contact_molecular_surface". <br>

The aptly-named `contact_molecular_surface` (shorthand `cms`) is the surface area in which the protein contacts the target. <br>

Because of inherent Lennard-Jones attraction between molecules, there is a bias for larger interfaces to also have lower `ddg`s. <br>
To counteract this, we filter on the ratio between `ddg` and `cms`, which you can think of as the "quality" or "density" of the interface.
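
To make the ratio concrete, here is a small sketch with invented numbers. It assumes `ddg_over_cms` is simply `ddg` divided by `contact_molecular_surface` (in the real data this column arrives precomputed from the upstream scripts), and it borrows the two cutoffs from the `terms_and_cuts` dictionary used in the next cell.

```python
import pandas as pd

# Toy interfaces (invented values)
toy_df = pd.DataFrame({
    "ddg": [-40.0, -20.0, -10.0],
    "contact_molecular_surface": [400.0, 250.0, 300.0],
})

# Assumption: the ratio is ddg divided by cms
toy_df["ddg_over_cms"] = toy_df["ddg"] / toy_df["contact_molecular_surface"]

ratio_cut = -0.078  # keep ratios at or below this value (more negative is better)
cms_cut = 225       # keep contact surfaces at or above this value

toy_df["passes"] = (
    (toy_df["ddg_over_cms"] <= ratio_cut)
    & (toy_df["contact_molecular_surface"] >= cms_cut)
)
print(toy_df)
```

The third row has a reasonably large interface but only about -0.033 units of `ddg` per unit of surface, so it fails the ratio filter even though it passes the size filter.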

In [None]:
### PRE-04

# The format for the dictionary of features for filtering is:
#   'NAME_IN_PILOT_DF'           :  [ CUTOFF , KEEP_ABOVE , NAME_IN_PREDICTOR_DF , IS_INT],
#      (string)                       (float)  (boolean)   (string)               (boolean)

# For instance, the following would say that we should filter on the variable "ddg", 
#     which is called "ddg_pred" in the predictor DF, which is not a discrete variable (integer),
#     and we should keep only cases with scores *below* -10 :
example_dict = {
    'ddg' :  [ -10  ,   False   , "ddg_pred" , False  ]
}
# and now we delete the example variable
del example_dict

# ...and here is the actual dictionary to use:
terms_and_cuts = {
    'ddg_over_cms'               :  [ -0.078 ,   False   , "ddg_over_cms_pred"              , False ],
    'contact_molecular_surface'  :  [   225  ,    True   , "contact_molecular_surface_pred" , False ],
}
# This is one place you could play around with the numbers to see what happens, if you'd like


# Next, we are going to actually apply our filters one by one and keep track of the pass-rates
# If you see that a filter is removing 99.99% of designs or doing nothing, you may want to change your cutoffs.
print('-------------------------------------------------')

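# Filter all the terms and print the thresholds
# NOTE: score_df is not defined anywhere above; we assume it refers to the merged pilot DataFrame
score_df = pilot_df
ok_terms = []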
for pilot_term in terms_and_cuts:
    cut, keep_high, term, is_int = terms_and_cuts[pilot_term]
    ok_term = pilot_term + "_ok"
    if ( keep_high ):
        score_df[ok_term] = score_df[pilot_term] >= cut
    else:
        score_df[ok_term] = score_df[pilot_term] <= cut

    ok_terms.append(ok_term)

    print("%30s: %6.2f"%(pilot_term, cut))

# Print the pass rates for each term
print()
score_df['orderable'] = True
for ok_term in ok_terms:
    score_df['orderable'] = score_df['orderable'] & score_df[ok_term]
    print("%30s: %5.0f%% pass-rate"%(ok_term.replace("_ok", ""), score_df[ok_term].sum() / len(score_df) * 100))

# print the overall pass rate   
print()
print("%30s: %i   -- %.2f%%"%('Passing', score_df['orderable'].sum(), (100*score_df['orderable'].sum() / len(score_df))))


------------------------------------------------------------------------------------------------------------------------<br>
This cell is going to plot the distributions of our filter metrics in the pilot and predictor datasets. <br>
The main take-away you should see is that the scores are better in the pilot dataset because of the Rosetta relax step. <br>

In [None]:
### PRE-05

# Get the names of metrics we care about from the filter dictionary
relevant_features = list(terms_and_cuts.keys())

# Set up the axes we will be plotting on using matplotlib (one panel per metric)
ncols = len(relevant_features)
nrows = 1
(fig, axs) = plt.subplots(
    ncols=ncols, nrows=nrows, figsize=[6*ncols,3*nrows], squeeze=False
)
axs = axs.reshape(-1)

# Make all of the plots
i = 0

for metric in terms_and_cuts:
    metric_pred = terms_and_cuts[metric][2]

    # seaborn's distplot essentially makes a histogram.
    # It also plots the kernel density estimate (kde), which is sort of a smooth fit to the histogram
    # (distplot is deprecated since seaborn 0.11; histplot(..., kde=True) is the suggested replacement)
    sns.distplot(pilot_df[metric], ax=axs[i], color='blue', label='pilot')
    sns.distplot(predictor_df[metric_pred], ax=axs[i], color='orange', label='predictor')

    # add legend and title
    axs[i].legend()
    axs[i].set_title(metric)

    # keep track of which axis we are plotting on with a simple incrementor
    i += 1

# Format our plots for better readability / aesthetics
sns.despine()
plt.tight_layout()

------------------------------------------------------------------------------------------------------------------------<br>
Here, we want to look at some of the designs that are passing our filters and decide if they actually look good. <br>
This cell will print out the tags for the passing designs into `tags.list`. <br>
For now, you will need to work out for yourself where the corresponding PDB files are. (TODO: fix this.) <br>
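
One way to write the passing tags to `tags.list` is sketched below. The table `toy_scores` is a hypothetical stand-in; in the notebook the same selection would presumably be made on the scored pilot DataFrame, assuming it carries the `tag` and `orderable` columns created in PRE-03 and PRE-04.

```python
import pandas as pd

# Hypothetical stand-in for the scored table (invented tags)
toy_scores = pd.DataFrame({
    "tag": ["design_001", "design_002", "design_003"],
    "orderable": [True, False, True],
})

# Select the tags of the rows that passed every filter
passing_tags = toy_scores.loc[toy_scores["orderable"], "tag"]

# Write one tag per line
out_path = "tags.list"
with open(out_path, "w") as f_out:
    for tag in passing_tags:
        f_out.write(f"{tag}\n")
```

A one-tag-per-line file is easy to feed to shell tools or to read back with `open(...).read().splitlines()` when locating the matching structures.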

------------------------------------------------------------------------------------------------------------------------<br>
Next we are going to set up our pre-filter equation. <br>
We will use this to decide which designs are worth running Rosetta relax (slow metrics) on. <br>
The equation is fit by multiple-exponential regression. <br>
The input variables are the "fast" metrics, which are the values of the features in the "predictor" dataset. <br>
The output variable, which we are trying to predict, is the *probability that a corresponding pilot model would pass the full metrics* <br>
<br>

The result of this cell is an equation which is saved to a text file (`filter_eq.txt`) so that we can use it later during the full-scale MPNN design step.

In [None]:
### PRE-06

print('-------------------------------------------------')

train_df = pilot_df.copy()
predictor_df = predictor_df.copy()

all_indices = list(range(len(train_df)))
test_indices = []

# This sets up the maximum likelihood method
_, prob_array, eqs = train_and_predict_mle(train_df, all_indices, test_indices, terms_and_cuts, "predict", predictor_df, True)

print("")
print('- predictor_filters ' + " , ".join( terms_and_cuts[x][2][:-5] for x in list(terms_and_cuts)))
print('- equation = "-' + "*".join(eqs) + '"')
with open(f'filter_eq.txt','w') as f_out:
    f_out.write('"-' + "*".join(eqs) + '"\n')
    

In [None]:
### PRE-07

print('-------------------------------------------------')


# Apply the mle method to the training set to get a feel for how well it worked
apply_prob_arrays(score_df, prob_array, "predict")
score_df['log_predict'] = np.log10(score_df['predict'])
plot_df = score_df
fpr,tpr,thresholds = roc_curve(plot_df["orderable"], plot_df["predict"])

# Make the ROC plot
fig, ax = plt.subplots(figsize=(6,4))
plt.title('ROC Curve')
plt.plot(fpr, tpr, 'r', label = "Predictor auc = %.2f"%(auc(fpr, tpr)))
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],'k--')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.xlabel("False positive rate")
plt.ylabel("True positive rate")
plt.tight_layout()
plt.show()

# Make the other plots

# Make a cool graph to get a feel for where different predict values lie
df_c=score_df.sort_values("predict", ascending=False)
df_c['total_orderable'] = df_c['orderable'].cumsum()
df_c['log_predict'] = np.log10(df_c['predict'])


cmap = sns.cubehelix_palette(as_cmap=True)

lowb = np.percentile(df_c['log_predict'], 2)
upb = np.percentile(df_c['log_predict'], 98)
f, ax = plt.subplots(figsize=(7, 4))
points = ax.scatter(range(len(df_c)), df_c['total_orderable'], c=df_c['log_predict'], vmin=lowb,vmax=upb, cmap=cmap)
plt.setp(ax.get_xticklabels(), visible=False)
plt.setp(ax.get_yticklabels(), visible=False)
cb = f.colorbar(points)
cb.set_label("log_predict")

plt.show()

apply_prob_arrays(predictor_df, prob_array, "predict")
predictor_df['log_predict'] = np.log10(predictor_df['predict'])
minimum = predictor_df['log_predict'].min()
maximum = predictor_df['log_predict'].max()
steps = 20
step = (maximum - minimum)/steps
probability_mapping_x = np.arange(minimum, maximum, step)
probability_mapping_y = []

last_prob = None
for step_prob in probability_mapping_x:
    upper = step_prob + step
    total = score_df[(score_df['log_predict'] > step_prob) & (score_df['log_predict'] < upper)]
    orderable = total['orderable'].sum()
    if ( len(total) < 10 ):
        prob = last_prob
    else:
        prob = orderable / len(total)
    probability_mapping_y.append(prob)
    last_prob = prob
# fill in the beginning
last_prob = probability_mapping_y[-1]
for i in range(len(probability_mapping_y)):
    i = len(probability_mapping_y) - i - 1
    if ( probability_mapping_y[i] is None ):
        probability_mapping_y[i] = last_prob
    last_prob = probability_mapping_y[i]

probability_mapping_y = np.array(probability_mapping_y)

plt.xlabel("log_predict")
plt.ylabel("Pilot success rate")
plt.scatter(probability_mapping_x, probability_mapping_y)
plt.show()

------------------------------------------------------------------------------------------------------------------------<br>
Now we are going to apply our maximum likelihood equation (MLE) to the predictor dataset and see what happens. <br>
For now, it's not important to understand the purpose of this plot.

In [None]:
### PRE-08

# Apply the mle to the predictor data and see how the values look
print('-------------------------------------------------')

predictor_df = predictor_df.copy()

apply_prob_arrays(predictor_df, prob_array, "predict")
predictor_df['log_predict'] = np.log10(predictor_df['predict'])
bounds = (np.percentile(predictor_df['log_predict'], 1), np.percentile(predictor_df['log_predict'], 99))
sns.distplot(predictor_df['log_predict'].clip(bounds[0], bounds[1]))
plt.title("All predicted data")
plt.show()

------------------------------------------------------------------------------------------------------------------------<br>
To actually use our maximum likelihood equation (MLE) as a filter, we need to choose a cutoff. <br>
This cell determines the cutoff by locating the 95th percentile of the predictor scores. <br>
We assume that the distribution will be similar in the full-scale design runs, so using this cutoff should result in a 5% pass-rate of the pre-filter. <br>
The end result is a file (`filter_cut.txt`) which contains the cutoff score that we'll use during the full design step.
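
The cutoff selection can be illustrated on synthetic scores. The sketch below draws made-up `log_predict` values from a normal distribution (an arbitrary choice for illustration only), sorts them from best to worst, and takes the score at the position marking the top 5%.

```python
import numpy as np

# Synthetic stand-in for predictor_df['log_predict'] (invented distribution)
rng = np.random.default_rng(0)
toy_log_predict = rng.normal(loc=-3.0, scale=1.0, size=1000)

fraction_to_design = 0.05
n_top = int(len(toy_log_predict) * fraction_to_design)

# Sort from highest to lowest score and read off the score at index n_top
toy_cut = np.sort(toy_log_predict)[::-1][n_top]

# Everything strictly above the cutoff passes the pre-filter
n_pass = int((toy_log_predict > toy_cut).sum())
print(f"cutoff = {toy_cut:.3f}; {n_pass} of {len(toy_log_predict)} pass")
```

Because the cutoff is the score at index `n_top` of the descending list, exactly `n_top` scores lie strictly above it as long as there are no ties, which gives the intended 5% pass rate on the calibration data.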

In [None]:
### PRE-09

print('-------------------------------------------------')

predictor_df = predictor_df.copy()

apply_prob_arrays(predictor_df, prob_array, "predict")
predictor_df['log_predict'] = np.log10(predictor_df['predict'])

fraction_to_design = 0.05
topXp = int(len(predictor_df)*fraction_to_design)
MLE_cut = list(sorted(-predictor_df['log_predict']))[topXp]
print(f'To predict for the top {fraction_to_design*100}% use an MLE cutoff > {-MLE_cut}')
print(f'In your data set this corresponds to {topXp} successes out of {len(predictor_df)}')

with open('filter_cut.txt','w') as f_out:
    f_out.write(str(-MLE_cut) + '\n')