# 0- Load Library

In [None]:
import os
import pickle
import subprocess


import exmol
import numpy as np
import seaborn as sns
import pandas as pd
import deepchem as dc
from sklearn.svm import SVC
from rdkit import Chem
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt
from matplotlib import font_manager
from imblearn.over_sampling import RandomOverSampler


import utils
import plots
#import retrieve_data
from pipeline import run_model_pipeline
from conformal_prediction import (
    ecfp_generator,
    probability_dataframe,
    run_data,
    calculate_conformity_scores,
    evaluate_conformal_predictor,
)

# 1- Downloading Data

Download ChEMBL data for three UniProt accessions (O43570, Q16790, P00918).

In [None]:
# `retrieve_data` is a project module whose top-level import is commented out
# in the import cell, so import it here; otherwise this loop raises NameError.
import retrieve_data

for uniprot_id in ["O43570", "Q16790", "P00918"]:
    retrieve_data.download_data(uniprot_id)

# 2- Pre-processing

Canonicalize SMILES of each ligand.

In [None]:
csv_files = [
    "../Data/O43570_ChEMBL_data.csv",
    "../Data/P00918_ChEMBL_data.csv",
    "../Data/Q16790_ChEMBL_data.csv",
]

# Name of the input column that holds the raw SMILES strings.
smiles_column = "smiles"


def process_smiles(row, smiles_column=smiles_column):
    """Store the canonical form of a row's SMILES under ``standardize_smiles``.

    Parameters
    ----------
    row : pandas.Series
        One DataFrame row; it is modified in place and returned.
    smiles_column : str, optional
        Key of the SMILES entry to canonicalize (default ``"smiles"``). Taking
        it as a parameter keeps the helper independent of any module-level
        ``smiles_column`` binding that other cells may change.

    Returns
    -------
    pandas.Series
        The same row with ``standardize_smiles`` set.
    """
    row["standardize_smiles"] = utils.canonical_smiles(row[smiles_column])
    return row


# Loop through each CSV file
for csv_file in csv_files:
    print(f"{csv_file} UniProt...")
    data = pd.read_csv(csv_file)

    # Canonicalize the SMILES column directly: a row-wise apply would rebuild
    # every row as a Series just to read and write one field.
    data["standardize_smiles"] = data[smiles_column].apply(utils.canonical_smiles)

    # Save the updated DataFrame to a new file
    output_file = csv_file.replace(".csv", "_canonical_smiles.csv")
    data.to_csv(output_file, index=False)
    print(f"Processed '{csv_file}' and saved to '{output_file}'")
    print("---------------------")

# 3- Substructure Filter

Flag ligands that contain a primary sulfonamide group. A `has_sulfonamide` column is added; no rows are removed.

In [None]:
csv_files = [
    "../Data/O43570_ChEMBL_data_canonical_smiles.csv",
    "../Data/P00918_ChEMBL_data_canonical_smiles.csv",
    "../Data/Q16790_ChEMBL_data_canonical_smiles.csv",
]

# Loop through each CSV file
for csv_file in csv_files:
    print(f"{csv_file} UniProt...")
    df = pd.read_csv(csv_file)

    # Extract SMILES strings from the "smiles" column. A distinct name is used
    # so the module-level ``smiles_column`` string (read by ``process_smiles``)
    # is not rebound to a Series.
    smiles_series = df.loc[:, "smiles"]

    # Check for sulfonamide substructure (values are summed below, so
    # presumably booleans from utils.check_sulfonamide)
    has_sulfonamide = smiles_series.apply(utils.check_sulfonamide)

    # Add a new column indicating the presence of sulfonamide
    df.loc[:, "has_sulfonamide"] = has_sulfonamide

    # Save the updated DataFrame to a new file
    output_file = csv_file.replace(".csv", "_filtered.csv")
    df.to_csv(output_file, index=False)

    print(
        f"From {df.shape[0]} data points {has_sulfonamide.sum()} has primary sulfonamide."
    )
    print(f"Processed '{csv_file}' and saved to '{output_file}'")
    print("---------------------")

# 4- Statistics

Calculate summary statistics of the binding affinity (pK) values.

In [None]:
file_paths = [
    "../Data/P00918_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/Q16790_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/O43570_ChEMBL_data_canonical_smiles_filtered_status.csv",
]

# Output column labels, in the same order as file_paths
# (P00918 -> CA II, Q16790 -> CA IX, O43570 -> CA XII).
file_names = ["II", "IX", "XII"]

# One summary of the pK column per dataset.
all_stats = []
for file_path in file_paths:
    df = pd.read_csv(file_path)
    stats = utils.calculate_statistics(df["pK"])
    all_stats.append(stats)

# Transpose so each dataset becomes a column; round to 3 decimals.
combined_stats = pd.DataFrame(all_stats).T.round(3)
combined_stats.columns = file_names
combined_stats.to_csv("../Data/binding_affinity_statistics.csv", index=True)

# 5- Count Binding Affinity Types

Count the number of IC50, Ki, and Kd measurements in each compiled dataset.

In [None]:
file_paths = [
    "../Data/O43570_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/P00918_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/Q16790_ChEMBL_data_canonical_smiles_filtered_status.csv",
]

types = ["IC50", "Ki", "Kd"]

# Counts of each affinity type, keyed by the file's base name
# (text after the last "/" up to the first ".").
counts_dict = {}
for file_path in file_paths:
    name = file_path.rsplit("/", 1)[-1].partition(".")[0]
    counts_dict[name] = utils.count_letters(pd.read_csv(file_path), types)

# One row per dataset in the saved table.
counts_table = pd.DataFrame(counts_dict).T
counts_table.to_csv("../Data/binding_affinity_types_counts.csv", index=True)

# 6- Calculate Molecular Properties

Calculate several molecular properties and their summary statistics.

In [None]:
uniprot_ids = ("P00918", "Q16790", "O43570")

# Step 1: compute molecular properties for each curated dataset
# (presumably written to ../Data/{id}_molecular_property.csv, read in step 2).
file_paths = [
    f"../Data/{uid}_ChEMBL_data_canonical_smiles_filtered_status.csv"
    for uid in uniprot_ids
]
for file_path in file_paths:
    utils.calculate_molecular_property(file_path)

# Step 2: summarize the per-dataset property tables.
file_paths = [f"../Data/{uid}_molecular_property.csv" for uid in uniprot_ids]
for file_path in file_paths:
    utils.molecular_property_stats(file_path)

# 7- Plots

## 7-1- pKa

Plot KDE diagram of pKa values.

In [None]:
font_manager.findfont("Helvetica")
plt.rc("font", family="Helvetica")
plt.rc("font", serif="Helvetica", size=32)
plt.rcParams.update(
    {
        "axes.linewidth": 1.25,
        "xtick.major.size": 10,
        "xtick.minor.size": 2,
        "xtick.major.width": 1.5,
        "xtick.minor.width": 1.5,
        "ytick.major.size": 10,
        "ytick.minor.size": 2,
        "ytick.major.width": 1.5,
        "ytick.minor.width": 1.5,
        "ytick.direction": "in",
        "xtick.direction": "in",
        # Fontconfig-style "family:style" pattern: the family name is all text
        # before ":", so there must be no space before it (matches the
        # "Helvetica Light:italic" patterns used for the other figures).
        "mathtext.it": "Helvetica:italic",
        "mathtext.rm": "Helvetica",
        "mathtext.default": "regular",
    }
)

file_paths = [
    "../Data/P00918_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/Q16790_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/O43570_ChEMBL_data_canonical_smiles_filtered_status.csv",
]


colors = ["#bee9e8", "#62b6cb", "#1b4965"]

plots.plot_pka(file_paths, colors)

## 7-2- Properties

Plot KDE diagrams of the molecular properties.

In [None]:
font_manager.findfont("Helvetica Light")
plt.rc("font", family="Helvetica Light")
plt.rc("font", serif="Helvetica Light", size=25)

# Thin axes and inward-pointing ticks for the property KDE panels.
plt.rcParams.update(
    {
        "axes.linewidth": 1.25,
        "xtick.major.size": 8,
        "xtick.minor.size": 2,
        "xtick.major.width": 1.25,
        "xtick.minor.width": 1.25,
        "ytick.major.size": 8,
        "ytick.minor.size": 2,
        "ytick.major.width": 1.25,
        "ytick.minor.width": 1.25,
        "ytick.direction": "in",
        "xtick.direction": "in",
        "mathtext.it": "Helvetica Light:italic",
        "mathtext.rm": "Helvetica Light",
        "mathtext.default": "regular",
    }
)

file_paths = [
    "../Data/P00918_molecular_property.csv",
    "../Data/Q16790_molecular_property.csv",
    "../Data/O43570_molecular_property.csv",
]
colors = ["#bee9e8", "#62b6cb", "#1b4965"]

plots.plot_property(file_paths, colors)

## 7-3- T-SNE

Plot a t-SNE diagram.

In [None]:
font_manager.findfont("Helvetica Light")
plt.rc("font", family="Helvetica Light")
plt.rc("font", serif="Helvetica Light", size=20)
plt.rcParams.update(
    {
        "axes.linewidth": 1.25,
        "xtick.major.size": 8,
        "xtick.minor.size": 2,
        "xtick.major.width": 1.25,
        "xtick.minor.width": 1.25,
        "ytick.major.size": 8,
        "ytick.minor.size": 2,
        "ytick.major.width": 1.25,
        "ytick.minor.width": 1.25,
        "ytick.direction": "in",
        "xtick.direction": "in",
        "mathtext.it": "Helvetica Light:italic",
        "mathtext.rm": "Helvetica Light",
        "mathtext.default": "regular",
    }
)

file_paths = [
    "../Data/P00918_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/Q16790_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/O43570_ChEMBL_data_canonical_smiles_filtered_status.csv",
]

colors = ["#f46036", "#2e294e", "#1b998b"]

# `plot_t_sne` is never imported by name; call it through the `plots` module,
# as the other plotting cells do with plot_pka / plot_property.
# NOTE(review): assumes plots.py defines plot_t_sne — confirm.
plots.plot_t_sne(file_paths, colors)

# 8- Train-Test-Validation

## 8-1- LR

#### 8-1-1- Random

In [None]:
run_model_pipeline(split_type="random", model_type="logistic")  # logistic regression

#### 8-1-2- Scaffold

In [None]:
run_model_pipeline(split_type="scaffold", model_type="logistic")  # logistic regression

In [None]:
# NOTE(review): `file_path` is not assigned in this cell. It still holds
# whatever an earlier cell's loop last bound it to (a .csv path in the cells
# above), which pickle cannot load. Point this at the pickled data split you
# intend to inspect — TODO confirm the intended path.
# Only unpickle trusted files: pickle.load can execute arbitrary code.
with open(file_path, "rb") as f:
    data_split = pickle.load(f)

## 8-2- SVC

### 8-2-1- Random

In [None]:
run_model_pipeline(split_type="random", model_type="svm")  # support vector classifier

### 8-2-2- Scaffold

In [None]:
run_model_pipeline(split_type="scaffold", model_type="svm")  # support vector classifier

## 8-3- RF

### 8-3-1- Random

In [None]:
run_model_pipeline(split_type="random", model_type="randomforest")  # random forest

### 8-3-2- Scaffold

In [None]:
run_model_pipeline(split_type="scaffold", model_type="randomforest")  # random forest

## 8-4- XGBoost

### 8-4-1- Random

In [None]:
run_model_pipeline(split_type="random", model_type="xgboost")  # XGBoost

### 8-4-2- Scaffold

In [None]:
run_model_pipeline(split_type="scaffold", model_type="xgboost")  # XGBoost

## 8-5- FFNN

### 8-5-1- Random

In [None]:
run_model_pipeline(split_type="random", model_type="ffneuralnetwork")  # feed-forward NN

### 8-5-2- Scaffold

In [None]:
run_model_pipeline(split_type="scaffold", model_type="ffneuralnetwork")  # feed-forward NN

## 8-6- GIN

### 8-6-1- Random

In [None]:
run_model_pipeline(split_type="random", model_type="gin")  # GIN model

### 8-6-2- Scaffold

In [None]:
run_model_pipeline(split_type="scaffold", model_type="gin")  # GIN model

# 9- Results

## 9-1- McNemar Test 

In [None]:
import sys

# Run the McNemar comparison with the same interpreter as this kernel; a bare
# "python" may resolve to a different environment on PATH.
result = subprocess.run(
    [sys.executable, "mcnemar_test.py"], capture_output=True, text=True
)

# capture_output hides the script's output; show it, and fail loudly (with the
# captured stderr) instead of silently ignoring a non-zero exit status.
print(result.stdout)
if result.returncode != 0:
    print(result.stderr)
    result.check_returncode()

# 10- Conformal Prediction

## 10-1- Training and Testing

In [None]:
uniprot_to_isoform = {"P00918": "CA2", "Q16790": "CA9", "O43570": "CA12"}

CA2_HP = {
    "C": 66.96581256335178,
    "kernel": "rbf",
    "gamma": "scale",
    "coef0": 0.4168195903861248,
}
CA9_HP = {
    "C": 268.35706901387067,
    "kernel": "poly",
    "gamma": "scale",
    "coef0": 0.7717853353913062,
}
CA12_HP = {
    "C": 2.320983366278537,
    "kernel": "rbf",
    "gamma": "scale",
    "coef0": 0.9704194409850162,
}

# Tuned SVC hyperparameters, keyed by UniProt accession.
uniprot_to_hp = {"P00918": CA2_HP, "Q16790": CA9_HP, "O43570": CA12_HP}

file_paths = [
    "../Data/O43570_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/P00918_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/Q16790_ChEMBL_data_canonical_smiles_filtered_status.csv",
]

# Number of repeated training runs per isoform.
n_repeats = 10

folds_dict = {}

for file_path in file_paths:

    # UniProt accession, e.g. "O43570" from ".../O43570_ChEMBL_data_..."
    name = file_path.split("/")[-1].split("_")[0]

    df = pd.read_csv(file_path)
    x = df.loc[:, "standardize_smiles"].to_numpy()
    y = df.loc[:, "status"].values

    dataset = dc.data.DiskDataset.from_numpy(X=x, y=y, w=np.zeros(len(x)), ids=x)

    # DeepChem's scaffold split is deterministic (scaffold groups are assigned
    # by size, not at random), so splitting inside the repeat loop produced the
    # same partition every time. Split and featurize once; only the
    # oversampling and the SVC's internal probability calibration vary below.
    scaffoldsplitter = dc.splits.ScaffoldSplitter()
    train_set, val_set, test_set = scaffoldsplitter.train_valid_test_split(
        dataset, frac_train=0.7, frac_valid=0.1, frac_test=0.2
    )

    X_train, y_train = train_set.X, train_set.y.reshape(-1)
    X_val, y_val = val_set.X, val_set.y.reshape(-1)
    X_test, y_test = test_set.X, test_set.y.reshape(-1)

    X_train_ecfp = ecfp_generator(X_train)
    X_val_ecfp = ecfp_generator(X_val)
    X_test_ecfp = ecfp_generator(X_test)

    for i in range(n_repeats):
        # Seeded per repeat so each resampled training set is reproducible.
        ros = RandomOverSampler(random_state=i)
        X_train_ecfp_resampled, y_train_resampled = ros.fit_resample(
            X_train_ecfp, y_train
        )

        folds_dict[f"{name}_split_{i}"] = {
            "train_fv": X_train_ecfp_resampled,
            "train_label": y_train_resampled,
            "val_fv": X_val_ecfp,
            "val_label": y_val,
            "test_fv": X_test_ecfp,
            "test_label": y_test,
        }

for name in ["P00918", "Q16790", "O43570"]:
    hp = uniprot_to_hp[name]

    # SVC.fit returns the fitted estimator, so this list holds trained models.
    # random_state seeds the internal CV used for probability estimates.
    models = [
        SVC(**hp, probability=True, random_state=i).fit(
            folds_dict[f"{name}_split_{i}"]["train_fv"],
            folds_dict[f"{name}_split_{i}"]["train_label"],
        )
        for i in range(n_repeats)
    ]

    probability_dataframe(models, folds_dict, name, split_type="val")
    probability_dataframe(models, folds_dict, name, split_type="test")

## 10-2- Evaluation (Validity and Efficiency)

In [None]:
# Significance levels to evaluate, and the number of repeated runs stored in
# each probability file (the run index is passed to run_data).
epsilons = [0.01, 0.05, 0.1, 0.15, 0.20, 0.25, 0.3]
num_runs = 10

for isoform in ["CA2", "CA9", "CA12"]:
    cal_file = f"../Results/Conformal_Prediction/svm_ecfp_{isoform}_val_prob.csv"
    test_file = f"../Results/Conformal_Prediction/svm_ecfp_{isoform}_test_prob.csv"

    cal_df = pd.read_csv(cal_file)
    test_df = pd.read_csv(test_file)

    # Calibration scores and test probabilities depend only on the run index,
    # not on epsilon, so build them once per run instead of once per
    # (epsilon, run) pair. Assumes run_data / calculate_conformity_scores /
    # evaluate_conformal_predictor do not mutate their inputs — TODO confirm.
    per_run_inputs = []
    for i in range(num_runs):
        p_active_cal, p_inactive_cal, labels_cal = run_data(cal_df, i)
        p_active_test, p_inactive_test, labels_test = run_data(test_df, i)
        calibration_scores_dict = calculate_conformity_scores(
            p_active_cal, p_inactive_cal, labels_cal
        )
        per_run_inputs.append(
            (p_active_test, p_inactive_test, labels_test, calibration_scores_dict)
        )

    avg_validities = []
    avg_efficiencies_tl = []
    avg_empty_rate_overalls = []

    for epsilon in epsilons:
        run_validities = []
        run_efficiencies_tl = []
        run_empty_rates = []

        for (
            p_active_test,
            p_inactive_test,
            labels_test,
            calibration_scores_dict,
        ) in per_run_inputs:
            validity, efficiency_tl, empty_rate = evaluate_conformal_predictor(
                p_active_test,
                p_inactive_test,
                labels_test,
                calibration_scores_dict,
                epsilon,
            )
            run_validities.append(validity)
            run_efficiencies_tl.append(efficiency_tl)
            run_empty_rates.append(empty_rate)

        avg_validities.append(np.mean(run_validities))
        avg_efficiencies_tl.append(np.mean(run_efficiencies_tl))
        avg_empty_rate_overalls.append(np.mean(run_empty_rates))

        # Per-run results for this epsilon.
        cp_single_result_df = pd.DataFrame(
            {
                "Validity": run_validities,
                "Efficiency": run_efficiencies_tl,
                "Empty Rate": run_empty_rates,
            }
        )
        cp_single_result_df.to_csv(
            f"../Results/Conformal_Prediction/cp_{isoform}_{epsilon}.csv", index=False
        )

    # Run-averaged results for every epsilon. DataFrame.to_csv returns None,
    # so build the frame first and write it in a separate statement.
    cp_total_result_df = pd.DataFrame(
        {
            "Epsilon": epsilons,
            "Avg Validity": avg_validities,
            "Avg Efficiency": avg_efficiencies_tl,
            "Avg Empty Rate": avg_empty_rate_overalls,
        }
    )
    cp_total_result_df.to_csv(
        f"../Results/Conformal_Prediction/cp_result_{isoform}.csv", index=False
    )

# 11- XAI 

## 11-0 Functions

In [None]:
def counterfactual_explain(samples, name):
    """Plot the top counterfactuals among ``samples`` and save them as an SVG.

    Parameters
    ----------
    samples : list
        Output of ``exmol.sample_space``.
    name : str
        Prefix of the output file ``{name}_counterfactual_samples.svg``.
    """
    sns.set_context("notebook")
    sns.set_style("dark")

    font_manager.findfont("Helvetica")
    plt.rc("font", family="Helvetica")
    plt.rc("font", serif="Helvetica", size=22)

    # 8x6 inch white figure at 600 dpi.
    figure_style = dict(
        figsize=(8, 6), dpi=600, facecolor="white", edgecolor="white"
    )

    counterfactuals = exmol.cf_explain(samples, nmols=3)
    exmol.plot_cf(
        counterfactuals,
        figure_kwargs=figure_style,
        mol_size=(350, 300),
        nrows=2,
        mol_fontsize=8,
    )
    plt.tight_layout()

    svg_text = exmol.insert_svg(counterfactuals)
    with open(f"{name}_counterfactual_samples.svg", "w") as svg_file:
        svg_file.write(svg_text)

In [None]:
def counterfactual_space(samples, name):
    """Plot the sampled chemical space with counterfactuals and save it as an SVG.

    Parameters
    ----------
    samples : list
        Output of ``exmol.sample_space``.
    name : str
        Prefix of the output file ``{name}_counterfactual_space.svg``.
    """
    sns.set_context("notebook")
    sns.set_style("dark")

    font_manager.findfont("Helvetica")
    plt.rc("font", family="Helvetica")
    plt.rc("font", serif="Helvetica", size=22)

    # 10x8 inch white figure at 300 dpi.
    figure_style = dict(
        figsize=(10, 8), dpi=300, facecolor="white", edgecolor="white"
    )

    counterfactuals = exmol.cf_explain(samples, nmols=3)
    exmol.plot_space(
        samples,
        counterfactuals,
        figure_kwargs=figure_style,
        mol_size=(350, 300),
        mol_fontsize=8,
    )

    # Empty scatters exist only to create legend entries, colored with the
    # two ends of the viridis colormap.
    viridis = plt.get_cmap("viridis")
    for entry_label, cmap_position in (("Same Class", 1.0), ("Counterfactual", 0.0)):
        plt.scatter([], [], label=entry_label, s=150, color=viridis(cmap_position))

    legend_style = dict(
        fontsize=22,
        loc=None,  # None -> matplotlib's default placement
        ncol=1,
        frameon=True,
        shadow=False,
        title=None,
        title_fontsize=18,
        facecolor="white",
        edgecolor="black",
        framealpha=0.8,
        labelspacing=0.5,
        handlelength=2.5,
        handletextpad=1.0,
        borderpad=0.5,
        borderaxespad=0.4,
        columnspacing=1.5,
        fancybox=True,
    )
    plt.legend(**legend_style)

    plt.tight_layout()

    svg_text = exmol.insert_svg(counterfactuals)
    with open(f"{name}_counterfactual_space.svg", "w") as svg_file:
        svg_file.write(svg_text)

## 11-1- Training

In [None]:
uniprot_to_isoform = {"P00918": "CA2", "Q16790": "CA9", "O43570": "CA12"}

CA2_HP = {
    "C": 66.96581256335178,
    "kernel": "rbf",
    "gamma": "scale",
    "coef0": 0.4168195903861248,
}
CA9_HP = {
    "C": 268.35706901387067,
    "kernel": "poly",
    "gamma": "scale",
    "coef0": 0.7717853353913062,
}
CA12_HP = {
    "C": 2.320983366278537,
    "kernel": "rbf",
    "gamma": "scale",
    "coef0": 0.9704194409850162,
}

# Tuned SVC hyperparameters, keyed by UniProt accession.
uniprot_to_hp = {"P00918": CA2_HP, "Q16790": CA9_HP, "O43570": CA12_HP}

file_paths = [
    "../Data/O43570_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/P00918_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/Q16790_ChEMBL_data_canonical_smiles_filtered_status.csv",
]

folds_dict = {}

for file_path in file_paths:

    # UniProt accession, e.g. "O43570"
    name = file_path.split("/")[-1].split("_")[0]

    # A cached model already exists for this isoform; no data needed.
    if os.path.exists(f"svm_{name}.pkl"):
        continue

    df = pd.read_csv(file_path)
    x = df.loc[:, "standardize_smiles"].to_numpy()
    y = df.loc[:, "status"].values

    dataset = dc.data.DiskDataset.from_numpy(X=x, y=y, w=np.zeros(len(x)), ids=x)

    scaffoldsplitter = dc.splits.ScaffoldSplitter()
    train_set, val_set, test_set = scaffoldsplitter.train_valid_test_split(
        dataset, frac_train=0.7, frac_valid=0.1, frac_test=0.2, seed=42
    )

    X_train, y_train = train_set.X, train_set.y.reshape(-1)
    X_val, y_val = val_set.X, val_set.y.reshape(-1)
    X_test, y_test = test_set.X, test_set.y.reshape(-1)

    X_train_ecfp = ecfp_generator(X_train)
    X_val_ecfp = ecfp_generator(X_val)
    X_test_ecfp = ecfp_generator(X_test)

    # Seeded like the SVC below so the cached model is reproducible.
    ros = RandomOverSampler(random_state=42)
    X_train_ecfp_resampled, y_train_resampled = ros.fit_resample(X_train_ecfp, y_train)

    folds_dict[f"{name}_split"] = {
        "train_fv": X_train_ecfp_resampled,
        "train_label": y_train_resampled,
        "val_fv": X_val_ecfp,
        "val_label": y_val,
        "test_fv": X_test_ecfp,
        "test_label": y_test,
    }

for name in ["P00918", "Q16790", "O43570"]:

    if os.path.exists(f"svm_{name}.pkl"):
        continue

    # probability=True because the GUI section reuses this same svm_{name}.pkl
    # cache and calls predict_proba on it. predict() still comes from the SVM
    # decision function, so the explanations below are unaffected.
    svm = SVC(**uniprot_to_hp[name], probability=True, random_state=42)

    svm.fit(
        folds_dict[f"{name}_split"]["train_fv"],
        folds_dict[f"{name}_split"]["train_label"],
    )

    # Cache the fitted model; later cells load it with pickle.
    with open(f"svm_{name}.pkl", "wb") as file:
        pickle.dump(svm, file)

## 11-2- Isoform II

In [None]:
# Load the cached SVM for P00918 (CA II) into the global `svm`.
with open("svm_P00918.pkl", "rb") as file:
    svm = pickle.load(file)

In [None]:
base = "NS(=O)(=O)c1ccc(NC(=O)Nc2ccc(F)cc2)cc1"


def model(smiles):
    """Return 1 or 0 for one SMILES, based on the currently loaded global ``svm``.

    The molecule is encoded as a 2048-bit Morgan fingerprint (radius 2); the
    result is 1 when the SVM's single prediction is truthy, else 0.
    """
    fingerprint = AllChem.GetMorganFingerprintAsBitVect(
        Chem.MolFromSmiles(smiles), 2, nBits=2048
    )
    features = np.array(fingerprint).reshape(1, -1)
    return int(bool(svm.predict(features)[0]))


samples = exmol.sample_space(base, model, batched=False)

In [None]:
for make_figure in (counterfactual_explain, counterfactual_space):
    make_figure(samples, "P00918")

## 11-3- Isoform IX

In [None]:
# Load the cached SVM for Q16790 (CA IX) into the global `svm`.
with open("svm_Q16790.pkl", "rb") as file:
    svm = pickle.load(file)

In [None]:
base = "NS(=O)(=O)c1ccc(NC(=O)Nc2ccc(F)cc2)cc1"


def model(smiles):
    """Return 1 or 0 for one SMILES, based on the currently loaded global ``svm``.

    The molecule is encoded as a 2048-bit Morgan fingerprint (radius 2); the
    result is 1 when the SVM's single prediction is truthy, else 0.
    """
    fingerprint = AllChem.GetMorganFingerprintAsBitVect(
        Chem.MolFromSmiles(smiles), 2, nBits=2048
    )
    features = np.array(fingerprint).reshape(1, -1)
    return int(bool(svm.predict(features)[0]))


samples = exmol.sample_space(base, model, batched=False)

In [None]:
for make_figure in (counterfactual_explain, counterfactual_space):
    make_figure(samples, "Q16790")

## 11-4- Isoform XII

In [None]:
# Load the cached SVM for O43570 (CA XII) into the global `svm`.
with open("svm_O43570.pkl", "rb") as file:
    svm = pickle.load(file)

In [None]:
base = "NS(=O)(=O)c1ccc(NC(=O)Nc2ccc(F)cc2)cc1"


def model(smiles):
    """Return 1 or 0 for one SMILES, based on the currently loaded global ``svm``.

    The molecule is encoded as a 2048-bit Morgan fingerprint (radius 2); the
    result is 1 when the SVM's single prediction is truthy, else 0.
    """
    fingerprint = AllChem.GetMorganFingerprintAsBitVect(
        Chem.MolFromSmiles(smiles), 2, nBits=2048
    )
    features = np.array(fingerprint).reshape(1, -1)
    return int(bool(svm.predict(features)[0]))


samples = exmol.sample_space(base, model, batched=False)

In [None]:
for make_figure in (counterfactual_explain, counterfactual_space):
    make_figure(samples, "O43570")

# 12- GUI

## 12-1- Training and Calibration Making

In [None]:
uniprot_to_isoform = {"P00918": "CA2", "Q16790": "CA9", "O43570": "CA12"}

CA2_HP = {
    "C": 66.96581256335178,
    "kernel": "rbf",
    "gamma": "scale",
    "coef0": 0.4168195903861248,
}
CA9_HP = {
    "C": 268.35706901387067,
    "kernel": "poly",
    "gamma": "scale",
    "coef0": 0.7717853353913062,
}
CA12_HP = {
    "C": 2.320983366278537,
    "kernel": "rbf",
    "gamma": "scale",
    "coef0": 0.9704194409850162,
}

# Tuned SVC hyperparameters, keyed by UniProt accession.
uniprot_to_hp = {"P00918": CA2_HP, "Q16790": CA9_HP, "O43570": CA12_HP}

file_paths = [
    "../Data/O43570_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/P00918_ChEMBL_data_canonical_smiles_filtered_status.csv",
    "../Data/Q16790_ChEMBL_data_canonical_smiles_filtered_status.csv",
]

folds_dict = {}

for file_path in file_paths:

    # UniProt accession, e.g. "O43570"
    name = file_path.split("/")[-1].split("_")[0]

    df = pd.read_csv(file_path)
    x = df.loc[:, "standardize_smiles"].to_numpy()
    y = df.loc[:, "status"].values

    dataset = dc.data.DiskDataset.from_numpy(X=x, y=y, w=np.zeros(len(x)), ids=x)

    scaffoldsplitter = dc.splits.ScaffoldSplitter()
    train_set, val_set, test_set = scaffoldsplitter.train_valid_test_split(
        dataset, frac_train=0.7, frac_valid=0.1, frac_test=0.2, seed=42
    )

    X_train, y_train = train_set.X, train_set.y.reshape(-1)
    X_val, y_val = val_set.X, val_set.y.reshape(-1)
    X_test, y_test = test_set.X, test_set.y.reshape(-1)

    X_train_ecfp = ecfp_generator(X_train)
    X_val_ecfp = ecfp_generator(X_val)
    X_test_ecfp = ecfp_generator(X_test)

    # Seeded like the SVC below so a retrained model is reproducible.
    ros = RandomOverSampler(random_state=42)
    X_train_ecfp_resampled, y_train_resampled = ros.fit_resample(X_train_ecfp, y_train)

    folds_dict[f"{name}_split"] = {
        "train_fv": X_train_ecfp_resampled,
        "train_label": y_train_resampled,
        "val_fv": X_val_ecfp,
        "val_label": y_val,
        "test_fv": X_test_ecfp,
        "test_label": y_test,
    }


def _has_probability_model(path):
    """Return True if ``path`` holds a pickled estimator fitted with probability=True."""
    if not os.path.exists(path):
        return False
    with open(path, "rb") as file:
        cached = pickle.load(file)
    return bool(getattr(cached, "probability", False))


for name in ["P00918", "Q16790", "O43570"]:
    model_path = f"svm_{name}.pkl"

    # The XAI section caches its models under the same file names, and they
    # may have been fitted without probability estimates. Reuse a cached model
    # only if it supports predict_proba, which the calibration step needs.
    if _has_probability_model(model_path):
        continue

    svm = SVC(**uniprot_to_hp[name], probability=True, random_state=42)
    svm.fit(
        folds_dict[f"{name}_split"]["train_fv"],
        folds_dict[f"{name}_split"]["train_label"],
    )

    with open(model_path, "wb") as file:
        pickle.dump(svm, file)

with open("svm_P00918.pkl", "rb") as f:
    ca2_svm = pickle.load(f)

with open("svm_Q16790.pkl", "rb") as f:
    ca9_svm = pickle.load(f)

with open("svm_O43570.pkl", "rb") as f:
    ca12_svm = pickle.load(f)

uniprot_to_svm = {"P00918": ca2_svm, "Q16790": ca9_svm, "O43570": ca12_svm}

# Save validation-set probabilities plus true labels. These files serve as the
# calibration sets for the conformal predictor in the backend cell.
for name in ["P00918", "Q16790", "O43570"]:
    split = folds_dict[f"{name}_split"]

    # predict_proba columns follow the model's classes_ order; column 0 is
    # saved as "active" and column 1 as "inactive".
    # NOTE(review): confirm classes_[0] really is the active class.
    pred_df = pd.DataFrame(uniprot_to_svm[name].predict_proba(split["val_fv"]))
    pred_df.columns = ["active", "inactive"]

    label_df = pd.DataFrame(split["val_label"]).rename({0: "label"}, axis=1)
    final_df = pd.concat([pred_df, label_df], axis=1)
    final_df.to_csv(f"svm_ecfp_{uniprot_to_isoform[name]}_prob.csv", index=False)

## 12-2- Backend

In [None]:
import pickle
from typing import Dict, List, Tuple

import exmol
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import font_manager
from rdkit import Chem
from rdkit.Chem import AllChem


def predict_with_conformal(
    smiles: str, model_info: Dict[str, Tuple[str, str]], epsilon: float = 0.3
) -> Dict[str, List[int]]:
    """
    Predict class membership for a molecule with label-conditional (Mondrian)
    inductive conformal prediction on top of pickled SVM models.

    For each class ``c`` (column ``c`` of ``predict_proba``), the conformity
    score is the predicted probability of ``c``. The test molecule's p-value
    for ``c`` is computed only against calibration molecules whose true label
    is ``c``::

        p_c = (#{calibration i with label c : score_i <= score_test} + 1) / (n_c + 1)

    Class ``c`` enters the prediction set when ``p_c > epsilon``.

    Parameters
    ----------
    smiles : str
        Molecule in SMILES format.
    model_info : dict of str to tuple (str, str)
        Mapping of model labels to (model pickle file path, calibration CSV
        file path). The CSV must contain the calibration probabilities in
        columns ``active`` and ``inactive`` (in ``predict_proba`` column order,
        as written by the calibration cell) and the true class in ``label``.
    epsilon : float, optional
        Significance level; a class is kept when its p-value exceeds it
        (default is 0.3).

    Returns
    -------
    dict of str to list of int
        Dictionary mapping model labels to prediction sets (``predict_proba``
        column indices 0 or 1). A class with no calibration examples gets
        p = 1 and is always included.

    Raises
    ------
    ValueError
        If ``smiles`` cannot be parsed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string")

    ecfp_fv = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
    feature_array = np.array(ecfp_fv).reshape(1, -1)

    # Calibration CSV column k holds predict_proba column k, so read the
    # columns in this order. The reverse order compared the test probability of
    # class k against the calibration probabilities of the other class.
    probability_columns = ["active", "inactive"]

    results = {}

    for label, (model_file, calibration_csv) in model_info.items():
        # Unpickling can execute code: only load trusted model files.
        with open(model_file, "rb") as f:
            model = pickle.load(f)

        cal_scores_df = pd.read_csv(calibration_csv)
        probs = model.predict_proba(feature_array)[0]
        prediction_set = []

        for class_idx, column in enumerate(probability_columns):
            # Mondrian calibration: compare only with calibration molecules
            # whose true label is this class.
            is_class = cal_scores_df["label"] == model.classes_[class_idx]
            conformity_scores = cal_scores_df.loc[is_class, column].to_numpy()
            score_test = probs[class_idx]
            p_value = (np.sum(conformity_scores <= score_test) + 1) / (
                len(conformity_scores) + 1
            )
            if p_value > epsilon:
                prediction_set.append(class_idx)

        results[label] = prediction_set

    return results


def model(smiles: str, svm) -> int:
    """
    Predict the binary class of a molecule using an SVM model.

    Parameters
    ----------
    smiles : str
        SMILES string representing the molecule.
    svm : object
        Preloaded SVM model with a `predict` method.

    Returns
    -------
    int
        Predicted class label (0 or 1), converted from the model's first
        (and only) prediction.

    Raises
    ------
    ValueError
        If ``smiles`` cannot be parsed (same error as ``predict_with_conformal``
        rather than an opaque RDKit argument error from the fingerprint call).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("Invalid SMILES string")

    # 2048-bit Morgan fingerprint, radius 2, as a single-row feature matrix.
    ecfp_fv = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
    y_predict = svm.predict(np.array(ecfp_fv).reshape(1, -1))
    return int(y_predict[0])


def counterfactual_explain(samples, name: str):
    """
    Plot the top counterfactuals among ``samples`` and save them as an SVG.

    Parameters
    ----------
    samples : list
        Molecular samples, typically the output of `exmol.sample_space`.
    name : str
        Prefix of the output file.

    Returns
    -------
    None
        Writes '{name}_counterfactual_samples.svg'.
    """
    sns.set_context("notebook")
    sns.set_style("dark")

    font_manager.findfont("Helvetica")
    plt.rc("font", family="Helvetica")
    plt.rc("font", serif="Helvetica", size=22)

    # 8x6 inch white figure at 300 dpi.
    figure_style = dict(
        figsize=(8, 6), dpi=300, facecolor="white", edgecolor="white"
    )

    counterfactuals = exmol.cf_explain(samples, nmols=3)
    exmol.plot_cf(
        counterfactuals,
        figure_kwargs=figure_style,
        mol_size=(350, 300),
        nrows=2,
        mol_fontsize=8,
    )
    plt.tight_layout()

    svg_text = exmol.insert_svg(counterfactuals)
    with open(f"{name}_counterfactual_samples.svg", "w") as svg_file:
        svg_file.write(svg_text)

In [None]:
smiles = "NS(=O)(=O)c1ccc(NC(=O)Nc2ccc(F)cc2)cc1"

# Isoform -> (pickled SVM, calibration CSV written by the calibration cell).
model_info = {
    "CA2": ("svm_P00918.pkl", "svm_ecfp_CA2_prob.csv"),
    "CA9": ("svm_Q16790.pkl", "svm_ecfp_CA9_prob.csv"),
    "CA12": ("svm_O43570.pkl", "svm_ecfp_CA12_prob.csv"),
}

predictions = predict_with_conformal(smiles, model_info)
print(predictions)

# Same model files as above, without the calibration CSVs.
isoform_models = {isoform: paths[0] for isoform, paths in model_info.items()}

for isoform, model_path in isoform_models.items():
    with open(model_path, "rb") as f:
        svm_model = pickle.load(f)

    # Bind this iteration's model explicitly as a default argument.
    def model_fn(candidate, _svm=svm_model):
        return model(candidate, _svm)

    samples = exmol.sample_space(smiles, model_fn, batched=False)
    counterfactual_explain(samples, isoform)
