In [None]:
!pip install openprompt


We simulate a two-class problem with the classes Sports and Health, and define three input examples that we want to classify.

In [None]:
from openprompt.data_utils import InputExample
classes = [
    "Sports",
    "Health"
]
dataset = [
    InputExample(
        guid=0,
        text_a="Cricket is a really popular sport in India.",
    ),
    InputExample(
        guid=1,
        text_a="Coronavirus is an infectious disease.",
    ),
    InputExample(
        guid=2,
        text_a="It's common to get hurt while doing stunts.",
    )
]


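Next, we load a pre-trained language model; here we use RoBERTa (roberta-base). Along with the model, load_plm returns its tokenizer, its configuration, and the tokenizer wrapper class used later by the data loader.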

In [None]:
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm(
    "roberta", "roberta-base")


Next, we define a template that dynamically inserts each input example, stored in the "text_a" field, into the prompt. The {"mask"} token marks the position that the model fills in. See How to Write a Template? for more detailed guidance on designing your own.

In [None]:
from openprompt.prompts import ManualTemplate
promptTemplate = ManualTemplate(
    text='{"placeholder":"text_a"} It was {"mask"}',
    tokenizer=tokenizer,
)
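
To make the template's role concrete, here is a minimal plain-Python sketch of the text the model effectively sees for one example. It is only an illustration and not OpenPrompt's implementation: fill_template is a hypothetical helper, and "<mask>" is assumed as the mask token because that is what RoBERTa's tokenizer uses.

```python
# Illustrative stand-in for a manual template; NOT OpenPrompt's internals.
MASK_TOKEN = "<mask>"  # assumption: RoBERTa's mask token


def fill_template(text_a: str, mask_token: str = MASK_TOKEN) -> str:
    """Place the input text in the placeholder slot and append the mask slot."""
    return f"{text_a} It was {mask_token}"


prompt = fill_template("Cricket is a really popular sport in India.")
print(prompt)
```

The language model is then asked which word best fits the masked position.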


Next, we define a verbalizer, which projects the model's predictions at the mask position onto our predefined class labels. See How to Write a Verbalizer? for more detailed guidance on designing your own.

In [None]:
from openprompt.prompts import ManualVerbalizer
promptVerbalizer = ManualVerbalizer(
    classes=classes,
    label_words={
        "Health": ["Medicine"],
        "Sports": ["Game", "Play"],
    },
    tokenizer=tokenizer,
)
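
The projection from label words to classes can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the exact computation in ManualVerbalizer (which may apply further normalization or calibration): the probabilities are made up, class_scores is a hypothetical helper, and each class score is taken to be the mean log-probability of its label words.

```python
import math

# Same label words as in the verbalizer above.
label_words = {"Sports": ["Game", "Play"], "Health": ["Medicine"]}


def class_scores(word_log_probs, label_words):
    """Average the log-probabilities of each class's label words."""
    return {
        cls: sum(word_log_probs[w] for w in words) / len(words)
        for cls, words in label_words.items()
    }


# Made-up probabilities for the mask position (illustration only).
word_probs = {"Game": 0.30, "Play": 0.10, "Medicine": 0.05}
word_log_probs = {w: math.log(p) for w, p in word_probs.items()}

scores = class_scores(word_log_probs, label_words)
predicted = max(scores, key=scores.get)
print(predicted)
```

Averaging rather than summing keeps a class with many label words from being favored merely because it has more of them.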


Next, we create the prompt model for classification by passing in the template, the language model, and the verbalizer.

In [None]:
from openprompt import PromptForClassification
promptModel = PromptForClassification(
    template=promptTemplate,
    plm=plm,
    verbalizer=promptVerbalizer,
)


Next, we create a data loader that samples mini-batches from the dataset.

In [None]:
from openprompt import PromptDataLoader
data_loader = PromptDataLoader(
    dataset=dataset,
    tokenizer=tokenizer,
    template=promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)


Next, we set the model to evaluation mode and predict a label for each input example in a masked language model (MLM) fashion.

In [None]:
import torch

promptModel.eval()
with torch.no_grad():
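    for batch in data_loader:
        logits = promptModel(batch)
        preds = torch.argmax(logits, dim=-1)
        # Iterate over the batch so this also works when batch_size > 1.
        for input_ids, pred in zip(batch['input_ids'], preds):
            text = tokenizer.decode(input_ids, skip_special_tokens=True)
            print(text, classes[pred.item()])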
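
The argmax only reports the winning class. If a rough confidence is also useful, the per-class scores can be normalized with a softmax, sketched below in plain Python. The scores here are made up, and treating the model output as unnormalized scores is an assumption for illustration.

```python
import math

classes = ["Sports", "Health"]


def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


example_scores = [1.2, -0.4]  # made-up per-class scores for one example
probs = softmax(example_scores)
best = max(range(len(probs)), key=probs.__getitem__)
print(classes[best], round(probs[best], 3))
```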
