In [None]:
%load_ext autoreload
%autoreload 2
import dt4dds_benchmark
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd

# Base names of the HDF5 result files to load from ./data/.
dataset_names = (
    'basic',
    'cdhit',
    'clover',
    'lsh',
    'mmseqs2',
    'starcode',
)

# Read each file through its HDF5Manager, then combine everything into a
# single Dataset that the rest of the notebook works from.
loaded_parts = [
    dt4dds_benchmark.pipelines.HDF5Manager(f'./data/{name}.hdf5').get_data()
    for name in dataset_names
]
data = dt4dds_benchmark.analysis.Dataset.combine(*loaded_parts)

### get the results, merge with performance data, and normalize to base scenario

In [None]:
# Start from the combined results and derive a short scenario label from the
# input file path: the last path component with the ".txt" text removed.
df = data.combined_results
df['scenario'] = (
    df.input_file
    .str.split('/').str[-1]
    # regex=False makes '.txt' a literal string. Before pandas 2.0 the default
    # was regex=True, where '.' matches ANY character (e.g. "_txt", "Xtxt").
    .str.replace('.txt', '', regex=False)
)

# Attach the performance data; the join uses the columns both frames share.
df = df.merge(data.performances)

# The "BasicSet" rows act as the per-scenario baseline.
basedf = df.loc[df["clustering.type"] == "BasicSet"]

# First baseline row per scenario, indexed by scenario. This yields the same
# value the original per-row lookup (`.values[0]`) picked, but is built once
# instead of re-filtering `basedf` for every row.
base_by_scenario = basedf.drop_duplicates('scenario').set_index('scenario')
for val in ('sensitivity', 'total_foundreferences'):
    df[f'base_{val}'] = df.scenario.map(base_by_scenario[val])

# Sensitivity relative to the baseline of the same scenario.
df['rel_sensitivity'] = df.sensitivity / df.base_sensitivity

### plot all the metrics

In [None]:
# One tiered bar chart per metric, grouped by clustering type and name and
# coloured by scenario. Each figure is shown and saved as SVG and PNG.
metrics = ["rel_sensitivity", "max_similarity", "mean_similarity", "specificity", "duration"]

for metric in metrics:
    # Every metric except 'duration' is drawn on a fixed 0-1 axis;
    # 'duration' keeps plotly's automatic range.
    if metric == 'duration':
        y_range = None
    else:
        y_range = [0, 1]

    fig = dt4dds_benchmark.analysis.plotting.tiered_bar(
        df, "clustering.type", "clustering.name", metric, color_by="scenario",
    )
    fig.update_yaxes(title_text=metric, range=y_range)
    fig.update_layout(
        width=1050,
        height=300,
        margin={"l": 0, "r": 10, "t": 10, "b": 30},
        showlegend=False,
    )

    # Apply the project's shared figure styling before showing and exporting.
    fig = dt4dds_benchmark.analysis.plotting.standardize_plot(fig)
    fig.show()

    out_stem = f'./figures/{metric}'
    fig.write_image(out_stem + '.svg')
    fig.write_image(out_stem + '.png', scale=2)

### generate the raw data as a table

In [None]:
# Columns shown in the raw-data table: three identifying columns followed by
# the plotted metrics.
table_columns = [
    "clustering.type",
    "clustering.name",
    "scenario",
    "rel_sensitivity",
    "max_similarity",
    "mean_similarity",
    "specificity",
    "duration",
]
sort_keys = ["clustering.type", "clustering.name", "scenario"]

# Select the columns, then order rows by type, name and scenario.
datadf = df[table_columns].sort_values(by=sort_keys)

datadf

### limit to best performers of each clustering type

In [None]:
# One "<clustering.type>_<clustering.name>" identifier per clustering type
# to keep (the selected best performer of each type).
best_ids = [
    'BasicSet_default',
    'CDHit_id85',
    'Clover_D15V4',
    'LSH_default',
    'MMseqs2_covmode1',
    'Starcode_sphereD6',
]

# Build the combined identifier on a new frame (datadf is left untouched),
# then keep only the rows whose identifier is in the list above.
selectdf = datadf.assign(id=datadf['clustering.type'] + "_" + datadf['clustering.name'])
selectdf = selectdf.loc[selectdf['id'].isin(best_ids)]

selectdf

In [None]:
# Reshape the selected results into long form for the polar plot: one row per
# (clustering method, metric), with each metric present once per scenario.
plotdf = selectdf.copy()

# Turn run time into a "higher is better" score: 1 at zero duration, 0 when
# duration / 60 reaches 30.
# NOTE(review): presumably duration is in seconds, making 30 minutes the
# zero point - confirm the unit.
plotdf['speed'] = 1 - (plotdf['duration'] / 60)/30
plotdf = plotdf.drop(columns=['duration'])

# Place the two scenarios side by side, one row per clustering method.
# The suffix order must follow the frame order: the left frame is the
# electrochemical scenario and the right frame is the material scenario.
# (Previously the suffixes were ('_mat', '_elec'), which labelled every
# electrochemical value as "_mat" and every material value as "_elec".)
plotdf = pd.merge(
    plotdf.loc[plotdf.scenario == 'exp_electrochemical_20x'],
    plotdf.loc[plotdf.scenario == 'exp_material_20x'],
    on=['id', 'clustering.type', 'clustering.name'],
    suffixes=('_elec', '_mat')
)
# The scenario is now encoded in the column suffixes.
plotdf = plotdf.drop(columns=['scenario_mat', 'scenario_elec'])

# Long form; the order of value_vars fixes the order of the melted rows.
# The "_elec" metrics are listed in reverse of the "_mat" ones.
plotdf = plotdf.melt(id_vars=['clustering.type', 'clustering.name', 'id'], value_vars=[
    "rel_sensitivity_mat", 
    "max_similarity_mat", 
    "specificity_mat", 
    "speed_mat",
    "speed_elec",
    "specificity_elec", 
    "max_similarity_elec", 
    "rel_sensitivity_elec", 
    ], var_name='metric', value_name='value')


# Polar (radar) chart of the long-form metrics, one closed line per
# clustering type.
type_order = ['BasicSet', 'Starcode', 'MMseqs2', 'Clover', 'LSH', 'CDHit']
type_colors = {
    'BasicSet': '#636363',
    'Starcode': '#31a354',
    'Clover': '#756bb1',
    'LSH': '#3182bd',
    'MMseqs2': '#e6550d',
    'CDHit': '#de2d26',
}

# Rotate the angular axis by half a sector (360/8/2 degrees) past 90 degrees.
polar_start_angle = 90+360/8/2

fig = px.line_polar(
    plotdf,
    r='value',
    theta='metric',
    color='clustering.type',
    line_close=True,
    start_angle=polar_start_angle,
    direction='counterclockwise',
    category_orders={'clustering.type': type_order},
    color_discrete_map=type_colors,
)

# Keep both grids, but hide the radial axis line, ticks and tick labels.
fig.update_polars(
    angularaxis=dict(showgrid=True, gridwidth=2, tickfont=dict(size=28/3)),
    radialaxis=dict(showgrid=True, showline=False, ticks="", showticklabels=False),
)

# Use the same line width for every trace.
for polar_trace in fig.data:
    polar_trace.line.width = 2.5

fig.update_layout(
    width=200,
    height=200,
    margin={"l": 20, "r": 20, "t": 20, "b": 20},
    showlegend=False,
)

# Apply the project's shared figure styling, then export and display.
fig = dt4dds_benchmark.analysis.plotting.standardize_plot(fig)

fig.write_image('./figures/combined_polar.svg')
fig.write_image('./figures/combined_polar.png', scale=2)
fig.show()