In [2]:
import json
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import jsonlines
import altair as alt
from vega_datasets import data



os.chdir("../")

SRC_PATH = ["src"]
for module_path in SRC_PATH:
    if module_path not in sys.path:
        sys.path.append(module_path)

from utils import *

In [3]:
dataset_names = ["phi", "gemma", "mistral", "gemma_chat", "zephyr", "llama3", "round_robin"]
training_method = "full_finetuning"
trained_on_models = {"distil_roberta-base": {"03_05_1735": "phi", "03_05_1741": "gemma", "03_05_1748": "mistral", "03_05_1754": "round_robin"},
                    "roberta_large": {"03_05_1845": "phi", "03_05_1910": "gemma", "03_05_1935": "mistral", "03_05_2001": "round_robin"},
                    "electra_large": {"03_05_1842": "phi", "03_05_1909": "gemma", "03_05_1935": "mistral", "03_05_2001": "round_robin"}}


freeze_base_df = create_df_from_test_logs("full_finetuning", trained_on_models, dataset_names)

freeze_base_df = freeze_base_df.sort_values(by="trained_on_dataset")
dataset_order = ["phi", "gemma", "mistral", "round_robin", "gemma_chat", "zephyr", "llama3"]
freeze_base_df = freeze_base_df.set_index("dataset").loc[dataset_order].reset_index()
detector_name_to_short_name = {"distil_roberta-base": "distil", "roberta_large": "roberta", "electra_large": "electra"}
freeze_base_df["detector_short_name"] = freeze_base_df["base_detector"].apply(lambda x: detector_name_to_short_name[x])
# set detector_name as f"{detector_short_name}_{trained_on_dataset}"
freeze_base_df["detector_name"] = freeze_base_df["detector_short_name"] + "_" + freeze_base_df["trained_on_dataset"]
freeze_base_df.head()

Unnamed: 0,dataset,accuracy,precision,recall,f1_score,fp_rate,std_accuracy,std_precision,std_recall,std_f1_score,std_fp_rate,TP,TN,FP,FN,base_detector,trained_on_dataset,detector,detector_short_name,detector_name
0,phi,0.947539,0.921121,0.978991,0.949152,0.083968,0.00489,0.008095,0.004428,0.004867,0.008634,975.819,911.679,83.563,20.939,roberta_large,gemma,roberta_large_gemma,roberta,roberta_gemma
1,phi,0.964984,0.957048,0.97371,0.96529,0.043752,0.004102,0.006376,0.005173,0.00419,0.006385,970.558,951.691,43.551,26.2,electra_large,gemma,electra_large_gemma,electra,electra_gemma
2,phi,0.945348,0.942504,0.948649,0.945541,0.057955,0.005259,0.007493,0.007078,0.00537,0.007521,945.572,937.561,57.681,51.186,distil_roberta-base,gemma,distil_roberta-base_gemma,distil,distil_gemma
3,phi,0.949168,0.928045,0.973925,0.950408,0.075626,0.004905,0.008146,0.004834,0.004921,0.008523,970.769,919.974,75.268,25.989,electra_large,mistral,electra_large_mistral,electra,electra_mistral
4,phi,0.955619,0.932033,0.982984,0.956811,0.071784,0.004458,0.007439,0.004227,0.004445,0.007775,979.794,923.799,71.443,16.964,roberta_large,mistral,roberta_large_mistral,roberta,roberta_mistral


In [9]:
heatmap = alt.Chart(freeze_base_df).mark_rect().encode(
    alt.X('dataset:N', sort=None, title="Dataset used for testing"),
    alt.Y('detector_short_name:N', sort=None, title="Detector"),
    alt.Color('accuracy:Q').scale(scheme='redyellowgreen', domain=[0.85, 1]),
    #alt.Row("trained_on_dataset:N", title="Dataset used for training"),
).properties(
    width=200,
    height=200
)

heatmap_text = alt.Chart(freeze_base_df).mark_text(baseline='middle').encode(
    alt.X('dataset:N', sort=None, title="Dataset used for testing"),
    alt.Y('detector_short_name:N', sort=None, title="Detector"),
    text='accuracy:Q',
    color=alt.condition(
        alt.datum.accuracy > 0.5,
        alt.value('black'),
        alt.value('white')
    )
).properties(
    width=200,
    height=200
)

plot = alt.layer(heatmap, heatmap_text).facet(
    column=alt.Column("trained_on_dataset:N", title="Dataset used for training")
).configure(
    numberFormat='0.2f'
).configure_axis(
    labelFontSize=18,
    titleFontSize=18
).configure_legend(
    labelFontSize=18,
    titleFontSize=18
).configure_header(
    titleFontSize=18,
    labelFontSize=18
)

# save the plot
plot.save("notebooks/plots/heatmap_full.png")

plot