In [13]:
import json
from functools import cache
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import plotly.express as px
from loguru import logger
from rdkit import Chem
from rdkit.Chem import DataStructs, rdMolDescriptors
from rdkit.Chem.GraphDescriptors import BertzCT
from tqdm import tqdm

In [26]:
def read_json_files(folder_path: str) -> tuple[list[Any], list[str]]:
    """Recursively load every ``*.json`` file below ``folder_path``.

    Args:
        folder_path: Root folder to search; subfolders are included (``Path.rglob``).

    Returns:
        ``(results, json_file_ids)``: the parsed JSON content of each file and that
        file's stem (file name without ``.json``), aligned index-by-index.
    """
    results: list[Any] = []
    json_file_ids: list[str] = []
    # Walk the tree once so each parsed file stays paired with its own id
    # (two separate rglob walks are not guaranteed to yield the same order).
    for json_file_path in Path(folder_path).rglob("*.json"):
        with json_file_path.open("r", encoding="utf-8") as f:
            results.append(json.load(f))
        json_file_ids.append(json_file_path.stem)
    return results, json_file_ids


@cache
def compute_fingerprint(smiles):
    """Return the RDKit topological fingerprint (maxPath=8, 2048 bits) of a SMILES string.

    Results are memoized per SMILES string.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` (``MolFromSmiles`` returns None).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Fail with a readable message instead of an opaque RDKit argument error.
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    return Chem.RDKFingerprint(mol, maxPath=8, fpSize=2048)


@cache
def tanimoto_similarity(smiles1, smiles2):
    """Tanimoto similarity between the fingerprints of two SMILES strings.

    Fingerprints come from ``compute_fingerprint``; the result is memoized per
    ordered ``(smiles1, smiles2)`` pair.
    """
    first_fp, second_fp = (compute_fingerprint(smiles) for smiles in (smiles1, smiles2))
    return DataStructs.FingerprintSimilarity(
        first_fp,
        second_fp,
        metric=DataStructs.TanimotoSimilarity,
    )


@cache
def compute_molecular_formula(smiles):
    """Return the molecular formula (RDKit ``CalcMolFormula``) of a SMILES string.

    Results are memoized per SMILES string.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` (``MolFromSmiles`` returns None).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    return rdMolDescriptors.CalcMolFormula(mol)


def filter_by_molecular_formula(results: list, molecular_formula: str):
    """Keep the SMILES in ``results`` whose molecular formula equals ``molecular_formula``.

    Input order is preserved.
    """
    return [smiles for smiles in results if compute_molecular_formula(smiles) == molecular_formula]


def filter_final_population(full_results: dict):
    """Return the final-population SMILES that share the target's molecular formula.

    Candidates are the keys of ``full_results["final_population"]`` in iteration
    order; the target is ``full_results["original_smiles"]``.
    """
    candidates = list(full_results["final_population"])
    target_formula = compute_molecular_formula(full_results["original_smiles"])
    return filter_by_molecular_formula(candidates, target_formula)


# max tanimoto initial population
def compute_max_tanimoto_initial_population(full_results: dict):
    original_smiles = get_original_smiles(full_results)
    initial_population = full_results["initial_population"][:20]
    tanimoto_similarities = []
    for smiles in initial_population:
        tanimoto_similarities.append(tanimoto_similarity(original_smiles, smiles))
    return max(tanimoto_similarities)


def get_original_smiles(full_results: dict):
    """Return the target SMILES stored under "original_smiles" (KeyError if absent)."""
    target = full_results["original_smiles"]
    return target


@cache
def compute_molecule_size(smiles):
    """Return ``Mol.GetNumAtoms()`` for the molecule parsed from ``smiles``.

    With RDKit defaults, hydrogens that are implicit in the SMILES are presumably
    not counted; confirm if exact atom counts matter. Memoized per SMILES string.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` (``MolFromSmiles`` returns None).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    return mol.GetNumAtoms()


@cache
def compute_molecular_complexity(smiles):
    """Return the Bertz complexity index (RDKit ``BertzCT``) of a SMILES string.

    Memoized per SMILES string.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` (``MolFromSmiles`` returns None).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    return BertzCT(mol)


def compute_results(full_results: dict):
    """Score one molecule's search result against its target structure.

    Only final-population candidates with the same molecular formula as
    "original_smiles" are considered. They are taken in ``final_population``
    iteration order (``merge_json_values`` writes that dict sorted by score,
    highest first). A candidate counts as correct when its Tanimoto similarity to
    the target is exactly 1. That is fingerprint identity, which may also hold for
    distinct molecules such as stereoisomers (TODO confirm whether that matters).

    Returns:
        A dict with the same keys in every case:
        "1"/"5"/"10": a correct candidate is among the first 1/5/10 formula matches;
        "Found": a correct candidate is anywhere among the formula matches;
        "max": similarity of the FIRST formula-matched candidate (not the maximum;
        0 if there is none);
        "best_score": final_population score of that candidate (0 if there is none);
        "mol_size", "max_initial_population", "molecular_complexity": per-molecule
        descriptors.
    """
    original_smiles = get_original_smiles(full_results)
    filtered_results = filter_final_population(full_results)
    tanimoto_similarities = [tanimoto_similarity(original_smiles, candidate) for candidate in filtered_results]

    # Descriptors do not depend on the search outcome, so both branches report them.
    mol_size = compute_molecule_size(original_smiles)
    max_initial_population = compute_max_tanimoto_initial_population(full_results)
    molecular_complexity = compute_molecular_complexity(original_smiles)

    if not tanimoto_similarities:
        # No formula-matching candidate: report failure, but keep the success-branch
        # schema so the per-molecule DataFrame gets no spurious NaN columns.
        return {
            "1": False,
            "5": False,
            "10": False,
            "Found": False,
            "max": 0,
            "best_score": 0,
            "mol_size": mol_size,
            "max_initial_population": max_initial_population,
            "molecular_complexity": molecular_complexity,
        }
    return {
        "1": tanimoto_similarities[0] == 1,
        "5": 1 in tanimoto_similarities[:5],
        "10": 1 in tanimoto_similarities[:10],
        "Found": 1 in tanimoto_similarities,
        "max": tanimoto_similarities[0],
        "best_score": full_results["final_population"][filtered_results[0]],
        "mol_size": mol_size,
        "max_initial_population": max_initial_population,
        "molecular_complexity": molecular_complexity,
    }


def stats_save(seed, folder_path):
    """Compute per-molecule results for every JSON file under ``folder_path``.

    Each file's stem is used as the integer molecule id. The table is written to
    ``results_<seed>.csv``; ``seed`` is only used to name that file.

    Returns:
        DataFrame indexed by "molecule_id" (sorted), one column per key returned
        by ``compute_results``.
    """
    results, ids = read_json_files(folder_path)
    stats = {}
    for result, file_id in zip(tqdm(results), ids):
        molecule_id = int(file_id)
        if molecule_id in stats:
            # rglob searches subfolders too, so two files can share a stem; the
            # later one wins, as before, but no longer silently.
            logger.warning(f"Duplicate result file for molecule {molecule_id} in {folder_path}; keeping the last one")
        stats[molecule_id] = compute_results(result)

    data_as_df = pd.DataFrame(stats).T.sort_index()
    data_as_df.index.name = "molecule_id"
    data_as_df.to_csv(f"results_{seed}.csv")
    return data_as_df

In [24]:

def merge_json_values(file_paths: list[str]) -> dict[str, Any]:
    """
    Reads several per-seed JSON result files of one molecule and merges them.

    Every file must have the same set of top-level keys. Values are first
    collected into per-key lists (one entry per file). Then the search-specific
    keys are collapsed:

    - "final_population": the ``{smiles: score}`` dicts of all files are unioned,
      keeping each SMILES' highest score. They are sorted by score, highest first;
      ties keep lexicographic SMILES order.
    - "initial_population", "original_smiles": the value from the first file.
    - "best_score", "best_individual": score and SMILES of the top merged candidate.
    - "tanimoto_similarity": similarity of that candidate to "original_smiles".

    Any other key keeps its list of per-file values.

    Args:
        file_paths (list[str]): Paths to the JSON files to merge.

    Returns:
        dict[str, Any]: The merged data, or ``{}`` when no file contributed any
        keys (e.g. ``file_paths`` is empty).

    Raises:
        ValueError: If the files' keys differ, the files disagree on
            "original_smiles", or the merged final population is empty.
    """
    combined_data: dict[str, Any] = {}
    for file_path in file_paths:
        with Path(file_path).open("r", encoding="utf-8") as f:
            data = json.load(f)
        if not combined_data:
            # The first non-empty file defines the expected key set.
            combined_data = {key: [value] for key, value in data.items()}
            continue
        if set(data) != set(combined_data):
            raise ValueError(f"Keys in {file_path} don't match the keys in other files")
        for key, value in data.items():
            combined_data[key].append(value)

    if not combined_data:
        return {}

    original_smiles_values = combined_data["original_smiles"]
    if len(set(original_smiles_values)) > 1:
        raise ValueError(f"Files disagree on original_smiles: {sorted(set(original_smiles_values))}")

    # Flatten the per-file {smiles: score} dicts into (smiles, score) rows.
    rows = [
        (smiles, score)
        for population in combined_data["final_population"]
        for smiles, score in population.items()
    ]
    if not rows:
        raise ValueError("Merged final_population is empty")
    dataframe = pd.DataFrame(rows, columns=["smiles", "score"])
    # One entry per SMILES with its best score; groupby returns SMILES in sorted order.
    best_scores = dataframe.groupby("smiles")["score"].max()
    # tolist() yields native Python numbers so json.dump never sees numpy scalars.
    # sorted() is stable, so equal scores keep the SMILES order from groupby.
    final_population = dict(
        sorted(zip(best_scores.index, best_scores.tolist()), key=lambda item: item[1], reverse=True)
    )
    best_individual = next(iter(final_population))

    combined_data["final_population"] = final_population
    combined_data["initial_population"] = combined_data["initial_population"][0]
    combined_data["best_score"] = final_population[best_individual]
    combined_data["original_smiles"] = original_smiles_values[0]
    combined_data["tanimoto_similarity"] = tanimoto_similarity(combined_data["original_smiles"], best_individual)
    combined_data["best_individual"] = best_individual
    return combined_data


def save_combined_json(combined_data: dict[str, Any], output_path: str) -> None:
    """
    Saves the combined data to a JSON file (indented by 2 spaces).

    Missing parent folders of ``output_path`` are created first.

    Args:
        combined_data (dict[str, Any]): Combined data to save; must be JSON-serializable.
        output_path (str): Path where to save the combined JSON.
    """
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with output_file.open("w", encoding="utf-8") as f:
        json.dump(combined_data, f, indent=2)


def json_files(parent_folder_pattern: str, range_number: int = 1001) -> None:
    """Merge the per-seed result files of each molecule id into one JSON file.

    For every ``i`` in ``range(range_number)``, all files matching
    ``<parent_folder_pattern>*/<i>.json`` are merged with ``merge_json_values``.
    These are presumably one folder per seed sharing the prefix; confirm against
    the run layout. The merged result is written to ``<parent_folder_pattern>/<i>.json``.

    Args:
        parent_folder_pattern: Folder-name prefix of the per-seed result folders;
            also the folder the merged files are written to.
        range_number: Number of molecule ids to process (``0 .. range_number - 1``).

    Failures for one id are logged and skipped, so the remaining ids are still merged.
    """
    from glob import glob

    for i in range(range_number):
        output_path = Path(f"{parent_folder_pattern}/{i}.json")
        # The pattern also matches the merged output itself (``*`` can be empty);
        # exclude it so a re-run merges only the per-seed files. Sorting makes the
        # "first file" that merge_json_values takes single values from deterministic.
        files = [
            file
            for file in sorted(glob(f"{parent_folder_pattern}*/{i}.json"))
            if Path(file).resolve() != output_path.resolve()
        ]
        if not files:
            logger.warning(f"No result files found for molecule {i}")
            continue
        try:
            result = merge_json_values(files)
            if result:
                save_combined_json(result, str(output_path))
        except Exception as e:  # best effort: one bad molecule must not stop the batch
            logger.warning(f"Could not merge results for molecule {i}: {e!r}")


json_files(parent_folder_pattern="final_results_graphga", range_number=1001)

In [27]:
def count_correct_seed(stats, top_n=1):
    """Count rows whose correctness column named ``str(top_n)`` equals 1 (i.e. True).

    ``top_n`` may be 1, 5, 10 or a column name such as "Found".
    """
    column_name = str(top_n)
    is_correct = stats[column_name] == 1
    return len(stats.loc[is_correct])


def check_no_true_duplicates(all_stats: list, output_path: str = "all_stats.csv") -> bool:
    """Combine per-seed result tables, keeping the best row per molecule, and save it.

    Despite the name, nothing is asserted. The frames are concatenated, and for
    every "molecule_id" the row with the highest "best_score" is kept. The result
    is sorted by id and written to ``output_path``.

    Args:
        all_stats: DataFrames indexed by "molecule_id" (as returned by ``stats_save``).
        output_path: CSV destination (default "all_stats.csv").

    Returns:
        Always True (kept for backward compatibility).
    """
    combined = pd.concat(all_stats).reset_index()
    deduplicated = (
        # Stable sorts make tie-breaking between equal scores deterministic.
        combined.sort_values("best_score", ascending=False, kind="stable")
        .drop_duplicates("molecule_id")
        .sort_values("molecule_id", kind="stable")
        .reset_index(drop=True)
    )
    deduplicated.to_csv(output_path)
    return True


def count_scores(seeds: list[int], folder_path="final_results_graphga") -> tuple:
    """Aggregate top-k accuracies over one or more seeds.

    For each seed, ``stats_save(seed, folder_path)`` computes per-molecule results.
    Note that the same ``folder_path`` is read for every seed; the seed only names
    the ``results_<seed>.csv`` output. The per-seed tables are also combined via
    ``check_no_true_duplicates``.

    Returns:
        ``(top_1, top_5, top_10, found)``: fractions of correctly solved molecules,
        pooled over all seeds (total correct / total molecules).

    Raises:
        ValueError: If ``seeds`` is empty or no molecules were found.
    """
    if not seeds:
        raise ValueError("count_scores needs at least one seed")
    top_1_seeds, top_5_seeds, top_10_seeds, top_50_seeds = [], [], [], []
    all_stats = []
    total_molecules = 0
    for seed in seeds:
        stats = stats_save(seed, folder_path)
        total_molecules += len(stats)
        all_stats.append(stats)
        top_1_seeds.append(count_correct_seed(stats, 1))
        top_5_seeds.append(count_correct_seed(stats, 5))
        top_10_seeds.append(count_correct_seed(stats, 10))
        # "Found": correct structure anywhere among the formula-matched final population.
        top_50_seeds.append(count_correct_seed(stats, "Found"))
    if total_molecules == 0:
        raise ValueError(f"No molecules found in {folder_path!r}")
    check_no_true_duplicates(all_stats)
    # Divide by the molecules of ALL seeds; dividing by the last seed's count alone
    # lets multi-seed fractions exceed 1.
    return (
        sum(top_1_seeds) / total_molecules,
        sum(top_5_seeds) / total_molecules,
        sum(top_10_seeds) / total_molecules,
        sum(top_50_seeds) / total_molecules,
    )


# Aggregate accuracies for seed 1000; other seeds (e.g. 555, 333, 222, 888) can be added to the list.
top_overall_1, top_overall_5, top_overall_10, top_overall_50 = count_scores([1000], "final_results_graphga")
for label, value in (
    ("Top 1 overall", top_overall_1),
    ("Top 5 overall", top_overall_5),
    ("Top 10 overall", top_overall_10),
    ("Percentage found", top_overall_50),
):
    logger.info(f"{label}: {value}")


100%|██████████| 1001/1001 [04:16<00:00,  3.90it/s]
[32m2024-11-12 11:00:43.782[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m46[0m - [1mTop 1 overall: 0.7602397602397603[0m
[32m2024-11-12 11:00:43.783[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m47[0m - [1mTop 5 overall: 0.8781218781218781[0m
[32m2024-11-12 11:00:43.783[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m48[0m - [1mTop 10 overall: 0.8891108891108891[0m
[32m2024-11-12 11:00:43.783[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m49[0m - [1mPercentage found: 0.8981018981018981[0m


In [29]:
import plotly.graph_objects as go


def create_barchart(top_1, top_5, top_10, top_50, output_path="figures/barchart.pdf"):
    """Horizontal bar chart comparing this pipeline's accuracies with literature references.

    Args:
        top_1: Fraction of molecules whose top-ranked candidate is correct.
        top_5: Fraction correct within the first 5 candidates.
        top_10: Accepted for symmetry with ``count_scores``' return value; not plotted.
        top_50: Fraction whose correct structure appears anywhere in the final
            population (plotted as "in the population").
        output_path: Where the figure is written; parent folders are created.

    Returns:
        The plotly Figure, which is also shown and saved to ``output_path``.
    """
    # Best literature reference values (source not cited in this notebook; TODO add
    # citation). A top-10 reference of 0.8650 also exists but is not plotted.
    best_literature_top_1, best_literature_top_5 = 0.6699, 0.8409
    # (label, value, colour) in plotting order; each "ref." bar is the lighter shade of its pair.
    bars = [
        ("ref. top 1", best_literature_top_1, "rgba(152, 86, 86, 0.3)"),
        ("top 1", top_1, "rgba(152, 86, 86, 0.7)"),
        ("ref. top 5", best_literature_top_5, "rgba(86, 86, 152, 0.3)"),
        ("top 5", top_5, "rgba(86, 86, 152, 0.7)"),
        ("in the population", top_50, "rgba(86, 86, 86, 0.7)"),
    ]

    fig = go.Figure()
    for metric, value, color in bars:
        fig.add_trace(
            go.Bar(
                x=[value],
                y=[metric],
                orientation="h",
                marker_color=color,
                width=0.6,
                showlegend=False,
            )
        )

    sizing = 305
    fig.update_layout(
        template="plotly_white",
        width=sizing * 1.618,
        height=sizing,
        margin={"l": 40, "r": 20, "t": 20, "b": 40},
        plot_bgcolor="white",
        font={"color": "rgb(120, 120, 120)"},
        bargap=0.15,
        bargroupgap=0.03,
        shapes=[
            # Hand-drawn x-axis line just below the bottom bar
            {
                "type": "line",
                "xref": "paper",
                "yref": "y",
                "x0": 0,
                "x1": 1,
                "y0": -0.5,
                "y1": -0.5,
                "line": {"color": "rgb(120, 120, 120)", "width": 1},
            },
            # Hand-drawn y-axis line at x = 0
            {
                "type": "line",
                "xref": "x",
                "yref": "paper",
                "x0": 0,
                "y0": 0,
                "x1": 0,
                "y1": 1,
                "line": {"color": "rgb(120, 120, 120)", "width": 1},
            },
        ],
    )

    fig.update_xaxes(
        title_text="fraction of correct predictions",
        range=[-0.0, 1.02],  # small right padding so the 1.0 line stays visible
        showgrid=False,
        zeroline=False,
        showline=False,
        ticks="outside",
        tickwidth=1,
        tickcolor="rgb(120, 120, 120)",
        ticklen=5,
        tickvals=np.arange(0, 1.2, 0.2),
        tickformat=".1f",
    )
    fig.update_yaxes(
        title_text="metric",
        showgrid=False,
        zeroline=False,
        showline=False,
        ticks="outside",
        tickwidth=1,
        tickcolor="rgb(120, 120, 120)",
        ticklen=5,
    )

    # Reference line for a perfect score
    fig.add_vline(
        x=1,
        line_dash="dot",
        line_color="rgb(120, 120, 120)",
        annotation={
            "text": "perfect score",
            "textangle": 90,
            "x": 1.01,
            "yref": "paper",
            "y": 0.7,
            "showarrow": False,
            "font": {"color": "rgb(120, 120, 120)"},
        },
    )
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    fig.write_image(output_path, scale=2)
    fig.show()
    return fig


# Plot the overall accuracies computed above against the literature references.
fig = create_barchart(
    top_1=top_overall_1,
    top_5=top_overall_5,
    top_10=top_overall_10,
    top_50=top_overall_50,
)

In [30]:
def plot_distr_of_scores_when_correct_vs_not_correct(
    prop_to_plot, color_column, stat=None, results_path="results_1000.csv"
):
    """Overlaid histogram of one per-molecule property, split by a correctness flag.

    Args:
        prop_to_plot: Results-CSV column to histogram, e.g. "best_score", "mol_size",
            "max_initial_population" or "molecular_complexity".
        color_column: Boolean column used to split and colour the bars: "Top 1",
            "Top 5", "Top 10" (renamed from "1"/"5"/"10") or "Found".
        stat: Optional plotly ``histnorm`` (e.g. "percent"); ``None`` plots raw counts.
        results_path: CSV produced by ``stats_save``.

    Side effects:
        Writes ``figures/correct_vs_not_correct_<prop>_<color>.pdf`` and shows the
        figure. For "molecular_complexity" the mean complexity is logged.
    """
    stats = pd.read_csv(results_path)
    # best_score == 0 marks molecules without any formula-matching candidate; drop them.
    stats = stats[stats["best_score"] != 0]
    stats = stats.rename(columns={"5": "Top 5", "10": "Top 10", "1": "Top 1"})

    fig = px.histogram(
        stats,
        x=prop_to_plot,
        color=color_column,
        nbins=30,
        barmode="overlay",
        histnorm=stat,
        color_discrete_map={
            True: "rgb(152, 86, 86)",
            False: "rgb(152, 146, 186)",
        },
    )
    # Unknown properties fall back to the raw column name as the axis title.
    xaxis_titles = {
        "best_score": "cosine similarity",
        "mol_size": "molecule size",
        "max_initial_population": "max initial population similarity",
        "molecular_complexity": "Bertz complexity",
    }
    fig.update_xaxes(
        title_text=xaxis_titles.get(prop_to_plot, prop_to_plot),
        showgrid=False,
        zeroline=False,
        showline=False,
        ticks="outside",
        tickwidth=1,
        tickcolor="rgb(120, 120, 120)",
        ticklen=5,
    )
    fig.update_yaxes(
        title_text="count" if stat is None else stat,
        showgrid=False,
        gridwidth=1,
        gridcolor="rgba(173, 216, 230, 0.3)",
        zeroline=False,
        showline=False,
        ticks="outside",
        tickwidth=1,
        tickcolor="rgb(120, 120, 120)",
        ticklen=5,
    )
    if prop_to_plot == "molecular_complexity":
        logger.info(stats["molecular_complexity"].mean())

    fig.update_layout(font_family="CMU Sans Serif")
    fig.update_traces(marker={"line": {"width": 0.2, "color": "DarkSlateGrey"}}, opacity=1)
    if prop_to_plot == "best_score":
        fig.update_xaxes(range=[0.5, 1], tickvals=np.arange(0.6, 1.11, 0.1))
        fig.add_vline(x=1, line_dash="dot", annotation_text="perfect similarity", annotation_position="bottom right")
    elif prop_to_plot == "mol_size":
        fig.update_xaxes(range=[4, 36])
    # Rotate any vline annotation so it runs along the line.
    fig.update_annotations(textangle=90)
    fig.update_layout(
        plot_bgcolor="white",
        width=250 * 1.6180339887,
        height=250,
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
    )
    Path("figures").mkdir(parents=True, exist_ok=True)
    fig.write_image(f"figures/correct_vs_not_correct_{prop_to_plot}_{color_column}.pdf", scale=5)
    fig.show()


# One figure per (property, correctness flag) pair; pass stat="percent" to normalise the counts.
for prop in ("mol_size", "molecular_complexity", "best_score", "max_initial_population"):
    for flag in ("Top 1", "Top 5", "Top 10", "Found"):
        plot_distr_of_scores_when_correct_vs_not_correct(prop, color_column=flag)


[32m2024-11-12 11:05:42.812[0m | [1mINFO    [0m | [36m__main__[0m:[36mplot_distr_of_scores_when_correct_vs_not_correct[0m:[36m74[0m - [1m728.3674094826654[0m


[32m2024-11-12 11:05:42.865[0m | [1mINFO    [0m | [36m__main__[0m:[36mplot_distr_of_scores_when_correct_vs_not_correct[0m:[36m74[0m - [1m728.3674094826654[0m


[32m2024-11-12 11:05:42.918[0m | [1mINFO    [0m | [36m__main__[0m:[36mplot_distr_of_scores_when_correct_vs_not_correct[0m:[36m74[0m - [1m728.3674094826654[0m


[32m2024-11-12 11:05:42.969[0m | [1mINFO    [0m | [36m__main__[0m:[36mplot_distr_of_scores_when_correct_vs_not_correct[0m:[36m74[0m - [1m728.3674094826654[0m


In [31]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from sklearn.calibration import calibration_curve


def plot_score_calibration(results_path="all_stats.csv", n_bins=10, top_k=5, strategy="uniform"):
    """
    Creates a calibration plot for best scores showing the relationship between
    predicted similarity scores and actual success rates.

    ``best_score`` is min-max scaled to [0, 1] and treated as a predicted
    probability of success. The binary outcome is the correctness column named
    after ``top_k``.

    Args:
        results_path (str): CSV with "best_score" and the column named by ``top_k``.
        n_bins (int): Number of bins for the calibration curve and score histogram.
        top_k (int | str): Outcome column: 1, 5, 10 or "Found".
        strategy (str): Binning strategy passed to sklearn's ``calibration_curve``
            ("uniform" or "quantile").

    Raises:
        ValueError: If, after filtering, ``best_score`` has no spread (or no rows),
            so min-max scaling is undefined.
    """
    stats = pd.read_csv(results_path)
    # best_score == 0 marks molecules without any formula-matching candidate; drop them.
    stats = stats[stats["best_score"] != 0]

    # Compare with True rather than astype(bool): astype would count a missing flag (NaN) as a success.
    y_true = stats[f"{top_k}"].eq(True)
    raw_scores = stats["best_score"]
    score_min, score_max = raw_scores.min(), raw_scores.max()
    if not score_max > score_min:  # also catches NaN from an empty frame
        raise ValueError("best_score has no spread after filtering; cannot min-max scale it for calibration")
    y_pred = (raw_scores - score_min) / (score_max - score_min)
    prob_true, prob_pred = calibration_curve(y_true, y_pred, n_bins=n_bins, strategy=strategy)

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=[0, 1], y=[0, 1], mode="lines", name="perfect calibration", line={"dash": "dash", "color": "gray"}, showlegend=True
        )
    )
    fig.add_trace(
        go.Scatter(
            x=prob_pred,
            y=prob_true,
            mode="lines+markers",
            name="pipeline calibration",
            line={"color": "rgb(152, 86, 86)"},
            marker={"size": 8, "color": "rgb(152, 86, 86)", "line": {"width": 1, "color": "DarkSlateGrey"}},
            showlegend=True,
        )
    )

    sizing = 250
    # Single layout call holding the final effective settings.
    fig.update_layout(
        template="plotly_white",
        font={"family": "CMU Sans Serif", "color": "rgb(120, 120, 120)"},
        plot_bgcolor="white",
        width=sizing * 1.618,
        height=sizing,
        margin={"l": 40, "r": 20, "t": 20, "b": 40},
        bargap=0.15,
        bargroupgap=0.03,
        xaxis={
            "title": "scaled cosine similarity",
            "range": [0, 1],
            "tickvals": np.arange(0, 1.1, 0.2),
        },
        yaxis={
            "title": f"top-{top_k} accuracy",
            "range": [0, 1],
            "tickvals": np.arange(0, 1.1, 0.2),
        },
        shapes=[
            # Hand-drawn x-axis line at y = 0
            {
                "type": "line",
                "xref": "paper",
                "yref": "y",
                "x0": 0,
                "x1": 1,
                "y0": 0,
                "y1": 0,
                "line": {"color": "rgb(120, 120, 120)", "width": 1},
            },
            # Hand-drawn y-axis line at x = 0
            {
                "type": "line",
                "xref": "x",
                "yref": "paper",
                "x0": 0,
                "y0": 0,
                "x1": 0,
                "y1": 1,
                "line": {"color": "rgb(120, 120, 120)", "width": 1},
            },
        ],
    )
    tick_style = {
        "showgrid": False,
        "gridcolor": "rgba(173, 216, 230, 0.3)",
        "zeroline": False,
        "showline": False,
        "ticks": "outside",
        "tickwidth": 1,
        "tickcolor": "rgb(120, 120, 120)",
        "ticklen": 5,
    }
    fig.update_yaxes(**tick_style)
    fig.update_xaxes(**tick_style)

    # Score distribution drawn as faint bars; the tallest bar reaches y = 0.5 (half the axis).
    hist, bins = np.histogram(y_pred, bins=n_bins, range=(0, 1))
    hist = hist / hist.max() * 0.5
    fig.add_trace(
        go.Bar(
            x=bins[:-1] + np.diff(bins) / 2,
            y=hist,
            marker_color="rgba(152, 86, 86, 0.3)",
            name="score distribution",
            width=np.diff(bins)[0],
            hoverinfo="skip",
        )
    )

    Path("figures").mkdir(parents=True, exist_ok=True)
    fig.write_image(f"figures/best_score_calibration_{strategy}.pdf", scale=5)
    fig.show()

    print(f"Number of samples: {len(stats)}")
    print(f"Average prediction: {y_pred.mean():.3f}")
    print(f"Success rate: {y_true.mean():.3f}")


# Top-1 calibration with equal-width bins, then with equal-population bins.
for strategy, n_bins in (("uniform", 11), ("quantile", 9)):
    plot_score_calibration(top_k=1, strategy=strategy, n_bins=n_bins)

Number of samples: 1001
Average prediction: 0.780
Success rate: 0.760


Number of samples: 1001
Average prediction: 0.780
Success rate: 0.760


In [32]:
import plotly.graph_objects as go

# Molecule-size distribution over all molecules in all_stats.csv
# (one row per molecule_id, as written by check_no_true_duplicates).
stats_999 = pd.read_csv("all_stats.csv")

fig = go.Figure()
fig.add_trace(
    go.Histogram(
        x=stats_999["mol_size"],
        histnorm="probability",
        name="all molecules",
        nbinsx=30,
        marker={"color": "rgb(152, 86, 86)", "line": {"width": 0.4, "color": "DarkSlateGrey"}},
        opacity=0.6,
    )
)
fig.update_layout(
    width=200 * 1.6180339887,
    height=200,
    title="molecule size distribution",
    xaxis_title="number of atoms",
    # histnorm="probability" gives fractions in [0, 1], not percentages.
    yaxis_title="fraction of molecules",
    barmode="overlay",
    plot_bgcolor="white",
    legend={"title": "Seeds"},
    font={"family": "Arial", "size": 12},
    margin={"l": 0, "r": 0, "t": 0, "b": 0},
)
Path("figures").mkdir(parents=True, exist_ok=True)
fig.write_image("figures/molecule_size_distribution.pdf", scale=5)
fig.show()
