In [8]:
import torch
device = 'cuda:1'

In [9]:
import loader
data_path = 'dataset/FSS/FewShotSST/dev.tsv'
dev_data = loader.read_label_data(data_path)

In [10]:
from openprompt.data_utils import InputExample
classes = [ # There are two classes in Sentiment Analysis, one for negative and one for positive
    "negative",
    "positive"
]
dataset = [
    InputExample(guid=label, text_a=sentence) for sentence, label in dev_data
]

In [11]:
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm("bert", "bert-base-cased")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [12]:
from openprompt.prompts import ManualTemplate
promptTemplate = ManualTemplate(
    text = '{"placeholder":"text_a"} It was {"mask"}',
    tokenizer = tokenizer,
)

In [13]:
from openprompt.prompts import ManualVerbalizer
promptVerbalizer = ManualVerbalizer(
    classes = classes,
    label_words = {
        "negative": ["bad", "stupid", "horrible", "awful", "disaster"],
        "positive": ["good", "wonderful", "great", "beautiful", "lovely"],
    },
    tokenizer = tokenizer,
)

In [14]:
from openprompt import PromptForClassification
promptModel = PromptForClassification(
    template = promptTemplate,
    plm = plm,
    verbalizer = promptVerbalizer,
).to(device)

In [15]:
from openprompt import PromptDataLoader
data_loader = PromptDataLoader(
    dataset = dataset,
    tokenizer = tokenizer,
    template = promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)

tokenizing: 1101it [00:01, 754.69it/s]


In [16]:
# making zero-shot inference using pretrained MLM with prompt
promptModel.eval()
pred_logits = []
with torch.no_grad():
    for batch in data_loader:
        batch = batch.to(device)
        logits = promptModel(batch)
        preds = torch.argmax(logits, dim = -1)
        pred_logits.append(preds.item())
# predictions would be 1, 0 for classes 'positive', 'negative'

In [17]:
print(pred_logits)

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [18]:
true_logit = [label for _, label in dev_data]

In [19]:
total_count = len(true_logit)
correct_count = 0
for i in range(total_count):
    if true_logit[i] == pred_logits[i]:
        correct_count += 1
acc = correct_count / total_count
print("Validation accuracy: %.3f" % acc)

Validation accuracy: 0.510


In [18]:
testdata_path = 'dataset/FSS/FewShotSST/test.tsv'
test_data = loader.read_unlabel_data(testdata_path)
testset = [
    InputExample(text_a=sentence) for _, sentence in test_data
]
test_loader = PromptDataLoader(
    dataset = testset,
    tokenizer = tokenizer,
    template = promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)
'''Test on Test set'''
promptModel.eval()
test_result = []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        logits = promptModel(batch)
        preds = torch.argmax(logits, dim = -1)
        test_result.append(preds.item())

tokenizing: 2210it [00:01, 1408.37it/s]


In [20]:
with open('0.tsv', 'w') as f:
    f.write('index\tprediction\n')
    for i in range(len(test_result)):
        f.write('%d\t%d\n' % (i, test_result[i]))
    f.close()