In [None]:
# Mount Google Drive when running on Colab. On a local kernel (BASE_PATH below
# points at a local directory) `google.colab` is not installed, so the mount is
# skipped instead of aborting a Restart & Run All.
try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive.mount("/content/drive", force_remount=True)

In [1]:
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import rdMolDescriptors
import plotly.graph_objects as go
import pickle
from pprint import pprint as pp
from tqdm import tqdm
import itertools
import umap
from sklearn.manifold import TSNE
import scipy


# Root of the pipeline data tree. Every path below is built by plain string
# concatenation, so BASE_PATH must keep its trailing slash.
# NOTE(review): absolute, machine-specific path; the commented line is the Colab variant.
# BASE_PATH = '/content/drive/MyDrive/Generative_ML/current_data/' #@param {type:"string"}
BASE_PATH = "/Users/morgunov/batista/Summer/pipeline/"

# One directory per pipeline stage (names mirror the numbered folders on disk).
PRETRAINING_PATH = BASE_PATH + "1. Pretraining/"
GENERATION_PATH = BASE_PATH + "2. Generation/"
SAMPLING_PATH = BASE_PATH + "3. Sampling/"
DIFFDOCK_PATH = BASE_PATH + "4. DiffDock/"
SCORING_PATH = BASE_PATH + "5. Scoring/"
AL_PATH = BASE_PATH + "6. ActiveLearning/"
# Intermediate pickles used by the descriptor-processing cells.
PICKLES = BASE_PATH + "Archive/pickle/"

  @numba.jit()
  @numba.jit()
  @numba.jit()
  from .autonotebook import tqdm as notebook_tqdm
  @numba.jit()


# Processing dataset descriptors

## Convert each dictionary to a dataframe

In [None]:
def split_the_dictionary(fname):
    """Split a pickled SMILES -> descriptors dictionary into two pickles.

    Loads ``PICKLES + f"{fname}.pkl"`` and, keeping the dictionary's iteration
    order, writes the first ``len // 2`` entries to ``{fname}_subpt1.pkl`` and
    the remaining entries to ``{fname}_subpt2.pkl`` (for an odd count the
    second part is one entry larger).

    Args:
        fname: Pickle file name inside ``PICKLES``, without the ``.pkl`` suffix.

    Returns:
        None. The sizes of both parts, their sum and the original size are
        printed as a sanity check.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
    with open(PICKLES + f"{fname}.pkl", "rb") as f:
        smiles_to_descriptors = pickle.load(f)

    items = list(smiles_to_descriptors.items())
    half_index = len(items) // 2
    pt1 = dict(items[:half_index])
    pt2 = dict(items[half_index:])
    print(len(pt1), len(pt2), len(pt1) + len(pt2), len(smiles_to_descriptors))

    # Context managers make sure each output file is flushed and closed.
    for suffix, part in (("subpt1", pt1), ("subpt2", pt2)):
        with open(PICKLES + f"{fname}_{suffix}.pkl", "wb") as out:
            pickle.dump(part, out)

In [None]:
# Split part 1 of the descriptor dictionary into two smaller pickles (_subpt1 / _subpt2).
split_the_dictionary("smile_to_descriptors_pt1")

In [None]:
def pickle_to_csv(fname):
    """Turn a pickled SMILES -> descriptors dictionary into a DataFrame.

    Despite the name, the result is written with ``DataFrame.to_pickle``, not
    as CSV. The output file name is ``fname`` with its first two
    underscore-separated tokens removed (``smile_to_descriptors_pt3`` ->
    ``descriptors_pt3.pkl``), inside ``PICKLES``.

    Args:
        fname: Pickle file name inside ``PICKLES``, without the ``.pkl`` suffix.

    Returns:
        DataFrame with a ``smile`` column followed by one column per name
        listed in ``descriptors_list.pkl``. An empty input yields an empty
        frame that still carries those columns.

    Raises:
        KeyError: if a molecule's descriptor dict lacks one of the listed names.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
    with open(PICKLES + f"{fname}.pkl", "rb") as f:
        smiles_to_descriptors = pickle.load(f)
    with open(PICKLES + "descriptors_list.pkl", "rb") as f:
        keys = pickle.load(f)

    # Create every column up front so the frame has a stable layout even when
    # the input dictionary is empty.
    keyToData = {"smile": []}
    for key in keys:
        keyToData.setdefault(key, [])

    pbar = tqdm(
        smiles_to_descriptors.items(), total=len(smiles_to_descriptors), desc=fname
    )
    for smile, descriptors in pbar:
        keyToData["smile"].append(smile)
        for key in keys:
            keyToData[key].append(descriptors[key])

    df = pd.DataFrame(keyToData)
    df.to_pickle(PICKLES + "_".join(fname.split("_")[2:]) + ".pkl")
    return df

In [None]:
# Convert every dictionary part into a DataFrame pickle
# (descriptors_pt1_subpt1.pkl ... descriptors_pt3.pkl).
# The trailing "# done" marks are presumably the author's progress log.
pickle_to_csv("smile_to_descriptors_pt1_subpt1")  # done
pickle_to_csv("smile_to_descriptors_pt1_subpt2")  # done
pickle_to_csv("smile_to_descriptors_pt2_subpt1")  # done
pickle_to_csv("smile_to_descriptors_pt2_subpt2")
pickle_to_csv("smile_to_descriptors_pt3")  # done

## Combine dataframes

In [None]:
def merge_parts():
    """Concatenate the five descriptor DataFrame pickles into a single pickle.

    Reads ``descriptors_{part}.pkl`` for each part from ``PICKLES``, prints the
    individual shapes and the merged shape, and writes the row-wise
    concatenation to ``descriptors_combined.pkl``. The original row indices
    are kept (no ``ignore_index``).
    """
    part_names = ("pt1_subpt1", "pt1_subpt2", "pt2_subpt1", "pt2_subpt2", "pt3")
    parts = [
        pd.read_pickle(PICKLES + f"descriptors_{part_name}.pkl")
        for part_name in part_names
    ]
    print(*(part.shape for part in parts))

    merged_df = pd.concat(parts)
    print(merged_df.shape)
    merged_df.to_pickle(PICKLES + "descriptors_combined.pkl")


# Build descriptors_combined.pkl from the five part files.
merge_parts()

In [None]:
def analyze_training_sets():
    """Compare the freq100 and freq1000 training corpora by their SMILES.

    For each corpus the train and validation CSVs are concatenated; the
    function then prints both shapes and the number of SMILES that occur in
    only one of the two corpora. Nothing is returned.
    """

    def combine_sets(train_fname, valid_fname):
        # Read both splits from the pretraining datasets folder and stack them.
        frames = [
            pd.read_csv(f"{PRETRAINING_PATH}datasets/{name}.csv.gz")
            for name in (train_fname, valid_fname)
        ]
        return pd.concat(frames, ignore_index=True)

    train_100 = combine_sets(
        "combined_processed_freq100_block133_train",
        "combined_processed_freq100_block133_val",
    )
    train_1000 = combine_sets(
        "combined_processed_freq1000_block133_train",
        "combined_processed_freq1000_block133_val",
    )
    print(f"{train_100.shape=}, {train_1000.shape=}")

    smiles_100 = set(train_100["smiles"])
    smiles_1000 = set(train_1000["smiles"])
    # Set differences: molecules exclusive to one corpus.
    in100 = smiles_100 - smiles_1000
    in1000 = smiles_1000 - smiles_100
    print(f"{len(in100)=}, {len(in1000)=}")


# Print the size of, and the overlap between, the freq100 and freq1000 corpora.
analyze_training_sets()

In [None]:
def fill_descriptors_for_training_set(
    all_descriptors_path, train_fname, valid_fname, extra_mols_path=None
):
    """Attach precomputed descriptors to every molecule of a train+val corpus.

    The train and validation CSVs are concatenated and inner-joined on
    ``smiles`` with the descriptor table. Molecules without descriptors
    ("extra molecules") are handled in one of two ways:

    * ``extra_mols_path is None``: their SMILES are written to
      ``{PRETRAINING_PATH}descriptors/{train_fname[:-6]}_extra_mols.csv`` and
      the function returns early so descriptors can be computed for them.
    * otherwise: the pickle at ``extra_mols_path`` must hold descriptors for
      all of them; its matching rows are appended to the joined table.

    The final table is pickled to
    ``{PRETRAINING_PATH}descriptors/{train_fname[:-6]}.pkl``.

    Args:
        all_descriptors_path: Pickle with the descriptor DataFrame; its
            ``smile`` column is renamed to ``smiles``. (This argument used to
            be ignored in favour of ``PICKLES + "descriptors_combined.pkl"``,
            which is what the notebook's caller passes.)
        train_fname: Training CSV name, without ``.csv.gz``. The last six
            characters (``_train`` in the notebook's call) are dropped to
            build the output names.
        valid_fname: Validation CSV name, without ``.csv.gz``.
        extra_mols_path: Optional pickle with descriptors of the extra
            molecules; must have a ``smiles`` column.

    Returns:
        None.

    Raises:
        AssertionError: if the extra descriptors miss some extra molecule, or
            if the final row count differs from the train+val row count.
    """
    descriptors_wsmiles = pd.read_pickle(all_descriptors_path).rename(
        columns={"smile": "smiles"}
    )
    print(f"Descriptors df has shape {descriptors_wsmiles.shape}")

    train_df = pd.read_csv(f"{PRETRAINING_PATH}datasets/{train_fname}.csv.gz")
    valid_df = pd.read_csv(f"{PRETRAINING_PATH}datasets/{valid_fname}.csv.gz")
    combined_df = pd.concat([train_df, valid_df], ignore_index=True)
    print(f"Training/validation df has shape {combined_df.shape}")

    merged_df = combined_df.merge(descriptors_wsmiles, on="smiles", how="inner")
    extra_mols = set(combined_df["smiles"]) - set(merged_df["smiles"])

    # Frames that make up the final table. Without extra molecules this is
    # only the joined frame (previously this path hit a NameError).
    frames = [merged_df]
    if extra_mols:
        print("Extra molecules in the training set detected")
        if extra_mols_path is None:
            extra_csv = (
                f"{PRETRAINING_PATH}descriptors/{train_fname[:-6]}_extra_mols.csv"
            )
            pd.DataFrame({"smiles": list(extra_mols)}).to_csv(extra_csv)
            print(f"Saved extra mols to {extra_csv}, abort execution")
            return

        extra_descriptors = pd.read_pickle(extra_mols_path)
        # Coverage check in the direction the message states: every extra
        # molecule needs a descriptor row. Surplus rows are filtered out below.
        missing = extra_mols - set(extra_descriptors["smiles"])
        assert not missing, "Extra descriptors do not cover all extra_mols"
        frames.append(
            extra_descriptors[extra_descriptors["smiles"].isin(extra_mols)]
        )

    final_df = pd.concat(frames)
    print(f"Final descriptors df has shape {final_df.shape}")
    assert (
        final_df.shape[0] == combined_df.shape[0]
    ), "Number of molecules in training/valid set and descriptors df do not match"
    final_df.to_pickle(f"{PRETRAINING_PATH}descriptors/{train_fname[:-6]}.pkl")

In [None]:
# Build the descriptor table for the freq1000 corpus, taking descriptors of
# molecules absent from descriptors_combined.pkl from the *_extra_mols.pkl file.
fill_descriptors_for_training_set(
    f"{PICKLES}descriptors_combined.pkl",
    "combined_processed_freq1000_block133_train",
    "combined_processed_freq1000_block133_val",
    f"{PRETRAINING_PATH}descriptors/combined_processed_freq1000_block133_extra_mols.pkl",
)

# For 100
# Descriptors df has shape (5770637, 210)
# Training/validation df has shape (5560874, 2)
# Extra molecules in the training set detected
# Final descriptors df has shape (5560874, 211)

# Definitions

In [None]:
def fit_pca(dataframe, n=3, sigma=None, whiten=False):
    """Standardize ``dataframe`` and fit a PCA on the standardized values.

    Args:
        dataframe: Feature table passed to ``StandardScaler.fit_transform``.
        n: Number of principal components.
        sigma: Optional cut-off in standardized units. When given, only rows
            whose every standardized feature is ``<= sigma`` are used to fit
            the PCA; the scaler is still fitted on all rows.
        whiten: Forwarded to ``PCA``.

    Returns:
        Tuple ``(scaler, pca)`` of the fitted objects.
    """
    scaler = StandardScaler()
    standardized = scaler.fit_transform(dataframe)
    if sigma is not None:
        # NOTE(review): the filter is one-sided, so rows far *below* the mean
        # are kept. Confirm whether abs(z) <= sigma was intended.
        keep_rows = np.all(standardized <= sigma, axis=1)
        standardized = standardized[keep_rows]
    pca = PCA(n_components=n, whiten=whiten).fit(standardized)
    return scaler, pca


def pca_transform(pca, dataframe, n):
    """Project ``dataframe`` with a fitted PCA and return the leading components.

    Args:
        pca: Fitted object exposing ``n_components_`` and ``transform``.
        dataframe: Data in the feature space the PCA was fitted on
            (callers in this notebook pass already scaled values).
        n: Number of leading components to return.

    Returns:
        List of ``n`` one-dimensional arrays; element ``k`` is column ``k`` of
        the projection.

    Raises:
        AssertionError: if the PCA holds fewer than ``n`` components.
    """
    assert (
        pca.n_components_ >= n
    ), f"PCA was fitted on {pca.n_components_} components, but {n} were requested."
    projected = pca.transform(dataframe)
    components = []
    for component_idx in range(n):
        components.append(projected[:, component_idx])
    return components


def plot_pca(scaler, pca=None, datapoints=None, yscale=1.05):
    """Scatter-plot the first two PCA components of one or more datasets.

    Two call styles are accepted:

    * ``plot_pca(scaler, pca, datapoints, yscale=...)``: ``datapoints`` is an
      iterable of ``(data, color, label)``. Each ``data`` is indexed with
      ``scaler.get_feature_names_out()``, scaled and projected here.
    * ``plot_pca(series, yscale=...)``: ``series`` is an iterable of
      ``(x, y, color, label)`` whose coordinates are already projected. This
      is the form the plotting cells of this notebook use; it previously
      failed with a TypeError for the missing positional arguments.

    Args:
        scaler: Fitted scaler, or the list of pre-projected series (see above).
        pca: Fitted PCA; omit together with ``datapoints`` for the second style.
        datapoints: Iterable of ``(data, color, label)`` for the first style.
        yscale: Padding factor for both axes. Each axis limit is pushed away
            from the data by ``(yscale - 1) * abs(limit)``. For a negative
            minimum and positive maximum this equals ``yscale * limit``, as
            before; for limits on the other side of zero the range now widens
            instead of cutting into the data.

    Returns:
        plotly ``go.Figure`` with one marker trace per dataset.

    Raises:
        TypeError: if exactly one of ``pca`` / ``datapoints`` is given.
    """
    if pca is None and datapoints is None:
        series = [
            (np.asarray(x), np.asarray(y), color, label)
            for x, y, color, label in scaler
        ]
    elif pca is None or datapoints is None:
        raise TypeError(
            "plot_pca expects either (scaler, pca, datapoints) or a single "
            "list of (x, y, color, label) tuples"
        )
    else:
        feature_names = scaler.get_feature_names_out()  # loop-invariant
        series = []
        for data, color, label in datapoints:
            transformed = pca.transform(scaler.transform(data[feature_names]))
            series.append((transformed[:, 0], transformed[:, 1], color, label))

    fig = go.Figure()
    for xarr, yarr, color, label in series:
        fig.add_trace(
            go.Scatter(
                x=xarr,
                y=yarr,
                mode="markers",
                name=label,
                marker=dict(
                    size=5,
                    color=color,
                    # A colorbar only makes sense for per-point color values.
                    showscale=isinstance(color, (list, np.ndarray)),
                    colorscale="Viridis",
                    opacity=0.5,
                ),
            )
        )

    xaxis = dict(title="PCA Component 1")
    yaxis = dict(title="PCA Component 2")
    non_empty = [(x, y) for x, y, _, _ in series if len(x)]
    if non_empty:
        pad = yscale - 1

        def padded_range(values):
            lo = float(min(np.min(v) for v in values))
            hi = float(max(np.max(v) for v in values))
            return [lo - pad * abs(lo), hi + pad * abs(hi)]

        xaxis.update(autorange=False, range=padded_range([x for x, _ in non_empty]))
        yaxis.update(autorange=False, range=padded_range([y for _, y in non_empty]))
    # With no data at all the axes stay on autorange rather than getting an
    # infinite range.
    fig.update_layout(xaxis=xaxis, yaxis=yaxis)
    return fig

In [None]:
def get_data_boundaries(data_list):
    """Return per-column bounds over several arrays, snapped outwards to tens.

    Args:
        data_list: Sequence of 2-D arrays with the same number of columns.

    Returns:
        Tuple ``(min_val, max_val)`` of 1-D arrays: the column-wise minimum
        rounded down, and the column-wise maximum rounded up, to a multiple
        of 10.
    """
    stacked = np.vstack(data_list)  # all datasets share one set of bounds
    lower = np.floor(stacked.min(axis=0) / 10.0) * 10
    upper = np.ceil(stacked.max(axis=0) / 10.0) * 10
    return lower, upper


def discretize_data(data, boundaries, bin_size):
    """Count points per cell of a regular 2-D grid.

    Args:
        data: Array whose first two columns are the x and y coordinates.
        boundaries: ``(min_val, max_val)`` as returned by
            ``get_data_boundaries``; index 0 is the x axis, index 1 the y axis.
        bin_size: Edge length of a grid cell, the same on both axes.

    Returns:
        Tuple ``(hist, xcenters, ycenters)``. ``hist`` is transposed so that
        rows follow ``ycenters`` and columns follow ``xcenters``.

    The bin edges now always reach the upper boundary. Previously
    ``np.arange(lo, hi, bin_size)`` stopped short of ``hi``, so points in the
    top interval of either axis were silently left out of the histogram.
    """
    bins = []
    for axis in range(2):
        lower = boundaries[0][axis]
        upper = boundaries[1][axis]
        # Small tolerance keeps float noise from adding a spurious extra bin;
        # at least one bin is always created.
        n_bins = max(int(np.ceil((upper - lower) / bin_size - 1e-9)), 1)
        bins.append(lower + bin_size * np.arange(n_bins + 1))

    hist_data, xedges, yedges = np.histogram2d(data[:, 0], data[:, 1], bins=bins)

    # Bin centers are the midpoints of consecutive edges.
    xcenters = (xedges[:-1] + xedges[1:]) / 2
    ycenters = (yedges[:-1] + yedges[1:]) / 2

    return hist_data.T, xcenters, ycenters


def plot_heatmap(
    args_list,
    difference=False,
    bin_size=10,
    width=1280,
    height=720,
    all_differences=False,
):
    """Plot 2-D count heatmaps, or their pairwise differences, on a shared grid.

    Args:
        args_list: Iterable of ``(data, name)`` pairs; the first two columns of
            ``data`` are used as coordinates.
        difference: If False, draw one count heatmap per dataset. If True,
            draw differences ``after - before``.
        bin_size: Grid cell size passed to ``discretize_data``.
        width: Figure width.
        height: Figure height.
        all_differences: Only used with ``difference=True``. If True, one
            trace per unordered pair of datasets (in ``itertools.combinations``
            order) with a fixed color range of +/-110. If False, exactly two
            datasets are required.

    Returns:
        plotly ``go.Figure``. Where several traces exist, only the first is
        visible initially; the others are toggled from the legend.

    Raises:
        ValueError: if ``args_list`` is empty.
        AssertionError: if a single difference is requested for a number of
            datasets other than two.
    """
    args_list = list(args_list)  # allow generators; the pairs are reused below
    if not args_list:
        raise ValueError("plot_heatmap needs at least one (data, name) pair")
    data_list, name_list = zip(*args_list)
    boundaries = get_data_boundaries(data_list)

    if difference and not all_differences:
        assert (
            len(data_list) == 2
        ), f"To plot a difference, please provide only 2 data sources"

    # Discretize each dataset exactly once. All histograms share the same
    # boundaries and bin size, hence the same bin centers.
    histograms = []
    for data in data_list:
        hist, xcenters, ycenters = discretize_data(data, boundaries, bin_size)
        histograms.append(hist)

    def difference_trace(before, after, **extra):
        # Heatmap of histogram[after] - histogram[before], centered on zero.
        return go.Heatmap(
            x=xcenters,
            y=ycenters,
            z=histograms[after] - histograms[before],
            zmid=0,
            colorscale="RdBu",
            name=f"|{name_list[after]}|<br>-|{name_list[before]}|",
            showlegend=True,
            **extra,
        )

    fig = go.Figure()
    if difference and all_differences:
        pairs = itertools.combinations(range(len(args_list)), 2)
        for i, (before, after) in enumerate(pairs):
            fig.add_trace(
                difference_trace(
                    before,
                    after,
                    # Fixed color range so every pair is drawn on one scale.
                    zmax=110,
                    zmin=-110,
                    visible=True if i == 0 else "legendonly",
                )
            )
    elif difference:
        fig.add_trace(difference_trace(0, 1))
    else:
        zmax = max(hist.max() for hist in histograms)
        zmin = min(hist.min() for hist in histograms)
        for i, (hist, name) in enumerate(zip(histograms, name_list)):
            # Datasets labelled "al1 good" get a 50x tighter upper color limit.
            mult = 1 / 50 if "al1 good" in name else 1
            fig.add_trace(
                go.Heatmap(
                    x=xcenters,
                    y=ycenters,
                    z=hist,
                    name=name,
                    zmin=zmin,
                    zmax=mult * zmax,
                    showlegend=True,
                    visible=True if i == 0 else "legendonly",
                )
            )

    # The title used to say "Difference" even for plain count heatmaps.
    if difference:
        title = f"Difference in distribution: # of datapoints per bin ({bin_size=})"
    else:
        title = f"Distribution: # of datapoints per bin ({bin_size=})"
    fig.update_layout(
        title=title,
        xaxis_title="X",
        yaxis_title="Y",
        width=width,
        height=height,
        legend=dict(x=1.2, y=1),
    )

    return fig

# PCA Analysis

## Preprocessing & fitting

In [None]:
def remove_mols_with_nan_descriptors(descriptors_fname):
    """Drop molecules that have any NaN descriptor and save ``_nonan`` copies.

    Reads ``{PRETRAINING_PATH}descriptors/{descriptors_fname}.pkl`` and
    ``{PRETRAINING_PATH}datasets/{descriptors_fname}.csv.gz``, removes every
    molecule with at least one NaN descriptor from both, and writes
    ``{descriptors_fname}_nonan.pkl`` and ``{descriptors_fname}_nonan.csv.gz``
    next to the inputs.

    Args:
        descriptors_fname: Common base name of the descriptor pickle and the
            dataset CSV.

    Returns:
        Set of the SMILES that were removed (previously nothing was returned,
        although the calling cell stores the result).

    Raises:
        AssertionError: if descriptors and dataset do not hold the same set of
            SMILES, before or after filtering.
    """
    # The "Unnamed: 0" column is an index written by an earlier to_csv; it is
    # dropped when present, and its absence no longer raises.
    descriptors = pd.read_pickle(
        f"{PRETRAINING_PATH}descriptors/{descriptors_fname}.pkl"
    ).drop(columns=["Unnamed: 0"], errors="ignore")
    dataset = pd.read_csv(f"{PRETRAINING_PATH}datasets/{descriptors_fname}.csv.gz")
    assert set(descriptors["smiles"]) == set(
        dataset["smiles"]
    ), "Descriptors object contains smiles different from the training dataset"

    has_nan = descriptors.isna().any(axis=1)
    nan_smiles = set(descriptors.loc[has_nan, "smiles"])
    print(f"There are {len(nan_smiles)=}")

    descriptors_nonan = descriptors[~descriptors["smiles"].isin(nan_smiles)]
    dataset_nonan = dataset[~dataset["smiles"].isin(nan_smiles)]
    print(f"Had {descriptors.shape} before")
    print(f"Now have {descriptors_nonan.shape} after")
    assert set(descriptors_nonan["smiles"]) == set(
        dataset_nonan["smiles"]
    ), "Descriptors object contains smiles different from the training dataset"

    descriptors_nonan.to_pickle(
        f"{PRETRAINING_PATH}descriptors/{descriptors_fname}_nonan.pkl"
    )
    dataset_nonan.to_csv(f"{PRETRAINING_PATH}datasets/{descriptors_fname}_nonan.csv.gz")
    return nan_smiles


def fit_pca_on_dataset(descriptors_fname, n_comps=100, drop_nan=False):
    """Fit scaler + PCA on a descriptor table and pickle the fitted pair.

    Args:
        descriptors_fname: Base name of the descriptor pickle in
            ``{PRETRAINING_PATH}descriptors/``. Its ``smiles`` and ``Ipc``
            columns are dropped before fitting.
        n_comps: Number of principal components.
        drop_nan: If True, columns containing any NaN are removed first.

    Returns:
        Tuple ``(scaler, pca)``, also written to
        ``{SAMPLING_PATH}pca_weights/scaler_pca_{descriptors_fname}_{n_comps}.pkl``.
    """
    descriptors = pd.read_pickle(
        f"{PRETRAINING_PATH}descriptors/{descriptors_fname}.pkl"
    ).drop(columns=["smiles", "Ipc"])
    if drop_nan:
        # Only scan for NaN columns when they are actually going to be dropped.
        nan_cols = descriptors.columns[descriptors.isna().any()]
        descriptors = descriptors.drop(columns=nan_cols)

    scaler, pca = fit_pca(descriptors, n=n_comps)

    weights_path = (
        f"{SAMPLING_PATH}pca_weights/scaler_pca_{descriptors_fname}_{n_comps}.pkl"
    )
    # Context manager closes the file even if pickling fails.
    with open(weights_path, "wb") as f:
        pickle.dump((scaler, pca), f)
    return scaler, pca


# Write the NaN-free copies first: the PCA below is fitted on the `_nonan`
# descriptor file that this step produces.
nan_rows = remove_mols_with_nan_descriptors("combined_processed_freq1000_block133")
# Alternative kept for reference: fit on the full file and drop NaN columns instead of rows.
# scaler, pca = fit_pca_on_dataset("combined_processed_freq1000_block133", n_comps=120, drop_nan=True)
scaler, pca = fit_pca_on_dataset(
    "combined_processed_freq1000_block133_nonan", n_comps=120, drop_nan=False
)

In [None]:
# scaler, pca = pickle.load(open(f"{SAMPLING_PATH}pca_weights/scaler_pca_combined_processed_freq1000_block133_nonan", 'rb'))
import matplotlib.pyplot as plt

# Explained-variance report for the PCA fitted above.
var_arr = pca.explained_variance_ratio_
print(sum(var_arr))  # variance captured by all fitted components
print(sum(var_arr[:2]))  # variance captured by the two plotted components
cum_varr = np.cumsum(var_arr)
plt.plot([1 for _ in range(len(var_arr))])  # reference line at 1
plt.plot(cum_varr)
print(sum(var_arr[:100]))  # variance captured by the first 100 components
# Indices of the components at which the cumulative variance exceeds each threshold.
print(np.where(cum_varr > 0.99))
print(np.where(cum_varr > 0.995))
print(np.where(cum_varr > 0.999))

## Visualization

In [None]:
# Two-tone palette: color name -> pair of hex codes.
biscale = {
    "blue": ("#03045e", "#023e8a"),
    "purple": ("#7b2cbf", "#c77dff"),
    "green": ("#008000", "#70e000"),
    "red": ("#a4133c", "#ff4d6d"),
}
from IPython.display import Markdown, display

# Extra candidates; the commented-out entries are not part of the palette.
biscale.update(
    {
        "orange": ("#e85d04", "#faa307"),
        # "pink": ("#d00000", "#ff6f69"),
        # "yellow": ("#c5a880", "#ffdd00"),
        "dark_grey": ("#2e2e2e", "#5a5a5a"),
        # "gray": ("#6d6875", "#a5a5a5"),
        "teal": ("#006a71", "#48cae4"),
        # "violet": ("#3a0ca3", "#6f23ff"),
        "lime": ("#679436", "#aee833"),
        # "indigo": ("#3c096c", "#5e60ce"),
        "gold": ("#b08d57", "#f0e442"),
        # "cyan": ("#0077b6", "#00b4d8"),
        # "magenta": ("#7209b7", "#e43f5a"),
        "beige": ("#bfa58a", "#f5deb3"),
        # "turquoise": ("#00707b", "#00a19d"),
        # "olive": ("#6a994e", "#a7c957"),
        "maroon": ("#4a0000", "#800000"),
        # "coral": ("#ff4b5c", "#ff6b6b")
    }
)
# Render a swatch for both tones of every color so the pairs can be compared by eye.
for scale, colors in biscale.items():
    print(scale)
    display(
        Markdown(
            "<br>".join(
                f'<span style="font-family: monospace">{color} <span style="color: {color}">████████</span></span>'
                for color in colors
            )
        )
    )

In [None]:
# Show the final palette dictionary so it can be copied elsewhere.
biscale

In [None]:
# Load the scaler/PCA pair and the descriptor column list of the freq100
# corpus, then the descriptor tables of the two generated sets to compare.
# Context managers close the pickle files after loading.
# NOTE(review): fit_pca_on_dataset in this notebook writes
# `scaler_pca_{name}_{n_comps}.pkl`; the file read here has no component-count
# suffix, so it presumably comes from an earlier run. Confirm.
with open(
    f"{SAMPLING_PATH}pca_weights/scaler_pca_combined_processed_freq100_block133.pkl",
    "rb",
) as fh:
    scaler, pca = pickle.load(fh)
with open(
    f"{SAMPLING_PATH}descriptors/combined_processed_freq100_block133_columnlist.pkl",
    "rb",
) as fh:
    columns = pickle.load(fh)
# descriptors = pd.read_pickle(f"{PRETRAINING_PATH}descriptors/combined_processed_freq100_block133.pkl")
gen100 = pd.read_pickle(f"{SAMPLING_PATH}descriptors/model4_baseline_temp1.0.pkl")
gen1000 = pd.read_pickle(f"{SAMPLING_PATH}descriptors/model5_baseline_temp1.0.pkl")

In [None]:
# Scatter of the first two PCA components for equally sized random samples of
# the two generated sets (gen100 vs gen1000).
sample = 5_000  # rows drawn from each generated set
seed = 42  # fixed so the same rows are drawn on every run
colors = {
    "gunmetal": "#31393C",
    "blue": "#4361EE",
    "plum": "#8E338C",
    "grape": "#7209B7",
    "red": "#D90429",
    "orange": "#FF7B00",
    "yellow": "#FFBA08",
    "mindaro": "#CBFF8C",
}

# Each entry is a pre-projected (x, y, color, label) tuple.
plot_pca(
    [
        (
            *pca_transform(
                pca,
                scaler.transform(gen100[columns].sample(n=sample, random_state=seed)),
                n=2,
            ),
            colors["gunmetal"],
            f"100 freq",
        ),
        (
            *pca_transform(
                pca,
                scaler.transform(gen1000[columns].sample(n=sample, random_state=seed)),
                n=2,
            ),
            colors["grape"],
            f"1000 freq",
        ),
    ],
    yscale=1.5,
).show()

In [None]:
# Scaler/PCA pair and descriptor column list fitted on MOSES + BindingDB.
# Context managers close the pickle files after loading.
with open(f"{SAMPLING_PATH}pca_weights/scaler_pca_moses+bindingdb.pkl", "rb") as fh:
    scaler, pca = pickle.load(fh)
with open(
    f"{SAMPLING_PATH}descriptors/descriptors_moses+bindingdb_columnlist.pkl", "rb"
) as fh:
    columns = pickle.load(fh)


def load_baseline(fname):
    """Load a generated-set descriptor table, keeping one row per SMILES."""
    return pd.read_pickle(f"{SAMPLING_PATH}descriptors/{fname}.pkl").drop_duplicates(
        subset="smiles"
    )


# Descriptor tables of the generated sets: the baseline model and the two
# selection branches (softsub / softdiv) after one and two AL rounds.
gpt_base = load_baseline("model1_baseline_temp1.0")
gpt_sub1 = load_baseline("model1_softsub_al1_temp1.0")
gpt_div1 = load_baseline("model1_softdiv_al1_temp1.0")
gpt_sub2 = load_baseline("model1_softsub_al2_temp1.0")
gpt_div2 = load_baseline("model1_softdiv_al2_temp1.0")

# File-name prefixes of the matching training-set CSVs in AL_PATH/training_sets.
pre_base = "model1_baseline_threshold11"
pre_sub1 = "model1_softsub_al1_threshold11"
pre_div1 = "model1_softdiv_al1_threshold11"
pre_sub2 = "model1_softsub_al2_threshold11"
pre_div2 = "model1_softdiv_al2_threshold11"


def _training_subset_loader(generated, prefix):
    """Build a loader ``fname -> rows of `generated` listed in a training set``.

    The returned function reads
    ``{AL_PATH}training_sets/{prefix}_{fname}.csv`` and keeps the rows of
    ``generated`` whose ``smiles`` occur in that file. One factory replaces
    five copy-pasted lambdas that differed only in these two arguments.
    """

    def loader(fname):
        selected = pd.read_csv(f"{AL_PATH}training_sets/{prefix}_{fname}.csv")[
            "smiles"
        ].unique()
        return generated[generated["smiles"].isin(selected)]

    return loader


l_base = _training_subset_loader(gpt_base, pre_base)
l_sub1 = _training_subset_loader(gpt_sub1, pre_sub1)
l_div1 = _training_subset_loader(gpt_div1, pre_div1)
l_sub2 = _training_subset_loader(gpt_sub2, pre_sub2)
l_div2 = _training_subset_loader(gpt_div2, pre_div2)

In [None]:
# Sanity check: shapes of the training-set subsets for every selection scheme.
# First row: subsets of the baseline generation.
print(
    l_base("linear").shape,
    l_base("linear_noscore").shape,
    l_base("softmax_divf0.25").shape,
    l_base("softmax_sub").shape,
)
# Second row: subsets of the AL round 1 generations (sub / div branches).
print(
    l_sub1("softmax_sub").shape,
    l_sub1("softmax_sub_noscore").shape,
    l_div1("softmax_divf0.25").shape,
    l_div1("softmax_divf0.25_noscore").shape,
)
# Third row: subsets of the AL round 2 generations.
print(
    l_sub2("softmax_sub").shape,
    l_sub2("softmax_sub_noscore").shape,
    l_div2("softmax_divf0.25").shape,
    l_div2("softmax_divf0.25_noscore").shape,
)

In [None]:
# PCA scatter of random samples of three generated sets, plus the
# "linear_noscore" training-set subset of the baseline.
sample = 5_000  # rows drawn from each generated set
seed = 42  # fixed so the same rows are drawn on every run
colors = {
    "gunmetal": "#31393C",
    "blue": "#4361EE",
    "plum": "#8E338C",
    "grape": "#7209B7",
    "red": "#D90429",
    "orange": "#FF7B00",
    "yellow": "#FFBA08",
    "mindaro": "#CBFF8C",
}

# Helper: project a whole training-set subset (no sampling) onto two components.
scatter = lambda loader, fname: pca_transform(
    pca, scaler.transform(loader(fname)[columns]), n=2
)
# Each entry is a pre-projected (x, y, color, label) tuple.
plot_pca(
    [
        (
            *pca_transform(
                pca,
                scaler.transform(gpt_base[columns].sample(n=sample, random_state=seed)),
                n=2,
            ),
            colors["gunmetal"],
            f"GPT baseline",
        ),
        (
            *pca_transform(
                pca,
                scaler.transform(gpt_sub1[columns].sample(n=sample, random_state=seed)),
                n=2,
            ),
            colors["grape"],
            f"GPT Softmax Sub",
        ),
        (
            *pca_transform(
                pca,
                scaler.transform(gpt_div1[columns].sample(n=sample, random_state=seed)),
                n=2,
            ),
            colors["red"],
            f"GPT Softmax Div0.25",
        ),
        # (*scatter('linear'), colors["blue"], f'AL1 Linear'),
        (*scatter(l_base, "linear_noscore"), colors["mindaro"], f"AL1 Diffusion"),
        # (*scatter(l_base, 'softmax_divf0.25'), colors["grape"], f'AL1 Softmax Div0.25'),
        # (*scatter(l_base, 'softmax_sub'), colors["red"], f'AL1 Softmax Sub'),
    ],
    yscale=1.5,
).show()

In [None]:
# @title
import matplotlib.pyplot as plt

# Explained variance of the currently loaded PCA: per component and cumulative.
var_arr = pca.explained_variance_ratio_
print(sum(var_arr))  # variance captured by all fitted components
print(sum(var_arr[:2]))  # variance captured by the two plotted components
cum_varr = np.cumsum(var_arr)
plt.plot(var_arr)
plt.plot([1 for _ in range(len(var_arr))])  # reference line at 1
plt.plot(cum_varr)

In [None]:
# Helper: first two PCA components of a training-set subset, as an (n, 2) array.
heat = lambda loader, fname: pca.transform(scaler.transform(loader(fname)[columns]))[
    :, :2
]
# Pairwise difference heatmaps for the baseline generation: a random sample of
# gpt_base against the training-set subsets picked by each selection scheme.
plot_heatmap(
    [
        (
            pca.transform(
                scaler.transform(gpt_base[columns].sample(n=5_304, random_state=seed))
            )[:, :2],
            f"GPT baseline",
        ),
        (heat(l_base, "linear_noscore"), f"Base Diffusion"),
        # (pca.transform(scaler.transform(al1_softmax_div_nosc[columns]))[:, :2], f'AL1 Diffusion Softmax Div0.25'),
        # (pca.transform(scaler.transform(al1_softmax_sub_nosc[columns]))[:, :2], f'AL1 Diffusion Softmax Sub'),
        (heat(l_base, "linear"), f"AL1 Linear"),
        (heat(l_base, "softmax_divf0.25"), f"AL1 Softmax Div0.25"),
        (heat(l_base, "softmax_sub"), f"AL1 Softmax Sub"),
    ],
    difference=True,
    all_differences=True,
    bin_size=1.5,
    width=900,
    height=500,
)

In [None]:
# Pairwise difference heatmaps for the softmax-sub branch, round 1: samples of
# gpt_base and gpt_sub1 against two training-set subsets drawn from gpt_sub1.
# NOTE(review): the sample sizes (5_304 / 5_406) are hard-coded; presumably
# chosen to match the subset sizes. Confirm.
plot_heatmap(
    [
        (
            pca.transform(
                scaler.transform(gpt_base[columns].sample(n=5_304, random_state=seed))
            )[:, :2],
            f"GPT baseline",
        ),
        (
            pca.transform(
                scaler.transform(gpt_sub1[columns].sample(n=5_406, random_state=seed))
            )[:, :2],
            f"GPT Sub1",
        ),
        (heat(l_sub1, "softmax_sub_noscore"), f"Sub1 Diffusion"),
        (heat(l_sub1, "softmax_sub"), f"AL2 Softmax Sub"),
    ],
    difference=True,
    all_differences=True,
    bin_size=1.5,
    width=900,
    height=500,
)

In [None]:
# Pairwise difference heatmaps for the softmax-div branch, round 1: samples of
# gpt_base and gpt_div1 against two training-set subsets drawn from gpt_div1.
plot_heatmap(
    [
        (
            pca.transform(
                scaler.transform(gpt_base[columns].sample(n=5_304, random_state=seed))
            )[:, :2],
            f"GPT baseline",
        ),
        (
            pca.transform(
                scaler.transform(gpt_div1[columns].sample(n=5_363, random_state=seed))
            )[:, :2],
            f"GPT Div1",
        ),
        (heat(l_div1, "softmax_divf0.25_noscore"), f"Div1 Diffusion"),
        (heat(l_div1, "softmax_divf0.25"), f"AL2 Softmax Div"),
    ],
    difference=True,
    all_differences=True,
    bin_size=1.5,
    width=900,
    height=500,
)

In [None]:
# Pairwise difference heatmaps for the softmax-sub branch, round 2: samples of
# gpt_base and gpt_sub2 against two training-set subsets drawn from gpt_sub2.
# NOTE(review): the gpt_sub2 sample size (5_406) equals the one used for
# gpt_sub1 in the round 1 cell. Confirm it was not carried over by copy-paste.
plot_heatmap(
    [
        (
            pca.transform(
                scaler.transform(gpt_base[columns].sample(n=5_304, random_state=seed))
            )[:, :2],
            f"GPT baseline",
        ),
        (
            pca.transform(
                scaler.transform(gpt_sub2[columns].sample(n=5_406, random_state=seed))
            )[:, :2],
            f"GPT Sub2",
        ),
        (heat(l_sub2, "softmax_sub_noscore"), f"Sub2 Diffusion"),
        (heat(l_sub2, "softmax_sub"), f"AL3 Softmax Sub"),
    ],
    difference=True,
    all_differences=True,
    bin_size=1.5,
    width=900,
    height=500,
)

In [None]:
# Bare expression: only displays the repr of the loader object (looks like a leftover check).
l_div2

In [None]:
# Pairwise difference heatmaps for the softmax-div branch, round 2: samples of
# gpt_base and gpt_div2 against two training-set subsets drawn from gpt_div2.
plot_heatmap(
    [
        (
            pca.transform(
                scaler.transform(gpt_base[columns].sample(n=5_304, random_state=seed))
            )[:, :2],
            f"GPT baseline",
        ),
        (
            pca.transform(
                scaler.transform(gpt_div2[columns].sample(n=5_363, random_state=seed))
            )[:, :2],
            f"GPT Div2",
        ),
        (heat(l_div2, "softmax_divf0.25_noscore"), f"Div2 Diffusion"),
        (heat(l_div2, "softmax_divf0.25"), f"AL3 Softmax Div"),
    ],
    difference=True,
    all_differences=True,
    bin_size=1.5,
    width=900,
    height=500,
)

# Distribution Analysis

## Definitions

In [None]:
import numpy as np
import plotly.graph_objects as go
from scipy.stats import gaussian_kde
from graph import Graph
import pprint

# Rebinds `pp`: the import cell bound it to the `pprint.pprint` function; from
# here on it is a PrettyPrinter instance, so printing goes through `pp.pprint(obj)`.
pp = pprint.PrettyPrinter(indent=1, compact=False, width=100)


def load_dist(fname):
    """Return the ``score`` column of a scored dataframe as a NumPy array.

    Args:
        fname: CSV base name inside ``{SCORING_PATH}scored_dataframes/``.
    """
    scored = pd.read_csv(f"{SCORING_PATH}scored_dataframes/{fname}.csv")
    return scored["score"].to_numpy()


def compute_cluster_scores(fname):
    """Return the mean score of every cluster in a scored dataframe.

    Args:
        fname: CSV base name inside ``{SCORING_PATH}scored_dataframes/``; the
            file must have ``cluster_id`` and ``score`` columns.

    Returns:
        1-D array with one mean per cluster, ordered by the first appearance
        of each ``cluster_id`` in the file.
    """
    good_data = pd.read_csv(f"{SCORING_PATH}scored_dataframes/{fname}.csv")
    cluster_to_scores = {}
    # Iterate over the two needed columns directly. ``iterrows`` builds a
    # Series per row, which is far slower and yields the same grouping.
    for cluster_id, score in zip(good_data["cluster_id"], good_data["score"]):
        cluster_to_scores.setdefault(cluster_id, []).append(score)
    return np.array([np.mean(scores) for scores in cluster_to_scores.values()])


# Palette used by the histogram helpers below: color name -> (dark, light).
# The dark tone draws the density curve, the light tone the histogram bars.
# NOTE: this rebinds the `biscale` built in the visualization section;
# "dark_grey" has different values here.
biscale = {
    "blue": ("#03045e", "#023e8a"),
    "purple": ("#7b2cbf", "#c77dff"),
    "green": ("#008000", "#70e000"),
    "red": ("#a4133c", "#ff4d6d"),
    "orange": ("#e85d04", "#faa307"),
    "teal": ("#006a71", "#48cae4"),
    "lime": ("#679436", "#aee833"),
    "gold": ("#b08d57", "#f0e442"),
    "beige": ("#bfa58a", "#f5deb3"),
    "maroon": ("#4a0000", "#800000"),
    "dark_grey": ("#000000", "#2e2e2e"),
}


def create_hist_trace(i, data, label, color, threshold, bin_step, trace_opacity):
    """Build the histogram bar trace, KDE curve and summary annotation for one dataset.

    Args:
        i: Position of the dataset in the figure; moves the annotation down by
            0.04 (paper units) per position.
        data: 1-D NumPy array of scores.
        label: Trace name, also used as annotation prefix.
        color: Key into ``biscale``; the light tone colors the bars, the dark
            tone the density curve.
        threshold: Score from which a value counts towards the
            "% > threshold" figure (the comparison is ``>=``).
        bin_step: Integer bin width; bins start at 0, so values below 0 are
            not counted.
        trace_opacity: Opacity of the bars.

    Returns:
        Tuple ``([bar_trace, density_trace], annotation_dict, max_bar_height)``.
    """
    # A constant bandwidth factor of 0.25, set through the public API instead
    # of patching ``covariance_factor`` and calling the private
    # ``_compute_covariance``.
    density = gaussian_kde(data, bw_method=0.25)
    xs = np.linspace(np.min(data), np.max(data), 200)

    # The last edge must lie above max(data). With the previous stop value
    # (int(max) + 2) and bin_step > 1 it could end below the maximum and drop
    # the top values from the histogram.
    upper_edge = int(np.max(data)) + bin_step + 1
    hist_vals, bin_edges = np.histogram(
        data, bins=range(0, upper_edge, bin_step), density=True
    )
    hist = go.Bar(
        x=bin_edges[:-1],
        y=hist_vals,
        name=label,
        opacity=trace_opacity,
        marker=dict(color=biscale[color][1]),
        hovertemplate=[
            f"[{int(edge)}, {int(edge + bin_step)})" for edge in bin_edges[:-1]
        ],
    )
    density_curve = go.Scatter(
        x=xs,
        y=density(xs),
        mode="lines",
        name=label + " Density",
        line=dict(color=biscale[color][0]),
    )

    above_threshold_pct = np.sum(data >= threshold) / len(data) * 100
    q25, q50, q75 = np.percentile(data, [25, 50, 75])

    def f(x):
        # One decimal; values below 10 are zero-padded to width 4 so the
        # annotation columns line up.
        return f"{x:0>4.1f}" if x < 10 else f"{x:.1f}"

    annotation = dict(
        x=0.95,
        y=1.0 - 0.04 * i,
        xref="paper",
        yref="paper",
        text=f"{label}: % > threshold = {f(above_threshold_pct)}, Q25 = {f(q25)}, Q50 = {f(q50)}, Mean = {f(data.mean())}, Q75 = {f(q75)}, max = {f(data.max())}",
        showarrow=False,
        font=dict(size=12),
    )

    return [hist, density_curve], annotation, max(hist_vals)


def plot_hist_density(
    prefix,
    descriptors_type,
    n_clusters,
    n_iters,
    selection,
    colors,
    title_spec,
    threshold=11,
    bin_step=2,
    trace_opacity=0.6,
    width=1280,
    height=600,
):
    """Overlay score histograms and KDE curves for a baseline and its AL rounds.

    One dataset is drawn for the baseline and one for each active-learning
    round ``1..n_iters``. File names and labels are derived from ``prefix``,
    ``descriptors_type``, ``n_clusters`` and ``selection``. Datasets are paired
    with ``colors`` by position; surplus entries on either side are ignored.

    Args:
        prefix: Model prefix of the scored files.
        descriptors_type: Descriptor family used in file names and labels.
        n_clusters: Cluster count used in file names and labels.
        n_iters: Number of active-learning rounds to include.
        selection: Selection scheme used in file names and labels.
        colors: Sequence of ``biscale`` keys, one per dataset.
        title_spec: ``"ligand"`` plots per-molecule scores (``load_dist``); any
            other value plots per-cluster means (``compute_cluster_scores``).
            Also used in the title and the output folder name.
        threshold: Score marked by a vertical dashed line; forwarded to
            ``create_hist_trace``.
        bin_step: Histogram bin width.
        trace_opacity: Opacity of the bars.
        width: Figure width.
        height: Figure height.

    Returns:
        None. The figure is styled and saved through the project ``Graph``
        helper under ``{BASE_PATH}/plots/{title_spec}_distribution/``.
    """
    suffix = f"{descriptors_type}_k{n_clusters}"
    label_stem = f"{descriptors_type} k{n_clusters}"

    # (file name, legend label) for the baseline followed by each AL round.
    runs = [(f"{prefix}_baseline_{suffix}", f"{label_stem} baseline")]
    for al_round in range(1, n_iters + 1):
        runs.append(
            (
                f"{prefix}_{descriptors_type}{n_clusters}_{selection}_al{al_round}_{suffix}",
                f"{label_stem} {selection} AL{al_round}",
            )
        )

    graph = Graph(BASE_PATH)
    if title_spec == "ligand":
        loader = load_dist
    else:
        loader = compute_cluster_scores

    traces = []
    annotations = []
    freq = 0  # tallest bar over all datasets; sets the y-axis range
    for position, ((fname, label), color) in enumerate(zip(runs, colors)):
        new_traces, annotation, max_freq = create_hist_trace(
            position, loader(fname), label, color, threshold, bin_step, trace_opacity
        )
        traces.extend(new_traces)
        annotations.append(annotation)
        freq = max(freq, max_freq)

    # Vertical dashed line at the threshold, spanning the full plot height.
    threshold_line = {
        "type": "line",
        "x0": threshold,
        "x1": threshold,
        "y0": 0,
        "y1": 1,
        "yref": "paper",
        "line": {"color": "grey", "width": 2, "dash": "longdash"},
    }
    layout = go.Layout(
        bargap=0.2,
        barmode="overlay",
        shapes=[threshold_line],
        annotations=annotations,
    )
    fig = go.Figure(data=traces, layout=layout)

    graph.update_parameters(
        {
            "title": f"Distribution of {title_spec} scores based on DiffDock Poses",
            "xaxis_title": "Prolif Score",
            "yaxis_title": "Rel. Frequency",
            "width": width,
            "height": height,
            "yrange": [0, freq * 1.05],
        }
    )
    graph.style_figure(fig)
    graph.save_figure(
        figure=fig,
        path=f"{BASE_PATH}/plots/{title_spec}_distribution/",
        fname=f"{suffix}_{selection}_al{n_iters}",
        html=True,
    )

## Individual Scores

In [None]:
# One entry per experiment: (descriptor type, number of clusters, AL rounds, selection scheme).
configs = [
    ("mix", 100, 5, "softsub"),
    ("mix", 10, 3, "softsub"),
    ("mqn", 100, 5, "softsub"),
    ("mqn", 10, 4, "softsub"),
]
# Save a per-ligand and a per-cluster score distribution plot for every experiment.
for calc_type in ["ligand", "cluster"]:
    for descriptor_type, n_clusters, n_iters, selection in configs:
        plot_hist_density(
            prefix="model7",
            descriptors_type=descriptor_type,
            n_clusters=n_clusters,
            n_iters=n_iters,
            selection=selection,
            # Six colors: baseline plus up to five AL rounds.
            colors=["dark_grey", "teal", "orange", "purple", "green", "red"],
            title_spec=calc_type,
            # Narrower bins for the per-cluster means than for per-ligand scores.
            bin_step=2 if calc_type == "ligand" else 1,
            threshold=11,
            trace_opacity=0.6,
            width=1280,
            height=600,
        )

# BindingDB Analysis

In [None]:
from rdkit import Chem
from rdkit.Chem import Draw


def draw_molecule_from_smiles(smiles_string):
    """Render a SMILES string as an image.

    Args:
        smiles_string: SMILES representation of the molecule.

    Returns:
        The result of ``Draw.MolToImage`` for the parsed molecule, or ``None``
        (after printing a message) when RDKit cannot parse the string.
    """
    molecule = Chem.MolFromSmiles(smiles_string)
    # MolFromSmiles signals a parse failure by returning None.
    if molecule is None:
        print("Invalid SMILES string.")
        return None
    return Draw.MolToImage(molecule)


def load_bindingdb():
    """Load BindingDB and keep de-duplicated rows that report a Kd value.

    Reads the SMILES, Ki, IC50, Kd and target-name columns from the TSV under
    ``PRETRAINING_PATH``, drops rows without a Kd, drops exact duplicate rows,
    and returns only the "Ligand SMILES", "Target Name" and "Kd (nM)" columns.
    Progress is printed after every step.
    """
    source_tsv = f"{PRETRAINING_PATH}datasets/original_files/BindingDB_all.tsv"
    wanted = ["Ligand SMILES", "Ki (nM)", "IC50 (nM)", "Kd (nM)", "Target Name"]
    db = pd.read_csv(source_tsv, sep="\t", usecols=wanted)
    print(f"Loaded {len(db)=} entries")

    db = db.dropna(subset=["Kd (nM)"])
    print(f"Removed entries with missing Kd, {len(db)=} entries left")

    # Duplicates are judged on all five loaded columns, not just the returned three.
    db = db.drop_duplicates()
    print(f"Removed duplicates, {len(db)=} entries left")

    # Output recorded by the author on a previous run:
    #   Loaded len(db)=2765983 entries
    #   Removed entries with missing Kd, len(db)=105706 entries left
    #   Removed duplicates, len(db)=86620 entries left
    return db[["Ligand SMILES", "Target Name", "Kd (nM)"]]


def find_good_binders(df, cutoff=100):
    """Return the rows of ``df`` whose "Kd (nM)" is at most ``cutoff``.

    The Kd column may mix numbers and strings such as ">10000", "<1" or " 50".
    Each entry is converted to a number first; see ``parse_kd`` for the rules.
    The input frame is left untouched: parsing happens on a copy, and the
    returned frame is an independent copy whose Kd column is numeric.

    Args:
        df: DataFrame with a "Kd (nM)" column.
        cutoff: Largest Kd (in nM) still considered a good binder.

    Returns:
        A new DataFrame with the qualifying rows.

    Raises:
        TypeError: If a Kd entry is neither a number nor a string.
    """

    def parse_kd(value):
        # Numbers (including NaN and NumPy scalars) pass through unchanged.
        if isinstance(value, (float, int, np.integer, np.floating)):
            return value
        if isinstance(value, str):
            # Strip first so a padded qualifier like " <20" is still recognised.
            text = value.strip()
            try:
                if text.startswith(">"):
                    # Lower bound only: use the bound itself.
                    return float(text[1:])
                if text.startswith("<"):
                    # Upper bound only: nudge below the bound so "<cutoff" passes.
                    return float(text[1:]) - 1
                return float(text)
            except ValueError:
                # Unparseable text can never satisfy "<= cutoff".
                return float("inf")
        raise TypeError(f"{value=} is of type {type(value)=}")

    parsed = df.copy()
    parsed["Kd (nM)"] = parsed["Kd (nM)"].apply(parse_kd)
    # Copy the selection so callers can modify it without pandas warnings.
    return parsed[parsed["Kd (nM)"] <= cutoff].copy()

In [None]:
def analyze_repeated_targets(good_binders):
    """Print how many ligands are associated with 1, 2, 3, ... distinct targets.

    Expects "Ligand SMILES" and "Target Name" columns. Returns nothing.
    """
    # Distinct targets per ligand, then a tally of those numbers in ascending order.
    targets_per_ligand = good_binders.groupby("Ligand SMILES")["Target Name"].nunique()
    tally = targets_per_ligand.value_counts().sort_index()
    print("Number of unique targets for each SMILES string:", tally, sep="\n")


def find_smiles_with_n_targets(good_binders, n):
    """Return the target names of every ligand that has exactly ``n`` distinct targets.

    The result is a Series indexed by "Ligand SMILES" whose values are arrays of
    that ligand's unique "Target Name" entries.
    """
    n_targets = good_binders.groupby("Ligand SMILES")["Target Name"].nunique()
    selected_smiles = n_targets[n_targets == n].index

    # Keep only the rows of the selected ligands, then list each one's targets.
    is_selected = good_binders["Ligand SMILES"].isin(selected_smiles)
    return good_binders[is_selected].groupby("Ligand SMILES")["Target Name"].unique()


# analyze_repeated_targets(good_binders)
# # Find SMILES strings associated with exactly n unique targets (here n = 328) and their target names
# smiles_with_n_targets = find_smiles_with_n_targets(good_binders, 328)
# print("SMILES strings associated with 328 unique targets and their target names:")
# for smiles, targets in smiles_with_n_targets.items():
#     print(f"SMILES: {smiles}\nTargets: {', '.join(targets)}\n")


def calculate_descriptors_for_bindingdb(kd_cutoff, target_freq):
    """Compute "mix" and "mqn" descriptors for the filtered BindingDB ligands.

    Reads ``no_rare_targets_Kd<{kd_cutoff}_freq>{target_freq}.csv`` from the
    analysis folder and calls ``calculate_descriptors`` once per descriptor
    family, writing ``bindingdb_{family}_Kd<...>_freq>...`` files.

    Args:
        kd_cutoff: Kd cutoff used in the name of the input CSV.
        target_freq: Target-frequency cutoff used in the name of the input CSV.
    """
    tag = f"Kd<{kd_cutoff}_freq>{target_freq}"
    df = pd.read_csv(f"{BASE_PATH}analysis/no_rare_targets_{tag}.csv")
    # Unique SMILES in first-seen order. A set would also de-duplicate, but its
    # iteration order for strings changes between interpreter sessions, which
    # made the row order of the saved descriptor tables non-reproducible.
    smiles = df["Ligand SMILES"].drop_duplicates().tolist()
    for desc_type in ("mix", "mqn"):
        calculate_descriptors(smiles, desc_type, f"bindingdb_{desc_type}_{tag}")

In [20]:
import graph
import importlib

importlib.reload(graph)
from graph import Graph


def plot_cumulative_histogram_of_smiles_distribution(good_binders, suffix=""):
    """Plot how many targets have at least x distinct ligands, for every x.

    The histogram is cumulative in the decreasing direction, so the bar at x
    counts the targets with x or more unique SMILES. Summary annotations for a
    few fixed thresholds are added in the top-right corner.

    Args:
        good_binders: DataFrame with "Target Name" and "Ligand SMILES" columns.
        suffix: Text appended to the saved figure's file name.

    Returns:
        The styled plotly figure, which is also saved via ``Graph.save_figure``.
    """
    figure_dir = f"{BASE_PATH}plots/bindingdb_analysis/"
    graph = Graph(figure_dir)
    # graph.create_folders([["html", "svg", "jpg"]])
    smiles_counts_per_target = good_binders.groupby("Target Name")[
        "Ligand SMILES"
    ].nunique()
    fig = go.Figure()
    fig.add_trace(
        go.Histogram(
            x=smiles_counts_per_target,
            cumulative=dict(enabled=True, direction="decreasing"),
            xbins=dict(start=0, end=smiles_counts_per_target.max(), size=1),
            marker=dict(opacity=0.7),
        )
    )

    thresholds = [50, 100, 200, 300, 500]
    for i, threshold in enumerate(thresholds):
        count = int((smiles_counts_per_target >= threshold).sum())
        fig.add_annotation(
            x=0.95,  # Relative horizontal position in the figure
            y=0.95 - 0.05 * i,  # Stack the annotations downwards
            xref="paper",  # Position is relative to the figure
            yref="paper",  # Position is relative to the figure
            text=f"{count} targets have {threshold} or more SMILES",
            showarrow=False,
            align="right",
        )
    graph.update_parameters(
        dict(
            title="Cumulative Histogram of SMILES Distribution Across Targets",
            # Was "... or Fewer", which contradicted the decreasing cumulative
            # direction and the ">=" annotations above.
            xaxis_title="Number of SMILES Strings or More",
            yaxis_title="Number of Targets",
            showlegend=False,
        )
    )
    graph.style_figure(fig)
    graph.save_figure(fig, figure_dir, "smiles_distribution" + suffix)
    return fig


def remove_rare_targets(kd_cutoff=100, target_freq=200):
    """Write the good binders whose target has at least ``target_freq`` ligands.

    Loads BindingDB, keeps rows with Kd <= ``kd_cutoff``, keeps one row per
    ligand SMILES, drops every target left with fewer than ``target_freq``
    distinct ligands, and saves the remainder to
    ``analysis/no_rare_targets_Kd<{kd_cutoff}_freq>{target_freq}.csv``.
    """
    output_csv = (
        f"{BASE_PATH}analysis/no_rare_targets_Kd<{kd_cutoff}_freq>{target_freq}.csv"
    )

    good_binders = find_good_binders(load_bindingdb(), kd_cutoff)
    print(f"{len(good_binders)=}")
    # One row per ligand: the first listed target wins when a ligand has several.
    good_binders = good_binders.drop_duplicates(subset=["Ligand SMILES"])
    print(f"{len(good_binders)=} after dropping duplicates")

    ligands_per_target = good_binders.groupby("Target Name")["Ligand SMILES"].nunique()
    targets_to_keep = ligands_per_target[ligands_per_target >= target_freq].index
    keep_row = good_binders["Target Name"].isin(targets_to_keep)
    no_rare_targets = good_binders[keep_row]
    print(f"Post filtering we have {len(no_rare_targets)=}")
    # no_rare_targets.drop_duplicates(subset=["Ligand SMILES"], inplace=True)
    # print(f"Post dropping {len(no_rare_targets)=}")
    no_rare_targets.to_csv(output_csv, index=False)


def find_unique_targets(kd_cutoff, target_freq):
    """Assign a colour to every target in the filtered BindingDB CSV.

    Reads ``no_rare_targets_Kd<{kd_cutoff}_freq>{target_freq}.csv``, maps each
    unique "Target Name" to a colour from a fixed 16-colour palette, writes the
    mapping to ``analysis/target_to_color_16.yaml`` and returns it.

    Every target receives a colour. If there are more targets than palette
    entries, colours repeat and a warning is printed. Previously the surplus
    targets were silently left out of the mapping.

    Returns:
        dict mapping target name to a hex colour string.
    """
    df = pd.read_csv(
        f"{BASE_PATH}analysis/no_rare_targets_Kd<{kd_cutoff}_freq>{target_freq}.csv"
    )
    unique_targets = df["Target Name"].unique()
    colors = [
        "#e6194B",
        "#3cb44b",
        "#ffe119",
        "#4363d8",
        "#f58231",
        "#911eb4",
        "#46f0f0",
        "#f032e6",
        "#42d4f4",
        "#bfef45",
        "#fabebe",
        "#469990",
        "#e6beff",
        "#9A6324",
        "#800000",
        "#aaffc3",
    ]
    if len(unique_targets) > len(colors):
        print(
            f"Warning: {len(unique_targets)} targets but only {len(colors)} colors; "
            "colors will be reused"
        )
    # Cycle the palette so no target is dropped from the mapping.
    target_to_color = dict(zip(unique_targets, itertools.cycle(colors)))
    # Save this dictionary to a YAML file
    with open(f"{BASE_PATH}analysis/target_to_color_16.yaml", "w") as file:
        yaml.dump(target_to_color, file)
    return target_to_color

In [None]:
# The BindingDB TSV has millions of rows, so read it once and reuse the frame
# for every cutoff instead of re-parsing the file on each iteration. Reuse is
# safe because find_good_binders returns already-numeric Kd values unchanged.
bindingdb = load_bindingdb()
for cut in [1, 10, 100, 1000, 10_000, 100_000]:
    plot_cumulative_histogram_of_smiles_distribution(
        find_good_binders(bindingdb, cutoff=cut), suffix=f"_cutoff{cut}"
    )

In [None]:
# remove_rare_targets(kd_cutoff=100, target_freq=100)
# The call below reads the CSV written by remove_rare_targets with the same
# arguments; re-enable the line above if that file does not exist yet.
calculate_descriptors_for_bindingdb(kd_cutoff=100, target_freq=100)

In [None]:
# Build (and display) the target -> colour map for the Kd<100, freq>100 subset.
find_unique_targets(kd_cutoff=100, target_freq=100)

In [2]:
def add_targets_to_descriptors(good_binders_fname, descriptors_fname):
    """Attach BindingDB target information to a pickled descriptor table.

    Args:
        good_binders_fname: CSV name (no extension) under ``analysis/``. Only
            the first row per "Ligand SMILES" is used.
        descriptors_fname: Pickle name (no extension) under
            ``{SAMPLING_PATH}descriptors/``; must contain a "smiles" column.

    Returns:
        Inner join of the descriptors with the good-binder rows on SMILES.
    """
    good_binders = pd.read_csv(
        f"{BASE_PATH}analysis/{good_binders_fname}.csv"
    ).drop_duplicates(subset="Ligand SMILES")
    # Context manager so the file handle is closed promptly.
    # NOTE: pickle can execute arbitrary code; only load files you produced.
    with open(f"{SAMPLING_PATH}descriptors/{descriptors_fname}.pkl", "rb") as f:
        descriptors = pickle.load(f)
    print(f"Loaded {descriptors.shape} descriptors")
    descriptors = descriptors.merge(
        good_binders, left_on="smiles", right_on="Ligand SMILES", how="inner"
    )
    print(f"Added targets, {descriptors.shape} entries post merging")
    return descriptors

In [19]:
from IPython.display import Markdown, display

# Palette preview: the same 16 hex colours that find_unique_targets assigns to targets.
colors = (
    "#e6194B #3cb44b #ffe119 #4363d8 "
    "#f58231 #911eb4 #46f0f0 #f032e6 "
    "#42d4f4 #bfef45 #fabebe #469990 "
    "#e6beff #9A6324 #800000 #aaffc3"
).split()

# One monospace line per colour: the hex code followed by a swatch drawn in that colour.
swatch_lines = [
    f'<span style="font-family: monospace">{color} <span style="color: {color}">████████</span></span>'
    for color in colors[:16]
]
display(Markdown("<br>".join(swatch_lines)))

<span style="font-family: monospace">#e6194B <span style="color: #e6194B">████████</span></span><br><span style="font-family: monospace">#3cb44b <span style="color: #3cb44b">████████</span></span><br><span style="font-family: monospace">#ffe119 <span style="color: #ffe119">████████</span></span><br><span style="font-family: monospace">#4363d8 <span style="color: #4363d8">████████</span></span><br><span style="font-family: monospace">#f58231 <span style="color: #f58231">████████</span></span><br><span style="font-family: monospace">#911eb4 <span style="color: #911eb4">████████</span></span><br><span style="font-family: monospace">#46f0f0 <span style="color: #46f0f0">████████</span></span><br><span style="font-family: monospace">#f032e6 <span style="color: #f032e6">████████</span></span><br><span style="font-family: monospace">#42d4f4 <span style="color: #42d4f4">████████</span></span><br><span style="font-family: monospace">#bfef45 <span style="color: #bfef45">████████</span></span><br><span style="font-family: monospace">#fabebe <span style="color: #fabebe">████████</span></span><br><span style="font-family: monospace">#469990 <span style="color: #469990">████████</span></span><br><span style="font-family: monospace">#e6beff <span style="color: #e6beff">████████</span></span><br><span style="font-family: monospace">#9A6324 <span style="color: #9A6324">████████</span></span><br><span style="font-family: monospace">#800000 <span style="color: #800000">████████</span></span><br><span style="font-family: monospace">#aaffc3 <span style="color: #aaffc3">████████</span></span>

# UMAP \ t-SNE

## Definitions

In [3]:
import umap
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
import plotly.graph_objects as go
from typing import List, Tuple, Union
import pprint
import pickle
import yaml

# NOTE(review): the first import cell binds `pp` to the `pprint.pprint` function;
# this rebinds it to a PrettyPrinter instance, so after this cell has run,
# pretty-printing must go through `pp.pprint(obj)` rather than `pp(obj)`.
pp = pprint.PrettyPrinter(indent=2, compact=False, width=100)

In [4]:
def open_scored_file(fname: str) -> pd.DataFrame:
    """Read ``scored_dataframes/{fname}.csv`` and keep the "smiles" and "score" columns."""
    scored = pd.read_csv(f"{SCORING_PATH}scored_dataframes/{fname}.csv")
    return scored[["smiles", "score"]]


def load_scored_all_iters(
    prefix: str, descriptors_type: str, n_clusters: int, n_iters: int, selection: str
) -> pd.DataFrame:
    """Stack the scored baseline and the scored files of AL iterations 1..n_iters.

    Raises:
        AssertionError: If the stacked (smiles, score) rows contain duplicates.
    """
    cluster_tag = f"{descriptors_type}_k{n_clusters}"
    fnames = [f"{prefix}_baseline_{cluster_tag}"]
    for i in range(1, n_iters + 1):
        fnames.append(
            f"{prefix}_{descriptors_type}{n_clusters}_{selection}_al{i}_{cluster_tag}"
        )

    stacked = pd.concat([open_scored_file(fname) for fname in fnames])
    assert not stacked.duplicated().any(), "There are duplicates in the array"
    return stacked


def open_descriptors_file(fname: str) -> pd.DataFrame:
    """Unpickle ``{SAMPLING_PATH}descriptors/{fname}.pkl`` and return its content.

    The file is opened in a context manager so the handle is closed as soon as
    loading finishes.
    NOTE: pickle can execute arbitrary code; only load files you produced.
    """
    with open(f"{SAMPLING_PATH}descriptors/{fname}.pkl", "rb") as f:
        return pickle.load(f)


def load_descriptors(
    prefix: str, descriptors_type: str, n_clusters: int, n_iters: int, selection: str
) -> pd.DataFrame:
    """Concatenate the baseline and per-iteration descriptor tables, without duplicate rows.

    Prints the row count before and after de-duplication.
    """
    temp_tag = f"{descriptors_type}_temp1.0"
    al_stem = f"{prefix}_{descriptors_type}{n_clusters}_{selection}"
    fnames = [f"{prefix}_baseline_{temp_tag}"] + [
        f"{al_stem}_al{i}_{temp_tag}" for i in range(1, n_iters + 1)
    ]

    merged = pd.concat([open_descriptors_file(fname) for fname in fnames])
    print(f"{len(merged)} descriptors loaded")
    merged = merged.drop_duplicates()
    print(f"{len(merged)} descriptors after dropping duplicates")
    return merged


def get_descriptors_for_scored(
    prefix: str,
    descriptors_type: str,
    n_clusters: int,
    n_iters: int,
    selection: str,
    override_fname: str = None,
) -> pd.DataFrame:
    """Join the scored molecules of a run with their descriptors on "smiles".

    Descriptors come from the run's own descriptor files unless
    ``override_fname`` names a single descriptor pickle to use instead.
    """
    run = (prefix, descriptors_type, n_clusters, n_iters, selection)
    scored = load_scored_all_iters(*run)
    if override_fname is not None:
        descriptors = open_descriptors_file(override_fname).drop_duplicates()
    else:
        descriptors = load_descriptors(*run)
    return scored.merge(descriptors, on="smiles", how="inner")


def load_nona_column_names(pca_fname: str) -> np.ndarray:
    """Return the feature names stored in the scaler of a pickled (scaler, pca) pair.

    Args:
        pca_fname: Pickle name (no extension) under ``{SAMPLING_PATH}pca_weights/``.

    NOTE: pickle can execute arbitrary code; only load files you produced.
    """
    # Context manager so the file handle is closed promptly.
    with open(f"{SAMPLING_PATH}pca_weights/{pca_fname}.pkl", "rb") as f:
        scaler, _ = pickle.load(f)
    return scaler.get_feature_names_out()

In [5]:
def process_smile(smile, desc_type):
    """Compute descriptors for one SMILES string.

    Returns ``None`` when RDKit cannot parse the string. For ``desc_type ==
    "mix"`` the result is RDKit's ``CalcMolDescriptors`` dictionary. For any
    other value it is a dict of the 42 MQN values keyed "MQN0" .. "MQN41".
    """
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return None
    if desc_type != "mix":
        mqn_values = rdMolDescriptors.MQNs_(mol)
        assert (
            len(mqn_values) == 42
        ), f"Expected 42 descriptors, got {len(mqn_values)}"
        return dict(zip((f"MQN{i}" for i in range(len(mqn_values))), mqn_values))
    return Descriptors.CalcMolDescriptors(mol, missingVal=None, silent=False)


def calculate_descriptors(smiles, desc_type, fname):
    """Compute descriptors for many SMILES and pickle them as one DataFrame.

    Unparseable SMILES are skipped. The table has a "smiles" column followed by
    one column per descriptor, and is written to
    ``{SAMPLING_PATH}descriptors/{fname}.pkl``.
    """
    columns = {}
    for smile in tqdm(smiles, total=len(smiles)):
        data = process_smile(smile, desc_type)
        if data is None:
            continue
        # Append the SMILES and then each descriptor value to its column list.
        for key, value in itertools.chain([("smiles", smile)], data.items()):
            columns.setdefault(key, []).append(value)
    pd.DataFrame(columns).to_pickle(f"{SAMPLING_PATH}descriptors/{fname}.pkl")


def load_mqn_smiles():
    """List the SMILES of the unique scored rows from both MQN-based runs."""
    runs = [
        ("model7", "mqn", 100, 5, "softsub"),
        ("model7", "mqn", 10, 4, "softsub"),
    ]
    smiles_mqn = pd.concat([load_scored_all_iters(*run) for run in runs])
    return smiles_mqn.drop_duplicates()["smiles"].to_list()


def load_mix_smiles():
    """List the SMILES of the unique scored rows from both mix-based runs."""
    runs = [
        ("model7", "mix", 100, 5, "softsub"),
        ("model7", "mix", 10, 3, "softsub"),
    ]
    smiles_mix = pd.concat([load_scored_all_iters(*run) for run in runs])
    return smiles_mix.drop_duplicates()["smiles"].to_list()


# calculate_descriptors(load_mqn_smiles(), "mix", "mqn100upto5_10upto4_mix_descs")
# calculate_descriptors(load_mix_smiles(), "mqn", "mix100upto5_10upto3_mqn_descs")

In [6]:
def load_all_scored_in_desc_space(desc, override_fname):
    """Collect the scored molecules of all four runs with their descriptors.

    A run whose own descriptor family is the opposite of ``desc`` takes its
    descriptors from ``override_fname``; the other runs use their own
    descriptor files. Prints the shape of the de-duplicated result and returns it.
    """
    # (run descriptor family, clusters, AL iterations)
    runs = [("mix", 100, 5), ("mix", 10, 3), ("mqn", 100, 5), ("mqn", 10, 4)]
    opposite = {"mix": "mqn", "mqn": "mix"}

    frames = []
    for run_desc, n_clusters, n_iters in runs:
        use_override = desc == opposite[run_desc]
        frames.append(
            get_descriptors_for_scored(
                "model7",
                run_desc,
                n_clusters,
                n_iters,
                "softsub",
                override_fname if use_override else None,
            )
        )
    scored = pd.concat(frames).drop_duplicates()
    print(scored.shape)
    return scored

In [89]:
def explore_umap(
    data,
    desc_type: str,
    nearest_neighbors: list,
    min_distances: list,
    prefix: str,
    score_key: str = "score",
):
    """Run UMAP for every (n_neighbors, min_dist) combination and pickle the results.

    Args:
        data: DataFrame holding the descriptor columns and ``score_key``.
        desc_type: "mix" or "mqn".
        nearest_neighbors: ``n_neighbors`` values to try (outer loop).
        min_distances: ``min_dist`` values to try (inner loop).
        prefix: Leading part of the output file name.
        score_key: Column stored alongside each embedding.

    The list returned by ``reduce_scored_exploration`` is written to
    ``{BASE_PATH}reducers/{prefix}_{desc_type}_umap_nn..._mindist....pkl``.
    """
    reduced = reduce_scored_exploration(
        data=data,
        reduction="UMAP",
        desc_type=desc_type,
        reduction_parameters=[
            {"n_neighbors": nn, "min_dist": dist}
            for nn in nearest_neighbors
            for dist in min_distances
        ],
        score_key=score_key,
    )
    nn_tag = ",".join([str(elt) for elt in nearest_neighbors])
    dist_tag = ",".join([str(elt) for elt in min_distances])
    out_path = (
        f"{BASE_PATH}reducers/{prefix}_{desc_type}_umap_nn{nn_tag}_mindist{dist_tag}.pkl"
    )
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(out_path, "wb") as f:
        pickle.dump(reduced, f)


def explore_tsne(
    data,
    desc_type: str,
    perplexities: list,
    early_exaggerations: list,
    prefix: str,
    score_key: str = "score",
):
    """Run t-SNE for every (perplexity, early_exaggeration) combination and pickle the results.

    Args:
        data: DataFrame holding the descriptor columns and ``score_key``.
        desc_type: "mix" or "mqn".
        perplexities: Perplexity values to try (inner loop).
        early_exaggerations: Early-exaggeration values to try (outer loop).
        prefix: Leading part of the output file name.
        score_key: Column stored alongside each embedding.

    The list returned by ``reduce_scored_exploration`` is written to
    ``{BASE_PATH}reducers/{prefix}_{desc_type}_tsne_perp..._ee....pkl``.
    """
    reduced = reduce_scored_exploration(
        data=data,
        reduction="t-SNE",
        desc_type=desc_type,
        reduction_parameters=[
            {"perplexity": iperp, "early_exaggeration": iee}
            for iee in early_exaggerations
            for iperp in perplexities
        ],
        score_key=score_key,
    )
    perp_tag = ",".join([str(elt) for elt in perplexities])
    ee_tag = ",".join([str(elt) for elt in early_exaggerations])
    out_path = (
        f"{BASE_PATH}reducers/{prefix}_{desc_type}_tsne_perp{perp_tag}_ee{ee_tag}.pkl"
    )
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(out_path, "wb") as f:
        pickle.dump(reduced, f)


def explore_pca2d(
    data,
    desc_type: str,
    prefix: str,
    score_key: str = "score",
):
    """Project ``data`` to two dimensions via the "PCA" reduction and pickle the result.

    Calls ``reduce_scored_exploration`` with ``{"reduce_to_2d": True}`` and
    writes its return value to ``{BASE_PATH}reducers/{prefix}_{desc_type}_pca.pkl``.
    """
    reduced = reduce_scored_exploration(
        data=data,
        reduction="PCA",
        desc_type=desc_type,
        reduction_parameters={"reduce_to_2d": True},
        score_key=score_key,
    )
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(f"{BASE_PATH}reducers/{prefix}_{desc_type}_pca.pkl", "wb") as f:
        pickle.dump(reduced, f)

In [57]:
def plot_2d_scatterplot(
    datapoints: List[Tuple[np.ndarray, Union[str, np.ndarray], str]],
    reduction_type: str,
    figure_path: str,
    figure_fname: str,
    title: str,
    force_colors=None,
    yscale: float = 1.05,
    width=900,
    height=500,
) -> go.Figure:
    """Draw one scatter trace per 2-D embedding and save the figure.

    Args:
        datapoints: ``(coordinates, color, label)`` tuples. ``coordinates`` is
            indexed as ``[:, 0]`` / ``[:, 1]``, ``color`` is passed to the
            marker, and ``label`` names the trace. Only the first trace is
            visible initially; the others are toggled from the legend.
        reduction_type: Used in the axis titles.
        figure_path: Folder handed to ``Graph`` and ``save_figure``.
        figure_fname: File name handed to ``save_figure``.
        title: Figure title.
        force_colors: If given, used as the marker colour of every trace, and
            each tuple's ``color`` entry becomes the hover text instead.
        yscale: Padding factor applied to BOTH axis ranges (despite the name).
        width, height: Figure size in pixels.

    Returns:
        The styled figure.
    """
    fig = go.Figure()
    graph = Graph(figure_path)
    minX, minY, maxX, maxY = float("inf"), float("inf"), float("-inf"), float("-inf")
    for i, (data, color, label) in enumerate(datapoints):
        # Hover text exists only when colours are overridden. Previously `text`
        # was left unbound otherwise and the go.Scatter call below raised.
        text = None
        if force_colors is not None:
            text = color
            color = force_colors
        points = np.asarray(data)  # also accepts DataFrames
        xarr = points[:, 0]
        yarr = points[:, 1]
        minX = min(minX, xarr.min())
        minY = min(minY, yarr.min())
        maxX = max(maxX, xarr.max())
        maxY = max(maxY, yarr.max())
        fig.add_trace(
            go.Scatter(
                x=xarr,
                y=yarr,
                mode="markers",
                name=label,
                visible="legendonly" if i > 0 else True,
                text=text,
                marker=dict(
                    size=5,
                    color=color,
                    showscale=True
                    if force_colors is None and isinstance(color, (list, np.ndarray))
                    else False,
                    colorscale="Hot",
                    opacity=0.5,
                ),
            )
        )
    # Pad each bound away from the data by a fraction of its magnitude. Plain
    # multiplication by yscale only widens the range when min < 0 < max; for an
    # all-positive or all-negative axis it cut points off.
    pad = yscale - 1
    graph.update_parameters(
        dict(
            title=title,
            xrange=[minX - pad * abs(minX), maxX + pad * abs(maxX)],
            yrange=[minY - pad * abs(minY), maxY + pad * abs(maxY)],
            xaxis_title=f"{reduction_type} Component 1",
            yaxis_title=f"{reduction_type} Component 2",
            width=width,
            height=height,
            showlegend=len(datapoints) > 1,
            xmirror=True,
            ymirror=True,
        )
    )
    graph.style_figure(fig)
    graph.save_figure(
        fig, figure_path, figure_fname, jpg=True, svg=True, pdf=False, html=True
    )
    return fig

In [87]:
def reduce_scored_exploration(
    data: pd.DataFrame,
    reduction: str,
    desc_type: str,
    reduction_parameters: list = None,
    score_key: str = "score",
):
    assert reduction in {"PCA", "UMAP", "t-SNE"}
    if reduction_parameters is None:
        reduction_parameters = {}
    scores = data[score_key].to_numpy()

    pca_fname = "scaler_pca_combined_processed_freq1000_block133_120"
    scaler, pca = pickle.load(open(f"{SAMPLING_PATH}pca_weights/{pca_fname}.pkl", "rb"))
    if desc_type == "mix":
        proper_columns = data[load_nona_column_names(pca_fname)]
        transformed = pca.transform(scaler.transform(proper_columns))
    elif desc_type == "mqn":
        proper_columns = data[[f"MQN{i}" for i in range(42)]]
        transformed = proper_columns
    else:
        raise KeyError(f"desc_type {desc_type} not supported")
    datapoints = []
    match reduction:
        case "PCA":
            if reduction_parameters.get("refit_pca", False):
                scaler = StandardScaler()
                pca = PCA(n_components=reduction_parameters["n_components"])
                scaled = scaler.fit_transform(proper_columns)
                transformed = pca.fit_transform(scaled)
                suffix = " (refitted)"
            elif reduction_parameters.get("reduce_to_2d", False):
                transformed = transformed[:, :2]
                suffix = " (2D)"
            else:
                suffix = ""
            datapoints.append((transformed, scores, "PCA Transformed" + suffix))
        case "UMAP":
            for kwargs in tqdm(reduction_parameters, total=len(reduction_parameters)):
                reducer = umap.UMAP(
                    metric="euclidean",
                    n_components=2,
                    random_state=42,
                    verbose=False,
                    **kwargs,
                )
                reduced = reducer.fit_transform(transformed)
                datapoints.append(
                    (
                        reduced,
                        scores,
                        f"{kwargs['n_neighbors']} neighbors, {kwargs['min_dist']} min dist",
                    )
                )
        case "t-SNE":
            for kwargs in tqdm(reduction_parameters, total=len(reduction_parameters)):
                tsne = TSNE(
                    n_components=2,
                    random_state=42,
                    metric="euclidean",
                    **kwargs,
                )
                reduced = tsne.fit_transform(transformed)
                datapoints.append(
                    (
                        reduced,
                        scores,
                        f"Perplexity {kwargs['perplexity']}, Early Exaggeration {kwargs['early_exaggeration']}",
                    )
                )
    return datapoints

In [76]:
import graph
import importlib

importlib.reload(graph)
from graph import Graph


def get_data_boundaries(data_list):
    """Return per-column (lower, upper) bounds covering every array in ``data_list``.

    The arrays are stacked row-wise. The column minima are rounded down and the
    column maxima rounded up to a multiple of 10.
    """
    stacked = np.vstack(data_list)
    lower = np.floor(stacked.min(axis=0) / 10.0) * 10
    upper = np.ceil(stacked.max(axis=0) / 10.0) * 10
    return lower, upper


def discretize_data_wscores(data, boundaries, bin_size, scores):
    """Average ``scores`` over a regular 2-D grid of the points in ``data``.

    Args:
        data: Array indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y).
        boundaries: ``(lower, upper)`` pair of per-axis bounds, as returned by
            ``get_data_boundaries``.
        bin_size: Edge length of a grid cell.
        scores: One value per point.

    Returns:
        ``(grid, xcenters, ycenters)`` where ``grid[row, col]`` is the mean
        score of the cell at ``ycenters[row]``, ``xcenters[col]``; empty cells
        are 0.
    """
    bins = []
    for axis in range(2):
        lo, hi = boundaries[0][axis], boundaries[1][axis]
        # Edges must run from lo up to (at least) hi. np.arange(lo, hi, size)
        # stops one edge short, which dropped every point in the last cell.
        # The tiny epsilon keeps float noise from adding a spurious extra cell.
        n_bins = max(int(np.ceil((hi - lo) / bin_size - 1e-9)), 1)
        bins.append(lo + bin_size * np.arange(n_bins + 1))
    hist_data, xedges, yedges, _ = scipy.stats.binned_statistic_2d(
        data[:, 0], data[:, 1], scores, statistic="mean", bins=bins
    )
    # The mean of an empty cell is NaN; show it as 0.
    hist_data = np.nan_to_num(hist_data, nan=0)

    xcenters = (xedges[:-1] + xedges[1:]) / 2
    ycenters = (yedges[:-1] + yedges[1:]) / 2
    # Transpose so rows follow y and columns follow x.
    return hist_data.T, xcenters, ycenters


def discretize_data(data, boundaries, bin_size):
    """Count the points of ``data`` in each cell of a regular 2-D grid.

    Args:
        data: Array indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y).
        boundaries: ``(lower, upper)`` pair of per-axis bounds, as returned by
            ``get_data_boundaries``.
        bin_size: Edge length of a grid cell.

    Returns:
        ``(counts, xcenters, ycenters)`` where ``counts[row, col]`` is the
        number of points in the cell at ``ycenters[row]``, ``xcenters[col]``.
    """
    bins = []
    for axis in range(2):
        lo, hi = boundaries[0][axis], boundaries[1][axis]
        # Edges must run from lo up to (at least) hi. np.arange(lo, hi, size)
        # stops one edge short, which dropped every point in the last cell.
        # The tiny epsilon keeps float noise from adding a spurious extra cell.
        n_bins = max(int(np.ceil((hi - lo) / bin_size - 1e-9)), 1)
        bins.append(lo + bin_size * np.arange(n_bins + 1))
    hist_data, xedges, yedges = np.histogram2d(data[:, 0], data[:, 1], bins=bins)
    # Compute bin centers
    xcenters = (xedges[:-1] + xedges[1:]) / 2
    ycenters = (yedges[:-1] + yedges[1:]) / 2
    # Transpose so rows follow y and columns follow x.
    return hist_data.T, xcenters, ycenters


def create_heatmap_traces_for_all_differences(args_list, bin_size):
    """Build one difference heatmap for every pair of datasets.

    ``args_list`` holds ``(data, name)`` tuples. For each pair, taken in
    ``itertools.combinations`` order, the trace shows the later dataset's
    counts minus the earlier dataset's counts on a shared grid. Only the first
    trace starts visible; the colour range is fixed at +/-110.
    """
    data_list, _ = zip(*args_list)
    boundaries = get_data_boundaries(data_list)

    traces = []
    for i, (before, after) in enumerate(itertools.combinations(args_list, 2)):
        data_before, name_before = before
        data_after, name_after = after
        hist_before, xcenters, ycenters = discretize_data(
            data_before, boundaries, bin_size
        )
        hist_after = discretize_data(data_after, boundaries, bin_size)[0]
        heatmap = go.Heatmap(
            x=xcenters,
            y=ycenters,
            z=hist_after - hist_before,
            zmid=0,
            zmax=110,
            zmin=-110,
            colorscale="RdBu",
            name=f"|{name_after}|<br>-|{name_before}|",
            showlegend=True,
            visible="legendonly" if i != 0 else True,
        )
        traces.append(heatmap)
    return traces


def create_heatmap_trace_for_difference(args_list, bin_size):
    """Build a heatmap of (second dataset counts) minus (first dataset counts).

    Args:
        args_list: Exactly two ``(data, name)`` tuples, ordered before/after.
        bin_size: Edge length of a grid cell.

    Returns:
        A ``go.Heatmap`` with a diverging colour scale centred on zero.

    Raises:
        AssertionError: If ``args_list`` does not hold exactly two entries.
    """
    # Validate the argument itself. The check used to read `data_list` before
    # it was assigned, so every call failed with UnboundLocalError.
    assert (
        len(args_list) == 2
    ), "To plot a difference, please provide only 2 data sources"
    data_list, name_list = zip(*args_list)
    boundaries = get_data_boundaries(data_list)
    hist_before, xcenters, ycenters = discretize_data(
        data_list[0], boundaries, bin_size
    )
    hist_after, _, _ = discretize_data(data_list[1], boundaries, bin_size)
    diff = hist_after - hist_before
    label = f"|{name_list[1]}|<br>-|{name_list[0]}|"
    return go.Heatmap(
        x=xcenters,
        y=ycenters,
        z=diff,
        zmid=0,
        colorscale="RdBu",
        name=label,
        showlegend=True,
    )


def create_heatmap_traces_for_all_scores(args_list, bin_size, scores, force_zmax=None):
    """Build one heatmap per dataset on a shared grid and colour range.

    ``args_list`` holds ``(data, name)`` tuples. Cells show point counts when
    ``scores`` is None and mean scores otherwise. The colour range spans the
    smallest and largest cell value over all datasets, unless ``force_zmax``
    overrides the upper end. Only the first trace starts visible.
    """
    data_list, name_list = zip(*args_list)
    boundaries = get_data_boundaries(data_list)

    # Bin every dataset once and reuse the grids for the range and the traces.
    if scores is None:
        binned = [discretize_data(data, boundaries, bin_size) for data in data_list]
    else:
        binned = [
            discretize_data_wscores(data, boundaries, bin_size, scores)
            for data in data_list
        ]

    zmax = max([grid.max() for grid, _, _ in binned])
    if force_zmax is not None:
        zmax = force_zmax
    zmin = min([grid.min() for grid, _, _ in binned])

    traces = []
    for i, (name, (grid, xcenters, ycenters)) in enumerate(zip(name_list, binned)):
        traces.append(
            go.Heatmap(
                x=xcenters,
                y=ycenters,
                z=grid,
                name=name,
                zmin=zmin,
                zmax=zmax,
                showlegend=True,
                colorscale="Thermal",
                visible="legendonly" if i != 0 else True,
            )
        )
    return traces


def plot_heatmap(
    args_list,
    title,
    xtitle,
    ytitle,
    figure_path: str,
    figure_fname: str,
    difference=False,
    bin_size=10,
    width=1280,
    height=720,
    all_differences=False,
    scores=None,
    force_zmax=None,
):
    """Plot binned heatmaps for a list of ``(data, name)`` datasets and save the figure.

    With ``difference=False`` each dataset gets its own count / mean-score
    heatmap. With ``difference=True`` the figure shows either the single
    pairwise difference or, if ``all_differences`` is set, every pairwise
    difference. Returns the styled figure.
    """
    fig = go.Figure()
    graph = Graph(f"{BASE_PATH}plots/reducers/")

    if not difference:
        traces = create_heatmap_traces_for_all_scores(
            args_list, bin_size, scores, force_zmax
        )
    elif all_differences:
        traces = create_heatmap_traces_for_all_differences(args_list, bin_size)
    else:
        traces = [create_heatmap_trace_for_difference(args_list, bin_size)]
    for heatmap in traces:
        fig.add_trace(heatmap)

    layout_parameters = dict(
        title=title,
        xaxis_title=xtitle,
        yaxis_title=ytitle,
        width=width,
        height=height,
        showlegend=len(traces) > 1,
    )
    graph.update_parameters(layout_parameters)
    graph.style_figure(fig)
    # Place the legend to the right of the plot area.
    fig.update_layout(legend=dict(x=1.2, y=1))
    graph.save_figure(fig, figure_path, figure_fname, html=True)
    return fig

In [72]:
def visualize_pickle_heatmap(
    fname: str,
    method: str,
    desc_type: str,
    figure_fname: str,
    force_zmax: int = None,
    bin_size: float = 2,
    width: int = 900,
    height: int = 500,
    data_index=None,
):
    """Plot mean-score heatmaps for the embeddings stored in a reducer pickle.

    Args:
        fname: Pickle name (no extension) under ``{BASE_PATH}reducers/``,
            holding a list of ``(embedding, scores, label)`` tuples.
        method: Reduction name used in the title and axis labels.
        desc_type: Descriptor-space name used in the title.
        figure_fname: File name for the saved figure.
        force_zmax: Optional upper end of the colour range.
        bin_size: Edge length of a grid cell.
        width, height: Figure size in pixels.
        data_index: If given, plot only that entry of the pickle.

    Returns:
        The figure produced by ``plot_heatmap``.
    """
    # Context manager so the file handle is closed promptly.
    # NOTE: pickle can execute arbitrary code; only load files you produced.
    with open(f"{BASE_PATH}reducers/{fname}.pkl", "rb") as f:
        data = pickle.load(f)
    # The scores of the first entry are used for every embedding.
    custom_scores = data[0][1].copy()
    # custom_scores[custom_scores < 11] = 0
    indices = range(len(data)) if data_index is None else [data_index]
    args_list = [(data[i][0], data[i][2]) for i in indices]
    return plot_heatmap(
        args_list=args_list,
        title=f"Scored molecules in {desc_type} space reduced with {method} (bin_size={bin_size})",
        xtitle=f"{method} Component 1",
        ytitle=f"{method} Component 2",
        figure_path=f"{BASE_PATH}plots/reducers/",
        figure_fname=figure_fname,
        bin_size=bin_size,
        width=width,
        height=height,
        scores=custom_scores,
        force_zmax=force_zmax,
    )


def visualize_pickle_scatter(
    fname: str,
    method: str,
    desc_type: str,
    figure_fname: str,
    width: int = 900,
    height: int = 500,
    data_index=None,
):
    """Scatter-plot the embeddings of a reducer pickle, coloured by target.

    Args:
        fname: Pickle name (no extension) under ``{BASE_PATH}reducers/``,
            holding a list of ``(embedding, targets, label)`` tuples.
        method: Reduction name used in the title and axis labels.
        desc_type: Descriptor-space name used in the title.
        figure_fname: File name for the saved figure.
        width, height: Figure size in pixels.
        data_index: If given, plot only that entry of the pickle.

    Returns:
        The figure produced by ``plot_2d_scatterplot``.

    Raises:
        KeyError: If a target is missing from ``target_to_color_16.yaml``.
    """
    # Context manager so the file handle is closed promptly.
    # NOTE: pickle can execute arbitrary code; only load files you produced.
    with open(f"{BASE_PATH}reducers/{fname}.pkl", "rb") as f:
        data = pickle.load(f)
    # The second element of the first entry holds the target names here.
    custom_scores = data[0][1].copy()
    # NOTE(review): FullLoader is kept as-is; yaml.safe_load would suffice for a
    # plain str -> str mapping and is safer if the file could be edited by others.
    with open(f"{BASE_PATH}analysis/target_to_color_16.yaml", "r") as f:
        target_to_color = yaml.load(f, Loader=yaml.FullLoader)
    colors = [target_to_color[target] for target in custom_scores]
    if data_index is not None:
        data = [data[data_index]]
    return plot_2d_scatterplot(
        datapoints=data,
        title=f"Distribution of binders (Kd < 100 nM) in {desc_type} space after reduction with {method}",
        reduction_type=method,
        force_colors=colors,
        figure_path=f"{BASE_PATH}plots/bindingdb_analysis/",
        figure_fname=figure_fname,
        width=width,
        height=height,
    )

## Runs

In [39]:
# Scored molecules from all runs, expressed once in "mix" and once in "mqn" space.
# The override file supplies the descriptors for runs made in the other space.
scored_mix = load_all_scored_in_desc_space(
    desc="mix", override_fname="mqn100upto5_10upto4_mix_descs"
)
scored_mqn = load_all_scored_in_desc_space(
    desc="mqn", override_fname="mix100upto5_10upto3_mqn_descs"
)

600106 descriptors loaded
574353 descriptors after dropping duplicates
400105 descriptors loaded
395994 descriptors after dropping duplicates
(20992, 211)
600236 descriptors loaded
588486 descriptors after dropping duplicates
500200 descriptors loaded
493236 descriptors after dropping duplicates
(20974, 44)


In [29]:
# BindingDB ligands (Kd < 100, targets with > 100 ligands) in both descriptor spaces.
BINDINGDB_SUBSET = "Kd<100_freq>100"
bindingdb_mix = add_targets_to_descriptors(
    good_binders_fname=f"no_rare_targets_{BINDINGDB_SUBSET}",
    descriptors_fname=f"bindingdb_mix_{BINDINGDB_SUBSET}",
)
bindingdb_mqn = add_targets_to_descriptors(
    good_binders_fname=f"no_rare_targets_{BINDINGDB_SUBSET}",
    descriptors_fname=f"bindingdb_mqn_{BINDINGDB_SUBSET}",
)

Loaded (7387, 210) descriptors
Added targets, (7387, 213) entries post merging
Loaded (7387, 43) descriptors
Added targets, (7387, 46) entries post merging


In [None]:
# UMAP sweep over n_neighbors for the BindingDB ligands in "mix" space.
# The former `perform_pca=42` argument is gone: explore_umap has no such
# parameter, so the call raised TypeError.
explore_umap(
    data=bindingdb_mix,
    desc_type="mix",
    nearest_neighbors=[60, 70, 80, 90, 100],
    min_distances=[0.99],
    prefix="bindingdb",
    score_key="Target Name",
)

In [49]:
# t-SNE sweep over three perplexities at one early-exaggeration value; the result
# is pickled under reducers/ as "scored_mqn_tsne_perp30,60,90_ee40.pkl".
# The commented lines are the alternative settings for the BindingDB data.
explore_tsne(
    data=scored_mqn,
    desc_type="mqn",
    perplexities=[30, 60, 90],
    early_exaggerations=[40],
    # prefix="bindingdb",
    prefix="scored",
    # score_key="Target Name",
    score_key="score",
)

100%|██████████| 3/3 [03:33<00:00, 71.14s/it]


In [None]:
# (dataset, descriptor type, output prefix, column stored with the embedding)
# Each entry carries its own score column: the BindingDB frames are labelled by
# "Target Name", while the scored frames only have "score". The tuples used to
# have three items while the loop unpacked four, and "Target Name" was forced
# on every dataset.
configs = [
    (bindingdb_mix, "mix", "bindingdb", "Target Name"),
    (bindingdb_mqn, "mqn", "bindingdb", "Target Name"),
    (scored_mix, "mix", "scored", "score"),
    (scored_mqn, "mqn", "scored", "score"),
]
for dataset, desc_type, prefix, score_key in configs:
    explore_pca2d(
        data=dataset,
        desc_type=desc_type,
        prefix=prefix,
        score_key=score_key,
    )

In [71]:
# Scatter of BindingDB ligands in MQN space; data_index=1 selects the second
# embedding in the pickle, i.e. perplexity 60 of [30, 60, 90].
visualize_pickle_scatter(
    # fname=f"bindingdb_mix_tsne_perp30,60,90_ee40",
    fname="bindingdb_mqn_tsne_perp30,60,90_ee40",
    method="t-SNE",
    # desc_type="MIX",
    desc_type="MQN",
    # figure_fname is a required argument; it was commented out, so the call
    # raised TypeError.
    figure_fname="bindingdb_mqn_tsne_perp60_ee40",
    data_index=1,
    width=800,
    height=600,
)

In [86]:
# Mean-score heatmap of the scored molecules in MQN space; data_index=1 selects
# the second embedding in the pickle, i.e. perplexity 60 of [30, 60, 90].
# The commented lines are the equivalent settings for the "mix" descriptors.
visualize_pickle_heatmap(
    # fname=f"scored_mix_tsne_perp30,60,90_ee40",
    fname=f"scored_mqn_tsne_perp30,60,90_ee40",
    method="t-SNE",
    # desc_type="MIX",
    desc_type="MQN",
    figure_fname=f"scored_mqn_tsne_perp60_ee40",
    force_zmax=50,
    bin_size=1,
    width=700,
    height=500,
    data_index=1,
)

# Correlation

In [None]:
import plotly.graph_objs as go


def plot_correlation_circle(pca, features):
    """Show the loadings of the first two principal components inside a unit circle.

    Args:
        pca: Fitted PCA model exposing ``components_`` and
            ``explained_variance_ratio_``.
        features: Labels for the loading points, one per original feature.

    Displays the figure and returns nothing.
    """
    loadings = pca.components_
    theta = np.linspace(0, 2 * np.pi, 100)

    def axis_settings(component):
        # Axis title carries the share of variance explained by that component.
        explained = pca.explained_variance_ratio_[component] * 100
        return dict(
            title=f"PC{component + 1} ({explained:.2f}%)",
            range=[-1.1, 1.1],
            zeroline=False,
            showgrid=True,
            domain=[0, 1],
        )

    # Reference unit circle
    circle = go.Scatter(
        x=np.cos(theta),
        y=np.sin(theta),
        mode="lines",
        line=dict(color="blue", width=1),
        showlegend=False,
    )

    # One labelled point per feature: (loading on PC1, loading on PC2)
    vectors = go.Scatter(
        x=loadings[0, :],
        y=loadings[1, :],
        mode="lines+markers+text",
        text=features,
        textposition="top center",
        line=dict(color="red"),
        marker=dict(size=10, color="blue"),
        textfont=dict(size=8),
    )

    layout = go.Layout(
        title="Correlation Circle",
        autosize=False,
        width=800,
        height=800,
        showlegend=False,
        xaxis=axis_settings(0),
        yaxis=axis_settings(1),
    )

    go.Figure(data=[vectors, circle], layout=layout).show()


# Assuming pca is your PCA model fitted with sklearn and descriptors is the pandas dataframe with your original data
# NOTE(review): in the cells visible here, `pca` and `descriptors` are only ever
# assigned as locals inside functions, so this call presumably relies on
# variables left over in the kernel — confirm and define them explicitly.
plot_correlation_circle(pca, descriptors.columns.values)