In [None]:
import os

# Run from the repository root so the relative "./data/processed" paths and the
# `src` package import used below resolve. Guarded (instead of an unconditional
# `%cd ../`) so re-running this cell does not climb one more directory level.
if not os.path.isdir(os.path.join("data", "processed")):
    os.chdir("..")

In [None]:
import os
from collections import defaultdict
from typing import List, Sequence, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots
import seaborn as sns
import torch
import umap
from rdkit.Chem import Draw, Mol
from rdkit.Chem.rdmolfiles import MolFromSmarts, MolFromSmiles
from rdkit.Chem.Scaffolds import MurckoScaffold
from sklearn.preprocessing import LabelEncoder
from tokenizers import Tokenizer
from torchmetrics.functional.clustering import davies_bouldin_score
from tqdm.auto import tqdm

from src.data.components.utils import smiles2vector_fg, smiles2vector_mfg

# Global figure styling: the "science" and "nature" style sheets are provided by
# the `scienceplots` import above; every figure defaults to 4x3 inches.
plt.style.use(["science", "nature"])
plt.rcParams.update({"figure.figsize": (4, 3)})

In [None]:
data_dir = "./data/processed"
BPE_VOCAB_SIZE = 500  # selects tokenizers/BPE_pubchem_<size>.json

# Functional-group SMARTS patterns, parsed once into RDKit query molecules.
fgroups = pd.read_parquet(os.path.join(data_dir, "training", "fg"))["SMARTS"].tolist()
fgroups_list = [MolFromSmarts(x) for x in fgroups]
# MolFromSmarts returns None for an unparseable pattern; fail here with the
# offending SMARTS instead of much later inside the featurizer.
invalid_smarts = [smarts for smarts, mol in zip(fgroups, fgroups_list) if mol is None]
assert not invalid_smarts, f"Unparseable functional-group SMARTS: {invalid_smarts}"

tokenizer = Tokenizer.from_file(
    os.path.join(
        data_dir,
        "training",
        "tokenizers",
        f"BPE_pubchem_{BPE_VOCAB_SIZE}.json",
    )
)

In [None]:
def get_representation(
    smiles: Sequence[str], method: str, fgroups_list: List[Mol], tokenizer: Tokenizer
) -> torch.Tensor:
    """Featurize SMILES strings into a stacked tensor, one row per molecule.

    Args:
        smiles: Non-empty sequence of SMILES strings (list, array or pandas Series);
            ``torch.stack`` rejects an empty sequence.
        method: ``"FG"`` -> vectors from ``smiles2vector_fg``;
            ``"MFG"`` -> vectors from ``smiles2vector_mfg``;
            ``"FGR"`` -> both, concatenated along dim 1 (FG first, then MFG).
        fgroups_list: Parsed functional-group SMARTS patterns used by the FG featurizer.
        tokenizer: Tokenizer used by the MFG featurizer.

    Returns:
        Tensor whose dim 0 indexes the input SMILES.

    Raises:
        ValueError: If ``method`` is not one of "FG", "MFG", "FGR".
    """
    if method not in ("FG", "MFG", "FGR"):
        raise ValueError("Method not supported")  # Raise error if method not supported

    parts = []
    if method in ("FG", "FGR"):
        parts.append(torch.stack([smiles2vector_fg(smi, fgroups_list) for smi in smiles]))
    if method in ("MFG", "FGR"):
        parts.append(torch.stack([smiles2vector_mfg(smi, tokenizer) for smi in smiles]))
    # FGR concatenates the two per-molecule vectors feature-wise.
    return parts[0] if len(parts) == 1 else torch.concat(parts, dim=1)

In [None]:
def get_scaffolds(s: Sequence[str], top_k: int = 10) -> Tuple[pd.DataFrame, LabelEncoder]:
    """Keep the molecules that belong to the ``top_k`` most populated Murcko scaffolds.

    Args:
        s: SMILES strings (list, NumPy array, or Series; accessed positionally).
        top_k: Number of most populated non-empty scaffolds to keep (default 10).

    Returns:
        ``(scaffold_df, label_encoder)``. ``scaffold_df`` has columns ``"SMILES"``
        and ``"Label"`` (integer-encoded scaffold SMILES). Rows are grouped by
        scaffold, largest group first, with input order inside each group.
        ``label_encoder`` maps the integer labels back to scaffold SMILES; its
        integer order is LabelEncoder's sorted order, not the popularity order.
    """
    smiles_list = list(s)  # positional access regardless of the container's index
    # scaffold SMILES -> input positions. A list (not a set) keeps members in input
    # order, so the output row order (and everything fitted on it) is deterministic.
    scaffolds = defaultdict(list)
    for i, smiles in enumerate(smiles_list):
        try:
            scaffold = MurckoScaffold.MurckoScaffoldSmiles(
                mol=MolFromSmiles(smiles), includeChirality=False
            )
        except Exception:  # not BaseException: Ctrl-C must still interrupt the loop
            print(f"{smiles} returns RDKit error and is thus omitted...")
            continue
        scaffolds[scaffold].append(i)

    # Empty scaffolds (e.g. molecules without rings) do not form a scaffold class.
    top_scaffolds = sorted(
        ((scaffold, members) for scaffold, members in scaffolds.items() if scaffold),
        key=lambda item: len(item[1]),
        reverse=True,
    )[:top_k]
    data = [(smiles_list[i], scaffold) for scaffold, members in top_scaffolds for i in members]
    scaffold_df = pd.DataFrame(data, columns=["SMILES", "Label"])
    label_encoder = LabelEncoder()
    scaffold_df["Label"] = label_encoder.fit_transform(scaffold_df["Label"])
    return scaffold_df, label_encoder

In [None]:
def plot_scaffolds(dataset_dir: str, label_encoder: LabelEncoder) -> None:
    """Draw every encoded scaffold to ``<dataset_dir>/scaffolds/scaffold_<label>.svg``.

    ``<label>`` is the integer label ``label_encoder`` assigned to that scaffold.

    Args:
        dataset_dir: Output directory for this dataset's figures.
        label_encoder: Fitted encoder whose ``classes_`` are scaffold SMILES.
    """
    scaffold_dir = os.path.join(dataset_dir, "scaffolds")
    os.makedirs(scaffold_dir, exist_ok=True)
    # Iterate over the classes that actually exist. A fixed range(10) makes
    # inverse_transform raise for datasets with fewer than ten scaffolds.
    # classes_[label] is exactly inverse_transform([label])[0].
    for label, scaffold in enumerate(label_encoder.classes_):
        mol = MolFromSmiles(scaffold)
        if mol is None:
            print(f"Could not parse scaffold {scaffold}; skipping drawing.")
            continue
        Draw.MolToFile(mol, os.path.join(scaffold_dir, f"scaffold_{label}.svg"))

In [None]:
def plot_kde_2d(dataset_dir: str, method: str, plot_df: pd.DataFrame, mode: str) -> None:
    """Filled 2-D KDE of the UMAP embedding, saved as ``<method>_<mode>_kde_2d.jpg``.

    Args:
        dataset_dir: Output directory.
        method: Representation name, used in the file name.
        plot_df: Frame with columns "UMAP 1" and "UMAP 2".
        mode: Tag inserted into the file name.
    """
    fig, ax = plt.subplots(figsize=(4, 3))
    sns.kdeplot(data=plot_df, x="UMAP 1", y="UMAP 2", cmap="Blues", fill=True, ax=ax)
    # Label the axes with what they actually show (previously both read "Features").
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    fig.savefig(f"{dataset_dir}/{method}_{mode}_kde_2d.jpg", bbox_inches="tight", dpi=600)
    plt.close(fig)  # close this figure explicitly; called once per dataset/method


def plot_kde(dataset_dir: str, method: str, plot_df: pd.DataFrame, mode: str) -> None:
    """KDE of the polar angle of each UMAP point, saved as ``<method>_<mode>_kde.jpg``.

    The angle is ``arctan2(UMAP 2, UMAP 1)`` in radians. ``plot_df`` is not modified
    (the previous version added an "arctan2" column to the caller's frame).

    Args:
        dataset_dir: Output directory.
        method: Representation name, used in the file name.
        plot_df: Frame with columns "UMAP 1" and "UMAP 2".
        mode: Tag inserted into the file name.
    """
    angles = np.arctan2(plot_df["UMAP 2"], plot_df["UMAP 1"]).rename("arctan2")
    fig, ax = plt.subplots(figsize=(4, 3))
    sns.kdeplot(x=angles, cmap="Blues", fill=True, ax=ax)
    ax.set_xlabel("Angles")
    ax.set_ylabel("Density")
    fig.savefig(f"{dataset_dir}/{method}_{mode}_kde.jpg", bbox_inches="tight", dpi=600)
    plt.close(fig)


def plot_umap(dataset_dir: str, method: str, dbi: float, plot_df: pd.DataFrame, mode: str) -> None:
    """Scatter the UMAP embedding coloured by scaffold and annotate it with the DBI.

    Saves ``<dataset_dir>/<method>_<mode>_umap.jpg`` at 600 dpi, without a legend.

    Args:
        dataset_dir: Output directory.
        method: Representation name, used in the file name.
        dbi: Davies-Bouldin index printed in the lower-right corner.
        plot_df: Frame with columns "UMAP 1", "UMAP 2" and "Scaffold".
        mode: Tag inserted into the file name.
    """
    # One colour per scaffold actually present. A fixed 10-colour list mismatches the
    # number of hue levels when fewer scaffolds survive; depending on the seaborn
    # version, that is a warning or a ValueError.
    n_scaffolds = max(1, plot_df["Scaffold"].nunique())
    fig, ax = plt.subplots(figsize=(4, 3))
    sns.scatterplot(
        ax=ax,
        data=plot_df,
        x="UMAP 1",
        y="UMAP 2",
        hue="Scaffold",
        palette=sns.color_palette("colorblind", n_colors=n_scaffolds),
        legend=False,  # never build a legend (get_legend() may otherwise be None)
    )
    ax.text(
        0.95,
        0.05,
        r"$\textbf{DBI:}$" f"{dbi}",
        transform=ax.transAxes,
        ha="right",
        va="bottom",
        fontsize=8,
        bbox=dict(facecolor="grey", alpha=0.2, edgecolor="black"),
    )
    fig.savefig(f"{dataset_dir}/{method}_{mode}_umap.jpg", bbox_inches="tight", dpi=600)
    plt.close(fig)

In [None]:
def plot_dataset(
    dataset: str, fgroups_list: List[Mol], tokenizer: Tokenizer, mode: str
) -> None:
    """Render scaffold drawings plus UMAP/KDE figures for one task dataset.

    Reads ``./data/processed/tasks/<dataset>/<dataset>.parquet`` and keeps the
    molecules of the most populated Murcko scaffolds (``get_scaffolds``). For each
    representation ("FG", "MFG", "FGR") it:

    - embeds them with UMAP (fixed ``random_state``);
    - scores scaffold separation with the Davies-Bouldin index;
    - writes figures under ``./reports/figures/<dataset>``.

    Args:
        dataset: Dataset directory name under ``./data/processed/tasks``.
        fgroups_list: Parsed functional-group SMARTS patterns.
        tokenizer: Tokenizer for the MFG representation.
        mode: Tag inserted into the output file names (e.g. "input").
    """
    df = pd.read_parquet(f"./data/processed/tasks/{dataset}/{dataset}.parquet")
    s = df["SMILES"].values

    dataset_dir = f"./reports/figures/{dataset}"
    os.makedirs(dataset_dir, exist_ok=True)

    scaffold_df, label_encoder = get_scaffolds(s)
    plot_scaffolds(dataset_dir, label_encoder)

    labels_np = scaffold_df["Label"].to_numpy()
    n_labels, n_samples = scaffold_df["Label"].nunique(), len(labels_np)
    # The Davies-Bouldin score is only defined for 2 <= n_clusters < n_samples.
    # Skip such datasets instead of aborting the whole multi-dataset run.
    if not 2 <= n_labels < n_samples:
        print(
            f"{dataset}: {n_labels} scaffold(s) for {n_samples} molecule(s); "
            "skipping UMAP/DBI plots."
        )
        return
    labels = torch.tensor(labels_np).reshape(-1)  # loop-invariant, built once

    for method in ["FG", "MFG", "FGR"]:
        x = get_representation(scaffold_df["SMILES"], method, fgroups_list, tokenizer)
        components = umap.UMAP(random_state=123).fit_transform(x)
        dbi = round(float(davies_bouldin_score(x, labels)), 2)
        plot_df = pd.DataFrame(components, columns=["UMAP 1", "UMAP 2"])
        # Assign by position, not by index alignment with scaffold_df.
        plot_df["Scaffold"] = labels_np

        plot_umap(dataset_dir, method, dbi, plot_df, mode)
        plot_kde(dataset_dir, method, plot_df, mode)
        plot_kde_2d(dataset_dir, method, plot_df, mode)

In [None]:
tasks_dir = "./data/processed/tasks"
# Sorted for a reproducible run order. Only directories are kept, so stray files
# (e.g. .DS_Store) are not treated as datasets.
tasks = sorted(
    name for name in os.listdir(tasks_dir) if os.path.isdir(os.path.join(tasks_dir, name))
)
for task in tqdm(tasks, desc="datasets"):
    plot_dataset(task, fgroups_list, tokenizer, "input")