In [None]:
# Import necessary libraries
from openprompt.plms import T5TokenizerWrapper
from datasets import load_from_disk
from openprompt.pipeline_base import PromptDataLoader
from transformers import T5ForConditionalGeneration, T5Tokenizer
from openprompt.prompts import ManualTemplate
from openprompt import PromptForClassification
from openprompt.data_utils import InputExample
from tqdm import tqdm
import torch
from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
from openprompt.prompts import ManualVerbalizer
import json

# Load dataset
dataset_path = "/path/to/your/dataset"
raw_dataset = load_from_disk(dataset_path)

# Prepare dataset: keep (at most) the first MAX_EXAMPLES_PER_SPLIT rows of each
# split and materialise them as a plain Python list of row dicts.
# min(...) guards against splits that are smaller than the cap; Dataset.select
# raises on out-of-range indices.
MAX_EXAMPLES_PER_SPLIT = 500
dataset = {}
for split in ['train', 'validation']:
    n_examples = min(MAX_EXAMPLES_PER_SPLIT, len(raw_dataset[split]))
    raw_dataset[split] = raw_dataset[split].select(range(n_examples))
    # Iterating a datasets.Dataset yields one dict per row.
    dataset[split] = list(raw_dataset[split])

# Print the first training example for debugging
print(dataset['train'][0])
print(type(dataset['train'][0]))

# Load the T5 checkpoint: the tokenizer and the conditional-generation model
# both come from the same local directory.
t5_path = "/path/to/t5-base"
tokenizer = T5Tokenizer.from_pretrained(t5_path)
model = T5ForConditionalGeneration.from_pretrained(t5_path)

# Logging setup: per-example validation outcomes are collected in `results`
# and written to `log_file` as JSON at the end of the run.
results = []
log_file = "rc_binary_t5.json"

# Prepare training dataset
# For the training, 'ropes_background_new_situation_answer' from P3 dataset.
# These 18 samples share the same background.

# Answers that map to class 0; every other answer is class 1. These mirror the
# class-0 label words of the training verbalizer ('cell X', 'cell A', 'larger',
# 'more'). Matching is case-insensitive, so "Cell A" and "cell A" land in the
# same class (the old list had 'Cell A' while the verbalizer used 'cell A').
CLASS0_ANSWERS = frozenset({'cell x', 'cell a', 'larger', 'more'})

data1 = dataset['train'][2:20]  # Select 18 samples (indices 2..19)
dataset1 = []
for idx, data in enumerate(data1):
    question = data["inputs_pretokenized"]  # Extract the question
    correct_answer = data["targets_pretokenized"].strip()  # Extract the correct answer
    label = 0 if correct_answer.lower() in CLASS0_ANSWERS else 1  # Binary label
    input_example = InputExample(
        text_a=question,
        label=label,
        guid=idx,
        meta={"correct_answer": correct_answer}
    )
    dataset1.append(input_example)

# Define manual template and verbalizer for training.
# The template appends "The answer is: <mask>" to each question; the model's
# prediction at the mask slot is what the verbalizer scores.
template1 = ManualTemplate(
    text='{"placeholder":"text_a"} The answer is: {"mask"}',
    tokenizer=tokenizer,
)

# train_label_words[i] lists the answer strings that stand for class i.
train_label_words = [
    ['cell X', 'cell A', 'larger', 'more'],   # class 0
    ['cell Z', 'cell B', 'smaller', 'less'],  # class 1
]
verbalizer1 = ManualVerbalizer(
    tokenizer=tokenizer,
    label_words=train_label_words,
    num_classes=2,
)

# Wrap T5 with the prompt template and verbalizer; freeze_plm=False leaves the
# T5 weights trainable.
prompt_model = PromptForClassification(
    plm=model,
    verbalizer=verbalizer1,
    template=template1,
    freeze_plm=False,
)

# Prepare training dataloader
train_dataloader = PromptDataLoader(
    dataset=dataset1,
    template=template1,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=T5TokenizerWrapper,
    decoder_max_length=68, max_seq_length=480,
    batch_size=1
)

# Define optimizer and loss function
loss_func = torch.nn.CrossEntropyLoss()
# Parameter-name substrings excluded from weight decay. 'bias' and
# 'LayerNorm.weight' follow BERT-style naming; T5 names its norm weights
# '...layer_norm.weight' / '...final_layer_norm.weight' and its linear layers
# have no bias, so without 'layer_norm.weight' the no-decay group is empty.
no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in prompt_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
# torch.optim.AdamW replaces the deprecated transformers.AdamW; eps=1e-6 keeps
# the epsilon default of the transformers implementation.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=0.0001, eps=1e-6)
prompt_model.train()

# Training loop
NUM_EPOCHS = 10
for epoch in range(NUM_EPOCHS):
    tot_loss = 0.0
    # Iterate the tqdm wrapper itself (not the raw dataloader) so the bar
    # advances and the running-loss postfix is actually displayed.
    pbar = tqdm(train_dataloader, desc="Training")
    for step, inputs in enumerate(pbar):
        logits = prompt_model(inputs)  # Get predictions
        labels = inputs['label']  # Ground-truth labels
        loss = loss_func(logits, labels)  # Compute loss
        loss.backward()  # Backpropagation
        tot_loss += loss.item()
        optimizer.step()  # Update weights
        optimizer.zero_grad()  # Clear gradients
        pbar.set_postfix({"loss": tot_loss / (step + 1)})  # Mean loss so far this epoch

# Prepare validation dataset.
# [:1000] keeps at most 1000 examples; the validation split was already cut to
# its first 500 rows when the dataset was prepared, so all of them are used.
data2 = dataset['validation'][:1000]

# Each example carries the dummy label 0. At validation time class 0 is bound
# to that example's own correct answer, so predicting 0 means "correct".
dataset2 = [
    InputExample(
        text_a=record["inputs_pretokenized"],
        label=0,
        guid=position,
        meta={"correct_answer": record["targets_pretokenized"].strip()},
    )
    for position, record in enumerate(data2)
]

# Validation loop
# The prompt text is identical for every example, so the template is built and
# installed once instead of on every iteration.
template2 = ManualTemplate(
    tokenizer=tokenizer,
    text='{"placeholder":"text_a"} The answer is: {"mask"}',
)
prompt_model.template = template2  # Update template

# Nothing inside the loop switches back to training mode, so eval() once suffices.
prompt_model.eval()
with torch.no_grad():
    for idx, data in enumerate(dataset2):
        # Per-example verbalizer: class 0 is this example's correct answer and
        # class 1 is the placeholder word "other".
        verbalizer2 = ManualVerbalizer(
            tokenizer=tokenizer,
            num_classes=2,
            label_words=[[data.meta['correct_answer']], ["other"]]
        )
        prompt_model.verbalizer = verbalizer2  # Update verbalizer

        validation_dataloader = PromptDataLoader(
            dataset=[data],
            template=template2,
            tokenizer=tokenizer,
            tokenizer_wrapper_class=T5TokenizerWrapper,
            decoder_max_length=3, max_seq_length=480,
            batch_size=1
        )

        # Defaults ensure an example whose loader yields nothing is recorded as
        # incorrect instead of reusing the previous example's outcome.
        correct = False
        prediction = None
        for inputs in validation_dataloader:
            logits = prompt_model(inputs)
            prediction = torch.argmax(logits, dim=-1).item()  # Predicted class
            correct = prediction == data.label  # Compare with true label

        results.append({"index": idx, "correct": correct, "prediction": prediction})

# Compute overall accuracy. An empty results list reports NaN instead of
# raising ZeroDivisionError.
num_correct = sum(r["correct"] for r in results)
accuracy = num_correct / len(results) if results else float("nan")
print(f"Validation Accuracy: {accuracy:.4f}")

# Save per-example results to a JSON file (explicit encoding for portability).
with open(log_file, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=4)


# Overview of Reading Comprehension Task with T5 and OpenPrompt Framework

This code implements a **reading comprehension** task using the OpenPrompt framework and a pre-trained **T5 model**. The task is reframed as a **binary classification problem**, where the model predicts whether the provided answer is correct (`right answer`) or incorrect (`wrong answer`). The implementation explores both **zero-shot** and **few-shot learning**, examining the model's ability to generalize and learn from a small number of training examples. It also highlights how fine-tuning can affect the performance of pre-trained models.

---

## Key Features

### 1. **Task Framing as Binary Classification**
- The reading comprehension task, typically requiring exact answer generation or selection, is reformulated as a binary classification problem:
  - **Right Answer**: The model predicts the correct answer from a predefined set.
  - **Wrong Answer**: Any answer not matching the correct label is classified as incorrect.

### 2. **Zero-Shot and Few-Shot Learning**
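- **Zero-Shot**: Evaluates the pre-trained T5 model without any additional training to establish baseline performance. (The code shown runs only the few-shot path; to obtain the zero-shot baseline, run the validation loop before the training loop.)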
- **Few-Shot**: Fine-tunes the T5 model on a small dataset of 18 samples with the same context but different questions to test its ability to learn patterns from limited data.

### 3. **Manual Template and Verbalizer**
- **Template**: A manual template structures the input text into a prompt, such as:
  ```python
  {"placeholder":"text_a"} The answer is: {"mask"}
  ```
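- **Verbalizer**: A manual verbalizer maps the model's prediction at the mask position to one of two classes using predefined label words. During training, class 0 is {`cell X`, `cell A`, `larger`, `more`} and class 1 is {`cell Z`, `cell B`, `smaller`, `less`}; during validation, class 0 is the example's correct answer and class 1 is the word `other`.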

### 4. **Training and Validation**
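- **Few-Shot Training**: Fine-tunes the T5 model (all weights trainable, `freeze_plm=False`) on 18 hand-picked examples that share the same background passage. The manual template is fixed text and has no trainable parameters.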
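- **Dynamic Validation**: During validation, the verbalizer is rebuilt for every example so that class 0 is that example's correct answer and class 1 is the placeholder word `other`. Every validation example carries label 0, so a prediction counts as correct when the model scores the correct answer above `other`.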

### 5. **Results Logging**
- Tracks and logs the accuracy of the model on the validation set after fine-tuning.
- Saves predictions and correctness information for individual validation samples to a JSON file.

---

## Observations and Limitations

1. **Few-Shot Learning Benefits**:
   - Few-shot learning improves T5's performance on the binary classification task, demonstrating its ability to learn structural patterns from limited data.

2. **Impact on Pre-Trained Models**:
   - Fine-tuning can *confuse pre-trained models in some cases*, reducing their baseline performance. This highlights the need for careful monitoring during training to ensure consistent improvements.

3. **Simplification of the Task**:
   - The reading comprehension task is simplified into binary classification. While this approach is efficient, it might not capture the nuanced understanding required for more complex tasks, such as ranking or reasoning between multiple answers.

---

## Applications

- **Reading Comprehension**: Adaptation of pre-trained models like T5 for reading comprehension tasks using prompt-based learning.
- **Few-Shot Learning**: Demonstrates the potential of a small number of examples to influence model performance on structured tasks.
- **Binary Classification**: Illustrates how complex NLP tasks can be reframed into simpler classification problems for ease of implementation and experimentation.

---

## Conclusion

This implementation highlights the flexibility of the OpenPrompt framework and the adaptability of T5 for reading comprehension tasks reframed as binary classification. The results show the potential of few-shot learning to enhance performance, while also cautioning about the possibility of confusing pre-trained models through fine-tuning. This approach provides an efficient and flexible framework for experimenting with prompt-based learning in reading comprehension and similar NLP tasks.
