In [None]:
!pip install datasets
!pip install openprompt
from datasets import load_dataset, Dataset
from openprompt.plms import load_plm
from openprompt.prompts import MixedTemplate, SoftVerbalizer, ManualVerbalizer
from openprompt.data_utils import InputExample
from openprompt import PromptDataLoader, PromptForClassification
from transformers import  AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import torch

In [None]:
# Read the train / validation / test CSV files; the result is indexed by split name.
rawdata = load_dataset(
    "csv",
    data_files={
        'train': ["../data/corpus_train.csv"],
        'validation': ["../data/corpus_valid.csv"],
        'test': ["../data/corpus_test.csv"],
    },
)
# Keep only the rows whose label is neither 'title' nor 'common-ground'.
rawdata = rawdata.filter(lambda row: row['label'] not in ('title', 'common-ground'))

In [None]:
# Load the pretrained "t5-base" checkpoint through OpenPrompt's loader; the four
# returned objects are reused by the template, dataloader and model cells below.
plm, tokenizer, model_config, WrapperClass = load_plm(
    model_name="t5",
    model_path="t5-base",
)

In [None]:
# Build the prompt template.
#
# Template variants tried earlier (kept for reference, not executed):
#   '{"placeholder": "text_a"} {"soft": "In this sentence, the topic is"} {"mask"} {"soft"}.'
#   {"placeholder": "text_a"} {"placeholder": "text_b"} {"soft":"This"} topic {"soft":"is about"} {"mask"}.
#   '{"placeholder":"text_a"} {"soft"} {"soft"} {"soft"} {"soft"} {"mask"}.'
#   'In this argumentative text with the title {"meta": "title", "shortenable": False}, the role of this sentence: {"meta": "sentence", "shortenable": False}, is {"mask"}.'
#   'In an argumentative text, the role of this sentence: {"placeholder": "text_a"}, is {"mask"}.'
#
# Template in use (marked "optimal one" by the author): the input sentence,
# a soft-token slot duplicated 20 times, then the mask position.
template_text = '{"placeholder":"text_a"} {"soft": None, "duplicate": 20} {"mask"}.'
mytemplate = MixedTemplate(
    text=template_text,
    tokenizer=tokenizer,
    model=plm,
)

In [None]:
# The 'title' and 'common-ground' labels are filtered out of the data, so they have no class id here.
def int_label(label):
    """Map a label string to its integer class id.

    Args:
        label: one of "assumption", "testimony", "anecdote", "statistics"
            or "other".

    Returns:
        The class id in the range 0..4.

    Raises:
        ValueError: if ``label`` is not one of the five known classes. The
            previous if/elif chain fell through and silently returned None,
            which only surfaced later as a confusing failure downstream.
    """
    label_ids = {
        "assumption": 0,
        "testimony": 1,
        "anecdote": 2,
        "statistics": 3,
        # "title" and "common-ground" are intentionally absent: those rows
        # are filtered out of the data before labels are converted.
        "other": 4,
    }
    try:
        return label_ids[label]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs; report both cases the same way.
        raise ValueError(
            "unknown label {!r}; expected one of {}".format(label, sorted(label_ids))
        ) from None

In [None]:
# Wrap every row of every split in an OpenPrompt InputExample:
# the sentence becomes text_a, the string label becomes its class id,
# and the article id is carried along as the guid.
# (An earlier variant passed the sentence via meta={"sentence": ...} instead of text_a.)
dataset = {
    split: [
        InputExample(
            text_a=row['sentence'],
            label=int_label(row['label']),
            guid=row['article_id'],
        )
        for row in rawdata[split]
    ]
    for split in ('train', 'validation', 'test')
}
# Quick sanity check: show the first converted training example.
print(dataset['train'][0])

In [None]:
# Dataloader over the training split; shuffled, unlike the evaluation loaders.
train_dataloader = PromptDataLoader(
    dataset=dataset["train"],
    template=mytemplate,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass,
    max_seq_length=256,
    decoder_max_length=3,
    batch_size=4,
    shuffle=True,
    teacher_forcing=False,
    predict_eos_token=False,
    truncate_method="head",
)

In [None]:
# Soft verbalizer with one output per class; 5 matches the ids 0..4 produced by int_label.
myverbalizer = SoftVerbalizer(
    tokenizer,
    plm,
    num_classes=5,
)

In [None]:
# Use the GPU only when one is actually present. The flag used to be hard-coded
# to True, which made .cuda() fail on CPU-only machines. Later cells read
# `use_cuda` to decide whether to move each batch to the GPU.
use_cuda = torch.cuda.is_available()

# Assemble the classifier from the backbone, the template and the verbalizer.
# freeze_plm=False: the backbone is fine-tuned together with the prompt.
prompt_model = PromptForClassification(
    plm=plm,
    template=mytemplate,
    verbalizer=myverbalizer,
    freeze_plm=False,
)
if use_cuda:
    prompt_model = prompt_model.cuda()

In [None]:
loss_func = torch.nn.CrossEntropyLoss()

# --- optimizer 1: the backbone (PLM) -----------------------------------------
# No weight decay for parameters whose name contains one of these substrings
# (biases and LayerNorm weights); every other PLM parameter gets decay 0.01.
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters1 = [
    {
        'params': [p for n, p in prompt_model.plm.named_parameters()
                   if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01,
    },
    {
        'params': [p for n, p in prompt_model.plm.named_parameters()
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0,
    },
]

# --- optimizer 2: the prompt-side parameters ---------------------------------
# The template's own trainable parameters (the soft-token embeddings) were
# previously in neither optimizer, so the soft tokens never moved from their
# initial values. "raw_embedding" is skipped because it is the PLM's input
# embedding, which optimizer1 already updates.
soft_template_params = [
    p for n, p in prompt_model.template.named_parameters()
    if "raw_embedding" not in n
]
optimizer_grouped_parameters2 = [
    {'params': prompt_model.verbalizer.group_parameters_1, "lr": 3e-5},
    {'params': prompt_model.verbalizer.group_parameters_2, "lr": 3e-4},
    # NOTE(review): this learning rate reuses the larger verbalizer rate; tune as needed.
    {'params': soft_template_params, "lr": 3e-4},
]

optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5)
optimizer2 = AdamW(optimizer_grouped_parameters2)


In [None]:
# Put the model into training mode explicitly. Hugging Face `from_pretrained`
# returns the backbone in eval mode, and the evaluation cells call
# prompt_model.eval(), so without this call dropout stays disabled during
# training (both on a fresh run and when this cell is re-run after evaluating).
prompt_model.train()

for epoch in range(5):
    tot_loss = 0
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        # Both optimizers step on the same backward pass: optimizer1 updates
        # the backbone, optimizer2 the prompt-side parameters.
        optimizer1.step()
        optimizer1.zero_grad()
        optimizer2.step()
        optimizer2.zero_grad()
        # Report the running mean loss of the current epoch every 100 steps.
        if step % 100 == 0:
            print("epoch = {}, step = {}, tot_loss/(step+1) = {}".format(epoch, step, tot_loss/(step+1)))

In [None]:
# ## evaluate

# %%
# Evaluate on the validation split (same loader settings as training, no shuffling).
validation_dataloader = PromptDataLoader(
    dataset=dataset["validation"],
    template=mytemplate,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass,
    max_seq_length=256,
    decoder_max_length=3,
    batch_size=4,
    shuffle=False,
    teacher_forcing=False,
    predict_eos_token=False,
    truncate_method="head",
)

# Switch off dropout for inference.
prompt_model.eval()

allpreds = []
alllabels = []
# no_grad: inference needs no autograd graph; building one for every batch
# only wastes memory and time.
with torch.no_grad():
    for inputs in validation_dataloader:
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        alllabels.extend(labels.cpu().tolist())
        # The predicted class is the index of the largest logit.
        allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

# Accuracy = fraction of predictions equal to the gold label.
acc = sum(int(pred == gold) for pred, gold in zip(allpreds, alllabels)) / len(allpreds)
print("validation:", acc)

print(classification_report(alllabels, allpreds, zero_division=0))



In [None]:
# Evaluate on the held-out test split.
test_dataloader = PromptDataLoader(
    dataset=dataset["test"],
    template=mytemplate,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass,
    max_seq_length=256,
    decoder_max_length=3,
    batch_size=4,
    shuffle=False,
    teacher_forcing=False,
    predict_eos_token=False,
    truncate_method="head",
)

# Set eval mode here as well, so this cell does not depend on the validation
# cell having been run first.
prompt_model.eval()

allpreds = []
alllabels = []
# no_grad: inference needs no autograd graph.
with torch.no_grad():
    for inputs in test_dataloader:
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        alllabels.extend(labels.cpu().tolist())
        allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

acc = sum(int(pred == gold) for pred, gold in zip(allpreds, alllabels)) / len(allpreds)
print("test:", acc)  # roughly ~0.85 (author's note from an earlier run)
print(classification_report(alllabels, allpreds, zero_division=0))

In [None]:
# Load the CMV sentences as a single split named 'CMV'.
rawcmv = load_dataset("csv", data_files={'CMV': ["../data/cmv_train.csv"]})

# Apply the same label filter as for the main corpus: 'title' and
# 'common-ground' have no class id in int_label, so such rows would otherwise
# become examples with an invalid label.
# NOTE(review): whether the CMV file contains these labels is not visible here;
# the filter is a no-op if it does not.
rawcmv = rawcmv.filter(lambda row: row['label'] not in ('title', 'common-ground'))

# Convert to OpenPrompt examples (no guid is set for CMV rows).
cmvdata = [
    InputExample(text_a=row['sentence'], label=int_label(row['label']))
    for row in rawcmv['CMV']
]
print(cmvdata[0])

In [None]:
# Evaluate the model trained above on the CMV data.
cmv_dataloader = PromptDataLoader(
    dataset=cmvdata,
    template=mytemplate,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass,
    max_seq_length=256,
    decoder_max_length=3,
    batch_size=4,
    shuffle=False,
    teacher_forcing=False,
    predict_eos_token=False,
    truncate_method="head",
)

# Set eval mode here as well, so this cell does not depend on an earlier
# evaluation cell having been run.
prompt_model.eval()

allpreds = []
alllabels = []
# no_grad: inference needs no autograd graph.
with torch.no_grad():
    for inputs in cmv_dataloader:
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        alllabels.extend(labels.cpu().tolist())
        allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

acc = sum(int(pred == gold) for pred, gold in zip(allpreds, alllabels)) / len(allpreds)
# Printed as "cmv:" (it used to say "test:", copied from the test cell) so the
# output cannot be mistaken for the in-domain test score.
# NOTE(review): the original "# roughly ~0.85" remark here looks copied from
# the test cell as well; confirm the real CMV figure.
print("cmv:", acc)
print(classification_report(alllabels, allpreds, zero_division=0))

In [None]:
# Confusion matrix for whichever evaluation cell ran last: `alllabels` and
# `allpreds` are overwritten by each evaluation cell.
cm = confusion_matrix(alllabels, allpreds)
disp = ConfusionMatrixDisplay(cm)
disp.plot()
plt.show()