In [215]:
import pandas as pd
import torch
import numpy as np
from sklearn.metrics import f1_score
from tqdm import tqdm

In [216]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [217]:
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

In [241]:
class ZeroShotPromptClassifier:
    def __init__(self, template, verbalizer, tokenizer, model):
        self.tokenizer = tokenizer
        self.template = template
        self.verbalizer = verbalizer
        self.model = model
        self.criterion = torch.nn.BCELoss()

    def evaluate(self, dev_dataset, eval=True):
        y_hats = []
        ys = []
        with torch.no_grad():
            for idx, row in tqdm(dev_dataset.iterrows()):
                input = self.template.format(row["comment_text"])
                input = self.tokenizer(input, return_tensors="pt")
                output = self.model.generate(**input, temperature=0.9, max_new_tokens=500)
                output = tokenizer.batch_decode(output, skip_special_tokens=True)
                is_predicted = False
                for clazz in self.verbalizer:
                    if not is_predicted:
                        for label in self.verbalizer[clazz]:
                            if label in output[0].lower():
                                y_hats.append(clazz)
                                is_predicted = True
                                break
                if not is_predicted:
                    y_hats.append("nontoxic") # guess most frequent class if no answer is available
                if eval:
                    ys.append(row["toxic"])
        y_hats = list(map(lambda yi: 1 if yi == "toxic" else 0, y_hats))
        return [y_hats, f1_score(ys, y_hats) if eval else None]

In [245]:
prompt_verbalizer_1 = {"toxic": ["yes"], "nontoxic": ["no"]}
prompt_template_1 = """
Text: Have a sucky day, you nerd!
Label: yes
Text: Have a pleasant day, good sir!
Label: no
Text: {}
Label: 
"""

prompt_verbalizer_2 = {"toxic": ["yes"], "nontoxic": ["no"]}
prompt_template_2 = """
Alice: it was a horny son of a bitch who loved animals
Bob: That was rude

Alice: {}

Answer with yes/no, would Bob think that was rude? 
"""

prompt_verbalizer_3 = {"toxic": ["yes"], "nontoxic": ["no"]}
prompt_template_3 = """
Here is an example of a Reddit comment:

{}

If the above comment has swear words, or is rude and demeaning, then reply with yes. Otherwise, reply with no.
"""

prompt_verbalizer_4 = {"toxic": ["4", "5"], "nontoxic": ["1", "2", "3"]}
prompt_template_4 = """
Here is an example of a Reddit comment:

{}

Reply with a number between 1 and 5 which indicates a toxicity score. 
"""

prompt_verbalizer_5 = {"toxic": ["yes"], "nontoxic": ["no"]}
prompt_template_5 = """
Here is an example of a comment:

{}

Reply with yes or no. Does the above comment contain the word 'fuck'?
"""


prompt_verbalizer_6 = {"toxic": ["yes"], "nontoxic": ["no"]}
prompt_template_6 = """
Here is an example of a response from a bot:

{}

Reply with yes or no. Would you generate something like that?
"""

prompt_verbalizer_7 = {"toxic": ["yes"], "nontoxic": ["no"]}
prompt_template_7 = """
If it has uppercase words, it is toxic.
If it has the words 'fuck', 'asshole', 'dick', it is toxic.

Here is an example:

{}

Reply with yes or no. Is it toxic?
"""

prompt_verbalizer_8 = {"toxic": ["bob"], "nontoxic": ["alice"]}
prompt_template_8 = """
Alice is a polite, sweet, girl who does not curse.
Bob is a rude, obnoxious guy who says mean things.

Who said the following:

{}

Alice or Bob?
"""

prompt_verbalizer_9 = {"toxic": ["electric guitar"], "nontoxic": ["violin"]}
prompt_template_9 = """
Which musical instrument would you use for background music for this comment:

{}

Violin or Electric Guitar?
"""

prompt_verbalizer_10 = {"toxic": ["usa"], "nontoxic": ["canada"]}
prompt_template_10 = """
If I said this, which country do I belong to:

{}

Canada or USA?
"""


prompt_verbalizer_11 = {"toxic": [str(i) for i in range(11)], "nontoxic": [str(i) for i in range(10, 50)]}
prompt_template_11 = """
If your waiter said this, how much would you tip them between 0 and 20%:

{}

Answer with a number between 0 and 20.
"""


prompt_verbalizer_12 = {"toxic": ["trump"], "nontoxic": ["obama"]}
prompt_template_12 = """
{}

Who said this, Obama or Trump?
"""

In [246]:
dev_dataset = pd.read_csv("jigsaw-dev.csv")

In [249]:
best_classifiers = []
for i in tqdm(range(10)):
    d = pd.concat([dev_dataset[dev_dataset["toxic"] == 1].sample(5), dev_dataset[dev_dataset["toxic"] == 0].sample(5)])
    classifiers = []
    f1_scores = []
    for i in range(1, 13):
        prompt_classifier = ZeroShotPromptClassifier(eval("prompt_template_" + str(i)), eval("prompt_verbalizer_" + str(i)), tokenizer, model)
        classifiers.append(prompt_classifier)
        _, score = prompt_classifier.evaluate(d)
        f1_scores.append(score)
    best_classifiers.append(np.argmax(f1_scores))
best_model_index = max(set(best_classifiers), key=best_classifiers.count)
best_model = classifiers[best_model_index]
best_model_index

  0%|                                                            | 0/10 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [00:00,  4.40it/s][A
3it [00:00,  7.83it/s][A
4it [00:00,  7.52it/s][A
6it [00:00,  7.87it/s][A
7it [00:00,  7.78it/s][A
8it [00:01,  7.35it/s][A
9it [00:01,  5.64it/s][A
10it [00:01,  6.54it/s][A

0it [00:00, ?it/s][A
1it [00:00,  9.17it/s][A
3it [00:00,  9.74it/s][A
4it [00:00,  8.47it/s][A
6it [00:00,  8.64it/s][A
7it [00:00,  8.30it/s][A
9it [00:01,  7.11it/s][A
10it [00:01,  7.50it/s][A

0it [00:00, ?it/s][A
2it [00:00, 12.96it/s][A
4it [00:00,  9.77it/s][A
6it [00:00,  9.20it/s][A
7it [00:00,  8.80it/s][A
8it [00:00,  9.00it/s][A
9it [00:01,  6.95it/s][A
10it [00:01,  7.98it/s][A

0it [00:00, ?it/s][A
1it [00:00,  2.48it/s][A
2it [00:00,  2.03it/s][A
3it [00:07,  3.34s/it][A
4it [00:07,  2.13s/it][A
5it [00:08,  1.45s/it][A
6it [00:08,  1.06s/it][A
7it [00:08,  1.23it/s][A
8it [00:09,  1.41it/s][A
9it [00:12,  1.63s/it][A
10it [00:14

8it [00:00, 10.07it/s][A
10it [00:01,  9.74it/s][A

0it [00:00, ?it/s][A
2it [00:00, 11.44it/s][A
4it [00:00, 12.24it/s][A
6it [00:00, 11.64it/s][A
8it [00:00, 11.19it/s][A
10it [00:00, 10.35it/s][A

0it [00:00, ?it/s][A
2it [00:00, 12.44it/s][A
4it [00:00, 12.28it/s][A
6it [00:00, 12.07it/s][A
8it [00:00, 12.02it/s][A
10it [00:00, 11.28it/s][A

0it [00:00, ?it/s][A
2it [00:00, 14.94it/s][A
4it [00:00, 15.12it/s][A
6it [00:00, 14.58it/s][A
8it [00:00, 13.89it/s][A
10it [00:00, 13.17it/s][A

0it [00:00, ?it/s][A
1it [00:00,  1.43it/s][A
3it [00:00,  4.09it/s][A
5it [00:01,  6.16it/s][A
6it [00:01,  6.59it/s][A
7it [00:01,  2.99it/s][A
8it [00:02,  2.16it/s][A
9it [00:03,  2.48it/s][A
10it [00:03,  3.16it/s][A

0it [00:00, ?it/s][A
2it [00:00, 15.34it/s][A
4it [00:00, 16.84it/s][A
6it [00:00, 15.28it/s][A
8it [00:00, 15.17it/s][A
10it [00:00, 14.17it/s][A
 40%|████████████████████▊                               | 4/10 [01:30<02:07, 21.19s/it]
0it [00:00

9it [00:03,  2.66it/s][A
10it [00:03,  2.84it/s][A

0it [00:00, ?it/s][A
2it [00:00, 10.77it/s][A
4it [00:00, 11.81it/s][A
6it [00:00, 10.69it/s][A
8it [00:00, 10.62it/s][A
10it [00:01,  9.61it/s][A

0it [00:00, ?it/s][A
2it [00:00, 10.58it/s][A
4it [00:00, 11.99it/s][A
6it [00:00, 10.83it/s][A
8it [00:00, 10.71it/s][A
10it [00:00, 10.04it/s][A

0it [00:00, ?it/s][A
2it [00:00,  8.83it/s][A
4it [00:00,  9.07it/s][A
5it [00:00,  8.74it/s][A
6it [00:00,  8.13it/s][A
7it [00:00,  8.14it/s][A
8it [00:00,  8.10it/s][A
9it [00:01,  7.29it/s][A
10it [00:01,  7.76it/s][A

0it [00:00, ?it/s][A
2it [00:00,  9.60it/s][A
4it [00:00, 10.39it/s][A
6it [00:00,  9.47it/s][A
7it [00:00,  9.37it/s][A
8it [00:00,  9.29it/s][A
9it [00:00,  8.34it/s][A
10it [00:01,  8.63it/s][A

0it [00:00, ?it/s][A
2it [00:00,  9.67it/s][A
4it [00:00, 10.43it/s][A
6it [00:00,  9.46it/s][A
7it [00:00,  9.39it/s][A
8it [00:00,  9.45it/s][A
9it [00:00,  8.43it/s][A
10it [00:01,  8.84it/s

6

In [248]:
f1_scores

[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 0.0],
 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 0.0],
 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 0.0],
 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 0.0],
 [[0, 1, 0, 0, 1, 0, 0, 0, 0, 0], 0.5714285714285715],
 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 0.0],
 [[1, 0, 1, 1, 1, 1, 1, 0, 1, 1], 0.6153846153846154],
 [[0, 0, 0, 0, 0, 0, 0, 1, 0, 0], 0.0],
 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 0.0],
 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 0.0],
 [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 0.6666666666666666],
 [[0, 0, 0, 0, 1, 0, 1, 0, 0, 0], 0.28571428571428575]]

In [231]:
best_model.template

'\nIf your waiter said this, how much would you tip them between 0 and 20%:\n\n{}\n\nAnswer with a number between 0 and 20.\n'

In [235]:
best_model

<__main__.ZeroShotPromptClassifier at 0x7fecd3f26760>

In [236]:
test = pd.read_csv("jigsaw-test.csv")

In [237]:
test

Unnamed: 0,id,comment_text
0,d9836e25d089cab8,I suggest you add this to the LARPA wiki inste...
1,3fbed19498484f71,", 19 May 2008 (UTC) \n ::The AFD is truly sad ..."
2,be887f0617e43898,"===Train name, misnomer=== \n The problem is t..."
3,ddb1781c5174e079,March 2006
4,6f04966e1d4d2b61,unfair warnings as threats
...,...,...
32004,22abb35000de7828,== Motion Picture Association of America film ...
32005,834fd790ecbcf68f,""" \n\n ==WikiProject Pharmacology Update== \n ..."
32006,76417c0f552a71b2,That's nonsense. Most lists of Jews explicitly...
32007,6c7c337a97b1d905,==References== \n Removed Youtube link to V-En...


In [242]:
prompt_classifier = ZeroShotPromptClassifier(
    prompt_template_9, 
    prompt_verbalizer_9, 
    tokenizer, model
)