In [1]:
import torch
device = 'cuda:1'

In [2]:
import loader
data_path = 'dataset/FSS/FewShotSST/train_64.tsv'
train_data = loader.read_label_data(data_path)

In [3]:
from openprompt.data_utils import InputExample
classes = [ # There are two classes in Sentiment Analysis, one for negative and one for positive
    "negative",
    "positive"
]
dataset = [
    InputExample(guid=label, text_a=sentence) for sentence, label in train_data
]

In [4]:
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm("bert", "bert-base-cased")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
from openprompt.prompts import ManualTemplate
promptTemplate = ManualTemplate(
    text = '{"placeholder":"text_a"} It was {"mask"}',
    tokenizer = tokenizer,
)

In [6]:
from openprompt.prompts import ManualVerbalizer
promptVerbalizer = ManualVerbalizer(
    classes = classes,
    label_words = {
        "negative": ["bad", "stupid", "horrible", "awful", "disaster"],
        "positive": ["good", "wonderful", "great", "beautiful", "lovely"],
    },
    tokenizer = tokenizer,
)

In [7]:
from openprompt import PromptForClassification
prompt_model = PromptForClassification(
    template = promptTemplate,
    plm = plm,
    verbalizer = promptVerbalizer,
    freeze_plm=False
).to(device)

In [8]:
from openprompt import PromptDataLoader
train_dataloader = PromptDataLoader(
    dataset = dataset,
    tokenizer = tokenizer,
    template = promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)

tokenizing: 64it [00:00, 929.07it/s]


In [9]:
from transformers import  AdamW, get_linear_schedule_with_warmup
from torch.optim import SGD
loss_func = torch.nn.CrossEntropyLoss()

optimizer1 = SGD(prompt_model.parameters(), lr=1e-6)

epochs = 15
for epoch in range(1, epochs + 1):
    tot_loss = 0 
    batch_cnt = 0
    for step, inputs in enumerate(train_dataloader):
        inputs = inputs.to(device)
        logits = prompt_model(inputs)
        labels = inputs['guid'].to(device)
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        optimizer1.step()
        optimizer1.zero_grad()
        # optimizer2.step()
        # optimizer2.zero_grad()
        batch_cnt += 1
    tot_loss /= batch_cnt
    print('epoch: %d\tTraining Loss: %5lf' % (epoch, tot_loss))

val_data_path = 'dataset/FSS/FewShotSST/dev.tsv'
val_data = loader.read_label_data(val_data_path)
val_dataset = [
    InputExample(guid=label, text_a=sentence) for sentence, label in val_data
]
val_dataloader = PromptDataLoader(
    dataset = val_dataset,
    tokenizer = tokenizer,
    template = promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)
prompt_model.eval()
pred_logits = []
with torch.no_grad():
    for batch in val_dataloader:
        batch = batch.to(device)
        logits = prompt_model(batch)
        preds = torch.argmax(logits, dim = -1)
        pred_logits.append(preds.item())
true_logit = [label for _, label in val_data]
total_count = len(true_logit)
correct_count = 0
for i in range(total_count):
    if true_logit[i] == pred_logits[i]:
        correct_count += 1
acc = correct_count / total_count
print("Validation Accuracy: %.3f" % acc)

epoch: 1	Training Loss: 0.840910
epoch: 2	Training Loss: 0.817890
epoch: 3	Training Loss: 0.796914
epoch: 4	Training Loss: 0.777836
epoch: 5	Training Loss: 0.760513
epoch: 6	Training Loss: 0.744806
epoch: 7	Training Loss: 0.730579
epoch: 8	Training Loss: 0.717708
epoch: 9	Training Loss: 0.706072
epoch: 10	Training Loss: 0.695558
epoch: 11	Training Loss: 0.686061
epoch: 12	Training Loss: 0.677485
epoch: 13	Training Loss: 0.669738
epoch: 14	Training Loss: 0.662737
epoch: 15	Training Loss: 0.656407


tokenizing: 1101it [00:00, 1222.00it/s]


Validation Accuracy: 0.631


In [26]:
testdata_path = 'dataset/FSS/FewShotSST/test.tsv'
test_data = loader.read_unlabel_data(testdata_path)
testset = [
    InputExample(text_a=sentence) for _, sentence in test_data
]
test_loader = PromptDataLoader(
    dataset = testset,
    tokenizer = tokenizer,
    template = promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)
'''Test on Test set'''
prompt_model.eval()
test_result = []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        logits = prompt_model(batch)
        preds = torch.argmax(logits, dim = -1)
        test_result.append(preds.item())

tokenizing: 2210it [00:01, 1589.42it/s]


In [27]:
with open('32.tsv', 'w') as f:
    f.write('index\tprediction\n')
    for i in range(len(test_result)):
        f.write('%d\t%d\n' % (i, test_result[i]))
    f.close()