# Import Libraries

In [None]:
pip install rdkit mlflow

In [3]:
from google.colab import drive
import os
import numpy as np
import pandas as pd
import shap
import mlflow
import mlflow.pyfunc
import mlflow.sklearn
from mlflow.tracking import MlflowClient
from rdkit import Chem
from rdkit.Chem import Draw, Descriptors
import joblib
import matplotlib.pyplot as plt
import base64
from io import BytesIO
from IPython.display import display, HTML
import seaborn as sns

# Load Models

In [4]:
# Mount Google Drive in the Colab runtime
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# Switch into the project folder so the relative artifact paths used below resolve
PROJECT_DIR = '/content/drive/MyDrive/Solubility'
os.chdir(PROJECT_DIR)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
# Load the persisted classifier and the explainer object from the working directory.
# NOTE(review): joblib.load unpickles arbitrary Python objects — only load files you trust.
MODEL_PATH = "solubility_rf_model.joblib"
EXPLAINER_PATH = "explainer.pkl"

model = joblib.load(MODEL_PATH)
explainer = joblib.load(EXPLAINER_PATH)

# Define Functions for Example Processing

In [6]:
# function to compute molecular descriptors for input molecule
def compute_descriptors(smiles):
    """Compute the RDKit descriptors used as model features for one molecule.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.

    Returns
    -------
    pandas.DataFrame or None
        One-row DataFrame with one column per selected descriptor, in the
        order the descriptors appear in ``Descriptors.descList``. ``None`` is
        returned when ``smiles`` is not a string, cannot be parsed by RDKit,
        or parses to a molecule without atoms (e.g. an empty string).
    """
    # Non-string input would make RDKit raise instead of returning None
    mol = Chem.MolFromSmiles(smiles) if isinstance(smiles, str) else None

    if mol is None or mol.GetNumAtoms() == 0:
        # Plain print: this notebook has no Streamlit context (`st` is not imported)
        print("Invalid SMILES string")
        return None

    # A set gives O(1) membership tests while scanning the full descriptor list
    descriptor_names = {
        "MaxEStateIndex",
        "MinEStateIndex",
        "qed",
        "SPS",
        "MolWt",
        "MaxPartialCharge",
        "MinPartialCharge",
        "FpDensityMorgan2",
        "BCUT2D_MWHI",
        "BCUT2D_CHGHI",
        "BCUT2D_LOGPHI",
        "BCUT2D_MRHI",
        "AvgIpc",
        "BalabanJ",
        "HallKierAlpha",
        "Ipc",
        "Kappa3",
        "TPSA",
        "FractionCSP3",
        "NumAromaticCarbocycles",
        "NumAromaticRings",
        "NumHAcceptors",
        "NumHDonors",
        "NumHeteroatoms",
        "NumRotatableBonds",
        "Phi",
        "RingCount",
        "MolLogP",
    }

    # Iterate descList (not descriptor_names) so the column order follows RDKit's
    # own ordering. NOTE(review): a name missing from the installed RDKit's
    # descList is silently skipped, yielding fewer columns — confirm the RDKit
    # version matches the one used when the model was trained.
    descriptor_vals = {
        name: func(mol)
        for name, func in Descriptors.descList
        if name in descriptor_names
    }

    return pd.DataFrame([descriptor_vals])

In [7]:
# define function to get prediction for example molecule and map to solubility categories
def get_prediction(smiles):
    """Predict the solubility class of a molecule and its class probabilities.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.

    Returns
    -------
    tuple
        ``(prediction_label, probs_df)`` where ``prediction_label`` is the
        text label of the predicted class and ``probs_df`` is a one-row
        DataFrame with one probability column per solubility class.

    Raises
    ------
    ValueError
        If ``compute_descriptors`` returns ``None`` for ``smiles``.
    """
    # define target class mapping
    solubility_class_labels = {
        0: "Insoluble (< -4 LogS)",
        1: "Slightly soluble (-4 to -2 LogS)",
        2: "Soluble (> -2 LogS)",
    }

    # build the feature row from the SMILES argument instead of relying on a
    # `descriptors` variable that happens to exist in the kernel
    descriptors = compute_descriptors(smiles)
    if descriptors is None:
        raise ValueError(f"Invalid SMILES string: {smiles!r}")

    # get prediction class
    prediction_class = model.predict(descriptors)[0]
    prediction_label = solubility_class_labels[prediction_class]

    # get class probs; the single sample's probabilities are wrapped in a list so
    # they form one row with three columns rather than three rows.
    # NOTE(review): assumes predict_proba columns are ordered as classes 0, 1, 2 —
    # confirm against model.classes_
    prediction_probs = model.predict_proba(descriptors)[0]
    probs_df = pd.DataFrame(
        data=[prediction_probs],
        columns=list(solubility_class_labels.values()),
    )
    return prediction_label, probs_df

In [25]:
def generate_html_report(smiles, top_k=5):
    """Build and display an HTML solubility report for one molecule.

    The report contains the 2D structure image, the predicted solubility
    class, the class probabilities, and a bar plot of the ``top_k`` features
    with the largest absolute SHAP value for the predicted class.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.
    top_k : int, default 5
        Maximum number of features shown in the SHAP plot.

    Returns
    -------
    None or str
        ``None`` after the report has been displayed in the notebook. If the
        SMILES string is invalid, or the SHAP output does not line up with the
        descriptor columns, nothing is displayed and an HTML ``<p>`` snippet
        describing the problem is returned instead.
    """
    # validate the input here so the function does not depend on how
    # compute_descriptors reports an unparsable string
    mol = Chem.MolFromSmiles(smiles) if isinstance(smiles, str) else None
    if mol is None:
        return "<p>Invalid SMILES string. Please try again.</p>"

    # compute descriptors for input molecule
    descriptors_df = compute_descriptors(smiles)
    if descriptors_df is None:
        return "<p>Invalid SMILES string. Please try again.</p>"

    # map descriptors to interpretable names
    descriptor_labels = {
        "MolWt": "Molecular Weight",
        "MolLogP": "LogP",
        "TPSA": "Topological Polar Surface Area",
        "qed": "QED (Drug-likeness)",
        "FractionCSP3": "Fraction sp3 Carbons",
        "NumHAcceptors": "H-Bond Acceptor Count",
        "NumHDonors": "H-Bond Donor Count",
        "RingCount": "Ring Count",
        "FpDensityMorgan2": "Fragment Density",
        "BalabanJ": "Molecular Complexity (BalabanJ)",
        "MaxEStateIndex": "Max E-State Index",
        "MinEStateIndex": "Min E-State Index",
        "Phi": "Phi (Flexibility)",
        # RDKit's SPS descriptor is the spacial score, not a polar-surface measure
        "SPS": "Spacial Score (SPS)",
    }

    # render the molecule structure to an in-memory PNG and embed it as a data URI
    img = Draw.MolToImage(mol, size=(300, 300))
    buf = BytesIO()
    img.save(buf, format="PNG")
    img_data = base64.b64encode(buf.getvalue()).decode()
    img_html = f'<img src="data:image/png;base64,{img_data}" alt="Molecule Structure" style="max-width: 300px;">'

    # define target class mapping
    solubility_class_labels = {
        0: "Insoluble",
        1: "Slightly soluble",
        2: "Soluble",
    }

    # get predictions and probabilities (predict once; the class is reused for SHAP)
    predicted_class = model.predict(descriptors_df)[0]
    predicted_label = solubility_class_labels[predicted_class]
    probabilities = model.predict_proba(descriptors_df)[0]
    prob_df = pd.DataFrame(
        data=[probabilities],
        columns=list(solubility_class_labels.values()),
    )
    prob_html = prob_df.to_html(index=False, float_format="%.2f")

    # generate SHAP values; the return layout differs between SHAP releases
    shap_values = explainer.shap_values(descriptors_df)
    if isinstance(shap_values, list):
        # list layout: one (n_samples, n_features) array per class
        shap_values_for_class = np.asarray(shap_values[predicted_class])[0]
    else:
        # array layout: (n_samples, n_features, n_classes)
        shap_values_for_class = np.asarray(shap_values)[0][:, predicted_class]

    if len(descriptors_df.columns) != len(shap_values_for_class):
        # report the problem the same way as the invalid-SMILES path (HTML string)
        return (
            f"<p>Error: Mismatch between descriptors ({len(descriptors_df.columns)}) "
            f"and SHAP values ({len(shap_values_for_class)}).</p>"
        )

    # map feature names to interpretable labels
    shap_df = pd.DataFrame({
        "Feature": descriptors_df.columns,
        "SHAP Value": shap_values_for_class,
        "Feature Value": descriptors_df.iloc[0].values,
    })
    shap_df["Feature"] = shap_df["Feature"].map(descriptor_labels).fillna(shap_df["Feature"])

    # keep the top_k features by absolute SHAP value; reset the index so every
    # column passed to the plot lines up positionally
    shap_df["Abs SHAP Value"] = shap_df["SHAP Value"].abs()
    shap_df = (
        shap_df.sort_values(by="Abs SHAP Value", ascending=False)
        .head(top_k)
        .reset_index(drop=True)
    )

    # add color label for hue and the tick label shown for each bar
    shap_df["Contribution"] = [
        "Decreases Solubility" if val < 0 else "Increases Solubility"
        for val in shap_df["SHAP Value"]
    ]
    shap_df["Label"] = [
        f"{feat} ({val:.2f})"
        for feat, val in zip(shap_df["Feature"], shap_df["Feature Value"])
    ]

    # create the SHAP bar plot on an explicit figure so it can be closed reliably
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.barplot(
        data=shap_df,
        y="Label",
        x="SHAP Value",
        hue="Contribution",
        palette={"Decreases Solubility": "#E74C3C", "Increases Solubility": "#2ECC71"},
        dodge=False,
        ax=ax,
    )
    # title reflects the number of bars actually drawn (may be fewer than top_k)
    ax.set_title(f"Top {len(shap_df)} SHAP Feature Contributions")
    ax.set_xlabel("Impact on Solubility")
    ax.set_ylabel("Descriptor (Value)")
    ax.axvline(0, color="black", linewidth=0.8)
    fig.tight_layout()

    # save the plot to memory (like the molecule image) instead of a file on disk
    plot_buf = BytesIO()
    fig.savefig(plot_buf, format="png", bbox_inches="tight", dpi=300)
    plt.close(fig)
    shap_img_data = base64.b64encode(plot_buf.getvalue()).decode()
    shap_img_html = f'<img src="data:image/png;base64,{shap_img_data}" alt="SHAP Waterfall" style="max-width: 600px;">'

    # generate final HTML report
    html_content = f"""
    <html>
    <head><style>body {{ font-family: Arial, sans-serif; line-height: 1.6; }}</style></head>
    <body>
    <h2>Solubility Report</h2>
    {img_html}
    <h3>Predicted Solubility Class: {predicted_label}</h3>
    <h3>Class Probabilities</h3>
    {prob_html}
    <h3>SHAP Waterfall Plot</h3>
    {shap_img_html}
    </body>
    </html>
    """

    # display the report in the notebook
    display(HTML(html_content))

# Example

In [24]:
# Example input: the SMILES string for aspirin. generate_html_report displays the
# report inline via IPython's display(HTML(...)).
smiles = 'CC(=O)OC1=CC=CC=C1C(=O)O' # Aspirin
generate_html_report(smiles)

Insoluble,Slightly soluble,Soluble
0.0,0.23,0.78
