In [2]:
from openprompt.data_utils import InputExample

In [3]:
# Label names. A predicted class index is looked up in this list,
# so the order here defines which index means which label.
classes = ['negative', 'positive']

# Wrap every raw sentence in an InputExample, using its position as the guid.
dataset = [
    InputExample(guid=idx, text_a=sentence)
    for idx, sentence in enumerate([
        'Albert Einstein was one of the greatest intellects of his time.',
        'The film was badly made.',
    ])
]

In [4]:
from openprompt.plms import load_plm

In [11]:
# Load the pretrained cased BERT checkpoint together with its tokenizer, config,
# and the tokenizer wrapper class that the prompt data loader needs later on.
# The "Some weights ... were not used" warning printed below is expected: the
# next-sentence-prediction head is dropped when loading as a masked-LM model.
plm, tokenizer, model_config, WrapperClass = load_plm(
    model_name='bert',
    model_path='bert-base-cased',
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

In [12]:
from openprompt.prompts import ManualTemplate

In [13]:
# Hand-written prompt: the input sentence is substituted for the placeholder and
# the model is asked to fill in the masked word that follows "It was".
promptTemplate = ManualTemplate(
    tokenizer=tokenizer,
    text='{"placeholder": "text_a"} It was {"mask"}',
)

In [14]:
from openprompt.prompts import ManualVerbalizer

In [15]:
# Map each class to the vocabulary words whose scores at the mask position
# count as evidence for that class.
promptVerbalizer = ManualVerbalizer(
    tokenizer=tokenizer,
    classes=classes,
    label_words={
        'negative': ['bad'],
        'positive': ['good', 'wonderful', 'great'],
    },
)

In [16]:
from openprompt import PromptForClassification

# Combine the language model, the prompt template and the verbalizer into a
# single classifier that can be called directly on a batch.
promptModel = PromptForClassification(
    plm=plm,
    template=promptTemplate,
    verbalizer=promptVerbalizer,
)

In [17]:
from openprompt import PromptDataLoader

# Wrap the examples with the template and tokenize them up front so the model
# can simply iterate over ready-made batches.
data_loader = PromptDataLoader(
    dataset=dataset,
    template=promptTemplate,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass,
)

tokenizing: 2it [00:00, 22.98it/s]


In [20]:
import torch

# Inference only: put the model in evaluation mode and skip gradient tracking.
promptModel.eval()
with torch.no_grad():
    for batch in data_loader:
        logits = promptModel(batch)
        # Index of the highest-scoring class for each example in the batch.
        preds = torch.argmax(logits, dim=-1)
        # Look up every prediction individually. Indexing the list with the
        # tensor itself (`classes[preds]`) only works when the batch contains
        # exactly one example and raises TypeError for larger batches.
        labels = ', '.join(classes[idx] for idx in preds.tolist())
        print(f'logits: {logits}, classes: {labels}')

logits: tensor([[-3.2480, -1.4555]]), classes: positive
logits: tensor([[-1.3735, -1.8500]]), classes: negative
