In [None]:
import json
import os
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
from plotly import graph_objects as go
from plotly.subplots import make_subplots

In [None]:
input_csv = "./data/generic_cluster_labels.csv"
output_dir = "test"  # NOTE(review): not referenced by any cell shown here
compound_data_csv = "./data/unique_compounds.csv"
# Per-structure collection dates. The default is an absolute path on the original author's
# machine; set the DATE_JSON environment variable to use another location without editing code.
date_json = os.environ.get(
    "DATE_JSON",
    "/Users/alexpayne/Scientific_Projects/asapdiscovery-sars-retrospective/science/20240403_multi_pose_docking_v2/20240430_analyze_cross_docking_results/20240503_inputs_analysis/date_dict.json",
)
figures = Path("./figures")
# Every plotting cell writes into this directory; create it up front so write_image
# does not fail on a fresh checkout where ./figures does not exist yet.
figures.mkdir(parents=True, exist_ok=True)

In [None]:
def date_processor(date_string, formats=("%Y-%m-%d %H:%M:%S", "%d/%m/%Y %H:%M")):
    """Parse a collection-date string into a ``datetime``.

    Parameters
    ----------
    date_string : object
        Value read from the date JSON. Anything that is not a string, and the literal
        string ``"None"``, is treated as a missing date.
    formats : tuple of str
        ``strptime`` formats tried in order; the first one that matches wins.

    Returns
    -------
    datetime.datetime or None
        The parsed date, or ``None`` for a missing date.

    Raises
    ------
    ValueError
        If ``date_string`` is a non-missing string that matches none of ``formats``.
    """
    if not isinstance(date_string, str) or date_string == "None":
        return None
    for fmt in formats:
        try:
            return datetime.strptime(date_string, fmt)
        except ValueError:
            continue  # try the next known format
    raise ValueError(f"Unrecognised date string {date_string!r}; tried formats {formats}")

In [None]:
# Map each structure name in the date JSON to its parsed collection date (None when missing).
with open(date_json, "r") as f:
    raw_dates = json.load(f)
date_dict = [
    {"Name": structure_name, "Date": date_processor(raw_date)}
    for structure_name, raw_date in raw_dates.items()
]
date_df = pd.DataFrame.from_records(date_dict)

# Attach dates to the compound table. merge() defaults to an inner join, so compounds whose
# structure_name has no entry in the date JSON are dropped at this step.
compound_data = pd.read_csv(compound_data_csv).merge(
    date_df, left_on="structure_name", right_on="Name"
)

# Left join: every row of the cluster-label table is kept; compound data is attached where available.
df = pd.read_csv(input_csv).merge(compound_data, on="compound_name", how="left")

## Count the number of compounds in each cluster

In [None]:
# One row per cluster: number of (non-null) compound names carrying that cluster_id.
cluster_counts = (
    df.groupby("cluster_id")["compound_name"]
    .count()
    .reset_index(name="count")
)

In [None]:
# Annotate every compound row with the size of its cluster.
df_ccounts = df.merge(cluster_counts, on="cluster_id", how="left")

## Remove singletons (clusters containing only one compound)

In [None]:
# Compound rows belonging to clusters of two or more compounds.
no_singlets = df_ccounts.loc[df_ccounts["count"].gt(1)]

In [None]:
# Per-cluster counts restricted to clusters of two or more compounds.
no_singlets_ccs = cluster_counts.loc[cluster_counts["count"].gt(1)]

# Plot the distribution of cluster sizes

In [None]:
# Largest cluster size (as a plain int), singletons included.
int(cluster_counts["count"].agg("max"))

In [None]:
# Largest cluster size (as a plain int) after dropping singletons.
int(no_singlets_ccs["count"].agg("max"))

In [None]:
# Largest cluster size, as the raw pandas/numpy scalar.
cluster_counts["count"].agg("max")

In [None]:
# Smallest cluster size (expected to be 1 while singletons are included).
cluster_counts["count"].agg("min")

In [None]:
def plot_bar_chart(df, column="count", height=600, width=800, log_y=True):
    """Histogram of cluster sizes, roughly one bin per integer size.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per cluster; ``column`` holds the number of compounds in that cluster.
    column : str, default "count"
        Column plotted on the x axis.
    height, width : int, default 600, 800
        Figure size in pixels.
    log_y : bool, default True
        Use a logarithmic y axis with 1-2-4-6-8 ticks per decade.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = px.histogram(
        df,
        x=column,
        template="simple_white",
        height=height,
        width=width,
        log_y=log_y,
        text_auto=True,
        # Plotly treats nbins as an upper bound; asking for max-size bins gives ~1 bin per size.
        nbins=int(df[column].max()),
    )
    if log_y:
        # The tallest possible bar holds every row, so use enough decades to reach len(df)
        # (never fewer than three decades, i.e. ticks up to 800).
        n_decades = max(3, len(str(len(df))))
        fig.update_yaxes(
            tickvals=[(10**big) * small for big in range(n_decades) for small in [1, 2, 4, 6, 8]]
        )
    fig.update_yaxes(title_text="<b> Number of Clusters </b>")
    fig.update_xaxes(title_text="<b> Number of Compounds in Cluster </b>")
    return fig

In [None]:

# Histogram of cluster sizes, zoomed to clusters of 1-9 compounds.
# Bins are one compound wide and centred on integers (start=0.5), so each bar sits on its
# own tick instead of spanning the gap between two ticks. The bin end comes from the data
# rather than a hard-coded 65, so the binning covers every cluster size.
max_cluster_size = int(cluster_counts["count"].max())
fig = plot_bar_chart(cluster_counts)
fig.update_traces(xbins=dict(start=0.5, end=max_cluster_size + 0.5, size=1))
fig.update_xaxes(range=[0.5, 9.5], tickvals=list(range(1, 10)))
fig.show()

In [None]:
# Frequency table of cluster sizes: for each size, how many clusters have it (sorted by size).
hist_data = (
    cluster_counts.groupby("count")["cluster_id"]
    .count()
    .reset_index(name="number_of_clusters")
)

In [None]:
# Number of clusters at each size, zoomed to small clusters (x axis 0-10, log y axis).
# Log-axis ticks follow a 1-2-4-6-8 pattern per decade; use enough decades to span the
# tallest bar (at least three, i.e. ticks up to 800).
n_tick_decades = max(3, len(str(int(hist_data["number_of_clusters"].max()))))
log_tickvals = [(10**big) * small for big in range(n_tick_decades) for small in [1, 2, 4, 6, 8]]

fig1 = px.bar(
    hist_data,
    x="count",
    y="number_of_clusters",
    template="simple_white",
    log_y=True,
    text_auto=True,
    height=600,
    width=400,
)
fig1.update_xaxes(title_text="<b> Number of Compounds in Cluster </b>", range=[0, 10], tickvals=list(range(10)))
fig1.update_yaxes(title_text="<b> Number of Clusters </b>", tickvals=log_tickvals)
fig1.show()
for ext in ("png", "svg"):
    fig1.write_image(figures / f"20250114_generic_bm_cluster_histogram_fig1.{ext}")

In [None]:
# Show the cluster-size frequency table (one row per distinct cluster size).
hist_data

In [None]:
# Large clusters only. The previous chart's x axis ends at 10, so this chart shows sizes above that.
LARGE_CLUSTER_THRESHOLD = 10
# .copy() gives an independent frame. Later cells add columns to cut_hist_data, which on a
# plain boolean-mask slice triggers pandas' SettingWithCopyWarning.
cut_hist_data = hist_data[hist_data["count"] > LARGE_CLUSTER_THRESHOLD].copy()
fig2 = px.bar(
    cut_hist_data,
    x="count",
    y="number_of_clusters",
    template="simple_white",
    text_auto=True,
    height=600,
    width=400,
)
fig2.update_xaxes(title_text="<b> Number of Compounds in Cluster </b>")
fig2.update_yaxes(title_text="<b> Number of Clusters </b>")
fig2.show()

In [None]:
# Descending position index (n-1, ..., 0), one value per large cluster. It is derived from the
# row count, so it works for any number of large clusters; .assign avoids mutating a possible slice.
cut_hist_data = cut_hist_data.assign(my_index=list(range(len(cut_hist_data) - 1, -1, -1)))

In [None]:
# Treemap labels "Cluster k - <size> Molecules". Rows are in ascending size order (hist_data
# comes from a sorted groupby), and clusters are numbered in reverse row order, so the
# largest cluster is "Cluster 0". Works for any number of rows.
n_large_clusters = len(cut_hist_data)
cut_hist_data = cut_hist_data.assign(
    values=[
        f"Cluster {n_large_clusters - 1 - pos} - {size} Molecules"
        for pos, size in enumerate(cut_hist_data["count"])
    ]
)

In [None]:
# Inspect the large-cluster table, including the helper columns added in the cells above.
cut_hist_data

In [None]:
# Treemap of the large clusters: one tile per label, tile area proportional to cluster size.
fig = px.treemap(
    cut_hist_data,
    path=["values"],
    values="count",
    template="simple_white",
    height=600,
    width=600,
)
fig.update_layout(margin={"t": 25, "l": 25, "r": 25, "b": 25})
fig.show()
for ext in ("png", "svg"):
    fig.write_image(figures / f"20250114_generic_bm_cluster_histogram_fig2_treemap.{ext}")

In [None]:
# Short "Cluster k" labels, numbered in reverse row order (first row gets the highest number),
# derived from the row count instead of assuming exactly four large clusters.
n_large_clusters = len(cut_hist_data)
cut_hist_data = cut_hist_data.assign(
    simple_values=[f"Cluster {n_large_clusters - 1 - pos}" for pos in range(n_large_clusters)]
)
# Horizontal bar chart: one bar per large cluster, bar length = number of compounds.
fig2 = px.bar(
    cut_hist_data,
    y="simple_values",
    x="count",
    template="simple_white",
    text_auto=True,
    height=600,
    width=400,
)
fig2.update_xaxes(title_text="<b> Number of Compounds in Cluster </b>")
fig2.update_yaxes(title_text="<b> Cluster </b>")
fig2.show()
for ext in ("png", "svg"):
    fig2.write_image(figures / f"20250114_generic_bm_cluster_histogram_fig2_sideways_bar.{ext}")

In [None]:
# Short "Cluster k" labels, numbered in reverse row order (first row gets the highest number),
# derived from the row count instead of assuming exactly four large clusters.
n_large_clusters = len(cut_hist_data)
cut_hist_data = cut_hist_data.assign(
    simple_values=[f"Cluster {n_large_clusters - 1 - pos}" for pos in range(n_large_clusters)]
)
# Vertical bar chart: one bar per large cluster, bar height = number of compounds.
fig2 = px.bar(
    cut_hist_data,
    x="simple_values",
    y="count",
    template="simple_white",
    text_auto=True,
    height=600,
    width=400,
)
fig2.update_yaxes(title_text="<b> Number of Compounds in Cluster </b>")
fig2.update_xaxes(title_text="<b> Cluster </b>")
fig2.show()
for ext in ("png", "svg"):
    fig2.write_image(figures / f"20250114_generic_bm_cluster_histogram_fig2_bar.{ext}")

In [None]:
# Combine fig1 (cluster-size distribution) and fig2 (large clusters) side by side.
# Each figure's traces go into their own subplot column. Concatenating the trace tuples
# into a single go.Figure would overlay both bar charts on one shared pair of axes.
fig = make_subplots(rows=1, cols=2, column_widths=[0.5, 0.5])
for trace in fig1.data:
    fig.add_trace(trace, row=1, col=1)
for trace in fig2.data:
    fig.add_trace(trace, row=1, col=2)

# Axis settings live on each figure's layout, not on its traces, so copy them over explicitly.
fig.update_xaxes(
    title_text=fig1.layout.xaxis.title.text,
    range=fig1.layout.xaxis.range,
    tickvals=fig1.layout.xaxis.tickvals,
    row=1,
    col=1,
)
fig.update_yaxes(
    title_text=fig1.layout.yaxis.title.text,
    type=fig1.layout.yaxis.type,
    tickvals=fig1.layout.yaxis.tickvals,
    row=1,
    col=1,
)
fig.update_xaxes(title_text=fig2.layout.xaxis.title.text, row=1, col=2)
fig.update_yaxes(title_text=fig2.layout.yaxis.title.text, row=1, col=2)
fig.update_layout(template="simple_white", height=600, width=800, showlegend=False)
combined_fig = fig

In [None]:
# Display the combined figure assembled in the previous cell.
combined_fig

# Plot the cumulative number of structures in each scaffold cluster over time

In [None]:
def make_image(df, show_legend=True):
    """Plot the cumulative number of structures over time, one curve per cluster.

    NOTE: a later cell defines ``make_image`` again (legend hidden); once that cell has
    run, it replaces this definition.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs a ``Date`` column (x axis) and a ``cluster_id`` column (one ECDF curve per value).
    show_legend : bool, default True
        Draw the "Bemis-Murcko Cluster" legend.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    large_font = 24

    fig = px.ecdf(
        df,
        x="Date",
        color="cluster_id",
        ecdfnorm=None,  # raw cumulative counts rather than fractions
        template="simple_white",
        height=600,
        width=800,
    )
    fig.update_xaxes(title_text="<b> Date of Crystal Structure Collection </b>")
    fig.update_yaxes(title_text="<b> Cumulative Number of Structures </b>")

    # Larger axis titles and black axis colour on both axes.
    axis_style = dict(title_font=dict(size=large_font), color="black")
    fig.update_layout(xaxis=axis_style, yaxis=axis_style)

    if show_legend:
        # Legend anchored by its bottom-right corner at x=1.1 in paper coordinates,
        # i.e. just past the right edge of the plot area.
        fig.update_layout(
            legend=dict(
                title=dict(text="<b> Bemis-Murcko Cluster </b>"),
                yanchor="bottom",
                y=0.25,
                xanchor="right",
                x=1.1,
            )
        )
    else:
        fig.update_layout(showlegend=False)

    return fig


## Redefine `make_image` without the legend

In [None]:
def make_image(df, show_legend=False):
    """Plot the cumulative number of structures over time, one curve per cluster, legend hidden.

    This redefinition replaces the earlier ``make_image`` and is the version used for the
    saved over-time figure.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs a ``Date`` column (x axis) and a ``cluster_id`` column (one ECDF curve per value).
    show_legend : bool, default False
        Draw the "Bemis-Murcko Cluster" legend (hidden by default).

    Returns
    -------
    plotly.graph_objects.Figure
    """
    large_font = 24

    fig = px.ecdf(
        df,
        x="Date",
        color="cluster_id",
        ecdfnorm=None,  # raw cumulative counts rather than fractions
        template="simple_white",
        height=600,
        width=800,
    )
    fig.update_xaxes(title_text="<b> Date of Crystal Structure Collection </b>")
    fig.update_yaxes(title_text="<b> Cumulative Number of Structures </b>")

    # Larger axis titles and black axis colour. This styling must be passed to update_layout;
    # building the dict alone has no effect on the figure.
    axis_style = dict(title_font=dict(size=large_font), color="black")
    fig.update_layout(showlegend=show_legend, xaxis=axis_style, yaxis=axis_style)
    if show_legend:
        fig.update_layout(legend=dict(title=dict(text="<b> Bemis-Murcko Cluster </b>")))

    return fig

In [None]:
# Compound rows from clusters with at least four compounds (count > 3).
nothing_less_than_4 = no_singlets.loc[no_singlets["count"].gt(3)]

In [None]:
# Cumulative structures per cluster over time (clusters of >= 4 compounds), saved as PNG and SVG.
fig = make_image(nothing_less_than_4)
fig.show()
for ext in ("png", "svg"):
    fig.write_image(figures / f"20250114_generic_bm_cluster_over_time.{ext}")