In [None]:
import matplotlib.pyplot as plt
import torch
import os

# Global figure typography: serif text with Times New Roman as the preferred face.
plt.rcParams['font.family'] = 'serif'
# Put 'Times New Roman' at the head of the serif fallback list. Existing
# occurrences are filtered out first so that re-running this cell is idempotent
# instead of prepending another duplicate entry on every execution.
plt.rcParams['font.serif'] = ['Times New Roman'] + [
    font for font in plt.rcParams['font.serif'] if font != 'Times New Roman'
]
# Do not route text rendering through an external LaTeX installation.
plt.rcParams['text.usetex'] = False
# plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}'

from lib.utils import get_results, calc_grad_l2_norm, load_dataset
# Maps internal model identifiers to the display names used as plot legend labels.
name_map = {
    "gpt2": "GPT-2",
    "pythia": "Pythia",
    "gpt-neo": "GPT-Neo",
}

In [None]:
from lib.knockoff import evalute_knockoff
import numpy as np


# Experiment selection; both values are interpolated into the saved-tensor file names.
num_knockoffs, finetuned = 10, True

# Absolute paths (resolved against the current working directory) of the cached
# gradient tensors and of the directory intended for exported figures.
saved_dir = os.path.abspath(os.path.join(".", "saved_tensors", "grad_norms"))
figure_dir = os.path.abspath(os.path.join(".", "figures"))



# Knockoff FDR-control sweep: for each dataset, draw one figure with Power (top
# panel) and FDR (bottom panel) as functions of the target level q, with one
# curve per model plus the diagonal y = q "Bound" reference on the FDR panel.

# Target levels 0.05, 0.10, ..., 0.95. The grid is loop-invariant, so it is
# built once here; this also means the "Bound" curve below no longer relies on
# a name that only exists if the inner model loop ran at least once.
q_list = np.arange(1, 20) * 0.05

for dataset_name in ["wiki", "xsum", "bbc"]:
    dataset, input_column_name = load_dataset(dataset_name, 10)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 6), sharex=True)

    for model_name in ["gpt2", "pythia", "gpt-neo"]:
        saved_file_name = f"{model_name}-{dataset_name}-{num_knockoffs}-{finetuned}.pt"
        # NOTE(review): torch.load unpickles the file contents; only point this
        # at tensors produced by a trusted run.
        ft_origin_grad, ft_nkf_grad = torch.load(os.path.join(saved_dir, saved_file_name), map_location="cpu")

        ft_origin_l2_norm = calc_grad_l2_norm(ft_origin_grad)
        ft_nkf_l2_norm = calc_grad_l2_norm(ft_nkf_grad)

        results = get_results(ft_origin_l2_norm, ft_nkf_l2_norm)

        # The knockoff evaluation is run on the negated results, once per q.
        m_list = [evalute_knockoff(-results, dataset["label"], q) for q in q_list]

        # Element 2 of each result is plotted as power; FDR is plotted as
        # 1 - element 1 (presumably a precision-like value -- TODO confirm
        # against lib.knockoff.evalute_knockoff).
        power_list = [m[2] for m in m_list]
        fdr_list = [1 - m[1] for m in m_list]

        legend_label = name_map[model_name]
        ax1.plot(q_list, power_list, "o-", label=legend_label, markersize=4)
        ax2.plot(q_list, fdr_list, "o-", label=legend_label, markersize=4)

    # Axis styling does not depend on the model, so apply it once per figure
    # (after the model curves are drawn) instead of once per model.
    ax1.set_ylabel("Power", fontsize=16)
    # sharex=True hides the top panel's x tick labels by default; show them.
    ax1.tick_params(axis='x', which='both', labelbottom=True)
    ax2.set_ylabel("FDR", fontsize=16)
    ax2.set_xticks(np.arange(1, 10) * 0.1)
    ax1.tick_params(axis='both', labelsize=14)
    ax2.tick_params(axis='both', labelsize=14)

    # Diagonal reference: the nominal level q itself.
    ax2.plot(q_list, q_list, "o-", label="Bound", markersize=4)
    ax2.legend(fontsize=16)
    ax2.set_xlabel("q", fontsize=16)
    fig.tight_layout()
    # fig.savefig(os.path.join(figure_dir, f"{dataset_name}_fdr_control.pdf"), bbox_inches='tight', pad_inches=0)
    

In [3]:
from datasets import load_from_disk

# Load the dataset stored under ./datasets/xsum for interactive inspection.
ds = load_from_disk(dataset_path="./datasets/xsum")

In [4]:
# Display the dataset's summary repr (feature names and row count).
ds

Dataset({
    features: ['document', 'summary', 'id', 'label'],
    num_rows: 11332
})