In [None]:
%cd ../

In [None]:
import glob
import os
from collections import defaultdict
from typing import List, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots
import seaborn as sns
import torch
from rdkit.Chem import Descriptors, Draw
from rdkit.Chem.rdmolfiles import MolFromSmarts, MolFromSmiles
from rdkit.Chem.Scaffolds import MurckoScaffold
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder
from tokenizers import Tokenizer
from torchmetrics.functional.clustering import davies_bouldin_score
from tqdm.auto import tqdm

from src.data.components.utils import smiles2vector_fg, smiles2vector_mfg
from src.models.fgr_module import FGRLitModule

# Publication styling: the "science"/"nature" style sheets are presumably
# registered by the `scienceplots` import above; default panel size is 4x3 in.
plt.style.use(["science", "nature"])
plt.rcParams.update({"figure.figsize": (4, 3)})

In [None]:
data_dir = "./data/processed"

# Functional-group SMARTS patterns, converted once into RDKit Mol objects
fgroups = pd.read_parquet(os.path.join(data_dir, "training", "fg.parquet"))["SMARTS"].tolist()
fgroups_list = list(map(MolFromSmarts, fgroups))

# Serialized BPE tokenizer (presumably a 500-token PubChem vocabulary, per the file name)
tokenizer = Tokenizer.from_file(
    os.path.join(data_dir, "training", "tokenizers", f"BPE_pubchem_{500}.json")
)

In [None]:
def get_descriptors(smi_list: List[str]) -> torch.Tensor:
    """Compute an L2-normalized RDKit descriptor vector for each SMILES string.

    Every function in ``Descriptors._descList`` is evaluated on the molecule.
    A descriptor that raises is recorded as 0, and NaN / +inf / -inf values
    are replaced by 0 before normalization.

    :param smi_list: List of SMILES strings
    :return: Float tensor of shape ``(len(smi_list), len(Descriptors._descList))``.
        Each row has unit L2 norm, or is all zeros when every descriptor was 0
        (previously such rows became NaN through a 0/0 division).
    """
    desc_funcs = [func for _, func in Descriptors._descList]
    if not smi_list:
        # torch.stack([]) raises, so return an explicitly shaped empty tensor
        return torch.empty((0, len(desc_funcs)), dtype=torch.float32)

    rows = []
    for smi in smi_list:
        # NOTE(review): for an unparsable SMILES RDKit returns None here, so every
        # descriptor call below fails and the row ends up all zeros.
        mol = MolFromSmiles(smi)

        values = []
        for func in desc_funcs:
            try:
                values.append(func(mol))
            except Exception:  # not BaseException: keep KeyboardInterrupt working
                values.append(0)

        row = torch.nan_to_num(torch.FloatTensor(values), nan=0.0, posinf=0.0, neginf=0.0)
        norm = torch.norm(row)
        # Guard the normalization: an all-zero row would otherwise become 0/0 = NaN
        if norm > 0:
            row = row / norm
        rows.append(row)

    return torch.stack(rows)

In [None]:
def get_representation(
    smiles: List[str],
    method: str,
    fgroups_list: List[MolFromSmarts],
    tokenizer: Tokenizer,
    dataset: str = "BBBP",
    mode: str = "train",
) -> np.ndarray:
    """Build one representation vector per SMILES string.

    :param smiles: SMILES strings to featurize
    :param method: ``"FG"`` (``smiles2vector_fg``), ``"MFG"`` (``smiles2vector_mfg``)
        or ``"FGR"`` (both, concatenated along axis 1 in that order)
    :param fgroups_list: Functional-group patterns passed to ``smiles2vector_fg``
    :param tokenizer: Tokenizer passed to ``smiles2vector_mfg``
    :param dataset: Dataset whose checkpoint is loaded when ``mode="train"``
    :param mode: ``"input"`` returns the raw input vectors; ``"train"`` feeds them
        (plus ``get_descriptors`` features) through the trained ``FGRLitModule``
        and returns element 1 of its output
    :return: Array with one row per SMILES string
    :raises ValueError: If ``method`` or ``mode`` is not supported
    :raises FileNotFoundError: If ``mode="train"`` and no checkpoint matches
    """
    if method not in ("FG", "MFG", "FGR"):
        raise ValueError("Method not supported")  # Raise error if method not supported
    if mode not in ("input", "train"):
        # Validate before the (potentially slow) featurization below
        raise ValueError("Mode not supported")

    parts = []
    if method in ("FG", "FGR"):
        parts.append(np.stack([smiles2vector_fg(smi, fgroups_list) for smi in smiles]))
    if method in ("MFG", "FGR"):
        parts.append(np.stack([smiles2vector_mfg(smi, tokenizer) for smi in smiles]))
    x = parts[0] if len(parts) == 1 else np.concatenate(parts, axis=1)

    if mode == "input":
        return x

    pattern = f"./logs/train/multiruns/*/{dataset}/{method}/scaffold/*/checkpoints/last.ckpt"
    # glob order depends on the filesystem; sort so the same checkpoint is picked every run.
    # NOTE(review): when several runs match, the lexicographically first is used — confirm intent.
    ckpt_paths = sorted(glob.glob(pattern))
    if not ckpt_paths:
        raise FileNotFoundError(f"No checkpoint found matching {pattern}")

    model = FGRLitModule.load_from_checkpoint(ckpt_paths[0])
    model.eval()
    with torch.no_grad():
        features = torch.tensor(x, dtype=torch.float32, device=model.device)
        desc = get_descriptors(smiles).to(model.device)
        # NOTE(review): assumes the model returns a sequence whose element 1 is the embedding — confirm
        outputs = model((features, desc))
    return outputs[1].cpu().numpy()

In [None]:
def get_scaffolds(s: List[str], top_k: int = 5) -> Tuple[pd.DataFrame, LabelEncoder]:
    """Group molecules by Murcko scaffold and keep the most populated scaffolds.

    Molecules for which scaffold extraction raises are reported with a printed
    message and skipped. The empty-string scaffold is never ranked.

    :param s: SMILES strings
    :param top_k: Number of most populated non-empty scaffolds to keep
    :return: ``(scaffold_df, label_encoder)``. ``scaffold_df`` has columns
        ``"SMILES"`` and ``"Label"`` (integer-encoded scaffold SMILES) for every
        molecule whose scaffold is among the ``top_k``; ``label_encoder`` maps
        the integer labels back to scaffold SMILES.
    """
    # scaffold SMILES -> indices into `s`, appended in input order
    scaffolds = defaultdict(list)
    for i, smiles in enumerate(s):
        try:
            scaffold = MurckoScaffold.MurckoScaffoldSmiles(
                mol=MolFromSmiles(smiles), includeChirality=False
            )
        except Exception:  # not BaseException: keep KeyboardInterrupt working
            print(smiles + " returns RDKit error and is thus omitted...")
        else:
            scaffolds[scaffold].append(i)

    # An empty scaffold string (presumably acyclic molecules) is not a useful class.
    # sorted() is stable, so ties keep first-seen order.
    ranked = sorted(
        ((scaffold, idxs) for scaffold, idxs in scaffolds.items() if scaffold != ""),
        key=lambda item: len(item[1]),
        reverse=True,
    )[:top_k]
    data = [(s[i], scaffold) for scaffold, idxs in ranked for i in idxs]

    scaffold_df = pd.DataFrame(data, columns=["SMILES", "Label"])
    label_encoder = LabelEncoder()
    # NOTE: codes follow the encoder's class ordering, not the frequency ranking above
    scaffold_df["Label"] = label_encoder.fit_transform(scaffold_df["Label"])
    return scaffold_df, label_encoder

In [None]:
def plot_scaffolds(dataset_dir: str, label_encoder: LabelEncoder) -> None:
    """Save an SVG drawing of every scaffold class known to ``label_encoder``.

    Files are written to ``<dataset_dir>/scaffolds/scaffold_<label>.svg``, where
    ``<label>`` is the integer code the encoder assigned to that scaffold.

    :param dataset_dir: Output directory for the dataset's figures
    :param label_encoder: Fitted encoder whose ``classes_`` are scaffold SMILES
    """
    scaffold_dir = f"{dataset_dir}/scaffolds"
    os.makedirs(scaffold_dir, exist_ok=True)
    # Iterate over the classes actually fitted instead of assuming exactly five:
    # with fewer scaffolds, inverse_transform on an unseen label raises.
    # classes_[label] is what inverse_transform([label])[0] returns.
    for label, scaffold in enumerate(label_encoder.classes_):
        mol = MolFromSmiles(scaffold)
        Draw.MolToFile(
            mol,
            f"{scaffold_dir}/scaffold_{label}.svg",
        )

In [None]:
def plot_kde_2d(method: str, components: np.ndarray, ax: mpl.axes._axes.Axes) -> None:  # type: ignore
    """Draw a filled bivariate KDE of 2-D points, with a colour bar, on ``ax``.

    :param method: Representation name shown in the panel title
    :param components: Points to plot; column 0 is x and column 1 is y
    :param ax: Axes to draw on
    """
    xs, ys = components[:, 0], components[:, 1]
    # Many contour levels plus a narrow bandwidth give a smooth but detailed density
    kde_style = dict(cmap="rocket_r", fill=True, levels=500, bw_adjust=0.2, cbar=True)
    sns.kdeplot(x=xs, y=ys, ax=ax, **kde_style)

    for set_label in (ax.set_xlabel, ax.set_ylabel):
        set_label("Features")
    ax.set_title(r"$\textbf{Method:}$" f"{method}")


def plot_kde(components: np.ndarray, ax: mpl.axes._axes.Axes) -> None:  # type: ignore
    """Plot a filled 1-D KDE of each point's polar angle on ``ax``.

    :param components: 2-D points; column 0 is x and column 1 is y
    :param ax: Axes to draw on
    """
    # arctan2(y, x): angle of each point in radians
    theta = np.arctan2(components[:, 1], components[:, 0])
    # NOTE(review): `cmap` is passed to a univariate KDE; confirm seaborn uses it (maybe `color=` was meant)
    sns.kdeplot(x=theta, ax=ax, fill=True, cmap="Blues")
    for set_label, text in ((ax.set_xlabel, "Angles"), (ax.set_ylabel, "Density")):
        set_label(text)


def plot_tsne(
    method: str,
    dbi: float,
    components: np.ndarray,
    labels: np.ndarray,
    ax: mpl.axes._axes.Axes,  # type: ignore
) -> None:
    """Scatter-plot 2-D embeddings coloured by label and annotate them with a DBI score.

    :param method: Representation name shown in the panel title
    :param dbi: Davies-Bouldin index printed in the bottom-right corner
    :param components: Points to plot; column 0 is x and column 1 is y
    :param labels: Per-point labels used for the hue mapping
    :param ax: Axes to draw on
    """
    xs = components[:, 0]
    ys = components[:, 1]
    colours = sns.color_palette("colorblind")
    sns.scatterplot(x=xs, y=ys, hue=labels, palette=colours, ax=ax)

    # Boxed score anchored to the bottom-right corner, in axes coordinates
    score_box = {"facecolor": "grey", "alpha": 0.2, "edgecolor": "black"}
    ax.text(
        0.95,
        0.05,
        r"$\textbf{DBI:}$" f"{dbi}",
        transform=ax.transAxes,
        ha="right",
        va="bottom",
        fontsize=8,
        bbox=score_box,
    )

    # The hue legend created by seaborn is hidden
    legend = ax.get_legend()
    legend.set_visible(False)
    ax.set_title(r"$\textbf{Method:}$" f"{method}")

In [None]:
def _tsne_2d(x: np.ndarray) -> np.ndarray:
    """Embed ``x`` in 2-D with a fixed seed, capping perplexity below the sample count."""
    # sklearn's TSNE requires perplexity < n_samples; 30.0 is its default, so
    # results are unchanged whenever there are more than 31 samples.
    return TSNE(random_state=123, perplexity=min(30.0, len(x) - 1)).fit_transform(x)  # type: ignore


def _save_figure(fig: plt.Figure, path: str) -> None:
    """Save ``fig`` at 600 dpi with a tight, transparent layout, then close that figure."""
    fig.savefig(path, bbox_inches="tight", dpi=600, transparent=True)
    plt.close(fig)


def plot_dataset(
    dataset: str, fgroups_list: List[MolFromSmarts], tokenizer: Tokenizer, mode: str
) -> None:
    """Create scaffold, alignment and uniformity figures for one dataset.

    Writes under ``./reports/figures/<dataset>/``:
      * ``scaffolds/scaffold_<label>.svg`` via ``plot_scaffolds``;
      * ``<mode>_alignment.png``: t-SNE of molecules from the top scaffolds,
        coloured by scaffold label, annotated with the Davies-Bouldin index;
      * ``<mode>_uniformity.png``: t-SNE of all molecules projected onto the
        unit circle, shown as a 2-D KDE (top row) and an angle KDE (bottom row).

    :param dataset: Task name under ``./data/processed/tasks``
    :param fgroups_list: Functional-group patterns for ``get_representation``
    :param tokenizer: Tokenizer for ``get_representation``
    :param mode: Representation mode passed to ``get_representation``
    """
    methods = ["FG", "MFG", "FGR"]

    df = pd.read_parquet(f"./data/processed/tasks/{dataset}/{dataset}.parquet")
    s = df["SMILES"].tolist()

    dataset_dir = f"./reports/figures/{dataset}"
    os.makedirs(dataset_dir, exist_ok=True)

    scaffold_df, label_encoder = get_scaffolds(s)
    plot_scaffolds(dataset_dir, label_encoder)
    scaffold_smiles = scaffold_df["SMILES"].tolist()
    scaffold_labels = scaffold_df["Label"].to_numpy()

    # Alignment: how well each representation separates the scaffold classes
    tsne_fig, tsne_axes = plt.subplots(1, 3, figsize=(12, 3))
    for ax, method in zip(tsne_axes, methods):
        x = get_representation(scaffold_smiles, method, fgroups_list, tokenizer, dataset, mode)
        components = _tsne_2d(x)
        # DBI is computed on the full representation, not on the t-SNE projection
        dbi = round(
            float(
                davies_bouldin_score(
                    torch.tensor(x), torch.tensor(scaffold_labels).reshape(-1)
                )
            ),
            2,
        )
        plot_tsne(method, dbi, components, scaffold_labels, ax=ax)
    _save_figure(tsne_fig, f"{dataset_dir}/{mode}_alignment.png")

    # Uniformity: distribution of all molecules on the unit circle
    kde_fig, kde_axes = plt.subplots(2, 3, figsize=(12, 6))
    for i, method in enumerate(methods):
        # Use the same dataset/mode as above so `{mode}_uniformity.png` matches its
        # content (omitting them silently fell back to the BBBP "train" defaults).
        x = get_representation(s, method, fgroups_list, tokenizer, dataset, mode)
        components = _tsne_2d(x)
        # Normalize the points to project them onto the unit circle
        tsne_norm = components / np.linalg.norm(components, axis=1, keepdims=True)
        plot_kde_2d(method, tsne_norm, ax=kde_axes[0, i])
        plot_kde(tsne_norm, ax=kde_axes[1, i])
    _save_figure(kde_fig, f"{dataset_dir}/{mode}_uniformity.png")

In [None]:
# Render "input"-mode figures for every task with fewer than 100,000 datapoints
df = pd.read_parquet("./data/processed/tasks/summary.parquet")
tasks = df.loc[df["Datapoints"] < 100_000, "Task"].tolist()
for task in tqdm(tasks):
    plot_dataset(task, fgroups_list, tokenizer, "input")