In [21]:
import pandas as pd
import time, os
from usefull import correctLabels
from usefull import plot_metrics
import seaborn as sns
import matplotlib.pyplot as plt

In [11]:
def correctLabels(datasetName, modelName):
    df1 = readDatasetWithSentenceId(datasetName, "test")
    df2 = pd.read_csv(f'results/{datasetName}/{datasetName}_{modelName }_predictions.csv')
    df1.labels= df1.labels.str.replace("I-", "").str.replace("B-", "")
    df2.labels= df2.labels.str.replace("I-", "").str.replace("B-", "")
    custom_labels = list(set(list(df1.labels.unique()) + list(df2.labels.unique())))
    report = classification_report(df1.labels, df2.labels, labels=custom_labels, target_names=custom_labels, output_dict=True)
    data = {
        "modelName": [modelName ],
        "datasetName": [datasetName],
        "accuracy_global": [round(report['weighted avg']['precision']*100, 2)],
        "recall_global": [round(report['weighted avg']['recall']*100, 2)],
        "f1_score_global": [round(report['weighted avg']['f1-score'] *100, 2)]
    }
    for label in custom_labels:
        data[f"accuracy_{label}"] = [round(report[label]['precision']*100, 2)]
        data[f"recall_{label}"] = [round(report[label]['recall']*100, 2)]
        data[f"f1_score_{label}"] = [round(report[label]['f1-score']*100, 2)]

    df = pd.DataFrame(data)
    df.to_csv(f'results/{datasetName}/{datasetName}_{modelName }_eval_corrected.csv', sep=",", index=False)
    return df

In [26]:
def plot_metrics(df, metric, sufixe):
    
    # Extract column names starting  with '_accuracy'
    metric_columns = [col for col in df.columns if col.startswith(f'{metric}_')]

    # Set up Seaborn with a diverging color palette
    sns.set(style="whitegrid")
    colors = sns.color_palette("RdYlGn", len(metric_columns))

    # Plotting the bar chart with seaborn
    plt.figure(figsize=(10, 6))
    print("l12",df.iloc[0],metric_columns)
    ax = sns.barplot(x=metric_columns, y=df.iloc[0][metric_columns], hue=metric_columns, palette=colors)

    # Customize plot appearance
    plt.title(f'{model_name} {sufixe} {metric} Comparison for {dataset_name}', fontsize=16)
    plt.xlabel('Entity type', fontsize=14)
    plt.ylabel(metric, fontsize=14)
    plt.ylim(0, 110)  # Set y-axis limit to percentages (0-100)

    # Rotate x-axis labels for better readability
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')

    # Add legend
    plt.legend([col.replace(f'{metric}_', "") for col in metric_columns], title="Labels List", loc="upper left")

    # Add grid lines
    plt.grid(True, linestyle='--', alpha=0.7)

    # Increase tick label font size
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    # Save the plot with tight layout to prevent label overlaps
    plt.tight_layout()

In [None]:
def plot_metrics2(df, metricList, sufixe, model):
    
    # Extract column names starting  with '_accuracy'
    metric_columns = [col for col in df.columns if col.startswith(f'{metric}_')]

    # Set up Seaborn with a diverging color palette
    sns.set(style="whitegrid")
    colors = sns.color_palette("RdYlGn", len(metric_columns))

    # Plotting the bar chart with seaborn
    plt.figure(figsize=(10, 6))
    print("l12",df.iloc[0],metric_columns)
    
    for metric in metrics:
        
    
    ax = sns.barplot(x=metric_columns, y=df.iloc[0][metric_columns], hue=metric_columns, palette=colors)

    # Customize plot appearance
    plt.title(f'{model_name} {sufixe} {metric} Comparison for {dataset_name}', fontsize=16)
    plt.xlabel('Entity type', fontsize=14)
    plt.ylabel(metric, fontsize=14)
    plt.ylim(0, 110)  # Set y-axis limit to percentages (0-100)

    
    
    
    
    # Rotate x-axis labels for better readability
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')

    # Add legend
    plt.legend([col.replace(f'{metric}_', "") for col in metric_columns], title="Labels List", loc="upper left")

    # Add grid lines
    plt.grid(True, linestyle='--', alpha=0.7)

    # Increase tick label font size
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    # Save the plot with tight layout to prevent label overlaps
    plt.tight_layout()

In [28]:
df.columns

Index(['modelName', 'datasetName', 'accuracy_global', 'recall_global',
       'f1_score_global', 'accuracy_Entity', 'recall_Entity',
       'f1_score_Entity', 'accuracy_O', 'recall_O', 'f1_score_O'],
      dtype='object')

In [27]:
#  loop over dataset and model to generate plots
modelList = ["Bio-bert-based", "bert-based", "spark-nlp"]
metricList = ["accuracy", "recall", "f1_score"]
for dataset in os.listdir("datasets/"):
    print(dataset)
    for model in modelList:
        print(model)
        df = correctLabels(datasetName=dataset, modelName=model)
        plot_metrics2(df, metricList, sufixe, model)

bc5cdr
Bio-bert-based
l12 modelName          Bio-bert-based
datasetName                bc5cdr
accuracy_global             94.58
recall_global               94.76
f1_score_global             94.49
accuracy_Entity             89.14
recall_Entity                69.8
f1_score_Entity             78.29
accuracy_O                  95.43
recall_O                    98.67
f1_score_O                  97.02
Name: 0, dtype: object ['accuracy_global', 'accuracy_Entity', 'accuracy_O']


KeyError: "None of [Index([('accuracy_global', 'accuracy_Entity', 'accuracy_O')], dtype='object')] are in the [index]"

<Figure size 1000x600 with 0 Axes>