In [3]:
import torch
device = 'cuda:1'

In [4]:
import loader
data_path = 'dataset/FSS/FewShotSST/train_128.tsv'
train_data = loader.read_label_data(data_path)

In [5]:
from openprompt.data_utils import InputExample
classes = [ # There are two classes in Sentiment Analysis, one for negative and one for positive
    "negative",
    "positive"
]
dataset = [
    InputExample(guid=label, text_a=sentence) for sentence, label in train_data
]

In [6]:
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm("bert", "bert-base-cased")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
from openprompt.prompts import ManualTemplate
promptTemplate = ManualTemplate(
    text = '{"placeholder":"text_a"} It was {"mask"}',
    tokenizer = tokenizer,
)

In [8]:
from openprompt.prompts import ManualVerbalizer, SoftVerbalizer
promptVerbalizer1 = ManualVerbalizer(
    classes = classes,
    label_words = {
        "negative": ["bad", "stupid", "horrible", "awful", "disaster"],
        "positive": ["good", "wonderful", "great", "beautiful", "lovely"],
    },
    tokenizer = tokenizer,
)
promptVerbalizer2 = SoftVerbalizer(tokenizer=tokenizer, plm=plm, num_classes=2)
chooseVerbalizer = promptVerbalizer2

In [9]:
from openprompt import PromptForClassification
prompt_model = PromptForClassification(
    template = promptTemplate,
    plm = plm,
    verbalizer = chooseVerbalizer,
    freeze_plm=False
).to(device)

In [10]:
from openprompt import PromptDataLoader
train_dataloader = PromptDataLoader(
    dataset = dataset,
    tokenizer = tokenizer,
    template = promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)

tokenizing: 128it [00:00, 1614.34it/s]


In [11]:
from transformers import  AdamW, get_linear_schedule_with_warmup
loss_func = torch.nn.CrossEntropyLoss()

no_decay = ['bias', 'LayerNorm.weight']

# it's always good practice to set no decay to biase and LayerNorm parameters
optimizer_grouped_parameters1 = [
    {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

# Using different optimizer for prompt parameters and model parameters
optimizer_grouped_parameters2 = [
    {'params': [p for n,p in prompt_model.template.named_parameters() if "raw_embedding" not in n]}
]

optimizer1 = AdamW(optimizer_grouped_parameters1, lr=1e-4)
optimizer2 = AdamW(optimizer_grouped_parameters2, lr=1e-3)

num_epochs = 10
for epoch in range(num_epochs):
    tot_loss = 0 
    for step, inputs in enumerate(train_dataloader):
        inputs = inputs.to(device)
        logits = prompt_model(inputs)
        labels = inputs['guid'].to(device)
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        optimizer1.step()
        optimizer1.zero_grad()
        optimizer2.step()
        optimizer2.zero_grad()
        print(tot_loss/(step+1))

0.6461499333381653
0.32451853179372847
0.21634954227374692
4.3529297859472535
3.482560934043067
3.7461213381887624
3.5891133471728347
3.277431794390168
3.052184642142796
2.8637226648212164
2.650717817209046
2.5384911453625136
2.446076645842018
2.299532658789888
2.174836272470202
2.0641957004294227
1.9683119711385455
1.8762369308803197
1.8607300978522912
1.8365497428715571
1.8061853097109892
1.7422199834964802
1.6883373276513853
1.656466937606562
1.6108036723566328
1.5777577382280539
1.5453755913354985
1.5117170100084618
1.493794557591988
1.4783752493421465
1.4460723546589795
1.4147052143104588
1.3858073148185763
1.3776284694987386
1.3507390288250982
1.3450056904190306
1.3381911220067442
1.3275769880225965
1.3144828044726473
1.2985719004541807
1.284261486323885
1.267438293380413
1.2505922104286753
1.2420809900462784
1.2343390684896198
1.2267427737531604
1.2117226080919978
1.2055552907511735
1.1923669365206684
1.1797626443123954
1.1674960858892127
1.1544138738948697
1.1507273357072096
1.

In [12]:
val_data_path = 'dataset/FSS/FewShotSST/dev.tsv'
val_data = loader.read_label_data(val_data_path)
val_dataset = [
    InputExample(guid=label, text_a=sentence) for sentence, label in val_data
]
val_dataloader = PromptDataLoader(
    dataset = val_dataset,
    tokenizer = tokenizer,
    template = promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)

tokenizing: 1101it [00:00, 1497.72it/s]


In [13]:
prompt_model.eval()
pred_logits = []
with torch.no_grad():
    for batch in val_dataloader:
        batch = batch.to(device)
        logits = prompt_model(batch)
        preds = torch.argmax(logits, dim = -1)
        pred_logits.append(preds.item())

In [14]:
print(pred_logits)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [15]:
true_logit = [label for _, label in val_data]

In [16]:
total_count = len(true_logit)
correct_count = 0
for i in range(total_count):
    if true_logit[i] == pred_logits[i]:
        correct_count += 1
acc = correct_count / total_count
print("Accuracy: %.3f" % acc)

Accuracy: 0.507


In [17]:
testdata_path = 'dataset/FSS/FewShotSST/test.tsv'
test_data = loader.read_unlabel_data(testdata_path)
testset = [
    InputExample(text_a=sentence) for _, sentence in test_data
]
test_dataloader = PromptDataLoader(
    dataset = testset,
    tokenizer = tokenizer,
    template = promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)


test_preds = []
for step, inputs in enumerate(test_dataloader):
    inputs = inputs.to(device)
    logits = prompt_model(inputs)
    test_preds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

tokenizing: 2210it [00:01, 1591.60it/s]


In [18]:
print(test_preds)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 