In [39]:
from openprompt.data_utils import InputExample
classes = [
    "Learning",
    "Sports",
    "Health",
]
dataset = [
    InputExample(
        guid=0,
        text_a="football is a really popular in many countries.",
    ),
    InputExample(
        guid=1,
        text_a="Coronavirus is an infectious disease.",
    ),
    InputExample(
        guid=2,
        text_a="It's common to get hurt while doing stunts.",
    ),
    InputExample(
        guid=3,
        text_a="Machine learning is an important part of artificial intelligence.",
    ),
    InputExample(
        guid=4,
        text_a="School is where students come to study.",
    )
]

In [40]:
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm(
    "roberta", "roberta-base")

In [41]:
from openprompt.prompts import ManualTemplate
promptTemplate = ManualTemplate(
    text='{"placeholder":"text_a"} The topic is {"mask"}',
    tokenizer=tokenizer,
)

In [42]:
from openprompt.prompts import ManualVerbalizer
promptVerbalizer = ManualVerbalizer(
    classes=classes,
    label_words={
        "Health": ["Medicine"],
        "Sports": ["Game", "Play"],
        "Learning": ["Study", "learning"],
    }, # type: ignore
    tokenizer=tokenizer,
)

In [43]:
from openprompt import PromptForClassification
promptModel = PromptForClassification(
    template=promptTemplate,
    plm=plm, # type: ignore
    verbalizer=promptVerbalizer,
)

In [44]:
from openprompt import PromptDataLoader
data_loader = PromptDataLoader(
    dataset=dataset,
    tokenizer=tokenizer,
    template=promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)

tokenizing: 5it [00:00, 474.29it/s]


In [45]:
import torch

promptModel.eval()
with torch.no_grad():
    for batch in data_loader:
        logits = promptModel(batch)
        preds = torch.argmax(logits, dim=-1)
        print(tokenizer.decode(batch['input_ids'][0],
              skip_special_tokens=True), classes[preds])

football is a really popular in many countries. The topic is Sports
Coronavirus is an infectious disease. The topic is Health
It's common to get hurt while doing stunts. The topic is Health
Machine learning is an important part of artificial intelligence. The topic is Learning
School is where students come to study. The topic is Learning
