In [None]:
# Import necessary libraries
from openprompt.plms import T5TokenizerWrapper
from datasets import load_from_disk
from openprompt.pipeline_base import PromptDataLoader
from transformers import T5ForConditionalGeneration, T5Tokenizer
from openprompt.prompts import ManualTemplate
from openprompt import PromptForClassification
from openprompt.data_utils import FewShotSampler
from random import shuffle
from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
import torch
from openprompt.prompts import ManualVerbalizer
from openprompt.data_utils import InputExample
from tqdm import tqdm
import json

# Load the dataset
dataset_path = "/path/to/your/data/set"
raw_dataset = load_from_disk(dataset_path)

# Load the T5 model and tokenizer
t5_path = "/path/to/t5-base"
model = T5ForConditionalGeneration.from_pretrained(t5_path)
tokenizer = T5Tokenizer.from_pretrained(t5_path)

# Set up logging
log_file = "qamc_id_t5.json"
results = []

# Map textual labels to numeric values
label_map = {"A": 0, "B": 1, "C": 2, "D": 3, "E": 4}

# Maximum number of examples kept per split. The cap is clamped to the split's
# actual size so `select` never requests rows that do not exist.
split_sizes = {"train": 1000, "validation": 500}

# Prepare datasets for training and validation
dataset = {}
for split, max_size in split_sizes.items():
    split_data = raw_dataset[split]
    if split == "train":
        # Shuffle before subsetting so the training subset is a random sample
        split_data = split_data.shuffle(seed=42)
    split_data = split_data.select(range(min(max_size, len(split_data))))
    raw_dataset[split] = split_data

    dataset[split] = []
    skipped = 0
    for idx, data in enumerate(split_data):
        label_text = data["targets_pretokenized"].strip()  # Extract the correct answer
        if label_text not in label_map:
            # An unmapped target has no valid class index. Encoding it as -1
            # makes CrossEntropyLoss fail and lets FewShotSampler treat it as
            # an extra class, so drop it and report how many were dropped.
            skipped += 1
            continue
        dataset[split].append(
            InputExample(text_a=data["inputs_pretokenized"], guid=idx, label=label_map[label_text])
        )
    if skipped:
        print(f"[{split}] skipped {skipped} example(s) whose target is not one of {sorted(label_map)}")

if not dataset["train"]:
    raise ValueError("No training examples have a target in label_map; check 'targets_pretokenized'.")

# Print a sample InputExample and its type for verification
print(dataset['train'][0])
print(type(dataset['train'][0]))

# Few-shot sampling from the training data (up to 30 examples per label)
sampler = FewShotSampler(num_examples_per_label=30)
fewshot_data = sampler(dataset['train'], seed=42)

# Define the evaluation function
def evaluate(prompt_model, dataloader):
    """Return the classification accuracy of ``prompt_model`` on ``dataloader``.

    Args:
        prompt_model: A ``torch.nn.Module`` called as ``prompt_model(inputs)``
            that returns logits whose last dimension indexes the classes.
        dataloader: Iterable of batches. Each batch supports ``batch['label']``,
            which gives the ground-truth class indices for that batch.

    Returns:
        The fraction of examples whose argmax prediction equals the label.

    Raises:
        ValueError: If the dataloader yields no examples (accuracy is undefined).

    The model's previous train/eval mode is restored on exit. This makes the
    function safe to call in the middle of a training loop.
    """
    was_training = prompt_model.training
    prompt_model.eval()  # Evaluation mode (e.g. disables dropout)
    total, correct = 0, 0
    try:
        with torch.no_grad():
            for inputs in dataloader:
                logits = prompt_model(inputs)
                preds = torch.argmax(logits, dim=-1)  # Predicted class per example
                # Compare on the same device/type as the predictions
                labels = torch.as_tensor(inputs['label'], device=preds.device)
                total += labels.numel()
                correct += (preds == labels).sum().item()  # Count correct predictions
    finally:
        prompt_model.train(was_training)

    if total == 0:
        raise ValueError("evaluate() received an empty dataloader; accuracy is undefined.")
    return correct / total

# Hyperparameter ranges for grid search
learning_rates = [0.005, 0.001, 0.0005]  # Learning rates to test
warmup_steps = [10]  # Warm-up steps for scheduler
num_epochs = 10  # Training epochs per configuration

# Perform hyperparameter tuning
for lr in learning_rates:
    for warmup in warmup_steps:

        # Reload model and tokenizer so every configuration starts from the same pretrained weights
        model = T5ForConditionalGeneration.from_pretrained(t5_path)
        tokenizer = T5Tokenizer.from_pretrained(t5_path)

        # Define the manual template for input formatting
        template = ManualTemplate(
            tokenizer=tokenizer,
            text='{"placeholder":"text_a"} Which option is correct? {"mask"}',
        )

        # Define the verbalizer to map model predictions to labels.
        # NOTE(review): multi-word entries such as "Option A" tokenize into several
        # pieces. If ManualVerbalizer's multi-token handling keeps only the first
        # piece (its default, as far as I know), every class would share the token
        # "Option". Confirm this, and consider dropping those entries.
        verbalizer = ManualVerbalizer(
            tokenizer=tokenizer,
            num_classes=5,  # Five options (A, B, C, D, E)
            label_words=[
                ["A", "a", "Option A", "first choice"],
                ["B", "b", "Option B", "second choice"],
                ["C", "c", "Option C", "third choice"],
                ["D", "d", "Option D", "fourth choice"],
                ["E", "e", "Option E", "fifth choice"]
            ]
        )

        # Wrap one example for debugging (optional)
        wrapped_example = template.wrap_one_example(dataset['train'][0])

        # Initialize the prompt model
        prompt_model = PromptForClassification(
            plm=model,
            template=template,
            verbalizer=verbalizer,
            freeze_plm=False,  # Allow fine-tuning of T5 model
        )

        # Prepare data loaders for training and validation
        train_dataloader = PromptDataLoader(
            dataset=fewshot_data,
            template=template,
            tokenizer=tokenizer,
            tokenizer_wrapper_class=T5TokenizerWrapper,
            decoder_max_length=3, max_seq_length=480,
            batch_size=5
        )

        validation_dataloader = PromptDataLoader(
            dataset=dataset["validation"],
            template=template,
            tokenizer=tokenizer,
            tokenizer_wrapper_class=T5TokenizerWrapper,
            decoder_max_length=3, max_seq_length=480,
            batch_size=20
        )

        # Define the loss function
        loss_func = torch.nn.CrossEntropyLoss()

        # Set optimizer parameters and no decay for biases and LayerNorm weights
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in prompt_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in prompt_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
        # The schedule must span the optimizer steps actually taken; a fixed 1000
        # does not match len(train_dataloader) * num_epochs.
        num_training_steps = len(train_dataloader) * num_epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup, num_training_steps=num_training_steps
        )

        # Training loop, validating after every epoch
        epoch_losses, epoch_accuracies = [], []
        val_accuracy = None
        for epoch in range(num_epochs):
            prompt_model.train()  # (Re-)enter training mode; evaluation may have switched to eval mode
            total_loss = 0.0
            pbar = tqdm(train_dataloader, desc="Training")
            for step, inputs in enumerate(pbar):  # Iterate the bar so it actually advances
                logits = prompt_model(inputs)  # Forward pass
                labels = inputs['label']  # Ground-truth labels
                loss = loss_func(logits, labels)  # Compute loss
                loss.backward()  # Backpropagation
                optimizer.step()  # Update parameters
                scheduler.step()  # Advance the warm-up / linear-decay schedule
                optimizer.zero_grad()  # Reset gradients
                total_loss += loss.item()
                pbar.set_postfix({"loss": total_loss / (step + 1)})
            epoch_losses.append(total_loss / max(len(train_dataloader), 1))

            # Validate the model after each epoch
            val_accuracy = evaluate(prompt_model, validation_dataloader)
            epoch_accuracies.append(val_accuracy)
            print(f"Validation Accuracy after Epoch {epoch + 1}: {val_accuracy:.4f}")

        # Log results
        result = {
            "learning_rate": lr,
            "warmup_steps": warmup,
            # Mean training loss over the last epoch
            "final_loss": epoch_losses[-1] if epoch_losses else None,
            # Validation accuracy after the last epoch
            "accuracy": val_accuracy,
            "epoch_losses": epoch_losses,
            "epoch_accuracies": epoch_accuracies,
        }
        results.append(result)

        # Save results to a JSON file (rewritten after every configuration)
        with open(log_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4)

print("Tuning complete. Results saved to", log_file)


# Overview of QA Task Implementation with T5

This code demonstrates the implementation of a **question-answering multiple-choice (QA-MC)** task using the OpenPrompt framework and a pre-trained **T5 model**. The focus is on fine-tuning the T5 model for multiple-choice classification tasks by utilizing manual templates and verbalizers. The script includes hyperparameter tuning, training, and evaluation, with detailed logging of results for analysis.

---

## Key Features

### 1. **Dataset Preparation**
- **Loading and Preprocessing**:
  - Loads a dataset from disk and prepares it for classification.
  - Labels (A, B, C, D, E) are mapped to numeric values for model training.
- **Few-Shot Sampling**:
  - Creates a balanced training set with 30 examples per label, simulating a low-resource scenario.

### 2. **Manual Template**
- Formats the input text to include a task-specific prompt:
  ```python
  {"placeholder":"text_a"} Which option is correct? {"mask"}
  ``` 
- Appends the question "Which option is correct?" and a mask token to the input text. The model's prediction at the mask position is then mapped to a label by the verbalizer.

### 3. **Manual Verbalizer**
- Maps the model's output logits to corresponding labels (`A`, `B`, `C`, `D`, `E`).
- Allows flexibility in label representation using synonyms like `["Option A", "first choice"]`.

### 4. **Training Process**
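- Fine-tunes all T5 model parameters (`freeze_plm=False`); the manual template itself has no trainable parameters. Training uses:
  - **AdamW Optimizer**: Updates model parameters using the gradients computed by backpropagation.
  - **Linear Learning Rate Scheduler**: Increases the learning rate linearly during the warm-up steps, then decays it linearly.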
- Logs training loss and evaluates performance at the end of each epoch.

### 5. **Evaluation**
- Computes the model's accuracy on a validation set after every epoch.
- Compares predicted labels with ground-truth labels to measure performance.

### 6. **Hyperparameter Tuning**
- Explores different combinations of:
  - **Learning Rates**: `0.005`, `0.001`, `0.0005`.
  - **Warm-Up Steps**: `10`.
- Logs the results of every configuration so the best-performing one can be identified.

### 7. **Results Logging**
- Records key metrics such as:
  - Final loss
  - Validation accuracy
  - Hyperparameter settings
- Saves results in a JSON file for further analysis.

---

## Limitations of QA Tasks in the OpenPrompt Framework

### 1. **Label Dependency**
- QA tasks often have dynamic labels (e.g., options `A`, `B`, `C`, `D`, `E`), which require a manually defined verbalizer.
- OpenPrompt does not inherently support dynamically generated label mappings, making it less flexible for certain QA tasks.

### 2. **Output Limitations**
- The OpenPrompt framework generates outputs based on the verbalizer's predefined label words. For QA tasks, this can lead to:
  - Difficulty in capturing nuanced differences between options.
  - Over-reliance on exact label word matches.

### 3. **Scaling Issues**
- Large-scale QA datasets with diverse label structures may not fit well into OpenPrompt's manual verbalizer approach, which requires significant manual effort.

### 4. **Training Overhead**
- Fine-tuning all T5 model parameters (the manual template has no trainable parameters) is computationally expensive, which may not suit lightweight deployment.

---

## Workflow

### 1. **Dataset Handling**
- The dataset is split into training and validation sets, with a maximum of 1000 examples for training and 500 for validation.
- Each example is converted into OpenPrompt's `InputExample` format.

### 2. **Template and Verbalizer Setup**
- A manual template and verbalizer are defined to structure inputs and map predictions, respectively.
- The template ensures compatibility with the T5 model's input-output format.

### 3. **Training**
- All T5 model parameters are fine-tuned for 10 epochs per hyperparameter configuration (the manual template has no trainable parameters).
- Loss is computed using cross-entropy, and parameters are optimized using AdamW.

### 4. **Evaluation**
- Accuracy is calculated on the validation set after each epoch to monitor performance.

### 5. **Logging**
- Training loss, validation accuracy, and hyperparameter configurations are logged and saved for analysis.

---

## Applications

- **QA Multiple-Choice Tasks**: Solves tasks where the model predicts the correct option from a predefined list (e.g., `A`, `B`, `C`, `D`, `E`).
- **Few-Shot Learning**: Demonstrates how OpenPrompt can be applied in a low-resource setting (up to 30 training examples per label).

---

## Conclusion

This implementation provides a practical example of fine-tuning a T5 model for QA tasks using the OpenPrompt framework. While it effectively applies templates and verbalizers to guide model predictions, limitations like dynamic label dependency and scaling challenges highlight areas where the OpenPrompt framework may struggle with QA tasks.

This script is ideal for experimenting with prompt-based learning but requires manual intervention for handling dynamic and large-scale QA datasets. Further enhancements could include automating verbalizer generation or exploring more scalable prompt engineering techniques.
