In [None]:
from pathlib import Path

while Path.cwd().name != 'black-box-api-challenges':
    %cd ..

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import scipy
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import linregress

from tqdm.auto import tqdm
from IPython.display import display
from utils.constants import PERSPECTIVE_API_ATTRIBUTES_LOWER

# Global seaborn theme for every figure drawn after this cell.
sns.set(context="paper", style="white", font_scale=2.0, palette="RdBu")
# Make sure the folder that the savefig calls below write into exists.
Path("images").mkdir(exist_ok=True, parents=True)

## Scores - Original vs. Rescored

### Full sequences

In [3]:
# Load the rescored joint sequences (JSON Lines: one record per line).
full_rescored = pd.read_json("data/real-toxicity-prompts/rtp_joint_sequences_rescored.jsonl", lines=True)

In [None]:
# Bin edges in steps of 0.25, rounded to strip float noise from np.arange.
bins = list(np.round(np.arange(0, 1.1, 0.25), 2))

# include_lowest=True closes the first interval on the left, so a score of exactly 0.0
# falls into the first bin instead of silently becoming NaN (and vanishing from the counts).
full_rescored["bin"] = pd.cut(full_rescored["toxicity"], bins=bins, include_lowest=True)

normalize = True  # True -> share of rows per bin, False -> raw counts
# sort_index() lists the bins in ascending order rather than by frequency.
full_rescored["bin"].value_counts(normalize=normalize).sort_index().to_frame().round(2)

### Prompts

In [5]:
# Rescored prompts and the originally published prompts (both JSON Lines files).
prompt_rescored = pd.read_json("data/real-toxicity-prompts/prompts_feb2023.jsonl", lines=True)
original = pd.read_json("data/realtoxicityprompts-data/prompts.jsonl", lines=True)

In [6]:
def display_stats(rescored, original, column):
    """Display toxicity summary tables comparing rescored and original data.

    Args:
        rescored: DataFrame whose ``column`` holds one record per row that
            ``pd.json_normalize`` can flatten and that yields a "toxicity" field.
        original: DataFrame with the same layout, holding the original scores.
        column: Name of the column to summarise (used in the first table's row label).

    Two tables are shown through ``display``: toxic (> 0.5) / non-toxic (<= 0.5)
    counts, and mean / standard deviation of toxicity rounded to 2 decimals.
    Nothing is returned.
    """
    # Flatten each record column once; both tables reuse the same score series.
    toxicity = {
        "rescored": pd.json_normalize(rescored[column])["toxicity"],
        "original": pd.json_normalize(original[column])["toxicity"],
    }

    # Tuple keys become a two-level column header: (source, label).
    counts = {}
    for source, scores in toxicity.items():
        counts[(source, "toxic")] = (scores > 0.5).sum()
        counts[(source, "non-toxic")] = (scores <= 0.5).sum()
    display(pd.DataFrame(counts, index=[f"# {column}"]))

    summary = pd.DataFrame(
        {source: [scores.mean(), scores.std()] for source, scores in toxicity.items()},
        index=["Avg. Toxicity", "std"],
    ).round(2)
    display(summary)

In [None]:
# Toxic / non-toxic counts and mean / std toxicity of the prompts, rescored vs. original.
display_stats(prompt_rescored, original, column="prompt")

In [None]:
# Bin edges in steps of 0.1, rounded to strip float noise from np.arange.
bins = list(np.round(np.arange(0, 1.1, 0.1), 2))

# include_lowest=True closes the first interval on the left, so a score of exactly 0.0
# is counted in the first bin instead of silently becoming NaN.
original["bin"] = pd.cut(original.prompt.apply(lambda x: x["toxicity"]), bins=bins, include_lowest=True)
prompt_rescored["bin"] = pd.cut(prompt_rescored.prompt.apply(lambda x: x["toxicity"]), bins=bins, include_lowest=True)

normalize = False  # False -> raw counts per bin, True -> share of rows
# Side-by-side bin counts, listed in ascending bin order.
df_bins = pd.concat([original["bin"].value_counts(normalize=normalize), prompt_rescored["bin"].value_counts(normalize=normalize)], axis=1).sort_index()
df_bins.columns = ["original", "rescored"]
df_bins

In [9]:
def process_distributions(original, rescored, attributes, melt=True):
    """Collect published and rescored prompt scores for several attributes.

    Args:
        original: DataFrame with a "prompt" column of records indexable by attribute name.
        rescored: DataFrame with the same layout, holding the rescored values.
        attributes: Iterable of attribute names to extract from each prompt record.
        melt: If True (default), return long format with columns
            ``attribute``, ``score`` ("published" or "rescored") and ``value``.
            If False, return wide format with columns ``published``, ``rescored``
            and ``attribute``.

    Returns:
        DataFrame with the per-attribute frames stacked in the order of
        ``attributes``; the original row index is kept, so it repeats once per
        attribute. An empty ``attributes`` yields an empty frame with the
        expected columns.
    """
    wide_columns = ["published", "rescored", "attribute"]
    # Build all per-attribute frames first and concatenate once: growing a
    # DataFrame with concat inside the loop re-copies the data on every pass.
    frames = [
        pd.DataFrame({
            "published": original["prompt"].apply(lambda x: x[attr]),
            "rescored": rescored["prompt"].apply(lambda x: x[attr]),
            "attribute": attr,
        })
        for attr in attributes
    ]
    # pd.concat raises on an empty list, so fall back to an empty, correctly shaped frame.
    temp = pd.concat(frames) if frames else pd.DataFrame(columns=wide_columns)
    if melt:
        temp = temp.melt(value_vars=["published", "rescored"], var_name="score", id_vars=["attribute"])
    return temp

In [None]:
# Long-format scores: one row per (attribute, published|rescored, value).
distributions = process_distributions(original, prompt_rescored, attributes=PERSPECTIVE_API_ATTRIBUTES_LOWER)
# Underscores -> spaces so the tick labels read naturally.
distributions["attribute"] = distributions["attribute"].str.replace("_", " ")

plt.figure(figsize=(15, 6))
# Split violins: the two "score" levels share one violin per attribute, one half each.
sns.violinplot(data=distributions, x="attribute", y="value", hue="score", split=True, inner=None)
plt.xticks(rotation=45)
plt.xlabel("Perspective API attributes density plots")
plt.ylabel("Perspective API scores")
# Anchor the legend just outside the right edge of the axes.
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
plt.tight_layout()
plt.savefig('images/rtp_prompts_attr_distributions.svg', format="svg")
plt.savefig('images/rtp_prompts_attr_distributions.pdf')
plt.show()

In [None]:
# Wide format (published / rescored side by side), dropping rows where either score is missing.
# Note: this rebinds `distributions`, which held the long-format frame in the plotting cell.
distributions = process_distributions(original, prompt_rescored, attributes=PERSPECTIVE_API_ATTRIBUTES_LOWER, melt=False).dropna()
distributions

In [None]:
# Wasserstein distance between the published and rescored score samples,
# computed per attribute, rounded to 3 decimals and sorted ascending.
distributions.groupby("attribute").apply(
    lambda x: scipy.stats.wasserstein_distance(
        x['published'].values, 
        x['rescored'].values, 
    )
).round(3).sort_values()

In [None]:
def kl_divergence(p, q):
    """Return sum(p * log(p / q)) over the entries where ``p`` is non-zero.

    Args:
        p: Array-like of values; entries equal to 0 contribute 0 to the sum.
        q: Array-like broadcastable against ``p``.

    Returns:
        The summed value as a NumPy float (0.0 for empty input). Neither input
        is normalised here.
    """
    p, q = np.broadcast_arrays(np.asarray(p, dtype=float), np.asarray(q, dtype=float))
    # Select the non-zero entries of p *before* taking the log: np.where would
    # still evaluate log(0 / q) on the discarded entries and emit RuntimeWarnings.
    nonzero = p != 0
    return np.sum(p[nonzero] * np.log(p[nonzero] / q[nonzero]))

# kl_divergence of published vs. rescored scores, per attribute, rounded and sorted ascending.
# NOTE(review): the raw paired score arrays are passed as-is, not normalised
# probability distributions - confirm this is the intended quantity.
distributions.groupby("attribute").apply(
    lambda x: kl_divergence(
        x['published'].values, 
        x['rescored'].values, 
    )
).round(3).sort_values()

### Qualitative Assessment of Prompts

In [15]:
from wordcloud import WordCloud
from wordcloud import STOPWORDS
import string

In [16]:
def plot_wordcloud(df):
    """Render a word cloud built from the ``text`` column of ``df``.

    All texts are joined with single spaces. The WordCloud stop words and
    every single ASCII letter (both cases) are excluded from the cloud.

    Returns:
        The ``plt`` module, with the word cloud drawn on a new 5x7 figure and
        the axes hidden, so the caller can tweak, save or show it.
    """
    corpus = " ".join(df.text)
    # Single letters carry no meaning, so drop them together with the stop words.
    ignored_words = set(STOPWORDS) | set(string.ascii_letters)

    cloud_options = dict(
        stopwords=ignored_words,
        background_color="white",
        collocations=False,
        normalize_plurals=True,
        width=1000,
        height=1500,
        random_state=42,  # fixed seed for a reproducible layout
        colormap="copper",
    )
    cloud = WordCloud(**cloud_options).generate(corpus)

    plt.figure(figsize=(5, 7))
    plt.imshow(cloud, interpolation='bilinear')
    plt.axis("off")
    return plt

def display_samples(df):
    """Show the samples with the largest and smallest score change, then a regression plot.

    Args:
        df: DataFrame with columns ``text``, ``original``, ``rescored`` and ``abs_diff``.

    Displays two tables of up to 10 rows each (sorted by ``abs_diff`` descending,
    then ascending), indexed by ``text`` and rounded to 2 decimals, followed by
    a seaborn regression plot of ``rescored`` against ``original``.
    Nothing is returned.
    """
    shown_columns = ["text", "original", "rescored", "abs_diff"]
    # Same pipeline for both tables; only the heading and the sort direction differ.
    for heading, ascending in (
        ("Top-10 samples with bigger abs difference.", False),
        ("Top-10 samples with smaller abs difference.", True),
    ):
        display(heading)
        display(
            df.sort_values(by="abs_diff", ascending=ascending)
            .iloc[:10][shown_columns]
            .round(2)
            .set_index("text", drop=True)
        )

    # Called for its drawing side effect only; the returned Axes is not needed.
    sns.regplot(data=df, x="original", y="rescored")
    plt.show()

In [None]:
# Scores above this value are treated as toxic in the cells below.
threshold = 0.5
# Flatten the prompt records, keeping the text and the toxicity score only.
prompts = pd.json_normalize(original["prompt"])[["text", "toxicity"]].rename(columns={"toxicity": "original"})
rescored_prompts = pd.json_normalize(prompt_rescored["prompt"])[["text", "toxicity"]]

# Both files must list the same prompts in the same order before the scores are paired by row.
assert prompts["text"].equals(rescored_prompts["text"])
prompts["rescored"] = rescored_prompts["toxicity"]

# 2x2 table: toxic-by-original-score (rows) vs. toxic-by-rescored-score (columns).
pd.crosstab(prompts["original"] > threshold, prompts["rescored"] > threshold)

In [None]:
# Prompts whose label flipped between the original and the rescored toxicity:
# new_toxic went from <= threshold to > threshold, new_nontoxic the other way round.
# .copy() gives each subset its own data, so adding a column below is an explicit
# write to an independent frame (no SettingWithCopyWarning, `prompts` untouched).
new_toxic = prompts.query("original <= @threshold and rescored > @threshold").copy()
new_nontoxic = prompts.query("rescored <= @threshold and original > @threshold").copy()

# Magnitude of the score change for each flipped prompt.
new_toxic["abs_diff"] = np.abs(new_toxic.rescored.values - new_toxic.original.values)
new_nontoxic["abs_diff"] = np.abs(new_nontoxic.rescored.values - new_nontoxic.original.values)

In [19]:
# Export the texts of the flipped prompts, one CSV per direction.
path = Path("data/qualitative/prompts_new_toxic.csv")
path.parent.mkdir(parents=True, exist_ok=True)
# Using "text" as the index writes a single "text" column with no extra row-number column.
new_toxic[["text"]].set_index("text").to_csv(path)

path = Path("data/qualitative/prompts_new_nontoxic.csv")
new_nontoxic[["text"]].set_index("text").to_csv(path)

#### New Toxic Prompts

In [None]:
# Word cloud of the prompts that became toxic after rescoring; saved as PDF, then shown.
plot_wordcloud(new_toxic.copy())
plt.tight_layout()
plt.savefig("images/wordcloud_prompts_new_toxic.pdf")
plt.show()

In [None]:
# Show full cell contents instead of truncating long prompt texts.
# This is a global pandas option and stays in effect for the following cells.
pd.set_option('display.max_colwidth', None)

display_samples(new_toxic.copy())

#### New Non-Toxic prompts

In [None]:
# Word cloud of the prompts that became non-toxic after rescoring; saved as PDF, then shown.
plot_wordcloud(new_nontoxic.copy())
plt.tight_layout()
plt.savefig("images/wordcloud_prompts_new_nontoxic.pdf")
plt.show()

In [None]:
# Largest / smallest score changes and regression plot for the newly non-toxic prompts.
display_samples(new_nontoxic.copy())

### Continuations

In [24]:
# Load the rescored continuations (JSON Lines file).
cont_rescored = pd.read_json("data/real-toxicity-prompts/rtp_continuations_rescored.jsonl", lines=True)

In [None]:
# Same summary as for the prompts, computed on the "continuation" records.
display_stats(cont_rescored, original, column="continuation")

## Prompted scores - Baselines from RealToxicityPrompts

From Table 2: original (published) results, partially rescored results (generations only), and fully rescored results (generations and prompts).

In [26]:
# Toxicity-metric CSV per baseline model and scoring mode:
#   model -> {"original" | "all rescored" | "generations rescored" -> path}
# All files share one naming pattern, so the mapping is generated from the
# (display name, file stem) and (mode, file suffix) pairs.
prompted_models = {
    model: {
        mode: f"data/real-toxicity-prompts/rtp_generations/prompted_gens_{stem}_{suffix}_toxicities.csv"
        for mode, suffix in (
            ("original", "original"),
            ("all rescored", "all_rescored"),
            ("generations rescored", "gens_rescored"),
        )
    }
    for model, stem in (
        ("GPT1", "gpt1"),
        ("GPT2", "gpt2"),
        ("GPT3", "gpt3_davinci"),
        ("CTRL", "ctrl"),
        ("CTRL-W", "ctrl_wiki"),
    )
}

In [27]:
def load_toxicity(paths_dict):
    """Load per-model toxicity metric CSVs into one long DataFrame.

    Args:
        paths_dict: Mapping ``model name -> {mode -> CSV path}``. Paths that do
            not exist on disk are skipped silently.

    Returns:
        DataFrame with one row per (model, mode, CSV value column). The
        ``prompt_toxicity`` column holds the CSV column header the row came
        from; there is one column per metric named in the CSV's first column,
        plus ``model`` and ``mode``. If no file exists, the result is empty.
    """
    frames = []
    for model, mode_paths in paths_dict.items():
        for mode, path in mode_paths.items():
            if not Path(path).exists():
                continue
            # Transpose so the CSV's first column (metric names) becomes the first row.
            table = pd.read_csv(path).T
            table.columns = table.iloc[0].values.tolist()
            # Drop that header row; copy so the column assignments below write to
            # an independent frame rather than to a slice of the transposed table.
            table = table.iloc[1:].copy()
            table["model"] = model
            table["mode"] = mode
            frames.append(table)

    # Concatenate once at the end; pd.concat raises on an empty list.
    tox_metrics = pd.concat(frames) if frames else pd.DataFrame()
    tox_metrics.index.name = "prompt_toxicity"
    return tox_metrics.reset_index()

def show_toxicity_results(df, mode, row_order=None, round=2):
    """Pivot the long toxicity metrics of one scoring mode into a per-model table.

    Args:
        df: Long-format metrics with at least the columns ``prompt_toxicity``,
            ``model``, ``mode``, ``avg_max``, ``std_max`` and
            ``toxicity_probability`` (``toxic_fraction`` is optional).
        mode: Value of the ``mode`` column to keep.
        row_order: Optional list of model names used to reorder the rows.
        round: Number of decimals the values are rounded to.

    Returns:
        DataFrame indexed by ``Model`` with two-level columns
        (metric, prompt type), "Toxic" placed before "Non-Toxic".
        Rows with ``prompt_toxicity == 'full'`` are left out.
    """
    value_labels = {"toxic": "Toxic", "nontoxic": "Non-Toxic"}
    column_labels = {
        "avg_max": "Exp. Max. Toxicity",
        "toxicity_probability": "Toxicity Probability",
        "toxic_fraction": "Toxic Fraction",
        "model": "Model",
    }
    labelled = df.replace(to_replace=value_labels).rename(columns=column_labels)

    value_columns = ["Exp. Max. Toxicity", "std_max", "Toxicity Probability"]
    if 'Toxic Fraction' in labelled.columns:
        value_columns.append("Toxic Fraction")

    # `@mode` refers to this function's `mode` argument.
    selected = labelled.query("mode == @mode and prompt_toxicity != 'full'")
    table = pd.pivot_table(
        selected,
        index=["Model"],
        values=value_columns,
        columns=["prompt_toxicity"],
    ).round(round)

    # Within each metric, list the toxic prompts first.
    table = table.reindex(["Toxic", "Non-Toxic"], axis=1, level=1)

    # Optionally reorder the rows, e.g. to match the order used in a paper table.
    if row_order is None:
        return table
    return table.reindex(row_order)

In [28]:
# Long-format metrics for all baseline models and scoring modes.
toxicity_metrics = load_toxicity(prompted_models)
# Row order used for the result tables below.
row_order = ["GPT1", "GPT2", "GPT3", "CTRL", "CTRL-W"]

In [None]:
# Results for the "original" mode (no rescoring).
original_res = show_toxicity_results(toxicity_metrics, mode="original", row_order=row_order)
original_res

In [None]:
# Results for the "generations rescored" mode.
gens_res = show_toxicity_results(toxicity_metrics, mode="generations rescored", row_order=row_order)
gens_res

In [None]:
# Results for the "all rescored" mode.
all_res = show_toxicity_results(toxicity_metrics, mode="all rescored", row_order=row_order)
all_res

In [32]:
# Bring the three pivoted result tables back to long format and tag each with its mode.
orig_melt = original_res.reset_index().melt(id_vars=["Model"])
orig_melt["mode"] = "original"

gens_melt = gens_res.reset_index().melt(id_vars=["Model"])
gens_melt["mode"] = "generations rescored"

all_melt = all_res.reset_index().melt(id_vars=["Model"])
all_melt["mode"] = "all rescored"

# Stack the modes in the order they should appear on the x-axis of the plots below.
rtp_baselines = pd.concat([orig_melt, gens_melt, all_melt]).reset_index(drop=True)
# NOTE(review): the rename key `None` presumably matches the unnamed metric level of the
# pivoted column index as emitted by melt - verify against the pandas version in use.
rtp_baselines = rtp_baselines.rename(columns={None: "Metric", "prompt_toxicity": "Prompt"})

In [None]:
# Plot only the metric values; the std rows are left out.
temp = rtp_baselines.query("Metric != 'std_max'")
# Grid of point plots: one row per metric, one column per prompt type, one line per model.
g = sns.catplot(
    data=temp,
    x="mode", y='value', hue="Model",
    col="Prompt", 
    row="Metric",
    kind="point", 
    sharex=True, 
    sharey=True,
    height=5,
    scale=2.0
)
g.set_titles(template="Prompt = {col_name}")
g.set_xticklabels(["published", "generations\nrescored", "all\nrescored"], fontsize=16)
g.set_xlabels("", "")

# Label each row with the metric the grid actually placed there (g.row_names),
# so the y-labels cannot drift from the row order or the number of rows.
metrics = list(g.row_names)
for i, metric in enumerate(metrics):
    g.axes[i, 0].set_ylabel(metric)
    if i > 0:
        # Keep the "Prompt = ..." titles on the top row only.
        for ax in g.axes[i]:
            ax.set_title("")

g.savefig("images/rtp_baselines.svg", format="svg")
g.savefig("images/rtp_baselines.pdf")

plt.show()

In [None]:
# One figure per metric: point plots over the scoring modes, one column per prompt type.
for i, metric in enumerate(["Exp. Max. Toxicity", "Toxicity Probability"]):
    temp = rtp_baselines.query("Metric == @metric")
    g = sns.catplot(
        data=temp,
        x="mode", 
        y='value', 
        hue="Model",
        col="Prompt", 
        kind="point", 
        sharex=True, 
        sharey=True,
        height=4,
        scale=1.5
    )
    # Leave room above the axes for the legend moved to the top below.
    g.figure.subplots_adjust(top=0.82)
    g.set_titles(template="Prompt = {col_name}", size=14)
    g.set_xticklabels(["published", "generations\nrescored", "prompts and\ngenerations\nrescored"], fontsize=11)
    g.set(ylim=(0, 1.0))
    # g.set_yticklabels([f'{val:.2f}' for val in np.arange(0, 1.1, 0.25)], fontsize=11)
    g.set_xlabels("", "")
    g.set_ylabels("", "")

    sns.move_legend(g, "upper center", ncols=5, fontsize=11, title="")
    # Narrower y-range for the first metric only; `i` is 0 here, so this addresses axes[0, 0].
    if i == 0:
        g.axes[0, i].set(ylim=(0.3, 0.9))
        # g.legend.remove()

    # The metric name (including its spaces and dots) becomes part of the file name.
    g.savefig(f"images/rtp_baselines_{metric}.pdf")

plt.show()

## Prompted Generations - Other papers

### DExperts

Only the models showcased in Table 3 of the UDDIA paper.

In [35]:
# Toxicity-metric CSV per model and scoring mode:
#   model -> {"original" | "generations rescored" | "all rescored" -> path}
# The first five models share one file naming pattern, so their entries are generated.
dexperts_models = {
    model: {
        mode: f"data/dexperts/generations/toxicity/toxicity/prompted_gens_{stem}_{suffix}_toxicity.csv"
        for mode, suffix in (
            ("original", "original"),
            ("generations rescored", "gens_rescored"),
            ("all rescored", "all_rescored"),
        )
    }
    for model, stem in (
        ("GPT2 (large)", "gpt2"),
        ("DAPT", "dapt"),
        ("GeDi", "gedi"),
        ("DExperts (large)", "dexperts_large"),
        ("PPLM (10%)", "pplm"),
    )
}
# UDDIA lives in its own folder, uses different file names and has no "all rescored" file.
dexperts_models["UDDIA (TH=40)"] = {
    "original": "data/uddia/continuations/TH40/published_toxicity.csv",
    "generations rescored": "data/uddia/continuations/TH40/toxicity_dict.csv",
}

In [36]:
# Long-format metrics for the models above; paths that do not exist are skipped by load_toxicity.
dexperts_tox = load_toxicity(dexperts_models)

In [None]:
# Inspect the loaded long-format metrics.
dexperts_tox

In [None]:
# Results for the "original" mode, 3 decimals. This rebinds `original_res` from the RTP baselines section.
original_res = show_toxicity_results(dexperts_tox, mode="original", round=3)
original_res

In [None]:
# Results for the "generations rescored" mode, 3 decimals. This rebinds `gens_res`.
gens_res = show_toxicity_results(dexperts_tox, mode="generations rescored", round=3)
gens_res

In [None]:
# Results for the "all rescored" mode, 3 decimals. This rebinds `all_res`.
all_res = show_toxicity_results(dexperts_tox, mode="all rescored", round=3)
all_res

In [41]:
# Back to long format, tagging each table with its scoring mode.
orig_melt = original_res.reset_index().melt(id_vars=["Model"])
orig_melt["mode"] = "original"

gens_melt = gens_res.reset_index().melt(id_vars=["Model"])
gens_melt["mode"] = "generations rescored"

# Removed 'all' from plot since no major changes in results
dexperts_baselines = pd.concat([orig_melt, gens_melt]).reset_index(drop=True)
# NOTE(review): the rename key `None` presumably matches the unnamed metric level of the
# pivoted column index as emitted by melt - verify against the pandas version in use.
dexperts_baselines = dexperts_baselines.rename(columns={None: "Metric", "prompt_toxicity": "Prompt"})

In [None]:
# Non-toxic prompts only, without the std rows.
temp = dexperts_baselines.query("Metric != 'std_max' and Prompt == 'Non-Toxic'")

# One panel per metric with independent y-axes; one line per model across the two modes.
g = sns.catplot(
    data=temp,
    x="mode", 
    y="value", 
    hue="Model",
    col="Metric",
    kind="point", 
    sharex=True, 
    sharey=False,
    height=5,
    scale=2.2
)
g.set_titles(template="{col_name}")
g.set_xticklabels(["published", "generations\nrescored"])
g.set_xlabels("", "")
# Show the y tick labels on every panel.
for axis in g.axes.flat:
    axis.tick_params(labelleft=True)
# Put the legend in a single row above the panels.
sns.move_legend(g, "upper center", ncols=6, fontsize=12, title="", bbox_to_anchor=(0.5, 1.1))
plt.tight_layout()
g.savefig(f"images/uddia_results.svg", format="svg")
g.savefig(f"images/uddia_results.pdf")
plt.show()

In [None]:
# Non-toxic prompts only, without the std rows. .copy() makes `temp` an independent
# frame, so the column assignments below do not write into a view of
# `dexperts_baselines` (and do not raise SettingWithCopyWarning).
temp = dexperts_baselines.query("Metric != 'std_max' and Prompt == 'Non-Toxic'").copy()
# Min-max normalise the values within each (mode, metric) group so that models are
# compared on a common 0-1 scale. A group whose values are all equal yields NaN.
temp["normalized value"] = temp.groupby(["mode", "Metric"])["value"].transform(
    lambda x: (x - x.min()) / (x.max() - x.min()))
# Slope of each model's line between its first and second row, per metric.
# This takes the first two rows of every (Model, Metric) group, so each group needs at least two rows.
temp["slope"] = temp.groupby(["Model", "Metric"])["normalized value"].transform(lambda x: linregress([0, 1], [x.iloc[0], x.iloc[1]]).slope).round(2)

g = sns.catplot(
    data=temp,
    x="mode", 
    y="normalized value", 
    hue="Model",
    col="Metric",
    kind="point", 
    sharex=True, 
    sharey=False,
    height=5,
    scale=2.2
)
g.set_titles(template="{col_name}")
g.set_xticklabels(["published", "generations\nrescored"])
g.set_xlabels("", "")
for col, metric in enumerate(g.col_names):
    ax = g.axes[0, col]
    # NOTE(review): assumes each collection on the axes holds the two plotted points
    # of one model - confirm for the seaborn version in use.
    for c in ax.collections:
        offsets = c.get_offsets()
        # Annotate just next to second dot
        slope = linregress(offsets.data.T).slope

        x_bump = 0.1
        # Nudge the label down for rising lines.
        y_bump = -0.05  if slope > 0 else 0
        ax.annotate(f"{slope:.2f}", offsets[1, :] + [x_bump, y_bump], fontsize=12) 

sns.move_legend(g, "upper center", ncols=6, fontsize=12, title="", bbox_to_anchor=(0.5, 1.1))
plt.tight_layout()

g.savefig("images/uddia_slopes.svg", format="svg")
g.savefig("images/uddia_slopes.pdf")

plt.show()

### How is the 10k non-toxic prompt sample from DExperts distributed now?

In [44]:
# The 10k non-toxic prompt sample and its rescored counterpart (JSON Lines files).
dexperts_prompts_original = pd.read_json("data/dexperts/prompts/nontoxic_prompts-10k.jsonl", lines=True)
dexperts_prompts_rescored = pd.read_json("data/dexperts/prompts/nontoxic_prompts-10k_rescored.jsonl", lines=True)

In [45]:
# Boolean toxic flags (score > 0.5) per prompt for the original and the rescored sample.
# .get returns None when a prompt record has no "toxicity" key.
# The two columns are paired by row index - assumes both files list the same prompts
# in the same order; TODO confirm.
toxicity = pd.DataFrame({
    "toxic_original": dexperts_prompts_original.prompt.apply(lambda x: x.get("toxicity")) > 0.5,
    "toxic_rescored": dexperts_prompts_rescored.prompt.apply(lambda x: x.get("toxicity")) > 0.5
})

In [None]:
# 2x2 table: how many prompts are toxic before (rows) vs. after (columns) rescoring.
pd.crosstab(toxicity["toxic_original"], toxicity["toxic_rescored"])