## Tutorial for Supervised and Unsupervised Learning on Simulations of a Retro Aldolase

In this Jupyter notebook we will use the model_building.py module to identify differences in the molecular interactions of a retro aldolase when it is in a catalytically competent state versus when it is not.
This notebook will also cover all the pre- and post-processing steps required to prepare, analyse and visualise the results.

The dataset used here is the same as the one used in the manuscript.

<center><img src="miscellaneous/TODO.png" alt="Drawing" style="width: 70%" /></center>

In [None]:
import sys  # Temporary: lets this notebook import the package from the parent directory.
sys.path.append("..")  # Temporary (see above).

from key_interactions_finder import pycontact_processing
from key_interactions_finder import data_preperation
from key_interactions_finder import model_building
from key_interactions_finder import post_proccessing
from key_interactions_finder import pymol_projections

### Step 1. Process PyContact files with the pycontact_processing.py module 

In this section we will work with the PyContact output files we generated.
Here we will merge our separate runs together and remove any false interactions that the PyContact library can generate.
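To make the "horizontal" merge concrete, here is a minimal pandas sketch. It assumes (our reading, not confirmed by the module's documentation here) that horizontal merging means every block covers the same frames but a different set of residue pairs, so the tables are joined column-wise; a "vertical" merge is shown for contrast. The feature names are hypothetical, and the removal of false interactions is not reproduced.

```python
import pandas as pd

# Two toy per-frame tables standing in for PyContact block files.
# Assumption: both blocks cover the SAME 3 frames but DIFFERENT residue pairs.
# Feature names are hypothetical.
block1 = pd.DataFrame(
    {"12Lys 45Glu Saltbr sc-sc": [1.2, 0.0, 0.8],
     "30Ser 31Gly Hbond bb-bb": [0.5, 0.6, 0.0]}
)
block2 = pd.DataFrame(
    {"50Phe 72Leu Hydrophobic sc-sc": [2.1, 1.9, 2.3]}
)

# "Horizontal" merge: align on the frame index and place columns side by side.
merged = pd.concat([block1, block2], axis=1)
print(merged.shape)  # (3, 3)

# For contrast, a "vertical" merge stacks frames from consecutive trajectory
# blocks that share the same interaction columns.
later_frames = block1.copy()
stacked = pd.concat([block1, later_frames], axis=0, ignore_index=True)
print(stacked.shape)  # (6, 2)
```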

In [None]:
# The 17 PyContact output blocks share a naming pattern, so build the list programmatically.
pycontact_files_horizontal = [
    f"PyContact_Per_Frame_Interactions_Block{i}.csv" for i in range(1, 18)
]

pycontact_dataset = pycontact_processing.PyContactInitializer(
    pycontact_files=pycontact_files_horizontal,
    multiple_files=True,
    merge_files_method="horizontal",  
    remove_false_interactions=True,
    in_dir="datasets/retrol_aldolase_data/",
)

In [None]:
# As noted in the output above, we can inspect the newly prepared dataset by accessing the '.prepared_df' class attribute:
pycontact_dataset.prepared_df

### Step 2. Prepare the Dataset for Machine Learning with the data_preperation.py module. 

In this step, we take our processed dataframe and merge our per-frame classifications file into it.
We can also optionally perform several forms of filtering to select which types of interactions we
would like to study.
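As a rough illustration of what "merging the per-frame classifications" amounts to, the sketch below reads a one-label-per-frame file and appends it as a "Target" column (the column name used later in Step 3). The file layout (a single column with a header line) is an assumption for illustration; `header=0` plays the role of `header_present=True` here, and `header=None` would be the analogue for a file without a header.

```python
import io

import pandas as pd

# Toy per-frame feature table (one row per frame); names are hypothetical.
features = pd.DataFrame({
    "12Lys 45Glu Saltbr sc-sc": [1.2, 0.0, 0.8, 0.9],
    "30Ser 31Gly Hbond bb-bb": [0.5, 0.6, 0.0, 0.4],
})

# Stand-in for a classifications file: a header line, then one label per frame.
classifications_txt = "Classification\nCatComp\nNotCatComp\nCatComp\nNotCatComp\n"
labels = pd.read_csv(io.StringIO(classifications_txt), header=0)

# Every frame needs exactly one label.
if len(labels) != len(features):
    raise ValueError("Classification file and feature table have different frame counts")

# Append the labels as the target column.
df_processed = features.copy()
df_processed["Target"] = labels.iloc[:, 0].to_numpy()
print(df_processed["Target"].value_counts())
```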

In [None]:
# First we generate an instance of the SupervisedFeatureData class (because we have per frame class labels).
classifications_file = "datasets/retrol_aldolase_data/4a2s_RA95_5_Classifications.txt"

supervised_dataset = data_preperation.SupervisedFeatureData(
    input_df=pycontact_dataset.prepared_df,
    target_file=classifications_file,
    is_classification=True,
    header_present=True # If your target_file has a header present, set to True.
)

In [None]:
# As stated above, to access the newly generated dataframe we can use the class attribute as follows:
supervised_dataset.df_processed

##### Optional Feature Filtering

In the above dataframe we have 3057 columns (3056 features + 1 target). We can take all of these forward for the statistical analysis or we can perform some filtering in advance (the choice is yours).
There are five built-in methods available to you to perform filtering:

1. **filter_by_occupancy(min_occupancy)** - Remove features with a %occupancy below the provided cut-off. %Occupancy is the percentage of frames with a non-zero value, i.e. frames in which the interaction is present.

2. **filter_by_interaction_type(interaction_types_included)** - PyContact defines four types of interactions ("Hbond", "Saltbr", "Hydrophobic", "Other"). You select the interaction types you want to include.

3. **filter_by_main_or_side_chain(main_side_chain_types_included)** - PyContact can also determine whether each interaction involves primarily the backbone or the side chain of each residue. You select the combinations you want to include. Options are: "bb-bb", "sc-sc", "bb-sc", "sc-bb", where bb = backbone and sc = side chain.

4. **filter_by_avg_strength(average_strength_cut_off)** - PyContact calculates a per frame contact score/strength for each interaction. You can filter features by the average score. Values below the cut-off are removed. 

5. **filter_by_occupancy_by_class(min_occupancy)** - A special alternative to the standard filter-by-occupancy method. %Occupancy is determined for each class separately (as opposed to the whole dataset), so a feature is kept if the observations from just one class meet the cut-off. Only available for datasets with a categorical target variable (classification).


Finally, if at any point you want to undo any filtering you have already performed, you can use the following method:

6. **reset_filtering()** 
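To make the filter definitions above concrete, here is a small pandas sketch of what each criterion computes on a toy table. It follows the descriptions in the list, not the module's actual code, and assumes a hypothetical feature-name layout ("residue1 residue2 type chain") for the interaction-type filter. Note how the by-class occupancy filter keeps the H-bond feature that the whole-dataset filter drops, because it is present in one class only.

```python
import pandas as pd

# Toy dataset: 6 frames, 3 hypothetical interaction features + a class label.
df = pd.DataFrame({
    "12Lys 45Glu Saltbr sc-sc":      [1.5, 1.2, 1.1, 0.0, 0.0, 0.0],
    "30Ser 31Gly Hbond bb-bb":       [0.0, 0.0, 0.0, 0.0, 0.3, 0.0],
    "50Phe 72Leu Hydrophobic sc-sc": [2.0, 2.1, 1.9, 2.2, 2.0, 1.8],
    "Target": ["CatComp"] * 3 + ["NotCatComp"] * 3,
})
features = df.drop(columns="Target")
present = (features != 0).astype(float)  # 1.0 where the interaction is present

# filter_by_occupancy analogue: % of ALL frames with a non-zero value.
min_occupancy = 25
occupancy = present.mean() * 100
keep_overall = occupancy[occupancy >= min_occupancy].index

# filter_by_occupancy_by_class analogue: % occupancy within each class;
# keep the feature if ANY class meets the cut-off.
occ_by_class = present.groupby(df["Target"]).mean() * 100
keep_by_class = occ_by_class.columns[(occ_by_class >= min_occupancy).any(axis=0).to_numpy()]

# filter_by_avg_strength analogue: mean per-frame score across all frames.
avg_strength = features.mean()
keep_strong = avg_strength[avg_strength >= 1.0].index

# filter_by_interaction_type analogue, assuming the type is the third token
# of each (hypothetical) feature name.
wanted_types = {"Hbond", "Saltbr"}
keep_types = [c for c in features.columns if c.split()[2] in wanted_types]

print(list(keep_overall), list(keep_by_class), list(keep_strong), keep_types, sep="\n")
```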

In [None]:

# An example of filtering the dataset using four of the five available methods.
supervised_dataset.reset_filtering()
print(f"Number of features before any filtering: {len(supervised_dataset.df_processed.columns)}")

# Features with a %occupancy below 25% in every class are removed (one class meeting the cut-off is enough to keep a feature).
supervised_dataset.filter_by_occupancy_by_class(min_occupancy=25)
print(f"Number of features after filtering by occupancy: {len(supervised_dataset.df_filtered.columns)}")

# No filtering performed here as all possible combinations are included. 
supervised_dataset.filter_by_interaction_type(
    interaction_types_included=["Hbond", "Saltbr", "Hydrophobic", "Other"])  
print(f"Number of features after NOT filtering by interaction type: {len(supervised_dataset.df_filtered.columns)}")

# No filtering performed here as all possible combinations are included. 
supervised_dataset.filter_by_main_or_side_chain(
    main_side_chain_types_included=["bb-bb", "sc-sc", "bb-sc", "sc-bb"] 
)
print(f"Number of features after NOT filtering by main or side chain: {len(supervised_dataset.df_filtered.columns)}")

# Features with an average interaction strength less than 1.0 will be removed. 
supervised_dataset.filter_by_avg_strength(
    average_strength_cut_off=1.0,  
)
print(f"Number of features after filtering by average interaction scores: {len(supervised_dataset.df_filtered.columns)}")

Now if we look at the class attributes of our SupervisedFeatureData() instance (we called it supervised_dataset) using the special "\_\_dict\_\_" attribute, we can see two dataframes that we could use in the machine learning to follow.

In [None]:
supervised_dataset.__dict__.keys()

They are: 
- 'df_processed' - The unfiltered dataframe: 3057 columns (3056 features + the target).
- 'df_filtered' - The filtered dataframe, with fewer features than 'df_processed'.

In the following section we will use the filtered dataframe.

### Step 3. Perform the Machine Learning with the model_building.py module. 

Now we will set up and run the supervised machine learning (ML) on the retro aldolase enzyme. Here we will apply ML to distinguish between catalytically active and inactive conformations of the enzyme towards catalysis of XXXX.

Describe the ML in more detail TODO

In [None]:
supervised_dataset.df_filtered["Target"].value_counts()

Looking at the output from the above cell, we can see three classes. To keep things simple, we will use only the two clearly defined classes, "CatComp" and "NotCatComp", which we specify with the "classes_to_use" parameter in the following code block.
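Restricting the data to two classes boils down to dropping every frame whose label is not in the list. A minimal pandas sketch of that row selection follows; the third label "Unassigned" is a hypothetical placeholder, not the dataset's real third class.

```python
import pandas as pd

# Toy target column with three classes; "Unassigned" is a hypothetical third label.
df = pd.DataFrame({
    "feat_a": [0.1, 0.4, 0.3, 0.9, 0.2],
    "Target": ["CatComp", "NotCatComp", "Unassigned", "CatComp", "NotCatComp"],
})

# Keep only frames whose label is in classes_to_use; drop everything else.
classes_to_use = ["CatComp", "NotCatComp"]
df_two_class = df[df["Target"].isin(classes_to_use)].reset_index(drop=True)
print(df_two_class["Target"].value_counts())
```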

In [None]:
# Instantiate the model.
ml_model = model_building.ClassificationModel(
    dataset=supervised_dataset.df_filtered,
    evaluation_split_ratio=0.15,
    classes_to_use=["CatComp", "NotCatComp"], 
    models_to_use=["CatBoost", "XGBoost", "Random_Forest"],
    scaling_method="min_max",
    out_dir="outputs/retro_aldol_ml_classification",
    cross_validation_splits=5, 
    cross_validation_repeats=3,
    search_approach="none",
)
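For readers who want intuition for the parameters above, the following scikit-learn sketch shows what they conventionally mean: a 15% held-out validation split (`evaluation_split_ratio=0.15`), min-max scaling, and 5-fold cross-validation repeated 3 times, with no hyperparameter search (our reading of `search_approach="none"`). It uses synthetic data and a random forest only; it is an analogue, not a description of how `ClassificationModel` is implemented internally.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler

# Synthetic stand-in for the filtered feature table (200 frames, 10 features).
X, y = make_classification(n_samples=200, n_features=10, random_state=0)

# Hold out 15% of frames as a validation set that is never used for training.
X_train, X_eval, y_train, y_eval = train_test_split(
    X, y, test_size=0.15, stratify=y, random_state=0
)

# Min-max scaling inside a pipeline, so the scaler is fitted on training folds only.
model = make_pipeline(MinMaxScaler(), RandomForestClassifier(n_estimators=50, random_state=0))

# 5-fold stratified cross-validation repeated 3 times = 15 scores.
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=3, random_state=0)
scores = cross_val_score(model, X_train, y_train, cv=cv, scoring="accuracy")
print(f"CV accuracy: {scores.mean():.3f} +/- {scores.std():.3f} over {len(scores)} folds")

# Fit on all training data, then score once on the held-out validation set.
model.fit(X_train, y_train)
print(f"Validation accuracy: {model.score(X_eval, y_eval):.3f}")
```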

Now we can go ahead and build the models.
The command below has one optional parameter, which saves the models generated. This can be useful if, for instance, you want to come back and do the post-processing (described in steps 4 and 5) at a later date.

If you set this to True, all the files required will be saved to a folder called "temporary_files" in your current working directory.
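As a general illustration of why saving fitted models is handy, the sketch below persists a scikit-learn model with `joblib` and reloads it. The file name and format are hypothetical; we are not claiming this is how `build_models(save_models=True)` writes its files.

```python
import os
import tempfile

import joblib
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=100, n_features=5, random_state=0)
model = RandomForestClassifier(n_estimators=20, random_state=0).fit(X, y)

# Work in a throwaway directory; "temporary_files" mirrors the folder name above.
with tempfile.TemporaryDirectory() as workdir:
    out_dir = os.path.join(workdir, "temporary_files")
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, "Random_Forest_model.joblib")  # hypothetical file name
    joblib.dump(model, path)
    reloaded = joblib.load(path)
    # A reloaded model should reproduce the original predictions exactly.
    same_predictions = bool((reloaded.predict(X) == model.predict(X)).all())

print(same_predictions)
```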

In [None]:
ml_model.build_models(save_models=True)

With the models now built, we can see that they appear to be evenly matched in accuracy on the train and test sets.
We can now evaluate the quality of the models on the validation dataset (also sometimes referred to as the hold-out set).

For each ML model built a pandas dataframe is generated which contains key results on the validation dataset. 
If you are unfamiliar with any of the terms presented below, [feel free to check out this guide from scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html)
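To see where such numbers come from, here is a tiny scikit-learn example that builds a per-class report as a pandas dataframe from toy labels. The exact layout of the dataframes returned by `evaluate_models()` may differ; this only illustrates the metrics. With these labels, "CatComp" has precision 1.0 (its one prediction is correct) but recall 0.5 (one of its two frames is missed), and overall accuracy is 0.75.

```python
import pandas as pd
from sklearn.metrics import classification_report

# Toy validation labels and predictions.
y_true = ["CatComp", "CatComp", "NotCatComp", "NotCatComp"]
y_pred = ["CatComp", "NotCatComp", "NotCatComp", "NotCatComp"]

# output_dict=True returns nested dicts that convert neatly to a dataframe.
report = classification_report(y_true, y_pred, output_dict=True)
report_df = pd.DataFrame(report).transpose()
print(report_df)
```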

In [None]:
reports = ml_model.evaluate_models()

In [None]:
reports["XGBoost"]

In [None]:
reports["CatBoost"]

In [None]:
reports["Random_Forest"]

Another popular way to evaluate model quality is to generate confusion matrices. 
Using the below command we can generate confusion matrices (stored as numpy arrays) for each model we generated. 

You can then easily plot these confusion matrices in whatever graphing program you like. In this case, I will use seaborn.
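One pitfall when plotting is that the tick labels must match the row and column order of the matrix. The scikit-learn sketch below fixes that order explicitly with the `labels` argument and wraps the result in a labelled dataframe; rows are true classes and columns are predicted classes. We do not know which order `generate_confusion_matrix()` uses, so check it before reusing the `axis_labels` list in the cells that follow.

```python
import pandas as pd
from sklearn.metrics import confusion_matrix

y_true = ["CatComp", "CatComp", "NotCatComp", "NotCatComp"]
y_pred = ["CatComp", "NotCatComp", "NotCatComp", "NotCatComp"]

# Passing labels= fixes the row/column order instead of relying on sorted order.
labels = ["NotCatComp", "CatComp"]
cm = confusion_matrix(y_true, y_pred, labels=labels)

# A labelled dataframe keeps the class names tied to the numbers.
cm_df = pd.DataFrame(
    cm,
    index=[f"true {name}" for name in labels],
    columns=[f"pred {name}" for name in labels],
)
print(cm_df)
```

Passing `cm_df` to `sns.heatmap(cm_df, annot=True, fmt="d", cmap="Greens")` picks up the index and column names as tick labels automatically.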

In [None]:
confusion_matrices = ml_model.generate_confusion_matrix()

In [None]:
confusion_matrices["XGBoost"]

In [None]:
# TODO - write in plotly instead... 
import seaborn as sns
axis_labels = ["NotCatComp", "CatComp"]
ax = sns.heatmap(confusion_matrices["XGBoost"], annot=True, fmt="d", xticklabels=axis_labels, yticklabels=axis_labels, cmap="Greens")

In [None]:
ax = sns.heatmap(confusion_matrices["CatBoost"], annot=True, fmt="d", xticklabels=axis_labels, yticklabels=axis_labels, cmap="Greens")

In [None]:
ax = sns.heatmap(confusion_matrices["Random_Forest"], annot=True, fmt="d", xticklabels=axis_labels, yticklabels=axis_labels, cmap="Greens")

### Step 4. Work up the Machine Learning with the post_proccessing.py module. 

With this module, we can analyse our results in more detail to understand which features each model determined were important for distinguishing between the two states.

In order to perform the analysis we will need to provide the models generated in step 3. 
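For intuition on what a "feature importance" is, the sketch below fits a scikit-learn random forest on synthetic data and reads its impurity-based `feature_importances_`, which are non-negative and sum to 1. The feature names are hypothetical, and `get_feature_importance()` may compute importances differently for each model type.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Hypothetical feature names in a "residue1 residue2 type chain" style.
feature_names = ["12Lys 45Glu Saltbr sc-sc", "30Ser 31Gly Hbond bb-bb",
                 "50Phe 72Leu Hydrophobic sc-sc", "12Lys 88Asp Hbond sc-sc"]
X, y = make_classification(n_samples=300, n_features=4, n_informative=2,
                           n_redundant=0, random_state=0)
model = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)

# Impurity-based importances: one value per feature, summing to 1.
importances = pd.Series(model.feature_importances_, index=feature_names)
print(importances.sort_values(ascending=False))
```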

In [None]:
# First we will make an instance of the SupervisedPostProcessor class.
post_proc = post_proccessing.SupervisedPostProcessor(
    out_dir="outputs/retro_aldol_ml_classification",
)

# Option 1 - Load models from the ClassificationModel instance built in Step 3.
post_proc.load_models_from_instance(supervised_model=ml_model)

# Option 2 - Load models from disk.
#post_proc.load_models_from_disk(models_to_use=["XGBoost", "CatBoost", "Random_Forest"]) 

In [None]:
# After preparing the class we can now determine the feature importances for each model.
post_proc.get_feature_importance()

In [None]:
# We can also project these per feature importances onto the per-residue level. 
# This is done by summing the feature importances of each residue and normalising
# the scores relative to the residue with the greatest overall importance.
post_proc.get_per_res_importance()
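The per-residue projection described above can be sketched in plain Python: each feature's importance is added to both residues it involves, and the totals are then scaled so that the top residue scores 1. Both the scaling-to-1 step and the feature-name layout (residue number followed by residue name in the first two tokens) are assumptions for illustration, not a description of `get_per_res_importance()`.

```python
import re
from collections import defaultdict

# Hypothetical per-feature importances.
feature_importances = {
    "12Lys 45Glu Saltbr sc-sc": 0.5,
    "12Lys 88Asp Hbond sc-sc": 0.25,
    "45Glu 90Phe Hydrophobic sc-sc": 0.25,
}


def per_residue_scores(importances):
    """Sum each residue's feature importances, then scale so the top residue is 1."""
    totals = defaultdict(float)
    for feature, score in importances.items():
        # The first two tokens name the two residues, e.g. "12Lys" -> 12.
        for token in feature.split()[:2]:
            res_number = int(re.match(r"\d+", token).group())
            totals[res_number] += score
    top = max(totals.values())
    return {res: total / top for res, total in sorted(totals.items())}


print(per_residue_scores(feature_importances))
```

Here residues 12 and 45 each accumulate 0.75 and so score 1.0, while residues 88 and 90 score about 0.33.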

In [None]:
# Again, if we take a look at the class attributes we can see the per feature and 
# per residue importances were not just saved to disk, but are also now stored in the class
# meaning you can analyse them here if you wish. 
print(post_proc.__dict__.keys())
all_per_res_scores = post_proc.all_per_residue_scores
all_feature_scores = post_proc.all_feature_importances

We can also visualise these results graphically in an interactive manner with the plotly graphing library. 
If you don't have that library installed you can do so now by uncommenting the code block below. 

In [None]:
# ! pip install plotly

In [None]:
import pandas as pd 
import plotly.express as px
df_all_per_res_scores = pd.DataFrame(all_per_res_scores).reset_index()
df_all_per_res_scores = df_all_per_res_scores.rename(columns={"index": "Residue Number"})
df_all_per_res_scores = df_all_per_res_scores.sort_values(["Residue Number"], ascending=True)
df_all_per_res_scores.head() 

In [None]:
fig = px.line(df_all_per_res_scores, x="Residue Number", y=["CatBoost", "XGBoost", "Random_Forest"])
fig.update_layout(
    title="Per residue Relative Importances for All 3 Machine Learning Models",
    xaxis_title="Residue Number",
    yaxis_title="Relative Importance",
    legend_title="ML Models",
    font=dict(size=16)
)

fig.show()

### Step 5. Projecting the Results onto Protein Structures with the pymol_projections.py module. 
 
Naturally, we may want to visualise some of the results we have generated above onto a protein structure. We can take advantage of the functions provided in the pymol_projections.py module to do this. 

As the name suggests, this will output [PyMOL](https://pymol.org/) compatible Python scripts that can be run to represent the results at either the:

1. Per feature level. Cylinders are drawn between the two residues in each feature, with the cylinder radius reflecting the feature's relative importance.
2. Per residue level. The alpha carbon of each residue is depicted as a sphere, with the sphere radius reflecting the residue's relative importance for the machine learning model.
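To show the idea behind the per-residue representation, the sketch below writes PyMOL commands that display each residue's alpha carbon as a sphere whose `sphere_scale` grows with its relative importance. The `max_radius` scaling and the command layout are our own choices; the scripts produced by pymol_projections.py are not necessarily structured like this. Saved as a `.pml` file, the output could be run in PyMOL with `@filename.pml` after loading a structure.

```python
def per_residue_pymol_script(scores, max_radius=1.5):
    """Return PyMOL commands drawing each residue's CA as a sphere scaled by its score.

    `scores` maps residue number -> relative importance in [0, 1].
    """
    lines = ["hide everything", "show cartoon"]
    for res, score in sorted(scores.items()):
        selection = f"resi {res} and name CA"
        lines.append(f"show spheres, {selection}")
        lines.append(f"set sphere_scale, {score * max_radius:.2f}, {selection}")
    return "\n".join(lines) + "\n"


script = per_residue_pymol_script({12: 1.0, 88: 0.5})
print(script)
```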

In [None]:
pymol_projections.project_multiple_per_res_scores(
    all_per_res_scores=all_per_res_scores,
    out_dir="outputs/retro_aldol_ml_classification"
)

In [None]:
pymol_projections.project_multiple_per_feature_scores(
    all_feature_scores=all_feature_scores,
    numb_features="all",
    out_dir="outputs/retro_aldol_ml_classification"
)

In [None]:
# TODO ADD Picture of the outputs here as an example. 