In [1]:
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import ast
from plotly.subplots import make_subplots
import warnings
import re

warnings.filterwarnings("ignore")

In [2]:
# Matches NumPy numeric scalar reprs such as np.float64(0.5), np.float32(1.5),
# np.int64(3), np.int32(-2), np.uint8(7) and captures the literal inside.
_NP_NUMBER_RE = re.compile(r'np\.(?:u?int\d*|float\d*)\(([^)]+)\)')
# Matches NumPy boolean scalar reprs (np.True_ / np.False_, as printed by NumPy 2).
_NP_BOOL_RE = re.compile(r'np\.(True|False)_')


def clean_np_literals(s):
    """Strip NumPy scalar wrappers from a repr string so ``ast.literal_eval`` can parse it.

    ``np.float64(0.5)`` -> ``0.5``, ``np.int32(3)`` -> ``3``, ``np.True_`` -> ``True``.
    Covers the signed/unsigned int and float families of any bit width.

    Args:
        s: The serialized value. Non-string values are returned unchanged.

    Returns:
        The string with NumPy scalar wrappers removed, or ``s`` itself if it is not a str.

    Note:
        Non-finite floats (e.g. ``np.float64(nan)``) become ``nan``, which
        ``ast.literal_eval`` still cannot parse.
    """
    if not isinstance(s, str):
        return s
    s = _NP_NUMBER_RE.sub(r'\1', s)
    s = _NP_BOOL_RE.sub(r'\1', s)
    return s

In [3]:
RESULTS_DIR = "../data/results/hyperparameter_tuning_centroid_vs_full"


def _parse_metrics(value):
    """Turn a serialized metrics dict (as stored in the Excel sheet) into a dict.

    Strings have their NumPy scalar wrappers removed by ``clean_np_literals`` and are
    then parsed with ``ast.literal_eval``. Non-string values (an already-parsed dict,
    or NaN from an empty cell) are returned unchanged instead of raising ValueError.
    """
    if isinstance(value, str):
        return ast.literal_eval(clean_np_literals(value))
    return value


centroid_df = pd.read_excel(f"{RESULTS_DIR}/centroid_results_kmeans500_v2_l2.xlsx")
centroid_df['centroid_metrics'] = centroid_df['centroid_metrics'].apply(_parse_metrics)
full_df = pd.read_excel(f"{RESULTS_DIR}/full_results_kmeans500_v2_l2.xlsx")
full_df['full_metrics'] = full_df['full_metrics'].apply(_parse_metrics)

# Composite "<top_k>_<top_n_clusters>" label, used as the categorical x-axis of the centroid plots.
centroid_df['x_axis'] = centroid_df['top_k'].astype(str) + "_" + centroid_df['top_n_clusters'].astype(str)

In [4]:
# Preview the first rows of the centroid results after parsing.
centroid_df.head(n=5)

Unnamed: 0.1,Unnamed: 0,top_k,top_n_clusters,centroid_metrics,centroid_time,x_axis
0,0,3,5,"{'doc_accuracy': 0.5764677678968938, 'chunk_ac...",2201.52147,3_5
1,1,3,10,"{'doc_accuracy': 0.5995616388314935, 'chunk_ac...",2382.507434,3_10
2,2,3,20,"{'doc_accuracy': 0.6154636468452837, 'chunk_ac...",2775.859128,3_20
3,3,3,35,"{'doc_accuracy': 0.6242080389045537, 'chunk_ac...",3575.666577,3_35
4,4,5,5,"{'doc_accuracy': 0.5978492905170151, 'chunk_ac...",2254.86549,5_5


In [5]:
# Show one parsed metrics dict to see which keys are available for plotting.
centroid_df.at[0, "centroid_metrics"]

{'doc_accuracy': 0.5764677678968938,
 'chunk_accuracy': 0.6632153335083734,
 'doc_precision': 0.8496396817969115,
 'doc_recall': 0.6096159011549825,
 'doc_f1': 0.7098877898111584,
 'chunk_precision': 1.0,
 'chunk_recall': 0.6037872683319904,
 'chunk_f1': 0.7529518163090991,
 'correct_chunk_accuracy': 0.26499541471173566,
 'doc_true_positives': 45392,
 'doc_true_negatives': 5106,
 'doc_false_positives': 8033,
 'doc_false_negatives': 29068,
 'chunk_true_positives': 44958,
 'chunk_true_negatives': 13139,
 'chunk_false_positives': 0,
 'chunk_false_negatives': 29502}

In [6]:
groups = {
    "Accuracy": ["doc_accuracy", "chunk_accuracy"],
    "F1 Score": ["doc_f1", "chunk_f1"],
    "Precision": ["doc_precision", "chunk_precision"],
    "Recall": ["doc_recall", "chunk_recall"],
    "Correct chunk accuracy": ["correct_chunk_accuracy"],
}


def _metric_values(metrics, key):
    """Pull ``key`` out of every metrics dict in the Series ``metrics``.

    Entries that are not dicts (e.g. NaN from an empty Excel cell) yield None,
    which Plotly draws as a gap instead of raising AttributeError.
    """
    return metrics.apply(lambda m: m.get(key) if isinstance(m, dict) else None)


def plot_metric_group(title, keys):
    """Build a two-panel figure comparing centroid vs. full results for ``keys``.

    Left panel: ``centroid_df`` metrics against the "<top_k>_<top_n_clusters>" label.
    Right panel: ``full_df`` metrics against ``top_k``.
    Each metric key gets one colour shared by both panels and a single legend
    entry; clicking it toggles the key in both panels.

    Args:
        title: Group name used in the subplot titles.
        keys: Metric dict keys to draw as lines.

    Returns:
        The plotly Figure (not yet shown).
    """
    palette = px.colors.qualitative.Plotly
    fig = make_subplots(rows=1, cols=2, subplot_titles=(f"Centroid - {title}", f"Full - {title}"))
    panels = [
        # (subplot column, x values, Series of metrics dicts, x-axis label)
        (1, centroid_df['x_axis'], centroid_df['centroid_metrics'], 'top k with top n clusters'),
        (2, full_df['top_k'], full_df['full_metrics'], 'top k'),
    ]

    for idx, key in enumerate(keys):
        color = palette[idx % len(palette)]
        for col, x, metrics, _ in panels:
            fig.add_trace(go.Scatter(
                x=x,
                y=_metric_values(metrics, key),
                mode='lines+markers',
                name=key,
                # Group both panels' traces so the legend lists each key once.
                legendgroup=key,
                showlegend=(col == 1),
                line=dict(color=color),
                marker=dict(color=color),
            ), row=1, col=col)

    for col, _, _, x_label in panels:
        # The plotted metrics are ratios in [0, 1]; a fixed range keeps panels comparable.
        fig.update_yaxes(range=[0, 1.1], row=1, col=col)
        fig.update_xaxes(title_text=x_label, row=1, col=col)

    return fig


for title, keys in groups.items():
    fig = plot_metric_group(title, keys)
    fig.show()