In [1]:
import json
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import jsonlines
import altair as alt
from vega_datasets import data


In [None]:
# Per-threshold curve columns stored in each metrics file; not needed for the heatmap.
columns_to_ignore = ["fpr_at_thresholds", "tpr_at_thresholds", "thresholds"]

# Detector result sub-directories under RESULTS_DIR (also used as detector ids).
detectors = ["fast_detect_gpt", "roberta_base_open_ai", "gpt_zero", "electra_large/full_finetuning/fake_true_dataset_mistral_10k/10_06_1242",
             "electra_large/full_finetuning/fake_true_dataset_round_robin_10k/10_06_1308", "watermark_detector"]
chat_model = "zephyr"

attacks = ["cnn_style", "in_context_example", "paraphrased_llm", "repetition_penalty_1.2", "temperature_1.2"]

RESULTS_DIR = "detection_test_results"

# One frame per attack, holding the metrics of every detector tagged with
# `detector` and `attack` columns. Frames are concatenated rather than rebuilt
# row-by-row via iterrows(), which would upcast mixed-type rows to object dtype
# (losing the float dtype of the TPR column) and duplicate index labels.
attack_dfs = {}
for attack in attacks:
    detector_dfs = []
    for detector in detectors:
        json_path = f"{RESULTS_DIR}/{detector}/cnn_dailymail/{attack}_KGW/test/test_metrics_fix_bug.json"
        detector_dfs.append(pd.read_json(json_path).assign(detector=detector, attack=attack))
    attack_dfs[attack] = pd.concat(detector_dfs, ignore_index=True)

# Long-format table: one block of rows per (attack, detector) pair.
attack_df = pd.concat(attack_dfs.values(), ignore_index=True)

In [None]:
# rename column "tp_rate_at_given_threshold" to "TPR" (a no-op on re-run)
attack_df = attack_df.rename(columns={"tp_rate_at_given_threshold": "TPR"})
attack_df["chat_model"] = chat_model

# Display names for detectors. Every entry of `detectors` needs a key here;
# "watermark_detector" was previously missing and caused a KeyError.
detector_name_to_short_name = {
    "fast_detect_gpt": "fast_detect_gpt",
    "roberta_base_open_ai": "roberta_open_ai",
    "gpt_zero": "gpt_zero",
    "electra_large/full_finetuning/fake_true_dataset_mistral_10k/10_06_1242": "electra_mistral",
    "electra_large/full_finetuning/fake_true_dataset_round_robin_10k/10_06_1308": "electra_RR",
    "watermark_detector": "watermark",
}

attack_short_names = {"cnn_style": "CNN style",
                      "in_context_example": "In context example",
                      "paraphrased_llm": "Paraphrasing",
                      "repetition_penalty_1.2": "Repetition penalty",
                      "temperature_1.2": "Temperature"}

# Fail early with a readable message instead of a bare KeyError / silent NaN.
unmapped_detectors = set(attack_df["detector"]) - set(detector_name_to_short_name)
assert not unmapped_detectors, f"no short name for detectors: {sorted(unmapped_detectors)}"
unmapped_attacks = set(attack_df["attack"]) - set(attack_short_names)
assert not unmapped_attacks, f"no short name for attacks: {sorted(unmapped_attacks)}"

attack_df["detector_short_name"] = attack_df["detector"].map(detector_name_to_short_name)
attack_df["attack_short_name"] = attack_df["attack"].map(attack_short_names)

# keep unique rows. This is very important for vega altair, otw. blurry plots
# (several rect/text marks would otherwise be drawn on top of each other per cell)
attack_df = attack_df.drop_duplicates(subset=["detector", "attack", "TPR", "chat_model"])

# Text is drawn black when TPR exceeds this value, white otherwise.
TEXT_COLOR_TPR_THRESHOLD = 0.0
PLOT_PATH = "notebooks/plots/heatmap_all_attacks.png"

# Both layers share the x/y scales, so they must use the same sort order
# (sort=None keeps the order in which values appear in the data).
heatmap = alt.Chart(attack_df).mark_rect().encode(
    alt.X('attack_short_name:N', sort=None, title="Evasion attack"),
    alt.Y('detector_short_name:N', sort=None, title="Tested detector"),
    # NOTE(review): TPR values below 0.1 lie outside this colour domain — confirm intended.
    alt.Color('TPR:Q').scale(scheme='redyellowgreen', domain=[0.1, 1]),
).properties(
    width=300,
    height=300
)

heatmap_text = alt.Chart(attack_df).mark_text(baseline='middle').encode(
    alt.X('attack_short_name:N', sort=None, title="Evasion attack"),
    alt.Y('detector_short_name:N', sort=None, title="Tested detector"),
    text='TPR:Q',
    # Must reference an existing field: the previous `datum.accuracy` was
    # undefined, so the condition was always false.
    color=alt.condition(
        alt.datum.TPR > TEXT_COLOR_TPR_THRESHOLD,
        alt.value('black'),
        alt.value('white')
    )
).properties(
    width=300,
    height=300
)

chart = alt.layer(heatmap, heatmap_text).facet(
    column=alt.Column("chat_model:N", title="Chat model")
).configure(
    numberFormat='0.2f'
).configure_axis(
    labelFontSize=18,
    titleFontSize=18
).configure_legend(
    labelFontSize=18,
    titleFontSize=18,
    titleLimit=0
).configure_header(
    titleFontSize=18,
    labelFontSize=18
).configure_text(
    fontSize=14,
    font="Arial",
    fontWeight="bold"
)

# chart.save does not create missing directories.
os.makedirs(os.path.dirname(PLOT_PATH), exist_ok=True)
chart.save(PLOT_PATH)
chart