# Prompting FLAN-T5-base: zero-, one-, and few-shot summarization on XSum

In [None]:
## Config
import random

import torch

random_seed = 100               # seed applied to every RNG seeded below
data_path = "/kaggle/working/"  # directory where predictions and results are pickled

# `random_seed` was previously defined but never used. Seed Python's and
# PyTorch's RNGs so any stochastic step is reproducible across reruns
# (torch.manual_seed also seeds the CUDA generators).
random.seed(random_seed)
torch.manual_seed(random_seed)

In [None]:
%%capture
!pip install -U datasets
!pip install transformers datasets evaluate rouge_score --quiet
!pip uninstall keras -y
!pip install keras==2.11
!pip install bert_score

In [None]:
import evaluate
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import pandas as pd
from bert_score import score
import pickle
import os

2025-06-19 13:50:40.237407: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750341040.262774     253 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750341040.270596     253 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Load Data and Model

In [None]:
# Load every split of XSum as a DatasetDict: the `train` split supplies the
# in-context demonstrations and the `test` split is what gets evaluated.
xsum_dataset_id = "EdinburghNLP/xsum"
dataset = load_dataset(xsum_dataset_id)

In [None]:
# In-context demonstrations from the training split: the first row is the
# one-shot demonstration, both rows together form the few-shot set.
train_examples = dataset["train"].select([0, 1])

# Evaluate on the entire test split (no subsampling); `references` holds the
# gold summaries in the same order as `test_sample`.
test_sample = dataset["test"]
references = [row["summary"] for row in test_sample]

In [None]:
# FLAN-T5 base checkpoint. `model_max_length` sets the tokenizer's default
# maximum input length.
# NOTE(review): 2024 may be a typo for 2048 — confirm the intended limit.
model_name = "google/flan-t5-base"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=2024)

## Zero-Shot, One-Shot and Few-Shot Prompts

In [None]:
def build_zero_shot_prompt(doc):
    """Builds a zero-shot prompt.

    Uses the same layout as the one-shot and few-shot builders (a "Task:" line,
    then an "### INPUT TEXT:" section with a stripped document). The three
    strategies therefore differ only in the number of demonstrations, which is
    what the comparison is meant to measure.

    Args:
        doc: Document text to summarize; surrounding whitespace is stripped.

    Returns:
        The prompt string, ending in an open "Summary:" slot.
    """
    return (
        "Task: Summarize the input text.\n"
        f"### INPUT TEXT:\nDocument: {doc.strip()}\nSummary:[Fill the summary]"
    )

def build_one_shot_prompt(doc, train_example):
    """Return a one-shot summarization prompt.

    The prompt consists of a task instruction, a single worked demonstration
    built from ``train_example['document']`` / ``train_example['summary']``
    (both whitespace-stripped), and finally ``doc`` with an open summary slot.
    """
    demo_document = train_example['document'].strip()
    demo_summary = train_example['summary'].strip()
    sections = [
        "Task: Summarize the input text. An example is provided below. \n",
        f"### EXAMPLE:\nDocument: {demo_document}\nSummary: {demo_summary}\n\n",
        f"### INPUT TEXT:\nDocument: {doc.strip()}\nSummary:[Fill the summary]",
    ]
    return "".join(sections)

def build_few_shot_prompt(doc, few_shots):
    """Return a few-shot summarization prompt.

    Every item of ``few_shots`` (anything yielding mappings with 'document'
    and 'summary' keys) becomes one "### EXAMPLE:" section, in iteration
    order. The stripped ``doc`` follows with an open summary slot.
    """
    header = "Task: Summarize the input text. Examples are provided below. \n"
    demonstrations = "".join(
        f"### EXAMPLE:\nDocument: {shot['document'].strip()}\nSummary: {shot['summary'].strip()}\n\n"
        for shot in few_shots
    )
    query = f"### INPUT TEXT:\nDocument: {doc.strip()}\nSummary:[Fill the summary]"
    return header + demonstrations + query

In [None]:
# Superseded by the prompt-building cell below; kept only for reference.
# NOTE(review): the comprehensions use their own loop variable `document`, so
# the first line (binding `document` to the first test document's text) has no
# effect on them even if uncommented.
# document = test_sample[0]["document"]

# zero_shot_prompts = [build_zero_shot_prompt(document["document"]) for document in test_sample]
# one_shot_prompts = [build_one_shot_prompt(document["document"], train_examples[0]) for document in test_sample]
# few_shot_prompts = [build_few_shot_prompt(document["document"], train_examples) for document in test_sample]

In [None]:
# Build one prompt per test document for each prompting strategy. The
# comprehensions keep loop variables out of the notebook namespace. The
# one-shot demonstration is fetched once instead of once per test row.
one_shot_demo = train_examples[0]
zero_shot_prompts = [build_zero_shot_prompt(row["document"]) for row in test_sample]
one_shot_prompts = [build_one_shot_prompt(row["document"], one_shot_demo) for row in test_sample]

# Few-shot prompts are not built here because the few-shot predictions are
# loaded from a precomputed pickle further below. To regenerate them, build:
#   few_shot_prompts = [build_few_shot_prompt(row["document"], train_examples) for row in test_sample]

# Each strategy must yield exactly one prompt per reference summary.
assert len(zero_shot_prompts) == len(one_shot_prompts) == len(references)

### Generate Model Outputs

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# nn.Module.to moves the parameters in place and returns the same module, so
# rebinding `model` only silences the cell output (without clobbering `_`).
model = model.to(device)

In [None]:
def generate_prompt_output(prompts, model, device, batch_size=20, max_input_length=2024,
                           max_length=64, tokenizer=None):
    """Generate one decoded model prediction per prompt, processing prompts in batches.

    Args:
        prompts: List of prompt strings (e.g. the zero-/one-shot prompt lists).
        model: Model exposing ``generate``; expected to already live on ``device``.
        device: Device the tokenized inputs are moved to before generation.
        batch_size: Number of prompts tokenized and generated per step (default 20).
        max_input_length: Prompts are truncated to at most this many tokens (default 2024).
        max_length: ``max_length`` forwarded to ``model.generate`` (default 64).
        tokenizer: Tokenizer used to encode prompts and decode outputs. Defaults to
            the notebook-level ``tokenizer`` so existing three-argument calls still work.

    Returns:
        List of decoded strings (special tokens removed), aligned with ``prompts``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if tokenizer is None:
        # Backward compatible: the original implementation read this global implicitly.
        tokenizer = globals()["tokenizer"]

    preds = []
    for start in tqdm(range(0, len(prompts), batch_size), desc="Generating"):
        batch_prompts = list(prompts[start:start + batch_size])

        # Pad to the longest prompt in the batch and truncate overly long ones.
        inputs = tokenizer(batch_prompts, return_tensors="pt", truncation=True,
                           max_length=max_input_length, padding=True)
        # The model already lives on `device`; its inputs must be moved to match.
        inputs = {k: v.to(device) for k, v in inputs.items()}

        outputs = model.generate(**inputs, max_length=max_length)
        preds.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))

    return preds

In [None]:
zero_shot_results = generate_prompt_output(zero_shot_prompts, model, device)

# Persist the predictions so the long generation pass need not be repeated.
# A failed write is reported but deliberately does not stop the notebook,
# because the results are still available in memory.
file_name = "zero_shot_testset.pkl"
file_path = os.path.join(data_path, file_name)
try:
    with open(file_path, 'wb') as f:
        pickle.dump(zero_shot_results, f)
except Exception as e:
    print(f"An error occurred while saving the pickle file: {e}")
else:
    print(f"Successfully saved the list as pickle to: {file_path}")

  1%|          | 3/567 [00:21<1:08:22,  7.27s/it]

In [None]:
one_shot_results = generate_prompt_output(one_shot_prompts, model, device)

# Save the one-shot predictions to `data_path`. Write failures are printed
# rather than raised so the in-memory results remain usable.
file_name = "one_shot_testset.pkl"
file_path = os.path.join(data_path, file_name)
try:
    with open(file_path, 'wb') as fh:
        pickle.dump(one_shot_results, fh)
except Exception as err:
    print(f"An error occurred while saving the pickle file: {err}")
else:
    print(f"Successfully saved the list as pickle to: {file_path}")

In [None]:
# Few-shot generation is not re-run here; it would be
# `generate_prompt_output(few_shot_prompts, model, device)`. Instead, the
# predictions are loaded from a pickle attached as a Kaggle input dataset.
file_path = "/kaggle/input/few-shot-testset-pkl/few_shot_testset.pkl"

# SECURITY: pickle.load can execute arbitrary code embedded in the file —
# only load pickles you produced yourself.
try:
    with open(file_path, 'rb') as f:
        few_shot_results = pickle.load(f)
    print(f"Successfully loaded the list from pickle file: {file_path}")
except FileNotFoundError:
    # Best-effort load: mark the predictions as unavailable (None) rather than
    # leaving `few_shot_results` undefined, so downstream code can test for it.
    few_shot_results = None
    print(f"Error: The file '{file_path}' was not found.")
except Exception as e:
    few_shot_results = None
    print(f"An error occurred while loading the pickle file: {e}")

# Cached predictions must line up one-to-one with the reference summaries;
# otherwise every metric would silently pair outputs with the wrong targets.
if few_shot_results is not None and len(few_shot_results) != len(references):
    raise ValueError(
        f"Loaded {len(few_shot_results)} few-shot predictions but there are "
        f"{len(references)} references; does the pickle match this test split?"
    )

### Calculate Evaluation Metrics

In [None]:
# Ensure you have the necessary evaluation metrics loaded
rouge = evaluate.load("rouge")


def _summarization_metrics(predictions, refs):
    """Score `predictions` against `refs` with ROUGE-1/2/L and mean BERTScore F1.

    Returns:
        Dict with keys 'ROUGE-1', 'ROUGE-2', 'ROUGE-L', 'BERTScore F1'.

    Raises:
        ValueError: If the lists differ in length, since predictions would
            then be paired with the wrong references.
    """
    if len(predictions) != len(refs):
        raise ValueError(
            f"Got {len(predictions)} predictions for {len(refs)} references"
        )
    rouge_scores = rouge.compute(predictions=predictions, references=refs)
    _, _, bertscore_f1 = score(predictions, refs, lang="en", verbose=True)
    return {
        'ROUGE-1': rouge_scores['rouge1'],
        'ROUGE-2': rouge_scores['rouge2'],
        'ROUGE-L': rouge_scores['rougeL'],
        'BERTScore F1': bertscore_f1.mean().item(),
    }


# --- Collect predictions per prompt type ---
predictions_by_prompt = {
    'Zero-shot': zero_shot_results,
    'One-shot': one_shot_results,
}
# Few-shot predictions come from a best-effort pickle load and may be missing.
# Previously this cell referenced few_shot_rouge / few_shot_bertscore_f1,
# whose computation was commented out, so it always failed with NameError.
few_shot_loaded = globals().get("few_shot_results")
if few_shot_loaded is not None:
    predictions_by_prompt['Few-shot'] = few_shot_loaded
else:
    print("Few-shot predictions are unavailable; the Few-shot row is omitted.")

# --- Calculate metrics and create the DataFrame ---
results = {
    (prompt_type, model_name): _summarization_metrics(preds, references)
    for prompt_type, preds in predictions_by_prompt.items()
}

df_results = pd.DataFrame.from_dict(results, orient='index')

# Set the index names
df_results.index.names = ['Prompt Type', 'Model']

In [None]:
# Display the metrics table via the notebook's rich (HTML) DataFrame rendering
# instead of a plain-text print.
df_results

In [None]:
# Persist the metrics table alongside the cached predictions. A failed write
# is reported rather than raised.
file_name = "results.pkl"
file_path = os.path.join(data_path, file_name)

try:
    df_results.to_pickle(file_path)
except Exception as exc:
    print(f"An error occurred while saving the pickle file: {exc}")
else:
    print(f"Successfully saved the DataFrame as pickle to: {file_path}")