In [1]:
# Import necessary libraries
from openprompt.plms import T5TokenizerWrapper
from datasets import load_from_disk
from openprompt.pipeline_base import PromptDataLoader
from transformers import T5ForConditionalGeneration, T5Tokenizer
from openprompt.prompts import ManualTemplate, MixedTemplate
from openprompt import PromptForClassification
from openprompt.data_utils import FewShotSampler
from random import shuffle
from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
import torch
from openprompt.prompts import ManualVerbalizer
from openprompt.data_utils import InputExample
from tqdm import tqdm
import json

dataset_path = "/lustre/work/client/users/minhos/cache/datasets/p3_ropes_rc"
raw_dataset = load_from_disk(dataset_path)


dataset = {}
for split in ['train', 'validation']:
    dataset[split] = []
    raw_dataset[split] = raw_dataset[split].select(range(500))
    for idx, data in enumerate(raw_dataset[split]):
        dataset[split].append(data)

print(dataset['train'][0])
print(type(dataset['train'][0]))


# Load the T5 model
from openprompt.plms import load_plm
t5_path = "/lustre/work/client/users/minhos/models_for_supercomputer/t5-base"
model = T5ForConditionalGeneration.from_pretrained(t5_path)
tokenizer = T5Tokenizer.from_pretrained(t5_path)




# Logging setup
log_file = "qa_manual_template_multi_binary_t5.json"
results = []

# For Training, select 18 samples with the same context and different questions
data1 = dataset['train'][2:20]
dataset1 = []
for idx, data in enumerate(data1):
    question = data["inputs_pretokenized"]
    correct_answer = data["targets_pretokenized"].strip()
    label = 0 if correct_answer in ['cell X', 'Cell A', 'larger', 'more'] else 1  
    # Create an InputExample
    input_example = InputExample(
        text_a=question,
        label=label,
        guid=idx, # Assign a dummy label since there is only one answer
        meta={"correct_answer": correct_answer}
    )
    dataset1.append(input_example)

template1 = ManualTemplate(
    tokenizer=tokenizer,
    text='{"placeholder":"text_a"} The answer is: {"mask"}',
)
verbalizer1 = ManualVerbalizer(
    tokenizer=tokenizer,
    num_classes=2,
    label_words=[['cell X','cell A', 'larger','more'],['cell Z', 'cell B', 'smaller', 'less']]
)

prompt_model = PromptForClassification(
    plm=model,
    template=template1,
    verbalizer=verbalizer1,
    freeze_plm=False,
)
train_dataloader = PromptDataLoader(
    dataset = dataset1,
    template=template1,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=T5TokenizerWrapper,
    decoder_max_length=68, max_seq_length=480,
    batch_size=1)

loss_func = torch.nn.CrossEntropyLoss()
no_decay = ['bias', 'LayerNorm.weight']
# it's always good practice to set no decay to biase and LayerNorm parameters
optimizer_grouped_parameters = [
    {'params': [p for n, p in prompt_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
from tqdm import tqdm
optimizer = AdamW(optimizer_grouped_parameters, lr=0.0001)
prompt_model.train()
for epoch in range(10):
    tot_loss = 0
    pbar = tqdm(train_dataloader, desc="Training")
    for step, inputs in enumerate(train_dataloader):
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        pbar.set_postfix({"loss": tot_loss / (step + 1)})
        if step %100 ==1:
            print("Epoch {}, average loss: {}".format(epoch+1, tot_loss/(step+1)), flush=True)


data2 = dataset['validation'][:1000]
dataset2 = []
for idx, data in enumerate(data2):
    question = data["inputs_pretokenized"]
    correct_answer = data["targets_pretokenized"].strip()
    # Create an InputExample
    input_example = InputExample(
        text_a=question,
        label= 0,
        guid=idx, # Assign a dummy label since there is only one answer
        meta={"correct_answer": correct_answer}
    )
    dataset2.append(input_example)
    
for idx, data in enumerate(dataset2):

    template2 = ManualTemplate(
        tokenizer=tokenizer,
        text='{"placeholder":"text_a"} The answer is: {"mask"}',
    )
    
    verbalizer2 = ManualVerbalizer(
        tokenizer=tokenizer,
        num_classes=2,
        label_words=[[data.meta['correct_answer']],["other"]]
    )

    prompt_model.template = template2
    prompt_model.verbalizer = verbalizer2

    validation_dataloader = PromptDataLoader(
        dataset = [data],
        template=template2,
        tokenizer=tokenizer,
        tokenizer_wrapper_class=T5TokenizerWrapper,
        decoder_max_length=3, max_seq_length=480,
        batch_size=1
        )

    prompt_model.eval()
    with torch.no_grad():
        for inputs in validation_dataloader:
            logits = prompt_model(inputs)
            preds = torch.argmax(logits, dim=-1)
            correct = preds.item() == data.label
    
    results.append({"index": idx, "correct": correct})

# Compute overall accuracy
accuracy = sum(r["correct"] for r in results) / len(results)
print(f"Validation Accuracy: {accuracy:.4f}")

# Save results to JSON
with open(log_file, "w") as f:
    json.dump(results, f, indent=4)

  from .autonotebook import tqdm as notebook_tqdm


{'inputs': [27, 54, 169, 48, 2458, 10, 3424, 757, 1855, 6986, 116, 3, 9, 8096, 9016, 190, 8, 2358, 13304, 406, 174, 53, 136, 827, 12, 1903, 190, 5, 100, 2906, 116, 3, 9, 8096, 6914, 45, 46, 616, 213, 34, 19, 72, 18054, 12, 46, 616, 213, 34, 19, 705, 18054, 5, 27664, 257, 19, 8, 381, 13, 14219, 13, 3, 9, 8096, 16, 3, 9, 787, 2908, 5, 1563, 31, 7, 497, 25, 27157, 3, 9, 21776, 13, 3136, 16, 3, 9, 4119, 13, 387, 5, 37, 29, 25, 27157, 192, 21776, 7, 13, 3136, 16, 430, 4119, 13, 387, 5, 37, 511, 1127, 56, 43, 3, 9, 1146, 6145, 13, 3136, 5, 852, 6, 27, 43, 3, 9, 126, 1419, 10, 71, 388, 474, 192, 12294, 6, 4119, 71, 11, 4119, 272, 6, 3353, 28, 4081, 6201, 13, 387, 30, 12, 3, 9, 953, 11, 3, 6412, 550, 12, 281, 691, 112, 4842, 5, 978, 520, 764, 590, 11, 1509, 8, 192, 12294, 11, 1500, 12, 474, 128, 2656, 16, 135, 12, 143, 3, 9, 11915, 3281, 5, 37, 861, 3, 22929, 192, 14987, 1329, 7, 13, 2656, 139, 4119, 71, 11, 386, 14987, 1329, 7, 13, 2656, 139, 4119, 272, 5, 11801, 48, 822, 754, 10, 4073, 4119,

  return torch.load(checkpoint_file, map_location="cpu")
tokenizing: 18it [00:00, 418.43it/s]
Training:   0%|          | 0/18 [00:31<?, ?it/s, loss=0.94]

Epoch 1, average loss: 0.9402022957801819


Training:   0%|          | 0/18 [01:21<?, ?it/s, loss=0.601]
Training:   0%|          | 0/18 [01:21<?, ?it/s, loss=0.601]

Training:   0%|          | 0/18 [00:02<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/18 [00:04<?, ?it/s, loss=0.847][A

Epoch 2, average loss: 0.8473597131669521



Training:   0%|          | 0/18 [00:06<?, ?it/s, loss=0.637][A
Training:   0%|          | 0/18 [00:09<?, ?it/s, loss=0.497][A
Training:   0%|          | 0/18 [00:12<?, ?it/s, loss=0.442][A
Training:   0%|          | 0/18 [00:14<?, ?it/s, loss=0.373][A
Training:   0%|          | 0/18 [00:16<?, ?it/s, loss=0.437][A
Training:   0%|          | 0/18 [00:18<?, ?it/s, loss=0.473][A
Training:   0%|          | 0/18 [00:21<?, ?it/s, loss=0.439][A
Training:   0%|          | 0/18 [00:23<?, ?it/s, loss=0.398][A
Training:   0%|          | 0/18 [00:25<?, ?it/s, loss=0.386][A
Training:   0%|          | 0/18 [00:27<?, ?it/s, loss=0.441][A
Training:   0%|          | 0/18 [00:29<?, ?it/s, loss=0.414][A
Training:   0%|          | 0/18 [00:31<?, ?it/s, loss=0.386][A
Training:   0%|          | 0/18 [00:33<?, ?it/s, loss=0.371][A
Training:   0%|          | 0/18 [00:36<?, ?it/s, loss=0.351][A
Training:   0%|          | 0/18 [00:38<?, ?it/s, loss=0.352][A
Training:   0%|          | 0/18 [00:40<

Epoch 3, average loss: 0.12093606404960155


Training:   0%|          | 0/18 [00:35<?, ?it/s, loss=0.305] 
Training:   0%|          | 0/18 [00:35<?, ?it/s, loss=0.305]

Training:   0%|          | 0/18 [00:02<?, ?it/s, loss=0.727][A
Training:   0%|          | 0/18 [00:04<?, ?it/s, loss=0.4]  [A

Epoch 4, average loss: 0.3999467305839062



Training:   0%|          | 0/18 [00:05<?, ?it/s, loss=0.302][A
Training:   0%|          | 0/18 [00:09<?, ?it/s, loss=0.233][A
Training:   0%|          | 0/18 [00:11<?, ?it/s, loss=0.193][A
Training:   0%|          | 0/18 [00:13<?, ?it/s, loss=0.17] [A
Training:   0%|          | 0/18 [00:16<?, ?it/s, loss=0.199][A
Training:   0%|          | 0/18 [00:18<?, ?it/s, loss=0.247][A
Training:   0%|          | 0/18 [00:20<?, ?it/s, loss=0.234][A
Training:   0%|          | 0/18 [00:22<?, ?it/s, loss=0.212][A
Training:   0%|          | 0/18 [00:25<?, ?it/s, loss=0.227][A
Training:   0%|          | 0/18 [00:27<?, ?it/s, loss=0.44] [A
Training:   0%|          | 0/18 [00:30<?, ?it/s, loss=0.407][A
Training:   0%|          | 0/18 [00:32<?, ?it/s, loss=0.378][A
Training:   0%|          | 0/18 [00:34<?, ?it/s, loss=0.355][A
Training:   0%|          | 0/18 [00:37<?, ?it/s, loss=0.334][A
Training:   0%|          | 0/18 [00:39<?, ?it/s, loss=0.321][A
Training:   0%|          | 0/18 [00:41<

Epoch 5, average loss: 0.12957225181162357


Training:   0%|          | 0/18 [00:44<?, ?it/s, loss=0.172] 
Training:   0%|          | 0/18 [00:44<?, ?it/s, loss=0.172]

Training:   0%|          | 0/18 [00:02<?, ?it/s, loss=0.234][A
Training:   0%|          | 0/18 [00:04<?, ?it/s, loss=0.124][A

Epoch 6, average loss: 0.12358094193041325



Training:   0%|          | 0/18 [00:07<?, ?it/s, loss=0.0878][A
Training:   0%|          | 0/18 [00:09<?, ?it/s, loss=0.069] [A
Training:   0%|          | 0/18 [00:11<?, ?it/s, loss=0.0566][A
Training:   0%|          | 0/18 [00:13<?, ?it/s, loss=0.0493][A
Training:   0%|          | 0/18 [00:16<?, ?it/s, loss=0.087] [A
Training:   0%|          | 0/18 [00:18<?, ?it/s, loss=0.184][A
Training:   0%|          | 0/18 [00:20<?, ?it/s, loss=0.164][A
Training:   0%|          | 0/18 [00:23<?, ?it/s, loss=0.148][A
Training:   0%|          | 0/18 [00:25<?, ?it/s, loss=0.197][A
Training:   0%|          | 0/18 [00:27<?, ?it/s, loss=0.207][A
Training:   0%|          | 0/18 [00:29<?, ?it/s, loss=0.192][A
Training:   0%|          | 0/18 [00:31<?, ?it/s, loss=0.178][A
Training:   0%|          | 0/18 [00:33<?, ?it/s, loss=0.168][A
Training:   0%|          | 0/18 [00:36<?, ?it/s, loss=0.159][A
Training:   0%|          | 0/18 [00:38<?, ?it/s, loss=0.16] [A
Training:   0%|          | 0/18 [0

Epoch 7, average loss: 0.07416049297899008


Training:   0%|          | 0/18 [00:43<?, ?it/s, loss=0.401] 
Training:   0%|          | 0/18 [00:43<?, ?it/s, loss=0.401]

Training:   0%|          | 0/18 [00:02<?, ?it/s, loss=0.144][A
Training:   0%|          | 0/18 [00:04<?, ?it/s, loss=0.0777][A

Epoch 8, average loss: 0.07767152041196823



Training:   0%|          | 0/18 [00:07<?, ?it/s, loss=0.0602][A
Training:   0%|          | 0/18 [00:10<?, ?it/s, loss=0.0467][A
Training:   0%|          | 0/18 [00:12<?, ?it/s, loss=0.0756][A
Training:   0%|          | 0/18 [00:15<?, ?it/s, loss=0.0638][A
Training:   0%|          | 0/18 [00:17<?, ?it/s, loss=0.0756][A
Training:   0%|          | 0/18 [00:19<?, ?it/s, loss=0.232] [A
Training:   0%|          | 0/18 [00:21<?, ?it/s, loss=0.214][A
Training:   0%|          | 0/18 [00:24<?, ?it/s, loss=0.193][A
Training:   0%|          | 0/18 [00:26<?, ?it/s, loss=0.247][A
Training:   0%|          | 0/18 [00:29<?, ?it/s, loss=0.248][A
Training:   0%|          | 0/18 [00:32<?, ?it/s, loss=0.234][A
Training:   0%|          | 0/18 [00:34<?, ?it/s, loss=0.217][A
Training:   0%|          | 0/18 [00:37<?, ?it/s, loss=0.23] [A
Training:   0%|          | 0/18 [00:39<?, ?it/s, loss=0.216][A
Training:   0%|          | 0/18 [00:42<?, ?it/s, loss=0.207][A
Training:   0%|          | 0/18 [

Epoch 9, average loss: 0.09022496757097542


Training:   0%|          | 0/18 [00:41<?, ?it/s, loss=0.129] 
Training:   0%|          | 0/18 [00:41<?, ?it/s, loss=0.129]

Training:   0%|          | 0/18 [00:02<?, ?it/s, loss=0.12][A
Training:   0%|          | 0/18 [00:04<?, ?it/s, loss=0.0659][A

Epoch 10, average loss: 0.0658742911182344



Training:   0%|          | 0/18 [00:06<?, ?it/s, loss=0.051] [A
Training:   0%|          | 0/18 [00:09<?, ?it/s, loss=0.0407][A
Training:   0%|          | 0/18 [00:12<?, ?it/s, loss=0.0346][A
Training:   0%|          | 0/18 [00:14<?, ?it/s, loss=0.0302][A
Training:   0%|          | 0/18 [00:16<?, ?it/s, loss=0.0339][A
Training:   0%|          | 0/18 [00:18<?, ?it/s, loss=0.0636][A
Training:   0%|          | 0/18 [00:22<?, ?it/s, loss=0.0599][A
Training:   0%|          | 0/18 [00:24<?, ?it/s, loss=0.0541][A
Training:   0%|          | 0/18 [00:26<?, ?it/s, loss=0.0689][A
Training:   0%|          | 0/18 [00:29<?, ?it/s, loss=0.0883][A
Training:   0%|          | 0/18 [00:31<?, ?it/s, loss=0.0819][A
Training:   0%|          | 0/18 [00:33<?, ?it/s, loss=0.0763][A
Training:   0%|          | 0/18 [00:35<?, ?it/s, loss=0.0722][A
Training:   0%|          | 0/18 [00:38<?, ?it/s, loss=0.0685][A
Training:   0%|          | 0/18 [00:40<?, ?it/s, loss=0.0648][A
tokenizing: 1it [00:00, 

Validation Accuracy: 0.8820
