In [None]:
import torch
import pandas as pd
from rdkit import Chem
from sklearn.preprocessing import MinMaxScaler

from src.featurizers import GraphFeaturizer
from src.models.gnn import GraphConvolutionalNetwork
from src.dataset import (
    load_herg_data_split,
    load_cyp_data_split,
    load_pampa_data_split,
    load_synthetic_data_split,
)
from src.explanations import (
    grad_cam,
    saliency_map,
    complete_rings_in_components,
    plot_grad_cam_explanation,
    get_n_atom_connected_components,
)
from src.utils import (
    get_sub_molecule,
    get_iupac_name_of_smiles
)
from tuning_results import (
    synthetic_gnn_params,
    herg_gnn_params,
    cyp_gnn_params,
    pampa_gnn_params
)

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the CYP split; the middle element of the returned triple is not used here.
train, _, test = load_cyp_data_split()

# Turn the molecules into graph objects. "y" is passed as the first argument
# (presumably the target column name - confirm against GraphFeaturizer).
graph_featurizer = GraphFeaturizer("y", log_target_transform=False)
graph_train = graph_featurizer(train)
graph_test = graph_featurizer(test)

# Rebuild the network with the tuned hyper-parameters so that the saved
# weights fit the architecture.
best_params = cyp_gnn_params
dataset_name = "cyp"
model = GraphConvolutionalNetwork(
    input_dim=graph_train[0].x.shape[1],
    hidden_size=best_params["hidden_size"],
    n_layers=best_params["num_layers"],
    dropout=best_params["dropout"]
).to(device)
# The model is only used for explanations below, never trained: switch to
# evaluation mode so dropout is inactive while scores are computed.
model.eval()
# map_location lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine (and vice versa).
model.load_state_dict(
    torch.load(f"models/gnn_tuned_{dataset_name}.pth", map_location=device)
)

### Find all connected components of size N in a molecule among atoms highlighted by an explanation method

The resulting components will serve as input to the LLM when asking how the presence of a given component in a molecule affects an ADMET property.

In [None]:
# Pick a single training molecule to walk through the explanation pipeline.
mol_index = 7
example_graph = graph_train[mol_index]

# Per-atom Grad-CAM scores for the example, followed by their visualisation.
grad_cam_score = grad_cam(model, example_graph)
plot_grad_cam_explanation(model, train[mol_index], example_graph)

In [None]:
# Rescale the scores to [0, 1] before thresholding, exactly as the
# dataset-wide loops further down do, so that this worked example shows the
# same components the batch analysis would extract for this molecule.
scaled_grad_cam_score = MinMaxScaler().fit_transform(grad_cam_score.reshape(-1, 1)).reshape(-1)

# 0.1 and 3 are passed positionally; in the loops below the last argument is
# the component size, and 0.1 is presumably the score threshold - confirm
# against get_n_atom_connected_components.
n_atom_connected_components = get_n_atom_connected_components(
    train[mol_index], scaled_grad_cam_score, 0.1, 3
)
n_atom_connected_components

In [None]:
# Post-process the components found above with complete_rings_in_components
# (by its name, it extends components to cover whole rings - confirm in
# src.explanations).
example_molecule = train[mol_index]
complete_components = complete_rings_in_components(
    example_molecule,
    n_atom_connected_components,
)
complete_components

In [None]:
# Display the sub-molecule built from the first completed component.
first_component = complete_components[0]
get_sub_molecule(train[mol_index], first_component)

### Finding the most important components and converting them to IUPAC names

In [None]:
# Result containers keyed by dataset name, one pair per explanation method.
# The "iupacs_*" dicts hold the name lists and the "smiles_*" dicts hold the
# matching SMILES lists.
iupacs_grad_cam, iupacs_saliency_map = {}, {}
smiles_grad_cam, smiles_saliency_map = {}, {}

In [None]:
# Grad-CAM scores are computed from the model and the molecule graph only;
# they do not depend on the component size. Compute and rescale them once per
# molecule instead of once per (size, molecule) pair.
scaled_scores_per_molecule = []
for i in range(len(graph_train)):
    exp_scores = grad_cam(model, graph_train[i])
    # Min-max scale to [0, 1] so the same 0.1 cut-off is applied to every molecule.
    scaled_scores_per_molecule.append(
        MinMaxScaler().fit_transform(exp_scores.reshape(-1, 1)).reshape(-1)
    )

# One entry per component size (3..8 atoms), in increasing size order.
current_results = []
current_smiles = []
for component_size in range(3, 9):
    # Collect the highlighted sub-molecules of this size over the whole training set.
    components = []
    for i, scaled_exp_scores in enumerate(scaled_scores_per_molecule):
        connected_components = complete_rings_in_components(
            train[i],
            get_n_atom_connected_components(train[i], scaled_exp_scores, 0.1, component_size),
        )
        components.extend(
            get_sub_molecule(train[i], component) for component in connected_components
        )

    # Keep the 10 most frequent fragments, identified by their SMILES string.
    components_smiles = [Chem.MolToSmiles(component_mol) for component_mol in components]
    components_smiles_series = pd.Series(components_smiles)
    smiles_to_input = list(components_smiles_series.value_counts().head(10).index)

    iupac_names = []
    for smiles in smiles_to_input:
        try:
            iupac_names.append(get_iupac_name_of_smiles(smiles))
        except Exception:
            # Best effort: if the name lookup fails, keep the SMILES itself.
            # `Exception` (rather than a bare except) lets Ctrl-C / kernel
            # interrupts stop this long-running cell.
            iupac_names.append(smiles)
    current_results.append(iupac_names)
    current_smiles.append(smiles_to_input)

iupacs_grad_cam[dataset_name] = current_results
smiles_grad_cam[dataset_name] = current_smiles

In [None]:
# Saliency scores are computed from the model and the molecule graph only;
# they do not depend on the component size. Compute and rescale them once per
# molecule instead of once per (size, molecule) pair.
scaled_scores_per_molecule = []
for i in range(len(graph_train)):
    exp_scores = saliency_map(model, graph_train[i])
    # Min-max scale to [0, 1] so the same 0.1 cut-off is applied to every molecule.
    scaled_scores_per_molecule.append(
        MinMaxScaler().fit_transform(exp_scores.reshape(-1, 1)).reshape(-1)
    )

# One entry per component size (3..8 atoms), in increasing size order.
current_results = []
current_smiles = []
for component_size in range(3, 9):
    # Collect the highlighted sub-molecules of this size over the whole training set.
    components = []
    for i, scaled_exp_scores in enumerate(scaled_scores_per_molecule):
        connected_components = complete_rings_in_components(
            train[i],
            get_n_atom_connected_components(train[i], scaled_exp_scores, 0.1, component_size),
        )
        components.extend(
            get_sub_molecule(train[i], component) for component in connected_components
        )

    # Keep the 10 most frequent fragments, identified by their SMILES string.
    components_smiles = [Chem.MolToSmiles(component_mol) for component_mol in components]
    components_smiles_series = pd.Series(components_smiles)
    smiles_to_input = list(components_smiles_series.value_counts().head(10).index)

    iupac_names = []
    for smiles in smiles_to_input:
        try:
            iupac_names.append(get_iupac_name_of_smiles(smiles))
        except Exception:
            # Best effort: if the name lookup fails, keep the SMILES itself.
            # `Exception` (rather than a bare except) lets Ctrl-C / kernel
            # interrupts stop this long-running cell.
            iupac_names.append(smiles)
    current_results.append(iupac_names)
    current_smiles.append(smiles_to_input)

iupacs_saliency_map[dataset_name] = current_results
smiles_saliency_map[dataset_name] = current_smiles