In [1]:
from datasets import load_dataset
import loader
device = 'cuda:1'
data_root_dir = 'dataset/FSS/FewShotSST/'
split = {'train': 'train_64.tsv', 'validation': 'dev.tsv', 'test': 'test.tsv'}
raw_dataset = loader.generate_dialogue(data_root_dir, split)
raw_dataset['train'][0]

{'premise': "The most repugnant adaptation of a classic text since Roland Joff茅 and Demi Moore 's The Scarlet Letter .",
 'hypothesis': 'it was negative',
 'idx': 0,
 'label': 0}

In [2]:
from openprompt.data_utils import InputExample

In [3]:
dataset = {}
for split in ['train', 'validation', 'test']:
    dataset[split] = []
    for data in raw_dataset[split]:
        input_example = InputExample(text_a = data['premise'], text_b = data['hypothesis'], label=int(data['label']), guid=data['idx'])
        dataset[split].append(input_example)
print(dataset['train'][0])

{
  "guid": 0,
  "label": 0,
  "meta": {},
  "text_a": "The most repugnant adaptation of a classic text since Roland Joff\u8305 and Demi Moore 's The Scarlet Letter .",
  "text_b": "it was negative",
  "tgt_text": null
}



In [4]:
from openprompt.plms import load_plm

In [27]:
plm, tokenizer, model_config, WrapperClass = load_plm("bert", "bert-base-cased")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [28]:
from openprompt.prompts import MixedTemplate

In [29]:
mytemplate1 = MixedTemplate(model=plm, tokenizer=tokenizer, text='{"placeholder":"text_a"} {"soft": "Question:"} {"placeholder":"text_b"}? Is it correct? {"mask"}.')
mytemplate = MixedTemplate(model=plm, tokenizer=tokenizer, text='{"placeholder":"text_a"} {"soft"} {"soft"} {"soft"} {"placeholder":"text_b"} {"soft"} {"mask"}.')
choose_template = mytemplate

In [30]:
wrapped_example = mytemplate.wrap_one_example(dataset['train'][0]) 
print(wrapped_example)

[[{'text': "The most repugnant adaptation of a classic text since Roland Joff茅 and Demi Moore 's The Scarlet Letter .", 'soft_token_ids': 0, 'loss_ids': 0, 'shortenable_ids': 1}, {'text': '', 'soft_token_ids': 1, 'loss_ids': 0, 'shortenable_ids': 0}, {'text': '', 'soft_token_ids': 2, 'loss_ids': 0, 'shortenable_ids': 0}, {'text': '', 'soft_token_ids': 3, 'loss_ids': 0, 'shortenable_ids': 0}, {'text': ' it was negative', 'soft_token_ids': 0, 'loss_ids': 0, 'shortenable_ids': 1}, {'text': '', 'soft_token_ids': 4, 'loss_ids': 0, 'shortenable_ids': 0}, {'text': '<mask>', 'soft_token_ids': 0, 'loss_ids': 1, 'shortenable_ids': 0}, {'text': '.', 'soft_token_ids': 0, 'loss_ids': 0, 'shortenable_ids': 0}], {'guid': 0, 'label': 0}]


In [31]:
wrapped_tokenizer = WrapperClass(max_seq_length=128, decoder_max_length=3, tokenizer=tokenizer,truncate_method="head")

In [32]:
from openprompt import PromptDataLoader

train_dataloader = PromptDataLoader(dataset=dataset["train"], template=choose_template, tokenizer=tokenizer, 
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3, 
    batch_size=4,shuffle=True, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")

tokenizing: 64it [00:00, 799.25it/s]


In [33]:
from openprompt.prompts import ManualVerbalizer
import torch

In [34]:
from openprompt.prompts import SoftVerbalizer
promptVerbalizer1 = ManualVerbalizer(tokenizer, num_classes=2, label_words=[["yes"], ["no"], ["maybe"]])
promptVerbalizer2 = SoftVerbalizer(tokenizer=tokenizer, plm=plm, num_classes=2)
myverbalizer = promptVerbalizer1

In [35]:
logits = torch.randn(2,len(tokenizer)) # creating a pseudo output from the plm
# myverbalizer.process_logits(logits)

In [36]:
from openprompt import PromptForClassification

use_cuda = True
prompt_model = PromptForClassification(plm=plm,template=choose_template, verbalizer=myverbalizer, freeze_plm=False)
if use_cuda:
    prompt_model=  prompt_model.to(device)

In [37]:
from transformers import  AdamW, get_linear_schedule_with_warmup
from torch.optim import SGD
loss_func = torch.nn.CrossEntropyLoss()

no_decay = ['bias', 'LayerNorm.weight']

# it's always good practice to set no decay to biase and LayerNorm parameters
optimizer_grouped_parameters1 = [
    # {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    # {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.5},
    {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.5}
]

# Using different optimizer for prompt parameters and model parameters
optimizer_grouped_parameters2 = [
    {'params': [p for n,p in prompt_model.template.named_parameters() if "raw_embedding" not in n]}
]

# optimizer1 = AdamW(optimizer_grouped_parameters1, lr=1e-4)
# optimizer2 = AdamW(optimizer_grouped_parameters2, lr=1e-3)
# optimizer1 = AdamW(optimizer_grouped_parameters1, lr=2.75e-4)
# optimizer2 = AdamW(optimizer_grouped_parameters2, lr=1e-3)
optimizer1 = SGD(prompt_model.parameters(), lr=9e-6)

epochs = 15
for epoch in range(1, epochs + 1):
    tot_loss = 0 
    batch_cnt = 0
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.to(device)
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        optimizer1.step()
        optimizer1.zero_grad()
        # optimizer2.step()
        # optimizer2.zero_grad()
        batch_cnt += 1
    tot_loss /= batch_cnt
    print('epoch: %d\tTraining Loss: %5lf' % (epoch, tot_loss))

validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=choose_template, tokenizer=tokenizer, 
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3, 
    batch_size=4,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")


allpreds = []
alllabels = []
for step, inputs in enumerate(validation_dataloader):
    if use_cuda:
        inputs = inputs.to(device)
    logits = prompt_model(inputs)
    labels = inputs['label']
    alllabels.extend(labels.cpu().tolist())
    allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
print('Validation Accuracy: %.3lf' % acc)

epoch: 1	Training Loss: 1.644167
epoch: 2	Training Loss: 1.325929
epoch: 3	Training Loss: 1.149788
epoch: 4	Training Loss: 1.034256
epoch: 5	Training Loss: 0.945933
epoch: 6	Training Loss: 0.872767
epoch: 7	Training Loss: 0.810137
epoch: 8	Training Loss: 0.752352
epoch: 9	Training Loss: 0.701842
epoch: 10	Training Loss: 0.651557
epoch: 11	Training Loss: 0.609833
epoch: 12	Training Loss: 0.571508
epoch: 13	Training Loss: 0.536904
epoch: 14	Training Loss: 0.505138
epoch: 15	Training Loss: 0.476056


tokenizing: 1101it [00:00, 1224.10it/s]


Validation Accuracy: 0.794


In [20]:
# %%
validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=choose_template, tokenizer=tokenizer, 
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3, 
    batch_size=4,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")


allpreds = []
alllabels = []
for step, inputs in enumerate(validation_dataloader):
    if use_cuda:
        inputs = inputs.to(device)
    logits = prompt_model(inputs)
    labels = inputs['label']
    alllabels.extend(labels.cpu().tolist())
    allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
print(acc)

tokenizing: 1101it [00:00, 1339.42it/s]


1.0


In [21]:
print(allpreds)

[1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 

In [17]:
testdata_path = 'dataset/FSS/FewShotSST/test.tsv'
test_data = loader.read_unlabel_data(testdata_path)
testset = [
    InputExample(text_a=sentence) for _, sentence in test_data
]
test_dataloader = PromptDataLoader(
    dataset = testset,
    tokenizer = tokenizer,
    template = choose_template,
    tokenizer_wrapper_class=WrapperClass,
)


test_preds = []
for step, inputs in enumerate(test_dataloader):
    if use_cuda:
        inputs = inputs.to(device)
    logits = prompt_model(inputs)
    test_preds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

tokenizing: 2210it [00:01, 1369.50it/s]


In [18]:
print(test_preds)

[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 