In [None]:
from pathlib import Path
from itertools import combinations

import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from rdkit import Chem

from biosynfoni.moldrawing import draw, _get_highlight_loc_and_col

# Central colour / line-style lookup used by the figures in this notebook.
# First-level keys name a theme ("fp", "taxonomy", "class", ...); each maps a
# category label to a hex colour (or, for "fp_ls", to a matplotlib line style).
colourDict = {
    "fp": {  # colours from colorbrewer2.org for qualitative data
        "bsf": "#66c2a5",  # teal
        "maccs": "#fc8d62",  # orange
        "rdk": "#8da0cb",  # purpleblue
        "morgan": "#e78ac3",  # pink
    },
    # line style per fingerprint acronym (same keys as "fp")
    "fp_ls": {
        "bsf": "-",
        "maccs": "-.",
        "rdk": (5, (10, 3)),
        "morgan": "dotted",
    },
    "taxonomy": {
        "Viridiplantae": "#B9C311",  # green
        "Bacteria": "#9BC2BA",  # teal
        "Fungi": "#CFC0D6",  # purple
        "Metazoa": "#FFAD61",  # light orange
        "Archaea": "#EB6737",  # soft red
        "Eukaryota": "#D2CEC4",  # pale grey
        "Cellular organisms": "#FFEAA0",  # yellow
        "Opisthokonta": "#FFC4CE",  # pink
    },
    # string keys only; integer keys are added after this literal
    "separation": {
        "1": "#081d58",  #'#57BAC0',  # navy,
        "2": "#225ea8",  #'#77BC4D',  # royal blue,
        "3": "#41b6c4",  #'#F3C55F',  # teal,
        "4": "#7fcdbb",  #'#F48861',  # turquoise,
        "5": "#a7d9b4",  #'#F7A8B8',  # lemon green,
        "6": "#c7e9b4",  #'#F9CDAE',  # pale green,
        "7": "#FFEAA0",  # yellow
        "8": "#FF8B61",  # orange
        # "-1": "#c7e9b4",  #'#797979',  # lemon green,
        "-1": "#797979",  #'#797979',  # grey,
        "random pairs": "#c7e9b4",
        "control": "#c7e9b4",

    },
    "pathways": {
        "shikimate": "#A783B6",  # purple
        "acetate": "#FF8B61",  # orange,
        "mevalonate": "#B9C311",  # green,
        "methylerythritol": "#6FB5C6",  # blue
        "sugar": "#FFC4CE",  # pink
        "amino": "#FFEAA0",  # yellow
        "amino_acid": "#FFEAA0",  # yellow
    },
    # the same class colours are listed under capitalised, lowercase and
    # chebi-style labels so that any of the spellings can be looked up
    "class": {
        # "Terpenoids": "#9BC2BA",  # soft bluegreen
        "Terpenoids": "#B9C311",  # green
        "Alkaloids": "#B4CAD8",  # purple
        "Shikimates and Phenylpropanoids": "#A783B6",  # purple
        "Fatty acids": "#FF8B61",  # orange
        "Carbohydrates": "#FFC4CE",  # pink
        "Polyketides": "#C21100",  # soft red
        "Amino acids and Peptides": "#FFEAA0",  # yellow
        "No NP-Classifier prediction": "#797979",
        "None": "#595959",
        "Synthetic": "#393939",
        "Multiple": "#BBBBBB",
        # lowercase
        # "terpenoids": "#9BC2BA",  # soft bluegreen
        "terpenoids": "#B9C311",  # green
        "alkaloids": "#B4CAD8",  # purple
        "shikimates and phenylpropanoids": "#A783B6",  # purple
        "fatty acids": "#FF8B61",  # orange
        "carbohydrates": "#FFC4CE",  # pink
        "polyketides": "#C21100",  # soft red
        "amino acids and peptides": "#FFEAA0",  # yellow
        # chebi:
        "phenylpropanoid": "#A783B6",  # purple
        "fatty_acid": "#FF8B61",  # orange
        "polyketide": "#C21100",  # soft red
        "alkaloid": "#B4CAD8",  # purple
        # "isoprenoid": "#9BC2BA",  # soft bluegreen
        "isoprenoid": "#B9C311",  # green
        "carbohydrate": "#FFC4CE",  # pink
        "amino_acid": "#FFEAA0",  # yellow
        "synthetic": "#393939",  # grey
    },
    # capitalised subset of "class"
    "NPClassifier prediction": {
        "Terpenoids": "#B9C311",  # green
        "Alkaloids": "#B4CAD8",  # purple
        "Shikimates and Phenylpropanoids": "#A783B6",  # purple
        "Fatty acids": "#FF8B61",  # orange
        "Carbohydrates": "#FFC4CE",  # pink
        "Polyketides": "#C21100",  # soft red
        "Amino acids and Peptides": "#FFEAA0",  # yellow
        "No NP-Classifier prediction": "#797979",
        "None": "#595959",
        "Synthetic": "#393939",
        "Multiple": "#BBBBBB",
    },
    # chebi-style subset of "class"
    "chebi class": {
        "phenylpropanoid": "#A783B6",  # purple
        "fatty_acid": "#FF8B61",  # orange
        "polyketide": "#C21100",  # soft red
        "alkaloid": "#B4CAD8",  # purple
        # "isoprenoid": "#9BC2BA",  # soft bluegreen
        "isoprenoid": "#B9C311",  # green
        "carbohydrate": "#FFC4CE",  # pink
        "amino_acid": "#FFEAA0",  # yellow
        "synthetic": "#393939",  # grey
    },
}
# fingerprint acronym -> display name, and the reverse lookup
fp_ac_to_name = {
    "bsf": "Biosynfoni",
    "maccs": "MACCS",
    "rdk": "RDKit",
    "morgan": "Morgan",
}
fp_name_to_ac = {v: k for k, v in fp_ac_to_name.items()}
# make the fingerprint colours retrievable by display name as well as by acronym
colourDict["fp"].update({fp_ac_to_name[k]: v for k, v in colourDict["fp"].items()})
# add integer keys 1-6 with colours taken from seaborn's "mako" palette
# (zip stops after six entries, so the seventh palette colour is not used)
colourDict["separation"].update({separation: colour
    for separation, colour in zip(
        range(1, 7),
        sns.color_palette("mako", n_colors=7),
        )
    })
# colourDict["separation"].update({str(k): v for k, v in colourDict["separation"].items()})

# set the biosynfoni style
plt.style.use("biostylefoni.mplstyle")

folder = Path.home() / "article_bsf"
fig_folder = folder / "figures"
# parents=True so that a fresh checkout without ~/article_bsf does not fail here
fig_folder.mkdir(parents=True, exist_ok=True)

# Figure 1: fingerprint example, pathway similarity example, biosynthetic distance results

## biosynthetic distance

In [None]:
import seaborn as sns

# Ridge plot: one row per separation value, one KDE per fingerprint in each row.
df = pd.read_csv(
    f"{Path().home()}/article_bsf/output/biosynthetic_distances.tsv", sep="\t"
)
# separation becomes a string label; "-1" is renamed to "random pairs"
df.separation = df.separation.astype(str).replace("-1", "random pairs")
# left-pad to two characters ("1" -> "01"); "random pairs" is already longer and stays as is
df.separation = df.separation.apply(lambda x: x.zfill(2))
# replace remaining numeric -1 values in the frame with 0
df.replace(-1, 0, inplace=True)
# long format: one row per (separation, fingerprint, similarity)
together = df.melt(id_vars=["separation"], value_vars=["bsf", "maccs", "rdk", "morgan"])
together.rename(
    columns={"value": "similarity", "variable": "fingerprint"}, inplace=True
)

# line styles per hue level, listed in the same order as value_vars above
hue_kws = {"ls": ["-", "-.", (5, (10, 3)), "dotted"]}
g = sns.FacetGrid(
    together,
    row="separation",
    hue="fingerprint",
    hue_kws=hue_kws,
    palette=colourDict["fp"],
    sharex=True,
    sharey=True,
    aspect=5,
    height=1,
)


g.refline(y=0, linewidth=1, linestyle="-", color="grey", clip_on=False)
# first pass: translucent filled densities; second pass: solid outlines on top
g.map(sns.kdeplot, "similarity", fill=True, alpha=0.3, lw=0, zorder=-1)
g.map(
    sns.kdeplot,
    "similarity",
    lw=1.5,
    alpha=1,
    zorder=30,
    fill=False,
)

# # change the hue_kws to change the linestyle
# hue_kws = {"ls": ["-", "-", "-", "-"]}
# g.add_legend(bbox_to_anchor=(0.07, 0.9), loc="upper left")
# g.hue_kws = hue_kws
# g.map(sns.kdeplot, "similarity", fill=False, lw=1.5, zorder=-5, color="w")


# for the first column, add the separation as text above the refline
# NOTE(review): assumes the facet rows follow the order of together.separation.unique() — confirm
for ax, separation in zip(g.axes[:, 0], together.separation.unique()):
    ax.text(
        0,
        0.15,
        int(separation) if separation != "random pairs" else "random\npairs",
        ha="left",
        va="center",
        transform=ax.transAxes,
        fontsize=12,
        fontweight="medium",
    )
    # add the number of rows of df (compound pairs) with this separation
    n = len(df[df.separation == separation])
    ax.text(
        1,
        0.1,
        f"n={n}",
        ha="right",
        va="center",
        transform=ax.transAxes,
        fontsize=8,
        fontweight="medium",
    )


# negative hspace makes the rows overlap
g.figure.subplots_adjust(
    hspace=-0.6,
    # wspace=-1,
)

g.set_titles("")
g.set(yticks=[], ylabel="")

# rebuild the legend with display names instead of acronyms
# (reads seaborn's private _legend_data attribute)
handles, labels = g._legend_data.values(), g._legend_data.keys()
labels = [fp_ac_to_name[label] for label in labels]
g.add_legend(handles=handles, labels=labels,bbox_to_anchor=(0.09, 0.9), loc="upper left")

g.despine(bottom=True, left=True)
# add legend

# save to a temporary file in the working directory, then crop it
plt.savefig("ridgeplot.png", bbox_inches="tight", pad_inches=0.1)
img = plt.imread("ridgeplot.png")
# drop the top 200 pixel rows of the image
plt.imsave(
    fig_folder / "biosynthetic_distance.png", img[200:, :]
)  # cuts off the top 200 pixel rows (original note: needed due to high y-axis)

In [None]:
import seaborn as sns
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset

# Supplementary figure: one row per fingerprint, one KDE per separation value.
df = pd.read_csv(
    f"{Path().home()}/article_bsf/output/biosynthetic_distances.tsv", sep="\t"
)
df.separation = df.separation.astype(str).replace("-1", "random pairs")
# make them all 01, 02, 03, 04, 05, 06, 07, 08, 09, 10
df.separation = df.separation.apply(lambda x: x.zfill(2))
df.replace(-1, 0, inplace=True)

# NOTE(review): fp_colours is not referenced again in the visible notebook; the
# facets below are coloured by separation via `palette` instead
fp_colours = {
    "bsf": "#66c2a5",
    "maccs": "#fc8d62",
    "rdk": "#8da0cb",
    "morgan": "#e78ac3",
}
# long format: one row per (separation, fingerprint, similarity)
together = df.melt(id_vars=["separation"], value_vars=["bsf", "maccs", "rdk", "morgan"])
together.rename(
    columns={"value": "similarity", "variable": "fingerprint"}, inplace=True
)

# "mako" colours for every separation except the last unique value, which is
# assumed to be "random pairs" (it gets grey below) — TODO confirm ordering
palette = {
    separation: colour
    for separation, colour in zip(
        together.separation.unique()[:-1],
        sns.color_palette("mako", n_colors=len(together.separation.unique()) - 1),
    )
}
palette["random pairs"] = "#888"
g = sns.FacetGrid(
    together,
    row="fingerprint",
    hue="separation",
    palette=palette,
    sharex=True,
    sharey=True,
    aspect=2,
    height=1.5,
)
g.refline(y=0, linewidth=1, linestyle="-", color="grey", clip_on=False)

# add the multiple kde plots for each fingerprint
# NOTE(review): `fp` is never used in the loop body, so every iteration maps the
# same KDEs onto all facets again and adds another legend. The line indices used
# further down were presumably tuned against this exact artist order — confirm
# before simplifying.
for fp in together.fingerprint.unique():
    g.map(sns.kdeplot, "similarity", fill=True, alpha=0.1, lw=0, zorder=2)
    g.map(
        sns.kdeplot,
        "similarity",
        lw=1,
        alpha=1,
        zorder=3,
        fill=False,
    )
    g.add_legend()
    # strip the zero padding from the legend entries ("01" -> "1")
    for text in g.legend.texts:
        text.set_text(
            f"{int(text.get_text()) if text.get_text() != 'random pairs' else text.get_text()}"
        )
    # white filled copy underneath the coloured curves
    g.map(sns.kdeplot, "similarity", fill=True, alpha=0.6, lw=5, zorder=1, color="w")
g.figure.subplots_adjust(hspace=-0.5)

# add the labels for each refline
for ax, fingerprint in zip(g.axes[:, 0], together.fingerprint.unique()):
    ax.text(
        0,
        0.1,
        fp_ac_to_name[fingerprint],
        ha="left",
        va="center",
        transform=ax.transAxes,
        fontsize=8,
        fontweight="medium",
    )

    # get the location of the last line for each ax
    line = ax.get_lines()[-1]
    xdata, ydata = line.get_xdata(), line.get_ydata()
    idx = np.where(ydata > ydata.max() / 3)[0][
        0
    ]  # first index where the ydata is higher than 1/3 of the max
    x, y = xdata[idx], ydata[idx]
    # add the text
    ax.text(
        x - 0.02,
        y,
        "random pairs",
        ha="right",
        va="center",
        fontsize=5,
        fontweight="bold",
        color=palette["random pairs"],
    )
    # presumably the curve for separation "01" (its palette colour is used below) — verify
    line = ax.get_lines()[together.separation.nunique()]
    xdata, ydata = line.get_xdata(), line.get_ydata()
    # position of the curve's maximum
    idx = ydata.argmax()
    x, y = xdata[idx], ydata[idx]
    # add the text
    ax.text(
        x + 0.02,
        y - 0.01,
        "compound pairs \n  with one reaction \n        between them",
        ha="left",
        va="bottom",
        fontsize=5,
        fontweight="bold",
        color=palette["01"],
    )
    if fingerprint == "bsf":
        # collect the peak position of n - 1 consecutive lines, starting at index n
        maxima = []
        n = together.separation.nunique()
        for i in range(n - 1):
            line = ax.get_lines()[n + i]
            xdata = line.get_xdata()
            ydata = line.get_ydata()
            idx = ydata.argmax()
            x, y = xdata[idx], ydata[idx]
            maxima.append((x, y))
        # add a curved arrow in white connecting the first and the last of these maxima
        ax.annotate(
            "",
            xy=maxima[0],
            xytext=maxima[-1],
            arrowprops=dict(
                arrowstyle="<-",
                lw=1.5,
                color="w",
                connectionstyle="arc3,rad=0.2",
            ),
            zorder=1000,
        )


g.set_titles("")
g.set(yticks=[], ylabel="")
# make ticklabels for the x-axis bigger

g.despine(bottom=True, left=True)

plt.savefig(fig_folder/"sm_biosynthetic_distance_separations.png", bbox_inches="tight", pad_inches=0.1)

## fingerprint example

In [None]:
import pandas as pd
from rdkit import Chem

from biosynfoni import draw_with_highlights

# COCONUT compound properties and the matching Biosynfoni count fingerprints
properties_csv = folder / "data" / "input" / "coconut_properties.csv"
fingerprint_csv = folder / "fps" / "coconut_bsf.csv"

df = pd.read_csv(properties_csv)
fp = np.loadtxt(fingerprint_csv, delimiter=",", dtype=int)
# per compound: how many fingerprint positions are non-zero
df["n_bits"] = (fp != 0).sum(axis=1)

In [None]:
# order compounds from most to fewest set bits
df = df.sort_values(by="n_bits", ascending=False)
# example compound, chosen by its rank in that ordering (rank 70 is nice too)
row = df.iloc[140]

example_smiles = row.canonical_smiles
mol = Chem.MolFromSmiles(example_smiles)
mol

In [None]:
# render the example molecule with its highlights and store the SVG text
svg_path = fig_folder / "mol.svg"
svg = draw_with_highlights(mol)
with svg_path.open("w") as svg_file:
    svg_file.write(svg)

## pathway reconstruction

In [None]:
import ast
import networkx as nx
import pandas as pd
import numpy as np

pathways_tsv = Path.home() / "article_bsf" / "output" / "reconstructed_pathways.tsv"
df = pd.read_csv(pathways_tsv, sep="\t", header=0, index_col=0)


def _parse_cell(value):
    """Turn a stringified Python literal back into an object; keep missing values as they are."""
    return value if pd.isna(value) else ast.literal_eval(value)


# every column stores lists written out as text; read them back in as lists
for col in df.columns:
    df[col] = df[col].apply(_parse_cell)

# the true route in reverse order, so that predictions can match in either direction
df["true_r"] = df["true"].apply(lambda route: route[::-1])


def correct(col):
    """Return a boolean mask of the rows of ``df`` where column ``col`` equals
    the true pathway, read either forwards ("true") or backwards ("true_r")."""
    predicted = df[col]
    matches_forward = predicted == df["true"]
    matches_reversed = predicted == df["true_r"]
    return matches_forward | matches_reversed


# fig, ax = plt.subplots()
# column prefix -> name printed in the summary
fp_labels = {
    "bsf": "biosynfoni",
    "maccs": "maccs",
    "rdk": "rdkit",
    "morgan": "morgan",
}

for key, val in fp_labels.items():
    # pathways reconstructed correctly without a hint, with a start or end hint, and by the tsp run
    independent = df[correct(f"{key}_independent")]
    hinted_mask = correct(f"{key}_f_start") | correct(f"{key}_f_end")
    with_hint = df[hinted_mask]
    travellings_salesman = df[correct(f"{key}_tsp")]

    print(
        val,
        "independent: ",
        len(independent),
        "with hint: ",
        len(with_hint),
        "out of ",
        len(df),
    )

    # lengths of the correctly reconstructed true pathways
    ind_lengths = list(map(len, independent["true"]))
    with_lengths = list(map(len, with_hint["true"]))

    print(
        "independent: ",
        np.mean(ind_lengths),
        np.std(ind_lengths),
        "with hint: ",
        np.mean(with_lengths),
        np.std(with_lengths),
    )
    print("travellings_salesman: ", len(travellings_salesman))

In [None]:
def longest_common_subsequence(true_list, list2):
    """Return the length of the longest common subsequence of the two sequences.

    Elements must appear in the same relative order in both sequences but do
    not have to be adjacent. Only the previous row of the dynamic-programming
    table is kept while the next one is filled in.
    """
    previous_row = [0] * (len(list2) + 1)
    for left_item in true_list:
        current_row = [0]
        for col, right_item in enumerate(list2, start=1):
            if left_item == right_item:
                # extend the best subsequence that ends just before both items
                current_row.append(previous_row[col - 1] + 1)
            else:
                current_row.append(max(previous_row[col], current_row[col - 1]))
        previous_row = current_row
    return previous_row[-1]


def percentage(true_list, list2):
    """Return the longest-common-subsequence length as a fraction of len(true_list)."""
    shared = longest_common_subsequence(true_list, list2)
    reference_length = len(true_list)
    return shared / reference_length


# quick sanity check of the LCS helper on a rotated list
a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
b = [7, 8, 9, 10, 1, 2, 3, 4, 5]

print(longest_common_subsequence(a, b))

# column prefix -> name used in the printout and the legend
fp_labels = {
    "bsf": "biosynfoni",
    "maccs": "maccs",
    "rdk": "rdkit",
    "morgan": "morgan",
}

# per fingerprint: number of pathways whose independent prediction shares an
# LCS of at least 70% of the true pathway length
for key, val in fp_labels.items():
    count = 0
    for pathway in df.index:
        true_pathway = df.loc[pathway, "true"]
        predicted_pathway = df.loc[pathway, f"{key}_independent"]
        # missing predictions are not lists; skip them
        if not isinstance(predicted_pathway, list):
            continue
        lcs_length = longest_common_subsequence(true_pathway, predicted_pathway)
        if lcs_length >= 0.7 * len(true_pathway):
            count += 1
    print(
        f"{val}: {count} reactions have an LCS >= 0.7 of the length of the true pathway"
    )

# histogram of the recovered fraction (LCS / true length) of the independent
# predictions, one distribution per fingerprint on a shared axis
fig, ax = plt.subplots()
for key, val in fp_labels.items():
    results = []
    for pathway in df.index:
        true_pathway = df.loc[pathway, "true"]
        predicted_pathway = df.loc[pathway, f"{key}_independent"]
        if not isinstance(predicted_pathway, list):
            continue
        results.append(percentage(true_pathway, predicted_pathway))
    sns.histplot(results, bins=20, ax=ax, label=val, kde=True)
plt.legend()

In [None]:
# df["true"] = df["true"].apply(ast.literal_eval)
# number of entries in the true route of pathway PWY-8133
len(df.loc["PWY-8133"].true)

In [None]:
# draw the pathway
from rdkit import Chem

# compound id -> RDKit molecule for every record in the MetaCyc SD file.
# SDMolSupplier yields None for records it cannot parse; those are skipped so
# that GetProp is never called on None.
id_to_mol = {
    mol.GetProp("compound_id"): mol
    for mol in Chem.SDMolSupplier(f"{Path().home()}/article_bsf/data/input/metacyc.sdf")
    if mol is not None
}

In [None]:
# `from rdkit import Chem` does not load the Draw submodule by itself, so
# import it explicitly instead of relying on another module having done so
from rdkit.Chem import Draw

# pathways of interest
pws = [
    "PWY-6915",
    "PWY-7135",
    "PWY-7483",
    "PWY-7711",
    "PWY-7736",
    "PWY-8133",
    "PWY2DNV-5",
]
# prediction columns to print next to the true route
compare = ["bsf_f_start", "rdk_f_start"]

# only the first pathway is inspected; change the index to look at another one
pw = pws[0]
mols = [id_to_mol[cpd] for cpd in df.loc[pw].true]
# print (true, first prediction, second prediction) compound ids, one position per line
print(
    *zip(df.loc[pw].true, df.loc[pw][compare[0]], df.loc[pw][compare[1]]),
    sep="\n",
)
# draw mols to grid
Draw.MolsToGridImage(
    mols,
    molsPerRow=5,
    subImgSize=(200, 200),
    legends=[mol.GetProp("_Name") for mol in mols],
)
# get similarities between the mols
# get similarities between the mols

# Figure 2: Applicability domain - calculation times, coverage, substructure distribution

## calculation times

In [None]:
def times(fp_name) -> np.ndarray:
    """Load the recorded calculation times of one fingerprint on COCONUT.

    Args:
        fp_name: fingerprint acronym as used in the file name
            ``coconut_{fp_name}_times.csv`` (e.g. "bsf", "maccs").

    Returns:
        Float array read from the comma-delimited file in ``~/article_bsf/fps``.
        Presumably one time in seconds per compound (callers multiply by 1000
        and label the axis "Time (ms)") — TODO confirm.
    """
    times_path = Path.home() / "article_bsf" / "fps" / f"coconut_{fp_name}_times.csv"
    return np.loadtxt(times_path, delimiter=",", dtype=float)


def time_size_stats(times, sizes) -> np.ndarray:
    """
    Summarise calculation times per molecule size.

        Args:
            times: sequence of times, one per molecule
            sizes: sequence of sizes, aligned with ``times``

        Returns:
            Structured array with one record per unique size (sorted ascending)
            and the fields "size", "average_time" and "std_time"

        Raises:
            AssertionError: if ``times`` and ``sizes`` differ in length
    """
    # coerce to arrays so that boolean masking also works for plain lists
    times = np.asarray(times)
    sizes = np.asarray(sizes)
    assert len(times) == len(sizes), "The length of times and sizes should be the same"

    records = []
    for size in np.unique(sizes):
        times_for_size = times[sizes == size]
        records.append((size, np.mean(times_for_size), np.std(times_for_size)))

    return np.array(
        records,
        dtype=[("size", int), ("average_time", float), ("std_time", float)],
    )

In [None]:
actual_names = fp_ac_to_name
fig, ax = plt.subplots(figsize=(5, 4))
properties_path = Path().home() / "article_bsf/data/input/coconut_properties.csv"
sizes = pd.read_csv(properties_path, index_col=0)["heavy_atom_count"].values.tolist()
for fp_name, display_name in actual_names.items():
    ax.set_xlabel("Number of heavy atoms")
    ax.set_ylabel("Time (ms)")

    # mean and spread of the time (scaled by 1000) per heavy-atom count
    data = time_size_stats(times(fp_name) * 1000, sizes)
    fp_colour = colourDict["fp"][fp_name]
    mean_time = data["average_time"]
    half_std = 0.5 * data["std_time"]

    ax.plot(
        data["size"],
        mean_time,
        color=fp_colour,
        label=display_name,
        linestyle=colourDict["fp_ls"][fp_name],
        zorder=2,
    )

    # shaded band of half a standard deviation on either side of the mean
    ax.fill_between(
        data["size"],
        mean_time - half_std,
        mean_time + half_std,
        color=fp_colour,
        alpha=0.3,
        zorder=1,
    )

# size distribution on a second y axis (right-hand side); the bin edges use the
# sizes of the last fingerprint processed above
largest_size = data["size"].max()
ax_twin = ax.twinx()
ax_twin.hist(
    sizes,
    bins=np.linspace(0, largest_size, largest_size),
    alpha=0.1,
    color="black",
    zorder=0,
)
ax_twin.set_ylabel("Number of molecules")
# set all ticklabelsize to 8
for axis in (ax, ax_twin):
    axis.tick_params(axis="both", which="major", labelsize=8)


ax.legend(title="Fingerprint", loc="upper left", bbox_to_anchor=(1.16, 1))
fig.savefig(fig_folder / "generation_times_full.png", dpi=300, bbox_inches="tight")

In [None]:
# Same figure as the full timing plot, restricted to sizes below 100 heavy atoms.
actual_names = fp_ac_to_name
fig, ax = plt.subplots(figsize=(5, 4))
properties_path = Path().home() / "article_bsf/data/input/coconut_properties.csv"
sizes = pd.read_csv(properties_path, index_col=0)["heavy_atom_count"].values.tolist()
ax.set_xlabel("Number of heavy atoms")
ax.set_ylabel("Time (ms)")
for fp_name in actual_names.keys():
    # mean and spread of the time (scaled by 1000) per heavy-atom count
    data = time_size_stats(times(fp_name)*1000, sizes)
    data = data[data["size"] < 100]
    ax.plot(
        data["size"],
        data["average_time"],
        color=colourDict["fp"][fp_name],
        label=actual_names[fp_name],
        linestyle=colourDict["fp_ls"][fp_name],
        zorder=2,
    )

    # shaded band of half a standard deviation on either side of the mean
    ax.fill_between(
        data["size"],
        data["average_time"] - 0.5 * data["std_time"],
        data["average_time"] + 0.5 * data["std_time"],
        color=colourDict["fp"][fp_name],
        alpha=0.3,
        zorder=1,
    )

# size distribution on a second y axis (right-hand side); the bin edges use the
# (cropped) sizes of the last fingerprint processed above
ax_twin = ax.twinx()
ax_twin.hist(
    sizes,
    bins=np.linspace(0, data["size"].max(), data["size"].max()),
    alpha=0.1,
    color="black",
)
ax_twin.set_ylabel("Number of molecules")
# set all ticklabelsize to 8
ax.tick_params(axis="both", which="major", labelsize=8)
ax_twin.tick_params(axis="both", which="major", labelsize=8)


# build the legend from this figure's own labelled lines; the previous
# `handles=handles` picked up a variable left over from the ridge-plot cell
ax.legend(title="Fingerprint", loc="upper left", bbox_to_anchor=(1.16, 1))
fig.savefig(fig_folder / "generation_times_cropped.png", dpi=300, bbox_inches="tight")

In [None]:
# per fingerprint: average time per molecule, and that average times the number of molecules
for timed_fp in ("maccs", "morgan", "rdk", "bsf"):
    average_time = np.average(times(timed_fp))
    print(average_time, average_time * len(sizes))

## substructure occurrence

In [None]:
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from biosynfoni.subkeys import get_names, get_smarts


# Biosynfoni count fingerprints of COCONUT, one column per substructure key
bsf_counts_csv = Path.home() / "article_bsf" / "fps" / "coconut_bsf.csv"
ar = np.loadtxt(bsf_counts_csv, delimiter=",", dtype=int)
df = pd.DataFrame(ar, columns=get_names())

# one colour per column, written as consecutive runs of the same colour
colours = (
    ["grey"] * 3
    + ["#FFEAA0"] * 2
    + ["#FFC4CE"] * 4
    + ["#A783B6"] * 2
    + ["#FF8B61"] * 2
    + ["#A783B6"] * 3
    + ["#B9C311"]
    + ["#FF8B61"] * 4
    + ["grey"] * 18
)

In [None]:
# Narrow summary plot: one horizontal row per fingerprint column showing the
# min-max range, the interquartile range and the mean of its counts.
fig, ax = plt.subplots(figsize=(1, 4))
# plot the minimum and the maximum of the substructure counts as points, and make a thick line between connect them
min_max = df.agg(["min", "max"])
for i in range(min_max.shape[1]):
    plt.plot(
        [min_max.iloc[0, i], min_max.iloc[1, i]],
        [i, i],
        linewidth=1,
        alpha=0.5,
        color=colours[i],
    )
    # plt.plot(
    #     [min_max.iloc[0, i], min_max.iloc[1, i]],
    #     [i, i],
    #     linewidth=0,
    #     markersize=2,
    #     marker="o",
    #     markerfacecolor=mpl.color_sequences["tab10"][i%10],
    #     markeredgecolor="none",
    # )
mean = df.agg(["mean"])

# mark the mean of each column with a small vertical tick
for i in range(mean.shape[1]):
    plt.plot(
        mean.iloc[0, i],
        i,
        marker="|",
        markersize=3,
        markeredgecolor="black",
        linewidth=0,
        alpha=0.5,
        zorder=4,
    )

# thick segment between the first and third quartile of each column
q1_q3 = df.agg(["quantile"], q=[0.25, 0.75])
for i in range(q1_q3.shape[1]):
    plt.plot(
        [q1_q3.iloc[0, i], q1_q3.iloc[1, i]],
        [i, i],
        marker="o",
        linewidth=3,
        markersize=2,
        markerfacecolor="none",
        markeredgecolor=colours[i],
        color=colours[i],
    )
# set correct y labels from df.columns
# NOTE(review): these labels are removed again by plt.yticks([]) further down
plt.yticks(range(len(df.columns)), df.columns)
# reverse order of y labels
plt.gca().invert_yaxis()
plt.xscale("log")
plt.xlim(1, 500)
# remove y ticks and labels
plt.yticks([])

# show y axis line
# ax.spines["left"].set_linewidth(0.5)
plt.savefig(fig_folder / "occurrences.png", dpi=300, bbox_inches="tight")

In [None]:
# takes too long
# fig, ax = plt.subplots(figsize=(1, 5))
# # plot the minimum and the maximum of the substructure counts as points, and make a thick line between connect them

# sns.catplot(data=df, orient="h", alpha=0.01, s=1, ax=ax, kind="swarm", palette=colours)

# # set correct y labels from df.columns
# plt.yticks(range(len(df.columns)), df.columns)
# # reverse order of y labels
# plt.gca().invert_yaxis()
# plt.xscale("log")
# plt.xlim(1, 500)
# # remove x ticks and labels
# plt.yticks([])

# # show y axis line
# # ax.spines["left"].set_linewidth(0.5)

## PMI map

In [None]:
#!/usr/bin/env python3
import argparse, logging, os
import math
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import svds
from tqdm import tqdm

from figures import set_label_colors


def add_minuses(heatmap, array):
    """
    Write a white "-" in the centre of every heatmap cell whose value is negative.

        Args:
            heatmap: object with a matplotlib-style ``text(x, y, s, **kwargs)`` method
            array: 2D array of the values shown in the heatmap

        Returns:
            None
    """
    # only negative cells get a text artist; others would just receive an empty string
    for i, j in np.argwhere(np.asarray(array) < 0):
        heatmap.text(
            j + 0.5,  # cell (i, j) is centred at x = j + 0.5, y = i + 0.5
            i + 0.5,
            "-",
            horizontalalignment="center",
            verticalalignment="center",
            color="white",
        )
    return None


def get_labels() -> list[str]:
    """Return the 39 substructure labels used as heatmap tick labels, in fixed order."""
    cofactors = ["coa", "nadh", "nadph"]
    amino_acids = ["all standard aminos", "non-standard aminos"]
    sugars = ["open pyranose", "open furanose", "pyranose", "furanose"]
    building_blocks = [
        "indoleC2N",
        "phenylC2N",
        "c5n",
        "c4n",
        "phenylC3",
        "phenylC2",
        "phenylC1",
        "isoprene",
        "acetyl",
        "methylmalonyl",
        "ethyl",
        "methyl",
    ]
    groups_and_halogens = [
        "phosphate",
        "sulfonate",
        "fluorine",
        "chlorine",
        "bromine",
        "iodine",
        "nitrate",
        "epoxy",
        "ether",
        "hydroxyl",
    ]
    # "c3 ring" up to and including "c10 ring"
    rings = [f"c{ring_size} ring" for ring_size in range(3, 11)]
    return (
        cofactors
        + amino_acids
        + sugars
        + building_blocks
        + groups_and_halogens
        + rings
    )


def get_colours() -> list[str]:
    """Return 39 colours, one per label position, written as runs of equal colours."""
    grey = "grey"
    yellow = "#FFEAA0"
    pink = "#FFC4CE"
    purple = "#A783B6"
    orange = "#FF8B61"
    green = "#B9C311"
    # (colour, number of consecutive positions that use it)
    runs = [
        (grey, 3),
        (yellow, 2),
        (pink, 4),
        (purple, 2),
        (orange, 2),
        (purple, 3),
        (green, 1),
        (orange, 4),
        (grey, 18),
    ]
    return [colour for colour, repeat in runs for _ in range(repeat)]


def _pearson_correlation_matrix(fps: np.array) -> np.array:
    # randomly subsample 10000

    # fps = fps[np.random.choice(fps.shape[0], 10000, replace=False), 3:]
    # fps = fps[:, 3:]
    mat = np.corrcoef(fps, rowvar=False, dtype=np.float16)
    # print(fps.shape)
    # print(mat.shape)
    return mat


def correlation_heatmap(fps: np.ndarray) -> "matplotlib.axes.Axes":
    """
    Draw a heatmap of the Pearson correlations between fingerprint columns.

    Also writes the correlation matrix to "correlations.tsv" in the working
    directory and logs its shape.

        Args:
            fps: 2D array with one row per compound and one column per label

        Returns:
            The axes object returned by ``sns.heatmap``
    """
    keys = get_labels()
    correlations = _pearson_correlation_matrix(fps)
    np.savetxt("correlations.tsv", correlations, delimiter="\t", fmt="%.2f")
    logging.warning(correlations.shape)
    hm = sns.heatmap(
        correlations,
        xticklabels=keys,
        yticklabels=keys,
        # cmap="coolwarm",
        cmap="PiYG",  # more colorblind-friendly diverging colormap
        # colour limits come from the plotted correlations (not from the raw
        # fingerprints); NaN entries are ignored
        vmin=np.nanmin(correlations),
        vmax=np.nanmax(correlations),
        center=0,  # center of the colormap
    )
    # for all negative correlations, add a minus sign
    add_minuses(hm, correlations)
    # set_label_colors(hm.get_xticklabels(), colours)
    # set_label_colors(hm.get_yticklabels(), colours)
    return hm


fps = np.loadtxt(Path.home() / "article_bsf" / "fps" / "coconut_bsf.csv", delimiter=",")


# Count the number of times each bit is set.
# cx[i]      : summed value of position i over all fingerprints
# cxy[(i,j)] : summed min(value_i, value_j) over fingerprints where both are set
cx = Counter()
cxy = Counter()

for idx in tqdm(range(fps.shape[0])):
    fingerprint = fps[idx]
    # only set positions can contribute, so both loops are restricted to them
    set_bits = [
        (int(bit_idx), fingerprint[bit_idx])
        for bit_idx in np.flatnonzero(fingerprint > 0)
    ]
    for bit_idx, bit in set_bits:
        cx[bit_idx] += bit
        for bit_idx2, bit2 in set_bits:
            cxy[(bit_idx, bit_idx2)] += min(bit, bit2)

# Create lookup between key and fingerprint index.
keys = get_labels()
x2i = {label: position for position, label in enumerate(keys)}
i2x = {position: label for position, label in enumerate(keys)}

# Build sparse PMI matrix.
sx = sum(cx.values())
sxy = sum(cxy.values())
data, rows, cols = [], [], []
for (x_data, y_data), n in cxy.items():
    rows.append(x_data)
    cols.append(y_data)
    # log of the pair frequency divided by both single frequencies
    data.append(math.log((n / sxy) / (cx[x_data] / sx) / (cx[y_data] / sx)))

# fix the shape to the fingerprint width; otherwise it is inferred from the
# largest index seen and shrinks when the last positions are never set
n_bits = fps.shape[1]
PMI = csc_matrix((data, (rows, cols)), shape=(n_bits, n_bits))
# pairs that never occur together stay 0 in the dense matrix
mat = PMI.toarray()


In [None]:

# Visualize matrix.
fig = plt.figure(figsize=(8, 6))

hm = sns.heatmap(
    mat,
    xticklabels=keys,
    yticklabels=keys,
    # cmap="coolwarm",
    cmap="PiYG",  # more colorblind-friendly diverging colormap
    vmin=np.min(mat),
    vmax=np.max(mat),
    center=0,  # center of the colormap
)
# mark negative cells with a white "-"
add_minuses(hm, mat)

# make ticklabels bold
hm.set_yticklabels(hm.get_yticklabels(), fontweight="bold")
hm.set_xticklabels(hm.get_xticklabels(), fontweight="bold")

# set_label_colors(hm.get_xticklabels(), colours)
# set_label_colors(hm.get_yticklabels(), colours)
# colour the tick labels, one colour per label position
colours = get_colours()
set_label_colors(hm.get_xticklabels(), colours)
set_label_colors(hm.get_yticklabels(), colours)
# get the colorbar and set the size of its ticklabels
cbar = hm.collections[0].colorbar
cbar.ax.tick_params(labelsize=12)
# label the colorbar
cbar.ax.set_ylabel("Pointwise Mutual Information", rotation=90, labelpad=15)

plt.ylabel("")
plt.xlabel("")

plt.savefig(fig_folder / "pmi.png", bbox_inches="tight")

# replace all tick labels with blank strings and save a second, label-free version
plt.xticks(hm.get_xticks(), ["    " for _ in hm.get_xticks()])
plt.yticks(hm.get_yticks(), ["    " for _ in hm.get_yticks()])

plt.savefig(fig_folder / "pmi_no_labels.png", bbox_inches="tight")


# plt.savefig( bbox_inches="tight")
# plt.close()

# # get a correlation heatmap as well
# corr_hm = correlation_heatmap(fps)
# plt.title("Pearson correlation - non-overlap Biosynfoni on COCONUT", size=12)
# plt.savefig(args.o.replace(".png", "_correlation.png"), bbox_inches="tight")
# plt.close()

In [None]:
#!/usr/bin/env python3
import argparse, logging, os
import math
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import svds
from tqdm import tqdm
from scipy.stats import zscore

# from helper import  set_label_colors


def add_minuses(heatmap, array):
    """
    Write a white "-" in the centre of every heatmap cell whose value is negative.

        Args:
            heatmap: object with a matplotlib-style ``text(x, y, s, **kwargs)`` method
            array: 2D array of the values shown in the heatmap

        Returns:
            None
    """
    negative_cells = np.argwhere(np.asarray(array) < 0)
    # text artists are created for negative cells only, instead of placing an
    # empty string on every other cell
    for row_idx, col_idx in negative_cells:
        heatmap.text(
            col_idx + 0.5,  # x: horizontal centre of the cell
            row_idx + 0.5,  # y: vertical centre of the cell
            "-",
            horizontalalignment="center",
            verticalalignment="center",
            color="white",
        )
    return None


def set_label_colors(ticklabels: list, colors: list) -> None:
    """
    Give each tick label a rounded, semi-transparent background box in its colour

        Args:
            ticklabels (list): tick label artists to restyle
            colors (list): one colour per label, paired by position

        Returns:
            None

    Remarks:
        - labels and colours are paired with zip, so surplus items of the
          longer list are silently ignored
    """
    for label, color in zip(ticklabels, colors):
        plt.setp(
            label,
            backgroundcolor=color,
            bbox=dict(
                facecolor=color,
                alpha=0.5,
                boxstyle="round, rounding_size=0.7",
                edgecolor="none",
            ),
        )
    return None


def get_labels() -> list[str]:
    """
    Return the 39 substructure labels used as heatmap tick labels

        Returns:
            list[str]: labels, in the fixed order in which they are listed here
    """
    cofactors = ["coa", "nadh", "nadph"]
    aminos = ["all standard aminos", "non-standard aminos"]
    sugars = ["open pyranose", "open furanose", "pyranose", "furanose"]
    nitrogen_units = ["indoleC2N", "phenylC2N", "c5n", "c4n"]
    phenyl_units = ["phenylC3", "phenylC2", "phenylC1"]
    small_units = ["isoprene", "acetyl", "methylmalonyl", "ethyl", "methyl"]
    hetero_groups = ["phosphate", "sulfonate"]
    halogens = ["fluorine", "chlorine", "bromine", "iodine"]
    oxygen_nitrogen_groups = ["nitrate", "epoxy", "ether", "hydroxyl"]
    rings = [f"c{size} ring" for size in range(3, 11)]
    return (
        cofactors
        + aminos
        + sugars
        + nitrogen_units
        + phenyl_units
        + small_units
        + hetero_groups
        + halogens
        + oxygen_nitrogen_groups
        + rings
    )


def get_colours() -> list[str]:
    """
    Return one background colour per tick label, as consecutive runs of colours

        Returns:
            list[str]: 39 colours (names or hex codes)
    """
    grey = "grey"
    yellow = "#FFEAA0"
    pink = "#FFC4CE"
    purple = "#A783B6"
    orange = "#FF8B61"
    green = "#B9C311"
    # (colour, number of consecutive labels that get this colour)
    runs = [
        (grey, 3),
        (yellow, 2),
        (pink, 4),
        (purple, 2),
        (orange, 2),
        (purple, 3),
        (green, 1),
        (orange, 4),
        (grey, 18),
    ]
    return [colour for colour, count in runs for _ in range(count)]


def _pearson_correlation_matrix(fps: np.array) -> np.array:
    # randomly subsample 10000

    # fps = fps[np.random.choice(fps.shape[0], 10000, replace=False), 3:]
    # fps = fps[:, 3:]
    mat = np.corrcoef(fps, rowvar=False, dtype=np.float16)
    # print(fps.shape)
    # print(mat.shape)
    return mat


def correlation_heatmap(fps: np.array) -> "plt.Axes":
    """
    Draw a heatmap of the Pearson correlation between fingerprint columns

        Args:
            fps (np.array): 2D fingerprint matrix (rows: molecules, columns: substructures)

        Returns:
            plt.Axes: the axes returned by sns.heatmap

    Remarks:
        - writes the correlation matrix to "correlations.tsv" in the working directory
        - colour limits and minus markers are derived from the correlation
          matrix, not from the raw fingerprints
    """
    keys = get_labels()
    correlations = _pearson_correlation_matrix(fps)
    np.savetxt("correlations.tsv", correlations, delimiter="\t", fmt="%.2f")
    logging.warning(correlations.shape)
    hm = sns.heatmap(
        correlations,
        xticklabels=keys,
        yticklabels=keys,
        cmap="PiYG",  # more colorblind-friendly diverging colormap
        # nan-aware limits: a constant column yields NaN correlations
        vmin=float(np.nanmin(correlations)),
        vmax=float(np.nanmax(correlations)),
        center=0,  # center of the colormap
    )
    # for all negative correlations, add a minus sign
    add_minuses(hm, correlations)
    colours = get_colours()
    set_label_colors(hm.get_xticklabels(), colours)
    set_label_colors(hm.get_yticklabels(), colours)
    return hm


fps = np.loadtxt(Path.home() / "article_bsf" / "fps" / "coconut_bsf.csv", delimiter=",")


# cx[i]: summed count of substructure i over all molecules.
# cxy[(i, j)]: summed co-occurrence of i and j, counted per molecule as min(count_i, count_j).
cx = Counter()
cxy = Counter()

for idx in tqdm(range(fps.shape[0])):
    for bit_idx, bit in enumerate(fps[idx]):
        # an absent substructure contributes to neither counter, so skip the pair loop
        if not bit > 0:
            continue
        cx[bit_idx] += bit

        for bit_idx2, bit2 in enumerate(fps[idx]):
            if bit2 > 0:
                cxy[(bit_idx, bit_idx2)] += min(bit, bit2)

# Create lookup between key and fingerprint index.
x2i, i2x = {}, {}
keys = get_labels()
for i, x_data in enumerate(keys):
    x2i[x_data] = i
    i2x[i] = x_data

# Build sparse PMI matrix: log( p(x, y) / (p(x) * p(y)) ).
sx = sum(cx.values())
sxy = sum(cxy.values())
data, rows, cols = [], [], []
for (x_data, y_data), n in cxy.items():
    rows.append(x_data)
    cols.append(y_data)
    data.append(math.log((n / sxy) / (cx[x_data] / sx) / (cx[y_data] / sx)))

# pairs that never co-occur are absent from cxy and therefore become 0 in the dense matrix
PMI = csc_matrix((data, (rows, cols)))
mat = PMI.toarray()
fig = plt.figure(figsize=(8, 6))

# Z-score normalization of the rows
zscored_mat = zscore(mat, axis=1)

# Plot the z-scored matrix
hm = sns.heatmap(
    zscored_mat,
    xticklabels=keys,
    yticklabels=keys,
    cmap="PiYG",  # more colorblind-friendly diverging colormap
    vmin=np.min(zscored_mat),
    vmax=np.max(zscored_mat),
    center=0,  # center of the colormap
)
# minus markers follow the sign of the raw PMI values, not of the z-scores
add_minuses(hm, mat)

hm.set_yticklabels(hm.get_yticklabels(), fontweight="bold")
hm.set_xticklabels(hm.get_xticklabels(), fontweight="bold")


plt.ylabel("Substructures")
plt.xlabel("Substructures")
plt.title("Pointwise Mutual Information of Biosynfoni Substructures", size=12)

In [None]:
from itertools import chain

# number of distinct innermost items over all matches of the first molecule
# (presumably atom indices covered by Biosynfoni substructures; verify against Biosynfoni.matches)
first_mol = Chem.SDMolSupplier(sdf_path)[0]
covered = {
    item
    for substructure_matches in Biosynfoni(first_mol).matches
    for match in substructure_matches
    for item in match
}
len(covered)
# Chem.SDMolSupplier(sdf_path)[0].GetNumHeavyAtoms()

In [None]:
df = pd.read_csv(
    f"{Path().home()}/article_bsf/data/raw_data/coconut_complete-10-2024.csv"
)
df = df[~df.organisms.isna()]

# from rdkit.Chem import PandasTools
# coconut_ = PandasTools.LoadSDF(f"{Path().home()}/article_bsf/data/_old_raw_data/COCONUT_DB.sdf")
# print(len(coconut_))

# coconut_ = coconut_.query("textTaxa != '[notax]'").copy()
# print(len(coconut_))
# coconut_
# r"textTaxa.*\n\[[^n]"

In [None]:
df[df.organisms.str.lower().str.contains("fungus")].shape

In [None]:
df_ = pd.read_csv(
    f"{Path().home()}/article_bsf/data/input/coconut_taxonomy.csv",
    header=None,
    names=["compounds", "taxonomy"],
)
print(df_.shape)  # 695133
df_ = df_[~df_.taxonomy.isna()]
print(df_.shape)  # 11286

In [None]:
df.columns

In [None]:
# show the projected data
# NOTE(review): `projected_X` is only assigned in a later cell of this notebook
# (tm.filter(...) on the iris data), so this cell fails on a fresh "Run All".
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

df = pd.DataFrame(projected_X, columns=["x", "y"])
# second column of chebi_classes.csv is used as the class label; only the first
# 1000 labels are kept, which presumably has to match the row count of projected_X (TODO confirm)
df["class"] = np.loadtxt(
    f"{Path().home()}/article_bsf/data/input/chebi_classes.csv",
    delimiter=",",
    dtype=str,
    usecols=1,
)[:1000]

sns.scatterplot(data=df, x="x", y="y", hue="class", palette="tab20")
# legend outside the axes, to the right
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.show()

In [None]:
from sklearn import datasets
import pandas as pd

iris = datasets.load_iris()
X = iris.data
X = pd.DataFrame(X, columns=iris.feature_names)
from tmap.tda import mapper, Filter
from tmap.tda.cover import Cover
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Step1. initiate a Mapper
tm = mapper.Mapper(verbose=1)
# Step2. Projection
lens = [Filter.MDS(components=[0, 1], random_state=100)]
projected_X = tm.filter(X, lens=lens)
clusterer = DBSCAN(eps=0.75, min_samples=1)
cover = Cover(
    projected_data=MinMaxScaler().fit_transform(projected_X),
    resolution=20,
    overlap=0.75,
)

# Figure 3: Applications - unsupervised clustering, supervised classification, biosynthetic pathway reconstruction

## dimensionality reduction

In [None]:
from pathlib import Path

import pandas as pd
import numpy as np

# import PandasTools
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import PandasTools


tsne_output_dir = Path.home() / "article_bsf" / "output"
chebi_input_dir = Path.home() / "article_bsf" / "data" / "input"

df = pd.read_csv(
    tsne_output_dir / "bsf_tsne.csv", header=None, names=["bsf_x", "bsf_y"]
)
df[["maccs_x", "maccs_y"]] = np.loadtxt(
    tsne_output_dir / "maccs_tsne.csv", delimiter=",", dtype=float
)
# in this one, the tsne is better comparable if the x coordinate is flipped
df["maccs_x"] = -df["maccs_x"]

# column 0: compound id, column 1: class label (read once instead of parsing the file twice)
chebi_classes = np.loadtxt(
    chebi_input_dir / "chebi_classes.csv",
    delimiter=",",
    dtype=str,
    usecols=(0, 1),
    ndmin=2,
)
df["class"] = chebi_classes[:, 1]
df["class"] = df["class"].str.replace("_", " ", regex=False)
df["id"] = chebi_classes[:, 0]

# get the molecules from the sdf; the supplier is given a plain string filename
sdf_path = Path.home() / "article_bsf/data/input/chebi.sdf"
df["mol"] = [mol for mol in Chem.SDMolSupplier(str(sdf_path))]

# the supplier yields None for records it cannot parse; drop them before using the molecules
df = df[df["mol"].notna()].copy()
df["size"] = [mol.GetNumHeavyAtoms() for mol in df["mol"]]

# remove all molecules that have an "R" (wildcard atom, "*" in SMILES) in them
has_wildcard = df["mol"].apply(lambda mol: "*" in Chem.MolToSmiles(mol))
df = df[~has_wildcard]

In [None]:
df.shape

In [None]:
fig, ax = plt.subplots(figsize=(6, 5))
# only compounds with a single class (multi-class labels contain ";")
single_class = df[~df["class"].str.contains(";")]
sns.scatterplot(
    data=single_class,
    x="bsf_x",
    y="bsf_y",
    hue="class",
    palette="Spectral",
    alpha=0.5,
    ax=ax,
)
ax.set_aspect("equal", adjustable="box")
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), markerscale=3)
legend_handles, _ = ax.get_legend_handles_labels()
for handle in legend_handles:
    handle.set_alpha(1)

# the t-SNE axes carry no meaning: hide axis labels, ticks and tick labels
ax.set_xlabel("")
ax.set_ylabel("")
ax.tick_params(
    axis="both",
    which="both",
    bottom=False,
    left=False,
    labelbottom=False,
    labelleft=False,
)
fig.savefig(fig_folder / "tsne_bsf.png", dpi=300, bbox_inches="tight")

In [None]:
# exploring the fatty acids in the middle of the Biosynfoni t-SNE
# class names are normalised before comparing: underscores are replaced by spaces
# when the table is built, so a literal "fatty_acid" would never match
is_fatty_acid = df["class"].str.replace("_", " ", regex=False) == "fatty acid"
sel = df[is_fatty_acid]
sel = sel[(sel.bsf_x < 10) & (sel.bsf_x > -10) & (sel.bsf_y < 20) & (sel.bsf_y > -20)]

Chem.Draw.MolsToGridImage(
    sel.mol.values,
    molsPerRow=5,
    subImgSize=(200, 200),
)

In [None]:
# exploring the fatty acids at the bottom of the Biosynfoni t-SNE (bsf_y < -80)
# class names are normalised before comparing: underscores are replaced by spaces
# when the table is built, so a literal "fatty_acid" would never match
is_fatty_acid = df["class"].str.replace("_", " ", regex=False) == "fatty acid"
sel = df[is_fatty_acid]
sel = sel[(sel.bsf_x < 10) & (sel.bsf_x > -10) & (sel.bsf_y < -80)]

Chem.Draw.MolsToGridImage(
    sel.mol.values,
    molsPerRow=5,
    subImgSize=(200, 200),
)

In [None]:
# exploring the carbohydrate cluster in the upper-left region (bsf_x < 0, bsf_y > 50)
sel = df[df["class"] == "carbohydrate"] 
sel = sel[sel.bsf_x<0]
sel = sel[sel.bsf_y>50]
sel
Chem.Draw.MolsToGridImage(
    sel.mol.values,
    molsPerRow=5,
    subImgSize=(200, 200),
)

In [None]:
# exploring the leftmost carbohydrate cluster
sel = df[df["class"] == "carbohydrate"] 
sel = sel[sel.bsf_x<-80]
sel
Chem.Draw.MolsToGridImage(
    sel.mol.values,
    molsPerRow=5,
    subImgSize=(200, 200),
)


In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
# background: all single-class compounds (multi-class labels contain ";")
single_class = df[~df["class"].str.contains(";")]
sns.scatterplot(
    data=single_class,
    x="maccs_x",
    y="maccs_y",
    hue="class",
    palette="Spectral",
    alpha=0.5,
    ax=ax,
)
# foreground: the current selection `sel`, drawn in opaque black
sns.scatterplot(data=sel, x="maccs_x", y="maccs_y", c="k", alpha=1, ax=ax)
ax.set_aspect("equal", adjustable="box")
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), markerscale=3)
legend_handles, _ = ax.get_legend_handles_labels()
for handle in legend_handles:
    handle.set_alpha(1)

ax.tick_params(
    axis="both",
    which="both",
    bottom=False,
    left=False,
    labelbottom=False,
    labelleft=False,
)
fig.savefig(fig_folder / "tsne_maccs_carbohydrates.png", dpi=300)

In [None]:
Chem.Draw.MolsToGridImage(
    sel[sel.maccs_y<-100].mol.values,
    molsPerRow=5,
    subImgSize=(200, 200),
)


In [None]:

Chem.Draw.MolsToGridImage(
    sel[sel.maccs_y>-10].mol.values,
    molsPerRow=5,
    subImgSize=(200, 200),
)

## classification

In [None]:
import re
import numpy as np
import pandas as pd

o_f = Path.home() / "article_bsf" / "output"
labels = np.loadtxt(o_f / "classifications.tsv", delimiter="\t", dtype=str)


def multilabel_and_dict(classifications: np.ndarray) -> tuple[np.ndarray, dict]:
    """
    Turn delimited class strings into a multi-label indicator matrix

        Args:
            classifications (np.ndarray): one string per sample; several classes
                may be joined with ";" or ","

        Returns:
            tuple[np.ndarray, dict]: (n_samples, n_classes) 0/1 integer matrix,
                and a dict mapping column index to class name (columns follow
                the sorted class names)

    Remarks:
        - the class vocabulary and the per-sample parsing use the same split
          pattern, so a comma-separated label cannot end up as a spurious
          class or as an unknown class
    """
    split_labels = [
        re.split(r"[;,]", str(classification)) for classification in classifications
    ]
    classes = sorted({class_ for parts in split_labels for class_ in parts})
    class_to_id = {class_: i for i, class_ in enumerate(classes)}
    id_to_class = dict(enumerate(classes))
    classification_array = np.zeros((len(split_labels), len(classes)), dtype=int)
    for i, parts in enumerate(split_labels):
        for class_ in parts:
            classification_array[i, class_to_id[class_]] = 1
    return classification_array, id_to_class


y_true, id_to_class = multilabel_and_dict(labels)


# one row per compound; the per-class vectors are stored as lists inside the cells
df = pd.read_csv(o_f / "ids.tsv", sep="\t", dtype=str, header=None, index_col=0)
df["y_true"] = y_true.tolist()
df["class"] = labels
df["k"] = np.loadtxt(o_f / "ks.csv", dtype=int)
# sorted so that the column order (and everything derived from it) does not depend on
# the arbitrary order in which the file system lists the files
files = sorted(
    file for file in o_f.glob("*_proba.tsv") if not file.stem.startswith("tax")
)
for file in files:
    probabilities = np.loadtxt(file, delimiter="\t", dtype=float)
    df[file.stem] = probabilities.tolist()
    # binarise at a probability threshold of 0.5
    df[file.stem.replace("proba", "pred")] = (probabilities > 0.5).astype(int).tolist()

In [None]:
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    precision_score,
    recall_score,
    f1_score,
)

# ig, ax = plt.subplots(2, 2, figsize=(5, 5))
# separate the binarised predictions from the probabilities per class
y = np.array(df["y_true"].tolist())
bsf_proba = np.array(df["bsf_proba"].tolist())
maccs_proba = np.array(df["maccs_proba"].tolist())
morgan_proba = np.array(df["morgan_proba"].tolist())
rdk_proba = np.array(df["rdk_proba"].tolist())


# row labels belonging to each of the five folds stored in the "k" column
k_idx = {k: df[df["k"] == k].index for k in range(5)}

# names of the probability columns; defined here because `proba` is not assigned anywhere
# else in the visible notebook (NOTE(review): confirm no earlier cell was meant to provide it)
proba = [col for col in df.columns if str(col).endswith("_proba")]
fp_names = [col.split("_")[0] for col in proba]
fp_idx = {fp: i for i, fp in enumerate(fp_names)}


def df_class(df, class_i):
    """
    Reduce the per-class vector columns of a prediction table to a single class

        Args:
            df (pd.DataFrame): table whose "*pred", "*proba" and "y_true" columns
                hold one equally long sequence per row
            class_i (int): position of the class to keep

        Returns:
            pd.DataFrame: copy of df in which those columns hold the scalar at
                position class_i; the input frame is left untouched
    """
    per_class = df.copy()
    vector_cols = [
        name for name in per_class.columns if str(name).endswith(("pred", "proba"))
    ]
    vector_cols.append("y_true")
    # (n_rows, n_columns, n_classes) -> keep one class
    stacked = np.array(per_class[vector_cols].values.tolist())
    per_class[vector_cols] = stacked[:, :, class_i]
    return per_class


metrics = {
    "precision": precision_score,
    "recall": recall_score,
    "f1": f1_score,
    "roc_auc": roc_auc_score,
    "average_precision": average_precision_score,
}

In [None]:
a = np.array(
    df[[col for col in df.columns if str(col).endswith("pred")]].values.tolist()
)
a.shape

In [None]:
import pandas as pd
import numpy as np

# Multi-label prediction table (one row per compound, list-valued columns)
data = df

# Flatten to single-class DataFrame: every row becomes one (compound, class) pair
cols = [
    col
    for col in data.columns
    if str(col).endswith("pred") or str(col).endswith("proba")
]
cols.append("y_true")
flattened_data = data.explode(cols).reset_index()
# positional access: the index holds compound ids, so a label lookup of 0 is not meant here
n_classes = len(data["y_true"].iloc[0])
# explode keeps the order within each list, so the class ids simply repeat 0..n_classes-1
flattened_data["class_id"] = np.tile(range(n_classes), len(data))

from sklearn.metrics import precision_score, recall_score, f1_score


# Function to compute metrics
def compute_metrics(df, metrics):
    """
    Evaluate every metric per fold, class and classifier

        Args:
            df (pd.DataFrame): flattened table with the columns "k", "class_id",
                "y_true" and one "*pred" column per classifier
            metrics (list): callables taking (y_true, y_pred); their __name__
                becomes the name of the result column

        Returns:
            pd.DataFrame: one row per (fold, class, classifier) with one column
                per metric plus "classifier", "class_id" and "k"

    Remarks:
        - every metric receives the integer-cast "*pred" column, so score-based
          metrics are evaluated on hard labels as well
          (NOTE(review): confirm this is intended for ROC AUC / average precision)
    """
    pred_cols = [col for col in df.columns if str(col).endswith("pred")]
    records = []
    for k, fold_data in df.groupby("k"):
        for cls in fold_data["class_id"].unique():
            cls_data = fold_data[fold_data["class_id"] == cls]
            for clf in pred_cols:
                record = {}
                for metric in metrics:
                    record[metric.__name__] = metric(
                        cls_data["y_true"].astype(int), cls_data[clf].astype(int)
                    )
                record["classifier"] = clf
                record["class_id"] = cls
                record["k"] = k
                records.append(record)
    return pd.DataFrame(records)


# metric functions to evaluate; their __name__ becomes the value of the "metric" column
# NOTE(review): this rebinds the name `metrics`, which other cells use for a dict of
# metric functions and for the sklearn.metrics module
metrics = [
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    average_precision_score,
]
metric_results = compute_metrics(flattened_data, metrics)
# long format: one row per (class, fold, classifier, metric)
melted_results = metric_results.melt(
    id_vars=["class_id", "k", "classifier"],
    value_vars=[metric.__name__ for metric in metrics],
    var_name="metric",
    value_name="value",
)
# "bsf_pred" -> "bsf"
melted_results["classifier"] = melted_results["classifier"].str.replace("_pred", "")
# short fingerprint names -> display names used in the figures
replacements = {
    "bsf": "Biosynfoni",
    "maccs": "MACCS",
    "morgan": "Morgan",
    "rdk": "RDKit",
}
melted_results["classifier"] = melted_results["classifier"].replace(replacements)

In [None]:
from matplotlib import lines as mlines

# colourDict["fp"] is keyed by the short fingerprint names ("bsf", ...), while the
# "classifier" column may hold display names ("Biosynfoni", ...). The palette is keyed
# by the values that are actually plotted, so hue mapping and legend lookups agree.
fp_colours = colourDict["fp"]
display_to_fp = {
    "Biosynfoni": "bsf",
    "MACCS": "maccs",
    "Morgan": "morgan",
    "RDKit": "rdk",
}
palette = {
    name: fp_colours[display_to_fp.get(name, name)]
    for name in melted_results["classifier"].unique()
}
g = sns.catplot(
    data=melted_results,
    x="classifier",
    y="value",
    col="class_id",
    row="metric",
    hue="classifier",
    kind="strip",  # individual dots, one per fold
    height=2,
    aspect=0.5,
    palette=palette,
    alpha=0.4,  # Make dots slightly transparent for better overlap visibility
    jitter=0.0,  # no jitter: the dots of one classifier stay on one vertical line
)

# Remove individual plot titles
g.set_titles("")
g.tick_params(
    axis="x",
    labelrotation=90,
    size=0,
)
g.tick_params(axis="y", size=0)
g.set_ylabels("")

# Customize plot
g.set_axis_labels("", "")
g.fig.subplots_adjust(hspace=0, wspace=0)

# show gridlines and a thin frame around every panel
for ax in g.axes.flatten():
    ax.grid(True, axis="y", linestyle="--", alpha=0.5)
    for side in ("left", "bottom", "right", "top"):
        ax.spines[side].set_linewidth(0.5)
    ax.spines["right"].set_visible(True)
    ax.spines["top"].set_visible(True)


# Add column titles (over the top of the plot grid)
# horizontal centre of every column, in figure coordinates
xs = [(ax.get_position().x0 + ax.get_position().x1) / 2 for ax in g.axes[0, :]]
title_y = g.axes[0, 0].get_position().y1 + 0.03  # Position above the plots
for col_idx, col_name in enumerate(g.col_names):
    # the column variable is the class id, so look the name up by value, not by position
    text = id_to_class[int(col_name)].replace("_", "\n")
    if text == "phenylpropanoid":
        text = "phenyl-\npropanoid"
    g.fig.text(
        x=xs[col_idx],  # Center each title over the column
        y=title_y,
        s=text,
        ha="center",
        va="top",
        fontsize=8,
        fontweight="regular",
    )

# Add row titles (on the left side of the grid)
# vertical centre of every row, in figure coordinates
ys = [(ax.get_position().y0 + ax.get_position().y1) / 2 for ax in g.axes[:, 0]]
for row_idx, row_name in enumerate(g.row_names):
    g.fig.text(
        x=0.03,  # Position to the left of the plots
        y=ys[row_idx],
        s=row_name.replace("_", " "),
        ha="right",
        va="center",
        fontsize=8,
        fontweight="regular",
        rotation=90,
    )

# build the legend by hand: one dot per classifier in its palette colour
legend_labels = melted_results["classifier"].unique()
handles = [
    mlines.Line2D(
        [],
        [],
        marker="o",
        color="w",
        markeredgecolor="none",
        markerfacecolor=palette[label],
        markersize=6,
        label=label,
    )
    for label in legend_labels
]

# Add the legend to the plot (using `fig.legend`),
# anchored just inside the lower-left corner of the first panel
x0, y0 = g.axes[0, 0].get_position().x0, g.axes[0, 0].get_position().y0
g.fig.legend(
    handles=handles,
    title="classifier",
    loc="lower left",
    bbox_to_anchor=(x0 + 0.005, y0 + 0.005),
    frameon=True,
    fontsize=8,
    title_fontsize=8,
)
plt.savefig(fig_folder / "sm_metrics.png", dpi=300, bbox_inches="tight")

In [None]:
from matplotlib import lines as mlines

# colourDict["fp"] is keyed by the short fingerprint names ("bsf", ...), while the
# "classifier" column may hold display names ("Biosynfoni", ...). The palette is keyed
# by the values that are actually plotted, so hue mapping and legend lookups agree.
fp_colours = colourDict["fp"]
display_to_fp = {
    "Biosynfoni": "bsf",
    "MACCS": "maccs",
    "Morgan": "morgan",
    "RDKit": "rdk",
}
palette = {
    name: fp_colours[display_to_fp.get(name, name)]
    for name in melted_results["classifier"].unique()
}

# catplot creates its own figure, so no plt.figure() call is needed beforehand
g = sns.catplot(
    data=melted_results[melted_results.metric == "f1_score"],
    x="classifier",
    y="value",
    col="class_id",
    hue="classifier",
    kind="strip",  # individual dots, one per fold
    height=1.7,
    aspect=0.5,
    palette=palette,
    s=40,
    alpha=0.4,  # Make dots slightly transparent for better overlap visibility
    jitter=0.0,  # no jitter: the dots of one classifier stay on one vertical line
)

# Remove individual plot titles
g.set_titles("")
g.tick_params(
    axis="x",
    labelrotation=90,
    size=0,
)
g.tick_params(axis="y", size=0)
g.set_ylabels("")

# Customize plot
g.set_axis_labels("", "F1 score")
g.fig.subplots_adjust(wspace=0)

# show gridlines
for ax in g.axes.flatten():
    ax.grid(True, axis="y", linestyle="--", alpha=0.5)
    ax.spines["left"].set_linewidth(0.5)


# Add column titles (over the top of the plot grid)
# horizontal centre of every column, in figure coordinates
xs = [(ax.get_position().x0 + ax.get_position().x1) / 2 for ax in g.axes[0, :]]
title_y = g.axes[0, 0].get_position().y1 + 0.18  # Position above the plots
for col_idx, col_name in enumerate(g.col_names):
    # the column variable is the class id, so look the name up by value, not by position
    text = id_to_class[int(col_name)].replace("_", "\n")
    if text == "phenylpropanoid":
        text = "phenyl-\npropanoid"
    g.fig.text(
        x=xs[col_idx],  # Center each title over the column
        y=title_y,
        s=text,
        ha="center",
        va="top",
        fontsize=8,
        fontweight="medium",
    )

# build the legend by hand; colours are looked up by label, so they cannot drift out of
# step with the order in which the classifiers appear in the data
legend_labels = melted_results["classifier"].unique()
handles = [
    mlines.Line2D(
        [],
        [],
        marker="o",
        color="w",
        markeredgecolor="none",
        markerfacecolor=palette[label],
        markersize=7,
        label=label,
    )
    for label in legend_labels
]

# Add the legend to the plot (using `fig.legend`), to the right of the panels
x0, y0 = g.axes[0, 0].get_position().x0, g.axes[0, 0].get_position().y0
g.fig.legend(
    handles=handles,
    title="classifier",
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    frameon=True,
    fontsize=8,
    title_fontsize=8,
)

# set ylim
g.set(ylim=(0.71, 1.03))
# save the plot
plt.savefig(fig_folder / "f1_score.png", dpi=300)

### taxonomy

In [None]:
from pathlib import Path
import json
import numpy as np
import pandas as pd

# import PandasTools
import networkx as nx
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import PandasTools
from sklearn.metrics import classification_report
import sklearn.metrics as metrics
import matplotlib.pyplot as plt

tax_path = Path.home() / "taxonomy.tsv"

# fields are separated by "<tab>|<tab>"; the separator is a regular expression (raw string,
# so that "\|" is not an invalid string escape) and therefore needs the python parser engine
tax_info = pd.read_csv(tax_path, sep=r"\t\|\t", index_col=0, engine="python")
# the last column keeps the trailing "<tab>|" of every line: clean its values and its name
tax_info["flags\t|"] = tax_info["flags\t|"].str.strip("\t|")
tax_info = tax_info.rename(columns={"flags\t|": "flags"})


# get a tree of the taxonomies in tax_info with direction (edges point from parent to child)

G = nx.DiGraph()
for i, row in tqdm(tax_info.iterrows(), total=tax_info.shape[0]):
    G.add_node(i, **row.to_dict())
    # NOTE(review): pandas parses the literal text "null" (and empty fields) as NaN by
    # default, so this comparison may be True for every row, which would add an edge from
    # a NaN node to the root; confirm before relying on the root having no parent
    if row["parent_uid"] != "null":
        G.add_edge(row["parent_uid"], i)


# lookups between taxon uid and (lower-cased) taxon name
uid_name = tax_info["name"].to_dict()
name_uid = {v.lower(): k for k, v in uid_name.items()}

In [None]:
sdf_df = pd.read_csv(
    Path.home() / "article_bsf" / "data" / "input" / "temp_coconut.csv", sep=","
)


def _organisms_to_uids(organisms) -> list:
    """
    Map a "|"-separated organism string to the uids of the names known to name_uid

        Args:
            organisms: value of the "organisms" column (a string, or NaN when missing)

        Returns:
            list: uids of the recognised organism names, in their original order;
                empty when the value is missing or no name is recognised
    """
    if not isinstance(organisms, str):
        return []
    return [name_uid[name] for name in organisms.lower().split("|") if name in name_uid]


# one list of uids per compound, aligned with sdf_df: the cells that derive kingdoms and
# domains index each entry with [0] and re-use sdf_df.index, which an exploded
# one-uid-per-row Series cannot satisfy
taxonomies = sdf_df["organisms"].apply(_organisms_to_uids)

In [None]:
# i_f = Path.home() / "article_bsf" / "data" / "input"
# sdf_df = PandasTools.LoadSDF(i_f / "coconut.sdf", )

In [None]:
# taxonomies = sdf_df["organisms"].str.lower().str.split("|").unique().explode()
# taxonomies


# taxonomies

In [None]:
# tax_info[tax_info["name"].str.contains("bacteria")]
# tax_info[tax_info["rank"].str.contains("kingdom")]
tax_info[tax_info["rank"] == "kingdom"]

# get children of a node
# list(G.successors(5267059))

In [None]:
# tax_info[tax_info["name"] =="Bacteria"]
# tax_info.head(30)
tax_info[tax_info["rank"] == "domain"]

In [None]:
[uid_name[x] for x in list(G.successors(844192))]

In [None]:
def get_by_rank(uid, G, rank="kingdom"):
    """
    Walk up the taxonomy tree from uid until a node of the requested rank is found

        Args:
            uid: identifier of the starting node in G
            G: directed graph with edges from parent to child and a "rank"
                attribute on its taxon nodes
            rank (str): rank to look for, optional. Default is "kingdom"

        Returns:
            dict: attributes of the first node on the way up whose rank matches,
                or of the first node that has no "rank" attribute; an empty dict
                when the top of the tree is reached without a match

    Remarks:
        - only the first predecessor of each node is followed
    """
    while True:
        attrs = G.nodes[uid]
        if "rank" not in attrs or attrs["rank"] == rank:
            return attrs
        parent = next(iter(G.predecessors(uid)), None)
        if parent is None:
            # reached a node without a parent: the lineage has no such rank
            return {}
        uid = parent


# node attributes of the kingdom / domain of every entry of `taxonomies`, or None for empty entries
# NOTE(review): `taxonomy[0]` and `if taxonomy` assume every entry is a (possibly empty)
# sequence of uids, of which only the first is used; confirm against how `taxonomies` is built
kingdoms = [
    (get_by_rank(taxonomy[0], G) if taxonomy else None) for taxonomy in taxonomies
]
domains = [
    (get_by_rank(taxonomy[0], G, rank="domain") if taxonomy else None)
    for taxonomy in taxonomies
]
# kingdoms
# G.nodes[taxonomies[1][0]]['rank']

In [None]:
kingdoms = [(kingdom["name"] if kingdom else None) for kingdom in kingdoms]
domains = [(domain["name"] if domain else None) for domain in domains]

In [None]:
# same input folder as used for the other CoCoNuT files of this notebook
i_f = Path.home() / "article_bsf" / "data" / "input"

# one kingdom / domain entry is expected per compound
assert len(kingdoms) == len(sdf_df) == len(domains)
kingdoms = pd.Series(kingdoms, name="kingdom", index=sdf_df.index)

# kingdom and domain as two columns; assigning `domains` to a key of the Series would
# store the whole list as a single extra element instead
tax_table = kingdoms.to_frame()
tax_table["domains"] = domains

tax_table.to_csv(i_f / "temp_coconut_tax.csv", header=True)
kingdoms.value_counts()

In [None]:
kingdoms.value_counts()

In [None]:
o_f = Path.home() / "article_bsf" / "output"

y = np.loadtxt(o_f / "y.csv", delimiter=",", dtype=int)
ids = np.loadtxt(o_f / "tax_ids.tsv", delimiter=",", dtype=str)
# probabilities are fractional: they have to be read as floats for the 0.5 threshold to work
y_proba = np.loadtxt(o_f / "tax_bsf_proba.tsv", delimiter="\t", dtype=float)
y_pred = (y_proba > 0.5).astype(int)

cl_idx = o_f / "tax_class_labels.json"
with open(cl_idx, "r") as f:
    class_labels = json.load(f)

# get the classification report
report = classification_report(y, y_pred, target_names=class_labels)

In [None]:
X = np.loadtxt(
    Path.home() / "article_bsf" / "fps" / "coconut_bsf.csv", delimiter=",", dtype=int
)
y = np.loadtxt(
    Path.home() / "article_bsf" / "data" / "input" / "coconut_taxonomy.csv",
    delimiter=",",
    dtype=str,
)
ids, y = y[:, 0], y[:, 1]

idx = np.where(y != "")

X = X[idx]
y = y[idx]
ids = ids[idx]

# y has multiple classes, so we need to split them and then turn it into a binary matrix
from sklearn.preprocessing import MultiLabelBinarizer

mlb = MultiLabelBinarizer()
y = mlb.fit_transform(y.split(";") for y in y)
y
class_labels = mlb.classes_
class_labels

In [None]:
# {'bootstrap': False, 'max_depth': 50, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 10, 'n_estimators': 50}
0.7972974395381592

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)


param_grid = {
    "n_estimators": [50, 100, 200, 500],
    "max_depth": [20, 50, 100],
    "min_samples_split": [2, 5, 10],
    "min_samples_leaf": [1, 2, 4],
}

rf = RandomForestClassifier()
grid_search = GridSearchCV(
    estimator=rf, param_grid=param_grid, cv=5, n_jobs=-1, verbose=0
)
grid_search.fit(X_train, y_train)

print(grid_search.best_params_)
print(grid_search.best_score_)
print(grid_search.best_estimator_)

# get the classification report
y_pred = grid_search.predict(X_test)
report = classification_report(y_test, y_pred, target_names=class_labels)
print(report)

In [None]:
# get the classification report for the best model
y_pred = grid_search.predict(X_test)
report = classification_report(y_test, y_pred)

# y_test is a multi-label indicator matrix, which confusion_matrix does not accept:
# get one 2x2 confusion matrix per class instead
from sklearn.metrics import multilabel_confusion_matrix

cm = multilabel_confusion_matrix(y_test, y_pred)

# plot the confusion matrices, one panel per class
import seaborn as sns

fig, axes = plt.subplots(1, len(cm), figsize=(3 * len(cm), 3), squeeze=False)
for ax, class_cm, class_label in zip(axes[0], cm, class_labels):
    sns.heatmap(class_cm, annot=True, fmt="d", cbar=False, ax=ax)
    ax.set(title=class_label, xlabel="predicted", ylabel="true")
fig.tight_layout()

In [None]:
# now try other classifiers
from sklearn.svm import SVC
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier

classifiers = {
    "SVC": SVC(),
    "GradientBoostingClassifier": GradientBoostingClassifier(),
    "LogisticRegression": LogisticRegression(),
    "MLPClassifier": MLPClassifier(),
}

for name, classifier in classifiers.items():
    classifier.fit(X_train, y_train)
    y_pred = classifier.predict(X_test)
    report = classification_report(y_test, y_pred, target_names=class_labels)
    print(report)

    # get the confusion matrix
    cm = metrics.confusion_matrix(y_test, y_pred)

    # plot the confusion matrix
    import seaborn as sns

    sns.heatmap(cm, annot=True, fmt="d")
    plt.title(name)

# Supplementary

## 1. Coverage

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

sdf_path = f"{Path().home()}/article_bsf/data/input/coconut.sdf"
props_path = f"{Path().home()}/article_bsf/data/input/coconut_properties.csv"

sizes = pd.read_csv(props_path, index_col=0)["heavy_atom_count"].values
coverages = np.loadtxt(
    f"{Path().home()}/article_bsf/fps/coconut_coverage.csv", delimiter=","
)


cmap = sns.color_palette("mako_r", as_cmap=True)
print("average coverage", np.mean(coverages))


def _coverage_hexbin(frame: pd.DataFrame, filename: str) -> None:
    """
    Save a hexbin plot of coverage against molecular size

        Args:
            frame (pd.DataFrame): table with the columns "sizes" and "coverages"
            filename (str): name of the image written into fig_folder

        Returns:
            None
    """
    # a fresh figure per plot, so that a plot never inherits the hexagons and colourbar
    # of the one drawn before it
    plt.figure()
    plt.hexbin(
        frame.sizes,
        frame.coverages,
        gridsize=50,
        cmap=cmap,
        edgecolors="white",
        linewidths=0.1,
    )
    plt.colorbar(label="Number of molecules")
    plt.xlabel("Molecular size")
    plt.ylabel("Coverage")
    plt.title("Biosynfoni's coverage of natural products")
    plt.savefig(fig_folder / filename, dpi=300, bbox_inches="tight")


df = pd.DataFrame.from_dict({"sizes": sizes, "coverages": coverages}, orient="index").T
# cropped to molecules with fewer than 100 heavy atoms, then the full size range
_coverage_hexbin(df[df.sizes < 100], "sm_coverage_vs_size_cropped.png")
_coverage_hexbin(df, "sm_coverage_vs_size_full.png")

mean_coverage = np.mean(coverages)
std_coverage = np.std(coverages)


plt.figure(figsize=(2, 3))
sns.violinplot(data=coverages, inner="box", edgecolor="none", alpha=0.6, density_norm="width");
plt.xlabel("coconut dataset")
plt.ylabel("atom coverage")
plt.savefig(fig_folder / "sm_coverage_violinplot.png", dpi=300, bbox_inches="tight")

# fig, ax = plt.subplots()
# plt.xlim(0, 200)
# plt.ylim(0, 1)
# size_coverage = np.array(
#     [(s, c) for s, c in zip(sizes, coverages)],
#     dtype=[("size", int), ("coverage", float)],
# )

# sns.scatterplot(
#     data=pd.DataFrame(size_coverage),
#     x="size",
#     y="coverage",
#     ax=ax,
#     edgecolor="none",
#     alpha=0.1,
#     s=4,
# )

## Figure 1 - Chemical space of datasets

In [None]:
from pathlib import Path


import pandas as pd
import numpy as np
from rdkit import Chem
import seaborn as sns

import umap

In [None]:
maccs_path = f"{Path().home()}/article_bsf/coconut_maccs.csv"
maccs = np.loadtxt(maccs_path, delimiter=",", dtype=int)
label = pd.read_csv(
    f"{Path().home()}/article_bsf/data/raw_data/coconut_complete-10-2024.csv"
)

In [None]:
# UMAP
reducer = umap.UMAP(n_components=2, random_state=42)
random_subset = np.random.choice(maccs.shape[0], 10000, replace=False)
subset_maccs = maccs[random_subset]
subset_labels = label.iloc[random_subset]

In [None]:
embedding = reducer.fit_transform(subset_maccs)

In [None]:
sns.scatterplot(
    x=embedding[:, 0], y=embedding[:, 1], hue=subset_labels["rotatable_bond_count"]
)  # ,  legend="full")

In [None]:
# select the embedded points with x < -10 and 0 < y < 5
in_region = (embedding[:, 0] < -10) & (0 < embedding[:, 1]) & (embedding[:, 1] < 5)
points = embedding[in_region]
indices = np.where(in_region)[0]


# rows of the sampled label table that belong to those points
of_interest = subset_labels.iloc[indices]

# draw molecules from smiles
mols = [Chem.MolFromSmiles(smiles) for smiles in of_interest["canonical_smiles"]]

Chem.Draw.MolsToGridImage(mols, molsPerRow=5, subImgSize=(200, 200))

In [None]:
# plot the molecules corresponding to the close points
sdf = Chem.SDMolSupplier(sdf_path)

for i, j in zip(*close_points):
    mol_i = sdf[int(random_subset[i][0])]
    mol_j = sdf[int(random_subset[j][0])]
    break
mol_i, mol_j

In [None]:
mol_i

In [None]:
mol_j

## 5. Biosynthetic distance visualisations

In [None]:
def cleanfmt(text):
    """
    Make a label, or a list of labels, presentable for a plot

        Args:
            text (str or list): text to clean

        Returns:
            str or list: cleaned text

    Remarks:
        - underscores become spaces and the text is lower-cased
        - non-string items of a list, and inputs that are neither str nor
          list, are passed through unchanged
    """

    def _clean(item):
        # only strings are rewritten
        return item.replace("_", " ").lower() if isinstance(item, str) else item

    if isinstance(text, list):
        return [_clean(item) for item in text]
    return _clean(text)


def _set_ax_boxplot_i_colour(
    ax_boxplot: dict,
    i: int,
    colour: str,
    inner_alpha: float = 0.6,
):
    """
    Set the colour of all artists belonging to one box of a boxplot

        Args:
            ax_boxplot (dict): the artist dictionary returned by ``Axes.boxplot``
                               (keys "boxes", "medians", "whiskers", "caps", "fliers").
                               The boxes must be patches (``patch_artist=True``), as
                               their face and edge colours are set separately.
            i (int): the index of the box to change
            colour (str): the colour to change to
            inner_alpha (float): the alpha of the inner colour, optional. Default is 0.6

        Returns:
            dict: the same artist dictionary, modified in place

    Remarks:
        - the box fill and the flier marker edges get the translucent colour,
          all other lines get the opaque colour
        - caps and fliers are skipped when the boxplot was drawn without them
    """
    translucent = mpl.colors.to_rgba(colour, inner_alpha)

    box = ax_boxplot["boxes"][i]
    box.set_facecolor(translucent)
    box.set_edgecolor(colour)
    ax_boxplot["medians"][i].set_color(colour)

    # each box owns two consecutive whiskers and two consecutive caps
    pair = (i * 2, i * 2 + 1)
    for k in pair:
        ax_boxplot["whiskers"][k].set_color(colour)
    if ax_boxplot["caps"]:  # empty when drawn with showcaps=False
        for k in pair:
            ax_boxplot["caps"][k].set_color(colour)
    if ax_boxplot["fliers"]:  # empty when drawn with showfliers=False
        ax_boxplot["fliers"][i].set_markeredgecolor(translucent)
    return ax_boxplot


def scatter_boxplots(
    df: pd.DataFrame,
    col_x: str,
    col_y: str,
    color_by: str = "stepnum",
    *args,
    **kwargs,
) -> plt.Figure:
    """
    Make a scatterplot with per-category boxplots on the top and right margins

        Args:
            df (pd.DataFrame): dataframe to plot
            col_x (str): column to plot on the x-axis
            col_y (str): column to plot on the y-axis
            color_by (str): column to colour by, optional. Default is "stepnum".
                            Its values are looked up in ``colourDict[color_by]``.
            *args: extra positional arguments for the main ``Axes.scatter`` call,
                   passed after x and y
            **kwargs: extra keyword arguments for the main ``Axes.scatter`` call;
                      they override the defaults (alpha=0.5, edgecolors="none",
                      zorder=3). ``c``, ``label`` and ``data`` are set here and
                      must not be passed.
        Returns:
            plt.Figure: the figure

    Remarks:
        - the category "-1" is shown as "control" on the boxplot ticks and as
          "random pairs" in the legend
        - NaN values are dropped before drawing the boxplots
    """
    categories = df[color_by].unique()

    def _values_per_category(column: str) -> list:
        """Return one float array per category for `column`, NaNs removed."""
        per_category = []
        for category in categories:
            values = df.loc[df[color_by] == category, column].to_numpy(dtype=float)
            per_category.append(values[~np.isnan(values)])
        return per_category

    def _legend_label(category) -> str:
        """Legend text for a category."""
        return f"{category}" if category != "-1" else "random pairs"

    # square figure on a 2x2 grid: boxplots top/right, scatter bottom-left,
    # legend in the top-right corner
    fig = plt.figure(figsize=(5, 5))
    gs = fig.add_gridspec(
        2,
        2,
        width_ratios=(4, 1),
        height_ratios=(1, 4),
        left=0.1,
        right=0.9,
        bottom=0.1,
        top=0.9,
        wspace=-1,
        hspace=-5,
    )

    sc_ax = fig.add_subplot(gs[1, 0])
    legax = fig.add_subplot(gs[0, 1])

    all_data_x = _values_per_category(col_x)
    all_data_y = _values_per_category(col_y)

    # the marginal boxplots share their data axis with the scatterplot
    top_bp_ax = fig.add_subplot(gs[0, 0], sharex=sc_ax)
    right_bp_ax = fig.add_subplot(gs[1, 1], sharey=sc_ax)

    top_bp_ax.tick_params(length=0, labelbottom=False, labelsize=5)
    right_bp_ax.tick_params(length=0, labelrotation=90, labelleft=False, labelsize=5)
    legax.tick_params(length=0, labelleft=False, labelbottom=False, labelsize=0)
    sc_ax.tick_params(length=0)

    labels = [
        f"{category}" if category != "-1" else "control" for category in categories
    ]
    # boxes are placed 0.6 apart along the category axis
    positions = [0.6 * k for k in range(len(categories))]
    xplot = top_bp_ax.boxplot(
        all_data_x,
        vert=False,
        patch_artist=True,
        labels=labels,
        positions=positions,
    )
    yplot = right_bp_ax.boxplot(
        all_data_y,
        vert=True,
        patch_artist=True,
        labels=labels,
        positions=positions,
    )

    # Caller keywords take precedence over the defaults instead of colliding with them.
    scatter_kw = {"alpha": 0.5, "edgecolors": "none", "zorder": 3, **kwargs}
    # Drawn in reverse so the first categories end up on top.
    for category in categories[::-1]:
        # x and y are positional so that *args can follow them; with x=/y= keywords
        # any extra positional argument raised "multiple values for argument 'x'".
        sc_ax.scatter(
            col_x,
            col_y,
            *args,
            data=df[df[color_by] == category],
            c=colourDict[color_by][category],
            label=_legend_label(category),
            **scatter_kw,
        )

    alpha = 0.8
    for i, category in enumerate(categories):
        colour = colourDict[color_by][category]
        _set_ax_boxplot_i_colour(xplot, i, colour, inner_alpha=alpha)
        _set_ax_boxplot_i_colour(yplot, i, colour, inner_alpha=alpha)

        # scatter empty df, to get legend in right format in right position
        leg = legax.scatter(
            x=col_x,
            y=col_y,
            data=df[df[color_by] == category][0:0],
            c=colour,
            label=_legend_label(category),
            alpha=0.6,
            edgecolors="none",
            s=10,
        )
        leg.set_facecolor(mpl.colors.to_rgba(colour, alpha=alpha))

    legax.legend(loc="lower left", prop={"size": 6}, frameon=False)

    sc_ax.set_xlabel(cleanfmt(col_x), labelpad=10)
    sc_ax.set_ylabel(cleanfmt(col_y), labelpad=10)

    sc_ax.grid(True, alpha=0.3, linewidth=0.5, mouseover=True)
    gs.tight_layout(fig)

    return fig


# Load the tab-separated biosynthetic distance table from the article output folder.
df = pd.read_csv(
    f"{Path().home()}/article_bsf/output/biosynthetic_distances.tsv", sep="\t"
)  # , index_col=0)
# Separation is used as a string key into colourDict["separation"]; the "control"
# rows are recoded to "-1".
df.separation = df.separation.astype(str).replace("control", "-1")
# This runs after the string cast above, so the "-1" separation labels are untouched;
# numeric -1 values in any other column become 0.
df.replace(
    -1, 0, inplace=True
)  # any for which we could not calculate a similarity, we set to 0


# One figure per comparison fingerprint, each plotted against "bsf" on the x-axis.
# NOTE(review): `fig_folder` is not defined in the visible cells -- confirm it is
# set earlier in the notebook.
for fp_name in ["maccs", "morgan", "rdk"]:
    fig = scatter_boxplots(
        df,
        "bsf",
        fp_name,
        color_by="separation",
    )
    
    fig.savefig(fig_folder / f"sm_biosynthetic_scatterplot_{fp_name}.png", dpi=300, bbox_inches="tight")


In [None]:
# def scatter_boxplots(df, x, y, color_by="separation", *args, **kwargs):
#     fig = plt.figure(figsize=(5, 5))
#     gs = fig.add_gridspec(2, 2, wspace=-1)
#     sc_ax = fig.add_subplot(gs[1, 0])
#     legax = fig.add_subplot(gs[0, 1])
#     color_by = "separation"

#     categories = df[color_by].unique().tolist()
#     x_data = [
#         np.array(df[df[color_by] == cat][x].to_numpy(dtype=float)) for cat in categories
#     ]
#     x_data = [k[~np.isnan(k)] for k in x_data]
#     y_data = [np.array(df[df[color_by] == cat][y].tolist()) for cat in categories]
#     y_data = [k[~np.isnan(k)] for k in y_data]

#     top_bp_ax = fig.add_subplot(gs[0, 0], sharex=sc_ax)
#     right_bp_ax = fig.add_subplot(gs[1, 1], sharey=sc_ax)

#     top_bp_ax.tick_params(length=0, labelbottom=False, labelsize=5)
#     right_bp_ax.tick_params(length=0, labelrotation=-30, labelleft=False, labelsize=5)
#     legax.tick_params(length=0, labelleft=False, labelbottom=False, labelsize=0)
#     sc_ax.tick_params(length=0)

#     # labels = df[color_by].unique()

#     xplot = top_bp_ax.boxplot(
#         x_data,
#         vert=False,
#         patch_artist=True,
#         labels=categories,
#         positions=[0.6 * i for i in range(len(x_data))],
#     )
#     yplot = right_bp_ax.boxplot(
#         y_data,
#         vert=True,
#         patch_artist=True,
#         labels=categories,
#         positions=[0.6 * i for i in range(len(y_data))],
#     )
#     [tick.set_rotation(90) for tick in right_bp_ax.get_xticklabels()]

#     # colormap is mako, get evenly spaced colours
#     cat_ = [cat for cat in categories if cat != "random pairs"]
#     colours = {
#         cat: col
#         for cat, col in zip(
#             cat_,
#             sns.cubehelix_palette(
#                 start=0.5, rot=-0.75, reverse=True, n_colors=len(cat_) + 2
#             )[1:-1],
#         )
#     }
#     colours["random pairs"] = "#797979"

#     for category in categories[::-1]:
#         sc_ax.scatter(
#             x=x,
#             y=y,
#             data=df[df[color_by] == category],
#             c=colours[category],
#             label=str(category),
#             alpha=0.8,
#             edgecolors="none",
#             # zorder=3,
#         )

#     for i, category in enumerate(categories):
#         # alpha = 0.6
#         alpha = 0.8
#         _set_ax_boxplot_i_colour(xplot, i, colours[category], inner_alpha=alpha)
#         _set_ax_boxplot_i_colour(yplot, i, colours[category], inner_alpha=alpha)

#         leg = legax.scatter(
#             x=x,
#             y=y,
#             data=df[df[color_by] == category][0:0],
#             color=colours[category],
#             label=str(category),
#             # alpha=0.5,
#             edgecolors="none",
#             s=10,
#         )
#         leg.set_facecolor(mpl.colors.to_rgba(colours[category], alpha=alpha))

#     legax.legend(loc="lower left", prop={"size": 6}, frameon=False)

#     sc_ax.set_xlabel(cleanfmt(x), labelpad=10)
#     sc_ax.set_ylabel(cleanfmt(y), labelpad=10)
#     sc_ax.grid(True, alpha=0.3, linewidth=0.5, mouseover=True)

#     gs.tight_layout(fig)

#     return fig

In [None]:
# # later for visualisation:

# def get_square(
#     df: pd.DataFrame,
#     col1: str,
#     col2: str,
#     range1: tuple[float, float],
#     range2: tuple[float, float],
# ) -> pd.DataFrame:
#     """returns the molecule pair's inchis for 'dots' in a given square of the scatter plot"""
#     square = df.loc[
#         (df[col1] >= range1[0])
#         & (df[col1] <= range1[1])
#         & (df[col2] >= range2[0])
#         & (df[col2] <= range2[1])
#     ]
#     return square


# def draw_molpair(
#     pair: list[Chem.Mol], annotation: str = "", highlighting: bool = True
# ) -> None:
#     for i in range(len(pair)):
#         highlighting_info = None
#         if highlighting:
#             highlighting_info = get_highlight_mapping(mol=pair[i])
#         svg_text = moldrawing.draw(
#             pair[i], highlight_atoms_bonds_mappings=highlighting_info
#         )
#         if annotation:
#             svg_text = svg_text.replace(
#                 "</svg>",
#                 f'<text x="30" y="30" font-size="20" font-family="montserrat">{annotation}</text></svg>',
#             )
#         with open(f"{annotation}_{i}.svg", "w") as f:
#             f.write(svg_text)
#     return None


# def draw_squares(
#     square_df: pd.DataFrame,
#     pair_columns: tuple[str, str] = ("mol1", "mol2"),
#     squarename: str = "origin",
#     highlighting: bool = True,
# ) -> None:
#     """draws the molecules in the squares
#     input: (pd.DataFrame) square_df -- the dataframe containing the squares
#     (str) pair_columns -- the name of the column containing the molecule pairs
#     (str) squarename -- the name of the square
#     """
#     if square_df.empty:
#         return None
#     for _, row in tqdm(
#         square_df.iterrows(),
#         desc=f"drawing {squarename} squares",
#         total=square_df.shape[0],
#         position=1,
#     ):
#         pair = [row[pair_columns[0]], row[pair_columns[1]]]
#         pathway = row["pathway"]
#         outfilename = outfile_namer(f"{squarename}_{pathway}")
#         draw_molpair(pair, annotation=outfilename, highlighting=highlighting)
#     return None


# def loopsquares(
#     df: pd.DataFrame,
#     x_fp: str = "biosynfoni",
#     y_fps: list[str] = ["rdkit", "maccs", "morgan"],
#     size: int = 0.2,
# ) -> None:
#     for i, y_fp in tqdm(
#         enumerate(y_fps), desc="looping squares", leave=False, position=0
#     ):
#         _, iwd = output_direr(f"./{x_fp}_{y_fp}_squares")
#         min_val, max_val = 0.0, 1.0
#         min_border = 0.0 + size
#         max_border = 1.0 - size
#         left_bottom = get_square(
#             df,
#             x_fp,
#             y_fp,
#             (min_val, min_border),
#             (min_val, min_border),
#         )
#         left_top = get_square(
#             df,
#             x_fp,
#             y_fp,
#             (min_val, min_border),
#             (max_border, max_val),
#         )
#         right_bottom = get_square(
#             df,
#             x_fp,
#             y_fp,
#             (max_border, max_val),
#             (min_val, min_border),
#         )
#         right_top = get_square(
#             df,
#             x_fp,
#             y_fp,
#             (max_border, max_val),
#             (max_border, max_val),
#         )
#         exactly_middle = get_square(
#             df,
#             x_fp,
#             y_fp,
#             (0.5, 0.5),
#             (min_val, max_val),
#         )
#         draw_squares(left_bottom, squarename=f"{y_fp}_origin")
#         draw_squares(left_top, squarename=f"{y_fp}_left_top")
#         draw_squares(right_bottom, squarename=f"{y_fp}_right_bottom")
#         draw_squares(right_top, squarename=f"{y_fp}_right_top")
#         if i == 0:
#             draw_squares(exactly_middle, squarename=f"{x_fp}_middle")
#         os.chdir(iwd)
#     return None


# def biosynthetic_distance_analysis(pairs_df, metric):
#     df = pd.read_csv(structures_path, sep="\t", header=None, index_col=0)
#     logging.info(
#         f"{old_n_rows[0]-df.shape[0]} pathways dropped due to lack of mol pairs\n\n"
#     )

#     _, iwd = output_direr("./biosynthetic_distance")  # move to outputdir

#     pairs = pairs_per_separation(df)

#     dist_df = f"{outfile_namer(metric)}.tsv"
#     mols_pkl = f"{outfile_namer('mols')}.pkl"
#     if not os.path.exists(dist_df) or not os.path.exists(mols_pkl):
#         df = add_fp_to_df(
#             pairs,
#             fp_types=FP_FUNCTIONS.keys(),
#         )
#         df = get_all_similarity_scores(df, metric=metric)
#         mols = df.copy()
#         # save mols as pickle
#         mols.to_pickle(f"{outfile_namer('mols')}.pkl")
#         # remove mols from df
#         df = df.drop(columns=["mol1", "mol2"])
#         df.to_csv(dist_df, sep="\t", index=True)

#     mols = pd.read_pickle(f"{outfile_namer('mols')}.pkl")
#     df = mols
#     df["pathway"] = df.index

#     logging.debug(df.shape, mols.shape, df.columns)
#     logging.debug(df, pairs, pairs.columns)

#     # if args.annotate:
#     #     # pw_tax_file = "../../../metacyc/pathways_taxid.txt"
#     #     # tax_text_file = "../../../metacyc/cleaner_classes.dat"
#     #     pw_tax_file, tax_text_file = args.annotate
#     #     annotated_df = annotate_pathways(comparison_df, pw_tax_file, tax_text_file)

#     df["stepnum"] = df["separation"].apply(str)

#     logging.info("getting scatterplots...")
#     fp_combs = list(itertools.combinations(fp_names, 2))
#     for combination in tqdm(fp_combs, desc="getting scatterplots"):
#         scatter = fm.scatter_boxplots(
#             df,
#             col_x=combination[0],
#             col_y=combination[1],
#             figtitle=f"{args.metric} for different reaction step numbers",
#             color_by="stepnum",
#         )
#         filename = outfile_namer(f"{combination[0]}_{combination[1]}_{args.metric}.png")
#         fm.savefig(scatter, filename)

#     onestep = df[df["stepnum"] == "1"]
#     onestep.to_csv(f'{outfile_namer("onestep")}.tsv', sep="\t", index=False)
#     logging.info("getting squares...")
#     for fp in [
#         "biosynfoni",
#         "overlap_binosynfoni",
#         "overlap_biosynfoni",
#         "interoverlap_biosynfoni",
#     ]:
#         loopsquares(
#             onestep,
#             fp,
#             ["rdkit", "maccs", "morgan", "maccsynfoni", "overlap"],
#             size=0.2,
#         )
#     # loopsquares(
#     #     onestep,
#     #     "biosynfoni",
#     #     ["rdkit", "maccs", "morgan", "maccsynfoni", "overlap"],
#     #     size=0.2,
#     # )

#     # save the biosynfoni version for reference
#     logging.info("saving current biosynfoni version...")
#     save_version(defaultVersion)

#     os.chdir(iwd)
#     logging.info("done\nbyebye")
#     exit(0)
#     return None

## 2. Fingerprint heatmaps

In [None]:
import sys, os, logging
import argparse

import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

plt.ioff()


from biosynfoni.inoutput import outfile_namer
from biosynfoni.subkeys import fpVersions, defaultVersion, get_names, get_pathway
from figures import (
    heatmap,
    annotate_heatmap,
    savefig,
    set_label_colors_from_categories,
    custom_cmap,
)


# def count_distributions(coco, zinc, substructure_names):
#     """WIP: Plots substructure count distribution for coco and zinc"""
#     npcs = np.loadtxt(
#         "npcs.tsv", dtype="str", delimiter="\t"
#     )  # just added, not checked
#     s_coco = coco[npcs[:, 0] == "Alkaloids"]
#     # random subsample of zinc
#     np.random.seed(42)
#     s_zinc = zinc[np.random.choice(zinc.shape[0], size=s_coco.shape[0], replace=False)]

#     for i in range(3, len(substructure_names)):
#         # np.histogram(coco[:,i])
#         # print(np.mean(coco[:,i]))
#         fig = plt.figure()
#         nonzero = s_coco[:, i][s_coco[:, i] > 0]
#         if np.max(nonzero) == 0:
#             continue
#         n, bins, edges = plt.hist(
#             nonzero,
#             bins=np.max(nonzero) - 1,
#             color="green",
#             alpha=0.7,
#             histtype="step",
#             align="left",
#         )

#         plt.title(
#             f"substructure counts for {substructure_names[i]}, {len(nonzero)} nonzero values"
#         )
#         plt.xticks(bins)
#         plt.xlabel("substructure counts")
#         plt.ylabel("number of compounds")
#         plt.tight_layout()

#     for i in range(3, len(substructure_names)):
#         # np.histogram(coco[:,i])
#         # print(np.mean(zinc[:,i]))
#         fig = plt.figure()
#         nonzero = s_zinc[:, i][s_zinc[:, i] != 0]
#         if np.max(nonzero) < 2:
#             continue
#         n, bins, edges = plt.hist(
#             nonzero,
#             bins=np.max(nonzero) - 1,
#             color="purple",
#             alpha=0.7,
#             rwidth=1,
#             histtype="step",
#             align="mid",
#         )
#         plt.title(
#             f"histogram of substructure counts for {substructure_names[i]}, {len(nonzero)} nonzero values"
#         )
#         plt.xticks(bins)
#         plt.xlabel("substructure counts")
#         plt.ylabel("number of compounds")

#     plt.close()
#     return None


def heatmap_array(
    fps: np.array,
    max_height: int = 30,
    percentages=False,
    accumulative=True,
    end_accumulative=False,
):
    """
    Build the (height x substructures) count matrix behind a fingerprint heatmap

        Args:
            fps (np.array): fingerprint array (compounds x substructures)
            max_height (int): number of rows, i.e. the highest count that gets its own row
            percentages (bool): express the counts as a percentage of the compounds
            accumulative (bool): count compounds with at least that many occurrences
            end_accumulative (bool): when not accumulative, still count "at least"
                                     for the highest row only
        Returns:
            np.array: integer array of shape (max_height, number of substructures)

    Remarks:
        - rows are ordered top-down: row 0 holds the highest count, the last row holds count 1
        - accumulative=True: each row counts the compounds with >= that many occurrences
        - accumulative=False: each row counts the compounds with exactly that many occurrences,
          except the top row when end_accumulative=True, which counts >= max_height
        - the result is cast to int, so percentages are truncated to whole numbers
    """
    top_level = max_height - 1
    heat_array = np.zeros((max_height, fps.shape[1]))

    for level in range(max_height):
        # "at least level + 1" for accumulative rows, "exactly level + 1" otherwise
        at_least = accumulative or (end_accumulative and level == top_level)
        hits = (fps > level) if at_least else (fps == level + 1)
        # fill from the bottom row upwards
        heat_array[top_level - level] = np.count_nonzero(hits, axis=0)

    if percentages:
        heat_array = heat_array / fps.shape[0] * 100
    return heat_array.astype(int)


def fp_heatmap(
    fp_hm_array: np.array,
    subslabels: list = [],
    size: tuple[int] = (10, 6),
    percentages: bool = False,
    annotate: bool = False,
    color_scheme: str = "Purples",
    title: str = "Representative substructure count for compound collection",
    top_acc_array=None,
    standard_colour: bool = False,
):
    """
    Draw the substructure-count heatmap for a fingerprint collection

        Args:
            fp_hm_array (np.array): array for heatmap (height x substructures)
            subslabels (list): substructure labels for the x-axis; when empty,
                               "subs1", "subs2", ... are generated
            size (tuple): figure size
            percentages (bool): whether the values are percentages (only changes
                                the colourbar label)
            annotate (bool): whether to write the values into the cells
            color_scheme (str): name of the colour scheme of the main heatmap
            title (str): title of the plot
            top_acc_array (np.array): optional second array, drawn first in grey on
                                      the same axes with its own colourbar. When given,
                                      the top y-label becomes "≥height". Default is None.
            standard_colour (bool): whether to colour the x tick labels by the
                                    pathway of the default biosynfoni version
        Returns:
            matplotlib.figure.Figure: figure of heatmap

    Remarks:
        - y labels run from the number of rows (top) down to 1 (bottom)
        - underscores in the substructure labels are replaced with spaces
    """
    height = fp_hm_array.shape[0]
    cbarlab = "% of compounds" if percentages else "number of compounds"

    fig, ax = plt.subplots(figsize=size, dpi=500)

    if not subslabels:
        subslabels = [f"subs{i}" for i in range(1, fp_hm_array.shape[1] + 1)]
    subslabels = [name.replace("_", " ") for name in subslabels]

    # height, height - 1, ..., 1
    yaxlabels = list(range(height, 0, -1))

    if top_acc_array is not None:
        yaxlabels[0] = f"≥{height}"
        # the upper colour limit of the grey layer spans both arrays, ignoring NaNs
        maxval = max(
            top_acc_array[~np.isnan(top_acc_array)].max(),
            fp_hm_array[~np.isnan(fp_hm_array)].max(),
        )
        _, top_cbar = heatmap(
            top_acc_array,
            yaxlabels,
            subslabels,
            ax=ax,
            cmap=custom_cmap("Greys", first_color="#ffffff00"),
            cbar_kw={
                "drawedges": False,
                "shrink": 0.3,
                "pad": -0.05,
                "aspect": 10,
            },
            vmin=0,
            vmax=maxval,
        )
        top_cbar.set_label(
            f"{cbarlab} ≥{height}", rotation=90, va="bottom", labelpad=15
        )

    # main layer, drawn after (on top of) the optional grey layer
    im, cbar = heatmap(
        fp_hm_array,
        yaxlabels,
        subslabels,
        ax=ax,
        cmap=custom_cmap(color_scheme, first_color="#ffffff00"),
        cbarlabel=cbarlab,
        vmin=0,
        cbar_kw={"drawedges": False, "shrink": 0.3, "pad": 0.02, "aspect": 10},
    )
    cbar.set_label(f"{cbarlab}", rotation=90, va="bottom", labelpad=15)

    if annotate:
        annotate_heatmap(im, valfmt="{x:.0f}", size=7)
    if standard_colour:
        set_label_colors_from_categories(
            ax.get_xticklabels(),
            get_pathway(version=defaultVersion),
            colourDict["pathways"],
        )

    ax.set_xlabel("substructure", labelpad=10)
    ax.set_ylabel("counts", labelpad=10)
    ax.set_title(title, loc="center", pad=20)
    fig.tight_layout()
    return fig


def over_under_divide(fps: np.array, limit: int = 10, percentages: bool = True):
    """
    Split the heatmap array into an "up to the limit" part and an "above the limit" part

        Args:
            fps (np.array): fingerprint array
            limit (int): highest count that keeps its own exact-count row
            percentages (bool): whether to express the values as percentages
        Returns:
            tuple: (under, over), two float arrays of shape (limit + 1, substructures)

    Remarks:
        - the underlying array has exact counts 1..limit plus an accumulative top row
        - under has the top row set to NaN, over has every row but the top set to NaN
    """
    full = heatmap_array(
        fps,
        max_height=limit + 1,
        percentages=percentages,
        accumulative=False,
        end_accumulative=True,
    ).astype(float)

    under = full.copy()
    over = full.copy()
    under[0] = np.nan  # drop the accumulative top row
    over[1:] = np.nan  # keep only the accumulative top row
    return under, over


def fp_heatmap_accumulative(fp_arr: np.array, limit: int = 10, *args, **kwargs):
    """
    Plot the substructure-count heatmap with an accumulative top row

        Args:
            fp_arr (np.array): fingerprint array
            limit (int): highest count shown as its own row; everything above is pooled
            *args: further positional arguments for fp_heatmap (after the array)
            **kwargs: further keyword arguments for fp_heatmap
        Returns:
            matplotlib.figure.Figure: figure of heatmap

    Remarks:
        - values are always percentages of compounds
        - the top row pools all compounds with more than `limit` occurrences, which
          keeps the heatmap short
    """
    up_to_limit, above_limit = over_under_divide(fp_arr, limit, percentages=True)
    return fp_heatmap(
        up_to_limit,
        *args,
        percentages=True,
        top_acc_array=above_limit,
        **kwargs,
    )

In [None]:
# Load the COCONUT biosynfoni fingerprints and plot their count heatmap.
# NOTE(review): `folder` and `fig_folder` are not defined in the visible cells --
# confirm they are set earlier in the notebook.
fps = np.loadtxt(folder / "fps" / "coconut_bsf.csv", dtype=int, delimiter=",")
substructure_names = get_names(version=defaultVersion)
# Fall back to 1-based numeric labels when the name list does not match the
# number of fingerprint columns.
if fps.shape[1] != len(substructure_names):
    substructure_names = [str(i + 1) for i in range(fps.shape[1])]

# # binary heatmap
# fp_heatmap(
#     heatmap_array(fps, max_height=1, percentages=True, accumulative=False),
#     subslabels=substructure_names,
#     title=f"Distribution of {fp_name} substructure counts",
#     color_scheme="Greys",
#     percentages=True,
#     size=(15, 1),
#     standard_colour=True,
# )


# Counts 1-10 get their own row; higher counts are pooled in the top row.
hm = fp_heatmap_accumulative(
    fps,
    limit=10,
    title=f"Biosynfoni QR of all compounds in COCONUT",
    subslabels=substructure_names,
    color_scheme="GnBu",
    standard_colour=True,
)

hm.savefig(fig_folder / "sm_qr_coconut.png", dpi=300, bbox_inches="tight")


In [None]:
# Plot one biosynfoni count heatmap per ChEBI class.
# NOTE(review): `folder`, `fig_folder` and `substructure_names` come from earlier
# cells -- confirm they are defined on a fresh run.
fps = np.loadtxt(folder / "fps" / "chebi_bsf.csv", dtype=int, delimiter=",")
classes = np.loadtxt(
    folder / "data" / "input" / "chebi_classes.csv",
    dtype="str",
    delimiter=",",
    usecols=1,
)
# This one combination is folded into "isoprenoid"; every other ";"-joined
# annotation is pooled as "multiple", and empty annotations become "None".
classes[classes == "fatty_acid;isoprenoid"] = "isoprenoid"
# np.char is the public namespace for these string routines (np.core is private).
classes = np.where(np.char.find(classes, ";") != -1, "multiple", classes)
classes[classes == ""] = "None"

assert len(classes) == fps.shape[0], "check classes file"

hms = []
for classif in np.unique(classes):
    idx = np.where(classes == classif)
    focus = fps[idx]

    hm = fp_heatmap_accumulative(
        focus,
        limit=10,
        title=f"Biosynfoni QR of all {len(focus)} {classif.replace('_', ' ')} compounds",
        subslabels=substructure_names,
        color_scheme="GnBu",
        standard_colour=True,
    )
    hms.append(hm)
    hm.savefig(fig_folder / f"sm_qr_chebi_{classif}.png", dpi=300, bbox_inches="tight")


## 4. Chemical space visualisation

In [None]:
from pathlib import Path

import pandas as pd
import numpy as np

# Four-panel t-SNE figure (one panel per fingerprint) of single-class ChEBI compounds.
# The id/class table is the same for every fingerprint, so it is read once up front.
chebi_ids, chebi_class_labels = np.loadtxt(
    f"{Path().home()}/article_bsf/data/input/chebi_classes.csv",
    delimiter=",",
    dtype=str,
    usecols=(0, 1),
    unpack=True,
)

fig, ax = plt.subplots(2, 2, figsize=(6, 6))
for i, fp in enumerate(["bsf", "maccs", "morgan", "rdk"]):
    panel = ax.flatten()[i]
    df = pd.read_csv(
        f"{Path().home()}/article_bsf/output/{fp}_tsne.csv",
        header=None,
        names=["x", "y"],
    )
    df["class"] = chebi_class_labels
    df["id"] = chebi_ids

    # keep only compounds with a single class (multi-class annotations contain ";")
    sns.scatterplot(
        data=df[~df["class"].str.contains(";")],
        x="x",
        y="y",
        hue="class",
        palette="Spectral",
        alpha=0.5,
        ax=panel,
    )
    panel.set_title(fp, fontweight="regular")
    panel.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )
    panel.legend().remove()
    panel.set_xlabel("")
    panel.set_ylabel("")

# one shared legend next to the top-right panel, with fully opaque markers
for handle in ax[0, 1].get_legend_handles_labels()[0]:
    handle.set_alpha(1)
# NOTE(review): this only affects the current axes, not all four panels -- confirm
# whether equal aspect was meant for every panel.
plt.gca().set_aspect("equal", adjustable="box")
ax[0, 1].legend(loc="upper left", bbox_to_anchor=(1, 1), markerscale=3)
fig.suptitle("t-SNE of single class Chebi compounds", fontsize=20, fontweight="bold")
fig.savefig(fig_folder / "sm_tsne_single.png", dpi=300, bbox_inches="tight")

In [None]:
# Four-panel t-SNE figure (one panel per fingerprint) of all ChEBI compounds;
# multi-class compounds are overlaid in grey.
# The id/class table is the same for every fingerprint, so it is read once up front.
chebi_ids, chebi_class_labels = np.loadtxt(
    f"{Path().home()}/article_bsf/data/input/chebi_classes.csv",
    delimiter=",",
    dtype=str,
    usecols=(0, 1),
    unpack=True,
)

fig, ax = plt.subplots(2, 2, figsize=(6, 6))
for i, fp in enumerate(["bsf", "maccs", "morgan", "rdk"]):
    panel = ax.flatten()[i]
    df = pd.read_csv(
        f"{Path().home()}/article_bsf/output/{fp}_tsne.csv",
        header=None,
        names=["x", "y"],
    )
    df["class"] = chebi_class_labels
    df["id"] = chebi_ids

    # multi-class annotations contain ";"
    is_multi = df["class"].str.contains(";")

    # single-class compounds, coloured by class
    sns.scatterplot(
        data=df[~is_multi],
        x="x",
        y="y",
        hue="class",
        palette="Spectral",
        alpha=0.5,
        ax=panel,
    )
    # multi-class compounds, drawn on top in one grey series
    sns.scatterplot(
        data=df[is_multi],
        x="x",
        y="y",
        c="darkgrey",
        alpha=0.5,
        ax=panel,
        label="multiple",
    )
    panel.set_title(fp, fontweight="regular")
    panel.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )
    panel.legend().remove()
    panel.set_xlabel("")
    panel.set_ylabel("")

# one shared legend next to the top-right panel, with fully opaque markers
for handle in ax[0, 1].get_legend_handles_labels()[0]:
    handle.set_alpha(1)
# NOTE(review): this only affects the current axes, not all four panels -- confirm
# whether equal aspect was meant for every panel.
plt.gca().set_aspect("equal", adjustable="box")
ax[0, 1].legend(loc="upper left", bbox_to_anchor=(1, 1), markerscale=3)
fig.suptitle("t-SNE of Chebi compounds", fontsize=20, fontweight="bold")
fig.savefig(fig_folder / "sm_tsne.png", dpi=300, bbox_inches="tight")

In [None]:
# Four-panel UMAP figure (one panel per fingerprint) of single-class ChEBI compounds.
# The id/class table is the same for every fingerprint, so it is read once up front.
chebi_ids, chebi_class_labels = np.loadtxt(
    f"{Path().home()}/article_bsf/data/input/chebi_classes.csv",
    delimiter=",",
    dtype=str,
    usecols=(0, 1),
    unpack=True,
)

fig, ax = plt.subplots(2, 2, figsize=(6, 6))
for i, fp in enumerate(["bsf", "maccs", "morgan", "rdk"]):
    panel = ax.flatten()[i]
    df = pd.read_csv(
        f"{Path().home()}/article_bsf/output/{fp}_umap.csv",
        header=None,
        names=["x", "y"],
    )
    df["class"] = chebi_class_labels
    df["id"] = chebi_ids

    # keep only compounds with a single class (multi-class annotations contain ";")
    sns.scatterplot(
        data=df[~df["class"].str.contains(";")],
        x="x",
        y="y",
        hue="class",
        palette="Spectral",
        alpha=0.5,
        ax=panel,
    )
    panel.set_title(fp, fontweight="regular")
    panel.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )
    panel.legend().remove()
    panel.set_xlabel("")
    panel.set_ylabel("")

# one shared legend next to the top-right panel, with fully opaque markers
for handle in ax[0, 1].get_legend_handles_labels()[0]:
    handle.set_alpha(1)
# NOTE(review): this only affects the current axes, not all four panels -- confirm
# whether equal aspect was meant for every panel.
plt.gca().set_aspect("equal", adjustable="box")
ax[0, 1].legend(loc="upper left", bbox_to_anchor=(1, 1), markerscale=3)
fig.suptitle("UMAP of single class Chebi compounds", fontsize=20, fontweight="bold")
fig.savefig(fig_folder / "sm_umap_single.png", dpi=300, bbox_inches="tight")

In [None]:
# Four-panel UMAP figure (one panel per fingerprint) of all ChEBI compounds;
# multi-class compounds are overlaid in grey.
# The id/class table is the same for every fingerprint, so it is read once up front.
chebi_ids, chebi_class_labels = np.loadtxt(
    f"{Path().home()}/article_bsf/data/input/chebi_classes.csv",
    delimiter=",",
    dtype=str,
    usecols=(0, 1),
    unpack=True,
)

fig, ax = plt.subplots(2, 2, figsize=(6, 6))
for i, fp in enumerate(["bsf", "maccs", "morgan", "rdk"]):
    panel = ax.flatten()[i]
    df = pd.read_csv(
        f"{Path().home()}/article_bsf/output/{fp}_umap.csv",
        header=None,
        names=["x", "y"],
    )
    df["class"] = chebi_class_labels
    df["id"] = chebi_ids

    # multi-class annotations contain ";"
    is_multi = df["class"].str.contains(";")

    # single-class compounds, coloured by class
    sns.scatterplot(
        data=df[~is_multi],
        x="x",
        y="y",
        hue="class",
        palette="Spectral",
        alpha=0.5,
        ax=panel,
    )
    # multi-class compounds, drawn on top in one grey series
    sns.scatterplot(
        data=df[is_multi],
        x="x",
        y="y",
        c="darkgrey",
        alpha=0.5,
        ax=panel,
        label="multiple",
    )
    panel.set_title(fp, fontweight="regular")
    panel.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )
    panel.legend().remove()
    panel.set_xlabel("")
    panel.set_ylabel("")

# one shared legend next to the top-right panel, with fully opaque markers
for handle in ax[0, 1].get_legend_handles_labels()[0]:
    handle.set_alpha(1)
# NOTE(review): this only affects the current axes, not all four panels -- confirm
# whether equal aspect was meant for every panel.
plt.gca().set_aspect("equal", adjustable="box")
ax[0, 1].legend(loc="upper left", bbox_to_anchor=(1, 1), markerscale=3)
fig.suptitle("UMAP of Chebi compounds", fontsize=20, fontweight="bold")
fig.savefig(fig_folder / "sm_umap.png", dpi=300, bbox_inches="tight")

## 7. Biosynfoni substructures

In [None]:
from rdkit import Chem
from rdkit.Chem import Draw

# not referenced below; presumably imported for its notebook-rendering side
# effect -- TODO confirm before removing
from rdkit.Chem.Draw import IPythonConsole

from biosynfoni import subkeys

# SMARTS patterns of the default fingerprint version, parsed into query mols
smarts = subkeys.get_smarts(version=subkeys.defaultVersion)
mols = list(map(Chem.MolFromSmarts, smarts))

# Draw the molecules in a grid, three per row
Draw.MolsToGridImage(mols, molsPerRow=3)



## Figure 4 - Clustermaps

In [None]:
# import sys, os, argparse, logging

# import numpy as np
# import seaborn as sns
# import scipy.cluster.hierarchy as sch
# import matplotlib as mpl
# import matplotlib.pylab as plt
# from matplotlib.patches import Patch
# import pandas as pd
# from tqdm import tqdm

# # matplotlib.use('Agg')       #if in background

# from biosynfoni.subkeys import get_names, get_pathway
# from utils import set_style
# from utils.colours import colourDict
# from utils.figures import set_label_colors_from_categories


# def cli():
#     """
#     Command line interface for clustermap
#     """
#     parser = argparse.ArgumentParser()

#     parser.add_argument("fingerprints", help="Fingerprint file")
#     parser.add_argument("labels", help="Labels file")
#     parser.add_argument(
#         "-s",
#         "--subsample",
#         required=False,
#         type=int,
#         help="subsample size",
#         default=None,
#     )
#     parser.add_argument(
#         "-r",
#         "--seed",
#         "--randomseed",
#         required=False,
#         type=int,
#         help="seed for subsampling",
#         default=None,
#     )

#     args = parser.parse_args()
#     args.fingerprints = os.path.abspath(args.fingerprints)
#     args.labels = os.path.abspath(args.labels)
#     return args


# class recursion_depth:
#     def __init__(self, limit):
#         self.limit = limit
#         self.default_limit = sys.getrecursionlimit()

#     def __enter__(self):
#         sys.setrecursionlimit(self.limit)

#     def __exit__(self, type, value, traceback):
#         sys.setrecursionlimit(self.default_limit)


# class ClusterMap:
#     def __init__(self, df, labels, metric, method) -> None:
#         self.df = df
#         self.indexes = df.index.values
#         self.labels = labels
#         self.colordict = self.get_colordict()
#         self.metric = metric
#         self.method = method
#         logging.debug(f"calculating distances with {metric} and {method}...")
#         self.distances = self.get_distances()
#         self.clustering = self.get_clustering()
#         # self.distances = None
#         # self.clustering = None
#         # self.tree = self.get_tree()
#         self.colors, self.handles = self._get_category_colors_handles(self.labels)
#         logging.debug(f"plotting clustermap with {metric} and {method}...")
#         self.clustermap = self.seacluster()
#         self.clusterfig = self.get_clusterplot()
#         plt.close()
#         pass

#     def get_distances(self):
#         """
#         calculates distances from data frame using metric
#         """
#         return sch.distance.pdist(self.df, metric=self.metric)

#     def set_distances(self, distances):
#         """
#         sets distances from data frame
#         """
#         self.distances = distances
#         return None

#     def get_clustering(self):
#         """
#         calculates clustering from distances using method
#         """
#         # plt.title(out_file)
#         with recursion_depth(10000):
#             clustering = sch.linkage(self.distances, method=self.method)
#         # plt.close()
#         return clustering

#     def get_tree(self):
#         """returns dendrogram tree"""
#         tree = sch.dendrogram(
#             self.clustering, leaf_font_size=2, color_threshold=4, labels=self.indexes
#         )
#         return tree

#     def seacluster(self):
#         """
#         Makes a seaborn clustermap

#         Returns:
#         seacluster: seaborn clustermap object
#         """
#         # comp_color, comp_handles = get_gnsr_diff_color() #compounds

#         cmap = _cmap_makezerowhite("mako")

#         # self.distances = self.set_distances(self.get_distances())
#         # self.clustering = self.get_clustering()
#         # fig, ax = plt.subplots(figsize=(15,6))
#         # sns.set(font_scale=0.5)
#         with recursion_depth(20000):
#             seacluster = sns.clustermap(
#                 self.df,
#                 method=self.method,
#                 metric=self.metric,
#                 xticklabels=1,  # every 1: label
#                 # robust = True,
#                 # square = True,
#                 row_colors=self.colors,
#                 # col_colors = phyl_color,
#                 # z_score = 1,
#                 # figsize= (len(list(data_frame)), min(200, len(data_frame.index.values))),
#                 cmap=cmap,
#                 cbar_kws={
#                     "shrink": 0.3,
#                     "label": "counts",
#                     # "orientation": "horizontal",
#                     # "labelweight": "bold",
#                 },
#                 cbar_pos=(1.05, 0.2, 0.03, 0.4),
#             )
#         # make x tick labels bold

#         return seacluster

#     def get_clusterplot(self, legend_title: str = "class") -> plt.gca:
#         """returns plt of clustermap with legend of categories"""

#         # plt.figure(figsize=(15,6))

#         # plt.setp(seacluster.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
#         # plt.setp(seacluster.ax_heatmap.xaxis.get_majorticklabels(), rotation=30, fontsize=8)
#         # plt.setp(seacluster.ax_heatmap.xaxis.get_majorticklabels(), rotation=-30)
#         # seacluster.set_xlim([0,2])

#         # seacluster.set_title(TITLE, fontsize=16, fontdict={})
#         # plt.title(
#         #     f"Hierarchical clustering of compounds and substructures {self.method}, {self.metric}",
#         #     loc="center",
#         # )  # + 'z scored')
#         self.clustermap.fig.suptitle(
#             f"Hierarchical clustering of compounds and substructures {self.method}, {self.metric}",
#             x=0.6,
#             y=1.1,
#             weight="bold",
#             size=18,
#         )
#         self.clustermap.ax_heatmap.set_xticklabels(
#             self.clustermap.ax_heatmap.get_xmajorticklabels(),
#             fontsize=10,
#             fontweight="semibold",
#         )
#         self._set_substructure_colours()

#         # seacluster = self.seacluster()

#         # improve layout
#         # plt.tight_layout()

#         # add legend
#         handles = self.handles
#         legend_colors = [
#             Patch(facecolor=handles[name], edgecolor="#FFFFFF00")
#             for name in handles.keys()
#         ]

#         plt.legend(
#             legend_colors,
#             handles.keys(),
#             title=legend_title,
#             bbox_to_anchor=(1, 0.9),
#             bbox_transform=plt.gcf().transFigure,
#             loc="upper left",
#             frameon=False,
#             edgecolor="#FFFFFF",
#             fontsize=10,
#         )
#         # set legend title size
#         plt.setp(plt.gca().get_legend().get_title(), fontsize=10)

#         # plt.show()
#         # return dendrogram_linkage
#         return plt.gca()

#     def save_clustermap(self, fmt: str = "svg") -> None:
#         """saves clustermap to file

#         Args:
#             fmt: str, file format
#         Returns:
#             None
#         """
#         out_file = f"clustermap_{self.method}_{self.metric}.{fmt}"
#         # self.clusterfig.savefig(out_file, format=fmt)
#         self.clustermap.savefig(out_file, format=fmt)
#         # plt.savefig(out_file, format=fmt)
#         return None

#     def get_dendogram_tree(self):
#         """returns dendrogram tree object for branch cutting"""
#         return self.clustermap.dendrogram_row.dendrogram

#     def get_dendogram_linkage(self) -> np.ndarray:
#         """returns dendrogram linkage object"""
#         return self.clustermap.dendrogram_row.linkage

#     def get_colordict(self) -> dict:
#         """returns colour dictionary for categories
#         Returns:
#             colordict: dict, category: color
#         """
#         # colordict = {
#         #     "Terpenoids": sns.color_palette("Set3")[6],  # green
#         #     "Alkaloids": sns.color_palette("Set3")[9],  # purple
#         #     "Shikimates and Phenylpropanoids": sns.color_palette("Set3")[4],  # blue
#         #     "Fatty acids": sns.color_palette("Set3")[5],  # orange
#         #     "Carbohydrates": sns.color_palette("Set3")[7],  # pink
#         #     "Polyketides": sns.color_palette("Set3")[3],  # light red
#         #     "Amino acids and Peptides": "bisque",
#         #     # "No NP-Classifier prediction": "grey",
#         #     "None": "grey",
#         #     "Synthetic": "black",
#         # }
#         if self.labels[0][0].isupper():
#             colordict = colourDict["NPClassifier prediction"]
#         else:
#             colordict = colourDict["chebi class"]
#         return colordict

#     def _get_category_colors_handles(
#         self, categories: pd.Series
#     ) -> tuple[pd.DataFrame, dict]:
#         """uses colour dictionary to assign colors to the categories"""
#         network_dict = {}
#         categories.fillna("None", inplace=True)

#         for ind, cat in categories.items():
#             # network_dict[ind] = [self.colordict[str(cat).split(',')[0]]] #in case of multiple categories, take the first one
#             network_dict[ind] = [self.colordict[str(cat)]]
#         network_colors = pd.DataFrame.from_dict(network_dict, orient="index")
#         network_colors.columns = [""]
#         handles = self.colordict
#         return network_colors, handles

#     def _set_substructure_colours(self):
#         """sets substructure colours in clustermap"""
#         pathways = get_pathway()
#         substructures = self.df.columns
#         subs_to_pathways = {a: b for a, b in zip(substructures, pathways)}
#         ticklabels = self.clustermap.ax_heatmap.get_xticklabels()
#         # access text from ticklabels
#         pathways = [subs_to_pathways[x.get_text()] for x in ticklabels]
#         if len(pathways) != len(ticklabels):
#             logging.warning(
#                 f"cannot set {pathways} substructure for {ticklabels} colours"
#             )
#             return None
#         else:
#             set_label_colors_from_categories(
#                 ticklabels, pathways, colourDict["pathways"]
#             )


# def _cmap_makezerowhite(
#     default_cmap: str = "mako",
# ) -> mpl.colors.LinearSegmentedColormap:
#     # define color map:-------------------------------------------------
#     cmap = sns.color_palette(default_cmap, as_cmap=True)  # define the colormap
#     # extract all colors from the .jet map
#     cmaplist = [cmap(i) for i in range(cmap.N)]
#     # force the first color entry to be white (to distinguish 0 from low values)
#     cmaplist[0] = (1.0, 1.0, 1.0, 0)
#     # create the new map
#     cmap = mpl.colors.LinearSegmentedColormap.from_list("Custom cmap", cmaplist, cmap.N)
#     return cmap


# def main():
#     set_style()
#     args = cli()

#     filetype = "svg"
#     # version = input_file.split("/")[-1].split("_")[-1].split(".")[0]
#     substructure_names = [x.replace("_", " ") for x in get_names()]

#     fp = pd.read_csv(args.fingerprints, sep=",", header=None, dtype=int)
#     if fp.shape[1] == len(substructure_names):
#         fp.columns = substructure_names
#     db_name = args.fingerprints.split("/")[-1].split(".")[0].split("_")[0]
#     fp_name = args.fingerprints.split("/")[-1].split(".")[0].split("_")[1]

#     npcs = pd.read_csv(
#         args.labels,
#         sep="\t",
#         header=None,
#         dtype=str,
#         usecols=[0],
#     )

#     # as all isoprenoids are fatty acids according to chebi:
#     npcs.replace("fatty_acid,isoprenoid", "isoprenoid", inplace=True)
#     # filter out multiple-prediction compounds
#     npcs.fillna(",", inplace=True)
#     fp = fp[~npcs[0].str.contains(",")]
#     npcs = npcs[~npcs[0].str.contains(",")]

#     # # filter out only-zero columns in df ~~~~~~~~~~~~~~~ CHECK IF THIS APPLIES TO YOUR PURPOSES ~~~~~~~~~~~~~~~
#     # fp = fp.loc[:, (fp != 0).any(axis=0)]
#     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

#     # subsample indexes
#     if args.subsample:
#         np.random.seed(args.seed)
#         idx = np.random.choice(fp.index, args.subsample, replace=False)
#         fp = fp.loc[idx]
#         npcs = npcs.loc[idx]

#     npcs_series = npcs.iloc[:, 0]

#     # check if indexes are the same
#     assert fp[fp.index != npcs.index].empty
#     assert npcs[npcs.index != npcs_series.index].empty
#     assert type(npcs_series) == pd.Series

#     # if args.synthetic: #under construction
#     #     synthetic_fp = pd.read_csv(args.synthetic, dtype=int)
#     #     synthetic_fp = np.random.choice(
#     #         synthetic_fp.shape[0], fp.shape[0], replace=False
#     #     )
#     #     fp = np.concatenate((fp, synthetic_fp))
#     #     synthetic_labels = np.array(["synthetic" for _ in range(synthetic_fp.shape[0])])
#     #     labels = np.concatenate((labels, synthetic_labels))

#     iwd = os.getcwd()
#     # make a directory in grandparent directory called clustermaps
#     # os.chdir("../../")
#     os.makedirs(f"clustermaps/{db_name}/{fp_name}", exist_ok=True)
#     os.chdir(f"clustermaps/{db_name}/{fp_name}")

#     # # debugging
#     # clustermap = ClusterMap(fp, npcs_series, "euclidean", "average")
#     # clustermap.save_clustermap(fmt=filetype)

#     for method in tqdm(["average", "complete", "single", "weighted"]):
#         for metric in tqdm(
#             [
#                 "euclidean",
#                 "cityblock",
#                 "cosine",
#                 "correlation",
#                 "hamming",
#                 "jaccard",
#                 "mahalanobis",
#                 "chebyshev",
#                 "canberra",
#                 "braycurtis",
#                 "dice",
#                 "kulsinski",
#                 "matching",
#                 "rogerstanimoto",
#                 "russellrao",
#                 "sokalmichener",
#                 "sokalsneath",
#                 "yule",
#             ],
#             leave=False,
#         ):
#             # errors can occur for some metrics if they have too small sample sets, or with certain combinations:
#             try:
#                 clustermap = ClusterMap(fp, npcs_series, metric, method)
#                 clustermap.save_clustermap(fmt=filetype)
#             except:
#                 logging.warning(f"failed for {method} and {metric}")
#             pass
#         pass

#     os.chdir(iwd)
#     return None


# if __name__ == "__main__":
#     main()

## Figure 5 - Importances Classification

In [None]:
# from sys import argv

# import numpy as np
# import matplotlib.pyplot as plt
# import matplotlib as mpl

# from biosynfoni.subkeys import defaultVersion, get_values, get_pathway
# from utils.figures import cat_to_colour
# from utils.colours import colourDict
# from utils import set_style


# def main():
#     """
#     Plot the feature importances from a random forest model as a barplot
#     """
#     importances = np.loadtxt(argv[1], delimiter="\t", dtype=float)
#     bsf_name = defaultVersion
#     substructure_names = get_values("name", version=bsf_name)
#     pathways = get_pathway(version=bsf_name)
#     colors = cat_to_colour(pathways, colourDict["pathways"])
#     if len(substructure_names) != importances.shape[1]:
#         print("WARNING: substructure names not equal to importances")
#         substructure_names = [f"{i}" for i in range(importances.shape[1])]
#         colors = ["#888888" for _ in range(importances.shape[1])]
#     set_style()

#     means = np.mean(importances, axis=0)

#     # set plot size
#     # default: 6.4, 4.8
#     ratio = importances.shape[1] / 39
#     plt.figure(figsize=(ratio * 6.4, 4.8))

#     # plot barplot
#     barplot = plt.bar(substructure_names, means)
#     # add standard deviations as error bars
#     stds = np.std(importances, axis=0)
#     print(stds.shape)
#     e1 = plt.errorbar(substructure_names, means, yerr=stds, fmt="o", color="#606060")
#     e2 = plt.errorbar(substructure_names, means, yerr=stds, fmt="none", color="#606060")

#     plt.xticks(range(len(substructure_names)), substructure_names, rotation=90)
#     # set bar colours
#     for i, bar in enumerate(barplot):
#         bar.set_color(colors[i])

#     plt.xticks(rotation=90)
#     plt.ylabel("feature importance")
#     plt.xlabel("substructure")
#     plt.suptitle("Feature importances for random forest model", weight="bold")
#     plt.title("(averaged over k-fold cross validation with k=5)", weight="light")
#     plt.tight_layout()
#     plt.savefig(argv[1].replace(".tsv", ".png"), bbox_inches="tight")
#     return None


# if __name__ == "__main__":
#     main()

## Figure 6 - Confusion Matrix heatmaps

In [None]:
# import argparse, logging
# import os

# import matplotlib as mpl
# import matplotlib.pyplot as plt
# import numpy as np


# def set_style() -> None:
#     """
#     Set the style of the plot to biostylefoni
#     """
#     # get path of this script
#     script_path = os.path.dirname(os.path.realpath(__file__))
#     parent_path = os.path.dirname(script_path)
#     utils_path = os.path.join(parent_path, "utils")
#     print(utils_path)
#     style_path = os.path.join(utils_path, "biostylefoni.mplstyle")
#     # set style
#     plt.style.use(style_path)
#     return None


# def parse_cms_files(m_path: str) -> tuple[np.array, np.array]:
#     """
#     Parse the confusion matrices and names from a file

#     Args:
#         m_path (str): path to file with confusion matrices

#     Returns:
#         tuple[np.array, np.array]: array of confusion matrices and array of names (i.e. )
#     """
#     cms = np.loadtxt(
#         m_path,
#         delimiter="\t",
#         dtype=int,
#         skiprows=1,
#         usecols=(1, 2, 3, 4),
#     )
#     # for each row, split the array into a 2x2 array
#     cms = cms.reshape(cms.shape[0], 2, 2)

#     # now get only the 'index names'
#     cm_names = np.loadtxt(
#         m_path,
#         delimiter="\t",
#         dtype=str,
#         skiprows=1,
#         usecols=(0),
#     )
#     return cms, cm_names


# def get_matrices(cms: np.array) -> tuple[list[np.array], list[np.array]]:
#     """
#     Take an array of confusion matrices and return a list of matrices and a list of normalised matrices

#     Args:
#         cms (np.array): array of confusion matrices

#     Returns:
#         tuple[list[np.array], list[np.array]]: a list of matrices and a list of normalised matrices
#     """
#     matrices = []
#     norm_matrices = []
#     perc_matrices = []
#     for i in range(cms.shape[0]):
#         # # make random matrix with values between 0 and 100000
#         # matrix = np.random.randint(0, 100000, size=(2, 2))
#         matrix = cms[i]
#         # normalise matrix
#         norm_matrix = matrix / matrix.sum(axis=1, keepdims=True)
#         # turn normalised matrix into percentages
#         perc_matrix = norm_matrix * 100
#         # append matrices to list
#         matrices.append(matrix)
#         norm_matrices.append(norm_matrix)
#         perc_matrices.append(perc_matrix)
#     assert len(matrices) == len(perc_matrices), "#matrices don't match"
#     return matrices, perc_matrices


# def main(matrix_path, ):
#     matrix_path = Path(matrix_path)
#     ml_input = matrix_path.stem.split("_")[-1]

#     # read in the confusion matrix and names
#     cms, cm_names = parse_cms_files(matrix_path)
#     # print(cms, cm_names)

#     # get the matrices and the percentage versions for each category
#     matrices, perc_matrices = get_matrices(cms)
#     assert len(matrices) == len(cm_names), "#matrices and #categories don't match"

#     # make subplots
#     fig, axs = plt.subplots(
#         1, len(matrices), figsize=(len(matrices), 2), dpi=500
#     )  # , sharey=True) #sharing y makes the y axis ticks appear in each subplot

#     # make a heatmap in each subplot
#     for i, ax in enumerate(axs):
#         cmap_name = "Greys"
#         cmap = mpl.colormaps[cmap_name]
#         # plot heatmap
#         im = ax.imshow(perc_matrices[i], cmap=cmap_name, vmin=0, vmax=100)
#         # remove ticks
#         ax.set_xticks([])
#         ax.set_yticks([])
#         ax.set_xticklabels([])
#         ax.set_yticklabels([])
#         # set title
#         title = cm_names[i].replace("_", "\n")
#         axtitle = ax.set_title(title, fontsize=7, fontweight=600, wrap=True)
#         # force the wrap line width to be shorter
#         axtitle._get_wrap_line_width = lambda: 600.0  #  wrap to 600 screen pixels
#         # annotate values in each box, with dark text for light background and light text for dark background
#         fontweight = 500
#         a_size = 5
#         for j in range(2):
#             for k in range(2):
#                 if perc_matrices[i][j][k] < 50:
#                     text = ax.text(
#                         k,
#                         j,
#                         f"{round(matrices[i][j][k], 2)}\n({round(perc_matrices[i][j][k], 2)}%)",
#                         ha="center",
#                         va="center",
#                         color=cmap(1.0),
#                         fontsize=a_size,
#                         fontweight=fontweight,
#                     )
#                 else:
#                     text = ax.text(
#                         k,
#                         j,
#                         # round(matrices[i][j][k], 2),
#                         f"{round(matrices[i][j][k], 2)}\n({round(perc_matrices[i][j][k], 2)}%)",
#                         ha="center",
#                         va="center",
#                         color=cmap(0.0),
#                         fontsize=a_size,
#                         fontweight=fontweight,
#                     )

#     # set ticks
#     axs[0].set_yticks([0, 1])
#     axs[0].set_yticklabels(["P", "N"], fontsize=6, fontweight=500)
#     axs[0].set_xticks([0, 1])
#     axs[0].set_xticklabels(["P", "N"], fontsize=6, fontweight=500)

#     # set y label
#     axs[0].set_ylabel("truth", fontsize=7, fontweight=600)
#     axs[0].set_xlabel("prediction", fontsize=7, fontweight=600)

#     # set common title
#     fig.suptitle(
#         f"confusion matrices for multilabel RF on {ml_input}",
#         fontsize=9,
#         fontweight=600,
#     )

#     # set suptitle on y axis (for later when looping across all folders)
#     # fig.text(0.02, 0.5, "confusion matrices for multilable RF on", fontsize=8, fontweight=600, rotation=90, va="center")

#     # set tight layout
#     plt.tight_layout()

#     # set colorbar
#     cbar = fig.colorbar(im, ax=axs, shrink=0.8, orientation="horizontal")
#     cbar.ax.tick_params(labelsize=6)
#     # add label to colorbar
#     cbar.ax.set_xlabel("% of compounds", fontsize=7)

#     # get path to folder where confusion matrices are, using os.path.dirname
#     save_path = "/".join(matrix_path.split("/")[:-1])
#     logging.info(matrix_path, save_path)
#     # save figure
#     plt.savefig(
#         matrix_path.replace("_matrix.txt", "_heatmap.png"), dpi=500, bbox_inches="tight"
#     )
#     plt.clf()
#     plt.close()
#     exit(0)
#     return None


# if __name__ == "__main__":
#     main()

## Figure 7 - Butina Clustering with distance matrices

In [None]:
# clusters = Butina.ClusterData(
#         dist_matrix, num_fps, 0.2, isDistData=True
#     )

Chains per length cut

In [None]:
from pathlib import Path

import matplotlib as mpl  # used for the tab10 colour below; made explicit here
import matplotlib.pyplot as plt
import numpy as np

# two integer columns: column 0 is plotted as chain length, column 1 as count
num_per_length = np.loadtxt(
    Path.home() / "article_bsf" / "output" / "num_chains_per_length.csv",
    dtype=int,
    delimiter=",",
)

# plot on log scale y
chain_fig, chain_ax = plt.subplots(figsize=(10, 6))
chain_ax.plot(
    num_per_length[:, 0],
    num_per_length[:, 1],
    marker="o",
    markerfacecolor=mpl.color_sequences["tab10"][0],
    markeredgecolor="none",
    markersize=5,
)
chain_ax.set_yscale("log")
chain_ax.set_xlabel("chain length")
chain_ax.set_ylabel("number of chains")
# add grid
chain_ax.grid(True)
# add finer grid
chain_ax.grid(which="minor", linestyle="--")
chain_ax.set_title("Number of chains per length")

## Exploratory: chains

In [None]:
import os
from functools import partial
from pathlib import Path

import numpy as np
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt


from biosynfoni.inoutput import *

# MetaCyc pathway table (tab-separated); the first column becomes the index
metacyc_pathways_file = "~/article_bsf/data/input/metacyc_pathways.tsv"
pathway_data = pd.read_csv(metacyc_pathways_file, index_col=0, sep="\t")
pathway_data

In [None]:
# number of reaction rows listed for each pathway
lengths = [
    len(reactions["reaction_id"].tolist())
    for _, reactions in pathway_data.groupby("pathway_id")
]

# histogram of chain lengths
plt.hist(lengths, bins=50)
plt.show()

# zoom in on the tail: only lengths above 10
lengths = list(filter(lambda n_reactions: n_reactions > 10, lengths))

plt.hist(lengths, bins=50)
plt.show()

In [None]:
# columns that describe a single reaction
reaction_columns = ["reaction_id", "left", "direction", "right"]
reaction_info = pathway_data[reaction_columns]

# show the reactions whose "left" entry contains a space
left_has_space = reaction_info["left"].str.contains(" ")
reaction_info[left_has_space]

In [None]:
def get_pathway_graphs(pathway_data):
    """Build one directed graph per pathway from the reaction table.

    Args:
        pathway_data: DataFrame with at least the columns "pathway_id",
            "left", "right" and "reaction_id".

    Returns:
        dict mapping each pathway id to an ``nx.DiGraph`` with one edge
        ``left -> right`` per row, carrying the row's ``reaction_id`` as an
        edge attribute.
    """
    graphs_by_pathway = {}
    for pathway_id, reactions in pathway_data.groupby("pathway_id"):
        graph = nx.DiGraph()
        graph.add_edges_from(
            (row["left"], row["right"], {"reaction_id": row["reaction_id"]})
            for _, row in reactions.iterrows()
        )
        graphs_by_pathway[pathway_id] = graph
    return graphs_by_pathway


def get_longest_chains(pathways: dict):  # , with_compounds: list | None = None):
    """Yield ``(pathway_id, longest_path)`` for every pathway graph.

    Acyclic graphs use ``nx.dag_longest_path``. Graphs that contain a cycle
    fall back to enumerating simple paths between every ordered pair of
    distinct nodes and keeping the first path of maximal length.

    Args:
        pathways: dict mapping pathway id to a directed networkx graph.

    Yields:
        tuple of the pathway id and a list of nodes forming the longest path.
        For a cyclic graph with no path between distinct nodes (e.g. a lone
        self-loop) the path is a single node.

    Note:
        Enumerating all simple paths can take very long on large cyclic
        graphs.
    """
    for pathway_id, pathway_graph in pathways.items():
        if nx.is_directed_acyclic_graph(pathway_graph):
            longest_path = nx.dag_longest_path(pathway_graph)
        else:
            all_nodes = set(pathway_graph.nodes)
            # a graph with a cycle has at least one node; a single node is the
            # trivial path that any longer path will replace
            longest_path = [next(iter(pathway_graph.nodes))]
            for node in pathway_graph.nodes:
                # Never pass the source as one of its own targets: networkx
                # releases differ in how they treat that case, and older ones
                # return no paths at all, which left every cyclic pathway
                # with an empty chain.
                other_nodes = all_nodes - {node}
                if not other_nodes:
                    continue
                for path in nx.all_simple_paths(
                    pathway_graph, source=node, target=other_nodes
                ):
                    if len(path) > len(longest_path):
                        longest_path = path
        yield pathway_id, longest_path

In [None]:
pathways = get_pathway_graphs(pathway_data)

# longest chain per pathway, ordered from the longest chain to the shortest
# (with_compounds=compounds["UNIQUE-ID"].tolist() filtering is not enabled)
longest_chains = dict(
    sorted(
        get_longest_chains(pathways),
        key=lambda id_and_chain: len(id_and_chain[1]),
        reverse=True,
    )
)

# get the edge information
# nx_graph[longest_path[i]][longest_path[i+1]]

In [None]:
# number of nodes in the longest chain found for pathway PWY-8152
len(longest_chains["PWY-8152"])

In [None]:
# draw the full directed graph of pathway PWY-8152 with its node labels
nx.draw(pathways["PWY-8152"], with_labels=True)

In [None]:
# save all the graphs as edge lists, one file per pathway
# exist_ok=True lets the cell be re-run without a separate existence check
Path("pathway_graphs").mkdir(exist_ok=True)

for pathway_id, pathway_graph in pathways.items():
    nx.write_edgelist(pathway_graph, f"pathway_graphs/{pathway_id}.edgelist")

In [None]:
from biosynfoni import Biosynfoni
from rdkit import Chem

# quick check: compute the fingerprint of a single small molecule (SMILES "CCO")
mol = Chem.MolFromSmiles("CCO")
bsf = Biosynfoni(mol).fingerprint
# display the resulting fingerprint as the cell output
bsf

## 