In [58]:
import pandas as pd
from sklearn.metrics import f1_score
from tqdm import tqdm
import torch
import numpy as np

In [73]:
from transformers import BertTokenizer, AutoModelForSequenceClassification
from torch.optim import  AdamW

In [60]:
from openprompt.prompts import ManualTemplate
from openprompt.prompts import ManualVerbalizer
from openprompt.data_utils import InputExample
from openprompt.plms import load_plm
from openprompt import PromptForClassification, PromptDataLoader

In [61]:
plm_object = load_plm("bert", "microsoft/MiniLM-L12-H384-uncased")
tokenizer = plm_object[1]

Some weights of BertForMaskedLM were not initialized from the model checkpoint at microsoft/MiniLM-L12-H384-uncased and are newly initialized: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [68]:
classes = [
    "toxic",
    "nontoxic"
]

In [74]:
prompt_template_1 = ManualTemplate(
    text = '{"placeholder":"text_a"}. Were there swear words in the previous sentence? {"mask"}',
    tokenizer = tokenizer,
)

prompt_verbalizer_1 = ManualVerbalizer(
    classes = classes,
    label_words = {
        "toxic": ["yes"],
        "nontoxic": ["no"],
    },
    tokenizer = tokenizer,
)

prompt_template_2 = ManualTemplate(
    text = '{"placeholder":"text_a"}. On a scale of 1-5, how rude was that sentence? {"mask"}',
    tokenizer = tokenizer,
)

prompt_verbalizer_2 = ManualVerbalizer(
    classes = classes,
    label_words = {
        "toxic": ["4", "5"],
        "nontoxic": ["1", "2", "3"],
    },
    tokenizer = tokenizer,
)

prompt_template_3 = ManualTemplate(
    text = '{"placeholder":"text_a"}. What was the author of the previous sentence feeling? {"mask"}',
    tokenizer = tokenizer,
)

prompt_verbalizer_3 = ManualVerbalizer(
    classes = classes,
    label_words = {
        "toxic": ["anger", "frustration", "sad", "depressed"],
        "nontoxic": ["content", "happy", "calm", "nothing"],
    },
    tokenizer = tokenizer,
)

prompt_template_4 = ManualTemplate(
    text = '{"placeholder":"text_a"}. Would you kiss your mother with that mouth? {"mask"}',
    tokenizer = tokenizer,
)

prompt_verbalizer_4 = ManualVerbalizer(
    classes = classes,
    label_words = {
        "toxic": ["no"],
        "nontoxic": ["yes"],
    },
    tokenizer = tokenizer,
)

prompt_template_5 = ManualTemplate(
    text = '{"placeholder":"text_a"}. If that sentence were spoken out, how loud would it be on a scale of 1-5? {"mask"}',
    tokenizer = tokenizer,
)


prompt_verbalizer_5 = ManualVerbalizer(
    classes = classes,
    label_words = {
        "toxic": ["4", "5"],
        "nontoxic": ["1", "2", "3"],
    },
    tokenizer = tokenizer,
)

prompt_template_6 = ManualTemplate(
    text = 'Alice: How are you? Bob: {"placeholder":"text_a"}. Alice: That\'s {"mask"}',
    tokenizer = tokenizer,
)


prompt_verbalizer_6 = ManualVerbalizer(
    classes = classes,
    label_words = {
        "toxic": ["rude", "terrible", "bad"],
        "nontoxic": ["nice", "good", "alright"],
    },
    tokenizer = tokenizer,
)

prompt_template_7 = ManualTemplate(
    text = 'If someone said {"placeholder":"text_a"} that would be {"mask"}',
    tokenizer = tokenizer,
)

prompt_verbalizer_7 = ManualVerbalizer(
    classes = classes,
    label_words = {
        "toxic": ["rude", "terrible", "bad"],
        "nontoxic": ["nice", "good", "alright"],
    },
    tokenizer = tokenizer,
)

prompt_template_8 = ManualTemplate(
    text = 'If I agree with {"placeholder":"text_a"} then I am {"mask"}',
    tokenizer = tokenizer,
)

prompt_verbalizer_8 = ManualVerbalizer(
    classes = classes,
    label_words = {
        "toxic": ["rude", "terrible", "dick", "horrible"],
        "nontoxic": ["nice", "good", "alright", "calm", "composed"],
    },
    tokenizer = tokenizer,
)

prompt_template_9 = ManualTemplate(
    text = '''
    YOU GAY ASSHOLE MOTHERFUCKER, THAT'S BULLSHIT. Is this toxic? yes
    Well, that's alright, you didn't know. Is this toxic? no
    {"placeholder":"text_a"}. Is this toxic? {"mask"}
    ''',
    tokenizer = tokenizer,
)

prompt_verbalizer_9 = ManualVerbalizer(
    classes = classes,
    label_words = {
        "toxic": ["yes"],
        "nontoxic": ["no"],
    },
    tokenizer = tokenizer,
)

prompt_template_10 = ManualTemplate(
    text = 'As a Reddit mod would you ban this comment: {"placeholder":"text_a"}? {"mask"}',
    tokenizer = tokenizer,
)

prompt_verbalizer_10 = ManualVerbalizer(
    classes = classes,
    label_words = {
        "toxic": ["yes"],
        "nontoxic": ["no"],
    },
    tokenizer = tokenizer,
)

In [75]:
class PromptClassifier:
    def __init__(self, template, verbalizer, plm_obj):
        self.plm, self.tokenizer, self.model_config, self.WrapperClass = plm_obj
        self.template = template
        self.verbalizer = verbalizer
        self.model = PromptForClassification(
            template = template,
            plm = self.plm,
            verbalizer = verbalizer,
        )
        self.criterion = torch.nn.BCELoss()
        self.optim = AdamW(self.model.parameters(), lr=1e-2)
    
    def create_data_loader(self, dataset):
        return PromptDataLoader(
            dataset = dataset,
            tokenizer = self.tokenizer,
            template = self.template,
            tokenizer_wrapper_class = self.WrapperClass,
        )

    def train(self, dataset):
        for batch in tqdm(self.create_data_loader(dataset)):
            logits = self.model(batch)
            preds = torch.argmax(logits, dim = -1)
            loss = self.criterion(preds.float(), batch["label"].float())
            loss = torch.autograd.Variable(loss, requires_grad = True)
            self.optim.step()
            self.optim.zero_grad()
            loss.backward()

    def evaluate(self, dev_dataset):
        self.model.eval()
        y_hats = []
        ys = []
        with torch.no_grad():
            for batch in tqdm(self.create_data_loader(dev_dataset)):
                logits = self.model(batch)
                preds = torch.argmax(logits, dim = -1)
                y_hats.append(preds.item())
                ys.append(batch["label"].item())    
        return f1_score(ys, y_hats)

In [76]:
df = pd.read_csv("jigsaw-train.csv")
dev_df = pd.read_csv("jigsaw-dev.csv")

In [77]:
best_classifiers = []
for i in tqdm(range(10)):
    dataset_0 = [
        InputExample(
            guid=idx,
            text_a=row["comment_text"],
            label=row["toxic"]
        ) for idx, row in df[df["toxic"] == 0].sample(5).iterrows()
    ]

    dataset_1 = [
        InputExample(
            guid=idx,
            text_a=row["comment_text"],
            label=row["toxic"]
        ) for idx, row in df[df["toxic"] == 1].sample(5).iterrows()
    ]

    dataset = dataset_0 + dataset_1

    dev_dataset_0 = [
        InputExample(
            guid=idx,
            text_a=row["comment_text"],
            label=row["toxic"]
        ) for idx, row in dev_df[dev_df["toxic"] == 0].sample(5).iterrows()
    ]

    dev_dataset_1 = [
        InputExample(
            guid=idx,
            text_a=row["comment_text"],
            label=row["toxic"]
        ) for idx, row in dev_df[dev_df["toxic"] == 1].sample(5).iterrows()
    ]

    dev_dataset = dev_dataset_0 + dev_dataset_1
    
    classifiers = []
    f1_scores = []
    for i in range(1, 11):
        prompt_classifier = PromptClassifier(eval("prompt_template_" + str(i)), eval("prompt_verbalizer_" + str(i)), plm_object)
        prompt_classifier.train(dataset)
        classifiers.append(prompt_classifier)
        score = prompt_classifier.evaluate(dev_dataset)
        f1_scores.append(score)
    best_classifiers.append(np.argmax(f1_scores))
best_model_index = max(set(best_classifiers), key=best_classifiers.count)
best_model = classifiers[best_model_index]
best_model_index

  0%|                                                            | 0/10 [00:00<?, ?it/s]
tokenizing: 10it [00:00, 302.87it/s]

  0%|                                                            | 0/10 [00:00<?, ?it/s][A
 10%|█████▏                                              | 1/10 [00:00<00:06,  1.47it/s][A
 20%|██████████▍                                         | 2/10 [00:01<00:05,  1.57it/s][A
 30%|███████████████▌                                    | 3/10 [00:01<00:03,  1.76it/s][A
 40%|████████████████████▊                               | 4/10 [00:02<00:03,  1.78it/s][A
 50%|██████████████████████████                          | 5/10 [00:02<00:02,  1.90it/s][A
 60%|███████████████████████████████▏                    | 6/10 [00:03<00:02,  1.95it/s][A
 70%|████████████████████████████████████▍               | 7/10 [00:03<00:01,  2.01it/s][A
 80%|█████████████████████████████████████████▌          | 8/10 [00:04<00:00,  2.04it/s][A
 90%|████████████████████████████████████████

 40%|████████████████████▊                               | 4/10 [00:02<00:03,  1.68it/s][A
 50%|██████████████████████████                          | 5/10 [00:02<00:02,  1.70it/s][A
 60%|███████████████████████████████▏                    | 6/10 [00:03<00:02,  1.65it/s][A
 70%|████████████████████████████████████▍               | 7/10 [00:04<00:01,  1.67it/s][A
 80%|█████████████████████████████████████████▌          | 8/10 [00:04<00:01,  1.68it/s][A
 90%|██████████████████████████████████████████████▊     | 9/10 [00:05<00:00,  1.70it/s][A
100%|███████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.70it/s][A

tokenizing: 10it [00:00, 197.04it/s]

  0%|                                                            | 0/10 [00:00<?, ?it/s][A
 10%|█████▏                                              | 1/10 [00:00<00:05,  1.51it/s][A
 20%|██████████▍                                         | 2/10 [00:01<00:05,  1.44it/s][A
 30%|███████████████▌                    

 90%|██████████████████████████████████████████████▊     | 9/10 [00:08<00:00,  1.13it/s][A
100%|███████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.15it/s][A

tokenizing: 10it [00:00, 375.01it/s]

  0%|                                                            | 0/10 [00:00<?, ?it/s][A
 10%|█████▏                                              | 1/10 [00:00<00:05,  1.64it/s][A
 20%|██████████▍                                         | 2/10 [00:01<00:04,  1.72it/s][A
 30%|███████████████▌                                    | 3/10 [00:01<00:04,  1.62it/s][A
 40%|████████████████████▊                               | 4/10 [00:02<00:03,  1.62it/s][A
 50%|██████████████████████████                          | 5/10 [00:03<00:03,  1.26it/s][A
 60%|███████████████████████████████▏                    | 6/10 [00:04<00:03,  1.12it/s][A
 70%|████████████████████████████████████▍               | 7/10 [00:05<00:02,  1.26it/s][A
 80%|████████████████████████████████████

 20%|██████████▍                                         | 2/10 [00:01<00:05,  1.59it/s][A
 30%|███████████████▌                                    | 3/10 [00:01<00:04,  1.64it/s][A
 40%|████████████████████▊                               | 4/10 [00:02<00:03,  1.64it/s][A
 50%|██████████████████████████                          | 5/10 [00:03<00:03,  1.66it/s][A
 60%|███████████████████████████████▏                    | 6/10 [00:03<00:02,  1.66it/s][A
 70%|████████████████████████████████████▍               | 7/10 [00:04<00:01,  1.71it/s][A
 80%|█████████████████████████████████████████▌          | 8/10 [00:04<00:01,  1.67it/s][A
 90%|██████████████████████████████████████████████▊     | 9/10 [00:05<00:00,  1.59it/s][A
100%|███████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.62it/s][A

tokenizing: 10it [00:00, 463.63it/s]

  0%|                                                            | 0/10 [00:00<?, ?it/s][A
 10%|█████▏                              

 70%|████████████████████████████████████▍               | 7/10 [00:04<00:02,  1.45it/s][A
 80%|█████████████████████████████████████████▌          | 8/10 [00:05<00:01,  1.48it/s][A
 90%|██████████████████████████████████████████████▊     | 9/10 [00:06<00:00,  1.43it/s][A
100%|███████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.43it/s][A

tokenizing: 10it [00:00, 296.09it/s]

  0%|                                                            | 0/10 [00:00<?, ?it/s][A
 10%|█████▏                                              | 1/10 [00:00<00:07,  1.26it/s][A
 20%|██████████▍                                         | 2/10 [00:01<00:07,  1.12it/s][A
 30%|███████████████▌                                    | 3/10 [00:02<00:05,  1.23it/s][A
 40%|████████████████████▊                               | 4/10 [00:03<00:04,  1.33it/s][A
 50%|██████████████████████████                          | 5/10 [00:03<00:03,  1.38it/s][A
 60%|███████████████████████████████▏    

 10%|█████▏                                              | 1/10 [00:00<00:07,  1.15it/s][A
 20%|██████████▍                                         | 2/10 [00:01<00:06,  1.20it/s][A
 30%|███████████████▌                                    | 3/10 [00:02<00:05,  1.32it/s][A
 40%|████████████████████▊                               | 4/10 [00:02<00:04,  1.40it/s][A
 50%|██████████████████████████                          | 5/10 [00:03<00:03,  1.50it/s][A
 60%|███████████████████████████████▏                    | 6/10 [00:04<00:02,  1.53it/s][A
 70%|████████████████████████████████████▍               | 7/10 [00:04<00:01,  1.59it/s][A
 80%|█████████████████████████████████████████▌          | 8/10 [00:05<00:01,  1.51it/s][A
 90%|██████████████████████████████████████████████▊     | 9/10 [00:06<00:00,  1.52it/s][A
100%|███████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.47it/s][A

tokenizing: 10it [00:00, 246.09it/s]

  0%|                                    

 60%|███████████████████████████████▏                    | 6/10 [00:04<00:02,  1.51it/s][A
 70%|████████████████████████████████████▍               | 7/10 [00:04<00:01,  1.55it/s][A
 80%|█████████████████████████████████████████▌          | 8/10 [00:05<00:01,  1.50it/s][A
 90%|██████████████████████████████████████████████▊     | 9/10 [00:06<00:00,  1.53it/s][A
100%|███████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.48it/s][A

tokenizing: 10it [00:00, 375.98it/s]

  0%|                                                            | 0/10 [00:00<?, ?it/s][A
 10%|█████▏                                              | 1/10 [00:00<00:05,  1.69it/s][A
 20%|██████████▍                                         | 2/10 [00:01<00:04,  1.68it/s][A
 30%|███████████████▌                                    | 3/10 [00:01<00:04,  1.59it/s][A
 40%|████████████████████▊                               | 4/10 [00:02<00:03,  1.57it/s][A
 50%|██████████████████████████          

tokenizing: 10it [00:00, 204.12it/s]

  0%|                                                            | 0/10 [00:00<?, ?it/s][A
 10%|█████▏                                              | 1/10 [00:00<00:06,  1.30it/s][A
 20%|██████████▍                                         | 2/10 [00:01<00:06,  1.27it/s][A
 30%|███████████████▌                                    | 3/10 [00:02<00:05,  1.31it/s][A
 40%|████████████████████▊                               | 4/10 [00:03<00:04,  1.33it/s][A
 50%|██████████████████████████                          | 5/10 [00:03<00:03,  1.40it/s][A
 60%|███████████████████████████████▏                    | 6/10 [00:04<00:02,  1.39it/s][A
 70%|████████████████████████████████████▍               | 7/10 [00:05<00:02,  1.39it/s][A
 80%|█████████████████████████████████████████▌          | 8/10 [00:05<00:01,  1.36it/s][A
 90%|██████████████████████████████████████████████▊     | 9/10 [00:06<00:00,  1.37it/s][A
100%|█████████████████████████████████████

 40%|████████████████████▊                               | 4/10 [00:03<00:05,  1.08it/s][A
 50%|██████████████████████████                          | 5/10 [00:04<00:04,  1.16it/s][A
 60%|███████████████████████████████▏                    | 6/10 [00:05<00:03,  1.28it/s][A
 70%|████████████████████████████████████▍               | 7/10 [00:05<00:02,  1.38it/s][A
 80%|█████████████████████████████████████████▌          | 8/10 [00:06<00:01,  1.48it/s][A
 90%|██████████████████████████████████████████████▊     | 9/10 [00:07<00:00,  1.50it/s][A
100%|███████████████████████████████████████████████████| 10/10 [00:07<00:00,  1.27it/s][A

tokenizing: 10it [00:00, 270.75it/s]

  0%|                                                            | 0/10 [00:00<?, ?it/s][A
 10%|█████▏                                              | 1/10 [00:00<00:06,  1.33it/s][A
 20%|██████████▍                                         | 2/10 [00:01<00:06,  1.31it/s][A
 30%|███████████████▌                    

100%|███████████████████████████████████████████████████| 10/10 [00:07<00:00,  1.38it/s][A

tokenizing: 10it [00:00, 370.47it/s]

  0%|                                                            | 0/10 [00:00<?, ?it/s][A
 10%|█████▏                                              | 1/10 [00:00<00:05,  1.55it/s][A
 20%|██████████▍                                         | 2/10 [00:01<00:04,  1.61it/s][A
 30%|███████████████▌                                    | 3/10 [00:01<00:04,  1.65it/s][A
 40%|████████████████████▊                               | 4/10 [00:02<00:03,  1.70it/s][A
 50%|██████████████████████████                          | 5/10 [00:03<00:02,  1.69it/s][A
 60%|███████████████████████████████▏                    | 6/10 [00:03<00:02,  1.70it/s][A
 70%|████████████████████████████████████▍               | 7/10 [00:04<00:01,  1.72it/s][A
 80%|█████████████████████████████████████████▌          | 8/10 [00:04<00:01,  1.75it/s][A
 90%|████████████████████████████████████

 30%|███████████████▌                                    | 3/10 [00:01<00:03,  1.82it/s][A
 40%|████████████████████▊                               | 4/10 [00:02<00:03,  1.84it/s][A
 50%|██████████████████████████                          | 5/10 [00:02<00:02,  1.85it/s][A
 60%|███████████████████████████████▏                    | 6/10 [00:03<00:02,  1.88it/s][A
 70%|████████████████████████████████████▍               | 7/10 [00:03<00:01,  1.94it/s][A
 80%|█████████████████████████████████████████▌          | 8/10 [00:04<00:01,  1.96it/s][A
 90%|██████████████████████████████████████████████▊     | 9/10 [00:04<00:00,  2.00it/s][A
100%|███████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.93it/s][A

tokenizing: 10it [00:00, 328.16it/s]

  0%|                                                            | 0/10 [00:00<?, ?it/s][A
 10%|█████▏                                              | 1/10 [00:00<00:04,  1.96it/s][A
 20%|██████████▍                         

 80%|█████████████████████████████████████████▌          | 8/10 [00:03<00:00,  2.02it/s][A
 90%|██████████████████████████████████████████████▊     | 9/10 [00:04<00:00,  2.03it/s][A
100%|███████████████████████████████████████████████████| 10/10 [00:04<00:00,  2.01it/s][A

tokenizing: 10it [00:00, 232.75it/s]

  0%|                                                            | 0/10 [00:00<?, ?it/s][A
 10%|█████▏                                              | 1/10 [00:00<00:04,  2.05it/s][A
 20%|██████████▍                                         | 2/10 [00:01<00:04,  1.81it/s][A
 30%|███████████████▌                                    | 3/10 [00:01<00:03,  1.87it/s][A
 40%|████████████████████▊                               | 4/10 [00:02<00:03,  1.89it/s][A
 50%|██████████████████████████                          | 5/10 [00:02<00:02,  1.92it/s][A
 60%|███████████████████████████████▏                    | 6/10 [00:03<00:02,  1.94it/s][A
 70%|████████████████████████████████████

 10%|█████▏                                              | 1/10 [00:00<00:04,  2.00it/s][A
 20%|██████████▍                                         | 2/10 [00:00<00:03,  2.06it/s][A
 30%|███████████████▌                                    | 3/10 [00:01<00:03,  2.04it/s][A
 40%|████████████████████▊                               | 4/10 [00:02<00:03,  1.96it/s][A
 50%|██████████████████████████                          | 5/10 [00:02<00:02,  1.92it/s][A
 60%|███████████████████████████████▏                    | 6/10 [00:03<00:02,  1.96it/s][A
 70%|████████████████████████████████████▍               | 7/10 [00:03<00:01,  1.99it/s][A
 80%|█████████████████████████████████████████▌          | 8/10 [00:04<00:00,  2.02it/s][A
 90%|██████████████████████████████████████████████▊     | 9/10 [00:04<00:00,  2.04it/s][A
100%|███████████████████████████████████████████████████| 10/10 [00:04<00:00,  2.01it/s][A

tokenizing: 10it [00:00, 445.95it/s]

  0%|                                    

1

In [78]:
best_model_index

1

In [82]:
best_model

<__main__.PromptClassifier at 0x7fe8bd9b3400>

In [83]:
f1_scores

[0.0,
 0.6666666666666666,
 0.6666666666666666,
 0.6666666666666666,
 0.5714285714285714,
 0.6666666666666666,
 0.6666666666666666,
 0.6666666666666666,
 0.0,
 0.0]

In [84]:
df.shape

(159571, 3)