Import necessary libraries

In [None]:
%pip install datasets
%pip install evaluate
%pip install transformers[torch]
%pip install rouge-score
%pip install nltk
%pip install ipywidgets
%pip install accelerate -U

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset
from evaluate import load

Load dataset and metric

In [None]:
dataset = "xsum"
raw_datasets = load_dataset(dataset, trust_remote_code=True)
metric = load("rouge")

In [None]:
raw_datasets

Load model and tokenizer

In [None]:
model_checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [None]:
model

Define prompt-tuning parameters

In [None]:
num_virtual_tokens = 20  # Number of virtual tokens for prompt-tuning
initialize_from_vocab = True

Prepare the prompt embeddings


1. prompt_embeddings = torch.nn.Embedding(num_virtual_tokens, model.config.d_model)
This line creates a new embedding layer for the prompt tokens. Let's break it down:
torch.nn.Embedding is a lookup table that stores one embedding vector for each index of a fixed-size dictionary.
num_virtual_tokens is the number of tokens in our prompt (set to 20 earlier in the code).
model.config.d_model is the dimensionality of the embeddings in the T5 model (512 for t5-small).

So this creates an embedding layer for our prompt tokens, where each token has an embedding vector of the same size as the model's regular token embeddings.
2. if initialize_from_vocab:
This condition checks whether we want to initialize our prompt embeddings from the model's existing vocabulary instead of keeping PyTorch's random initialization. This is often a good starting point, because the prompt then starts from values on the same scale as the real token embeddings the model was trained on.
3. prompt_embeddings.weight.data = model.shared.weight[:num_virtual_tokens].clone().detach()
If we're initializing from the vocabulary, this line does the following:
*   model.shared.weight accesses the shared embedding matrix of the T5 model (used by both the encoder and the decoder).
*   [:num_virtual_tokens] selects the first num_virtual_tokens rows, i.e. the embeddings of token IDs 0 to 19 when num_virtual_tokens = 20. In T5's vocabulary these are special tokens such as <pad>, </s> and <unk>, followed mostly by very common subword pieces.
*   .clone() copies these embeddings, so the prompt does not share memory with the model's own embedding matrix.
*   .detach() detaches the copy from the autograd graph.

The purpose of these lines is to create a set of learnable embeddings for our prompt tokens. These embeddings will be prepended to the input embeddings of our actual text, allowing the model to learn task-specific information in the form of these "virtual tokens".
By initializing from the vocabulary, we give these prompt embeddings a head start that reuses the model's existing knowledge of language as encoded in its word embeddings. Compared with random initialization, this can lead to faster and more stable training.
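A minimal sketch of this initialization on a toy vocabulary may help. The sizes are made up, and toy_shared and toy_model are hypothetical stand-ins for model.shared and the T5 model. The sketch checks that the copied prompt starts from the vocabulary values, that it does not share memory with them, and that, once the prompt is registered as a submodule of an otherwise frozen model, exactly num_virtual_tokens × d_model values are trainable.

```python
import torch

# Toy sizes for illustration only (t5-small really has a ~32k-token vocabulary and d_model = 512)
toy_vocab_size, toy_d_model, toy_num_virtual = 100, 8, 20

# Hypothetical stand-in for T5's shared embedding matrix, frozen like the rest of the model
toy_shared = torch.nn.Embedding(toy_vocab_size, toy_d_model)
toy_shared.weight.requires_grad = False

# Soft prompt initialized from the first toy_num_virtual vocabulary rows
toy_prompt = torch.nn.Embedding(toy_num_virtual, toy_d_model)
toy_prompt.weight.data = toy_shared.weight[:toy_num_virtual].clone().detach()
same_values = torch.equal(toy_prompt.weight, toy_shared.weight[:toy_num_virtual])

# clone() gives separate storage: changing the prompt leaves the vocabulary untouched
with torch.no_grad():
    toy_prompt.weight.add_(1.0)
storage_is_separate = not torch.equal(toy_prompt.weight, toy_shared.weight[:toy_num_virtual])

# Registering the prompt as a submodule exposes it through parameters(),
# which is what optimizers are built from (toy_model is a hypothetical container)
toy_model = torch.nn.Module()
toy_model.shared = toy_shared
toy_model.prompt_embeddings = toy_prompt
trainable = sum(p.numel() for p in toy_model.parameters() if p.requires_grad)

print(same_values, storage_is_separate, trainable)  # True True 160
```

With the real sizes (20 virtual tokens, d_model = 512) the same count gives 10,240 trainable values.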


In [None]:
prompt_embeddings = torch.nn.Embedding(num_virtual_tokens, model.config.d_model)
if initialize_from_vocab:
    prompt_embeddings.weight.data = model.shared.weight[:num_virtual_tokens].clone().detach()

Freeze the model parameters

In [None]:
for param in model.parameters():
    param.requires_grad = False

Make prompt embeddings trainable

In [None]:
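prompt_embeddings.weight.requires_grad = True

# Register the prompt as a submodule of the model. The Trainer builds its optimizer from
# model.parameters(); without this line the prompt is not among them and is never updated.
# It also lets model.to(device) and model.save_pretrained() include the prompt.
model.prompt_embeddings = prompt_embeddings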

In [None]:
# Set up the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Move the model and the prompt embeddings to the device
model = model.to(device)
prompt_embeddings = prompt_embeddings.to(device)

Modify the model's forward pass to include prompt-tuning

The forward pass (often just called "forward") is a fundamental part of neural network computation. Let's break this down:

1. Basic Concept: The forward pass is the process of passing input data through a neural network to get an output. It's called "forward" because the data flows from the input layer, through hidden layers (if any), to the output layer.
2. In PyTorch: In PyTorch, the forward pass is typically defined in the forward() method of a nn.Module class. This method describes how input data should be transformed to produce the output.
3. Computation Graph: During the forward pass, PyTorch builds a computational graph. This graph records the operations performed on tensors that require gradients, and the backward pass (backpropagation) walks it in reverse during training.
4. Custom Forward Methods: Sometimes, as in our prompt-tuning example, we need to modify the forward pass of an existing model. This allows us to inject custom behavior, like adding prompt embeddings to the input.
5. In the Context of Prompt-Tuning: In our code, we modified the forward pass to include the following steps:
- Repeat the prompt embeddings for each item in the batch
- Concatenate the prompt embeddings with the input embeddings
- Adjust the attention mask to account for the added prompt tokens
- Call the original forward method with these modified inputs
6. Why Modify the Forward Pass: By modifying the forward pass, we can change how the model processes inputs without altering its fundamental architecture. In prompt-tuning, this allows us to prepend learnable prompt embeddings to every input, effectively giving the model additional context for its task.
7. Flexibility: Custom forward methods provide great flexibility. They allow us to adapt pre-trained models for new tasks, implement complex architectures, or introduce novel training techniques like prompt-tuning.
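8. Efficiency: Changing only the forward pass keeps training cheap. Gradients are still propagated back through the frozen network, but only the prompt embeddings receive weight gradients and optimizer state: num_virtual_tokens × d_model = 20 × 512 = 10,240 values for t5-small, compared with roughly 60 million parameters in the full model.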

In essence, the forward pass defines how data flows through the model, and by customizing it, we can significantly alter the model's behavior without changing its core parameters. This is particularly useful in transfer learning scenarios, where we want to adapt a pre-trained model to a new task with minimal changes to the original model.
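The steps listed under point 5 can be traced on toy tensors. In this sketch a frozen nn.Linear layer stands in for the T5 encoder (an assumption made only to keep the example small and self-contained), and all sizes are illustrative. It shows the shapes after prepending the prompt and extending the mask, and that after backward() only the prompt receives a gradient.

```python
import torch

torch.manual_seed(0)
toy_batch, toy_seq_len, toy_d_model, toy_num_virtual = 2, 7, 8, 20

# Frozen stand-ins (assumption: a single Linear layer plays the role of the T5 encoder)
toy_embed_tokens = torch.nn.Embedding(100, toy_d_model)
toy_encoder = torch.nn.Linear(toy_d_model, toy_d_model)
for p in [*toy_embed_tokens.parameters(), *toy_encoder.parameters()]:
    p.requires_grad = False

# The soft prompt is the only trainable tensor
toy_prompt = torch.nn.Parameter(torch.randn(toy_num_virtual, toy_d_model))

toy_input_ids = torch.randint(0, 100, (toy_batch, toy_seq_len))
toy_mask = torch.ones(toy_batch, toy_seq_len, dtype=torch.long)
toy_mask[1, -3:] = 0  # the second example ends with 3 padding tokens

# 1. Repeat the prompt for each batch item and prepend it to the token embeddings
token_embeds = toy_embed_tokens(toy_input_ids)                    # (2, 7, 8)
prompt_batch = toy_prompt.unsqueeze(0).expand(toy_batch, -1, -1)  # (2, 20, 8)
full_embeds = torch.cat([prompt_batch, token_embeds], dim=1)      # (2, 27, 8)

# 2. Extend the attention mask so the virtual tokens are always attended to
full_mask = torch.cat([toy_mask.new_ones(toy_batch, toy_num_virtual), toy_mask], dim=1)  # (2, 27)

# 3. Forward through the frozen layer and backpropagate a dummy loss over unmasked positions
hidden = toy_encoder(full_embeds)
loss = (hidden * full_mask.unsqueeze(-1)).pow(2).mean()
loss.backward()

print(tuple(full_embeds.shape), tuple(full_mask.shape))               # (2, 27, 8) (2, 27)
print(toy_encoder.weight.grad is None, tuple(toy_prompt.grad.shape))  # True (20, 8)
```

The frozen layer's weight never gets a .grad, yet the gradient still flows through it to the prompt, which is exactly what prompt-tuning relies on.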

In [None]:
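import inspect

# Save the original (unpatched) forward methods. Run this cell only once:
# re-running it would wrap the already-patched methods a second time.
original_forward = model.forward
encoder = model.get_encoder()
original_encoder_forward = encoder.forward
# Arguments the real T5 encoder accepts; generate() can pass extras such as labels
encoder_params = set(inspect.signature(original_encoder_forward).parameters)

def prepend_prompt_mask(attention_mask):
    # Prepend ones so the virtual tokens are always attended to
    ones = attention_mask.new_ones(attention_mask.shape[0], num_virtual_tokens)
    return torch.cat([ones, attention_mask], dim=1)

def encoder_forward(input_ids=None, attention_mask=None, inputs_embeds=None, **kwargs):
    # Convert input_ids to embeddings unless embeddings were passed in directly
    if inputs_embeds is None:
        if input_ids is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        inputs_embeds = encoder.embed_tokens(input_ids)

    # Repeat the prompt for each batch item: (num_virtual_tokens, d_model) -> (batch, num_virtual_tokens, d_model)
    batch_size, seq_len = inputs_embeds.shape[:2]
    prompt_embeds = prompt_embeddings.weight.unsqueeze(0).expand(batch_size, -1, -1)
    inputs_embeds = torch.cat([prompt_embeds.to(inputs_embeds.dtype), inputs_embeds], dim=1)

    # Extend the mask unless model_forward (below) has already done it
    if attention_mask is not None and attention_mask.shape[1] == seq_len:
        attention_mask = prepend_prompt_mask(attention_mask)

    kwargs = {k: v for k, v in kwargs.items() if k in encoder_params}
    return original_encoder_forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)

def model_forward(self, input_ids=None, attention_mask=None, inputs_embeds=None,
                  decoder_input_ids=None, decoder_attention_mask=None, labels=None, **kwargs):
    # Newer Trainer versions may pass num_items_in_batch when forward() accepts **kwargs;
    # T5's own forward may not accept it, so drop it
    kwargs.pop("num_items_in_batch", None)

    # The decoder cross-attends over [prompt; input], so its mask must also cover the prompt.
    # This also handles generate(), which calls forward() with precomputed encoder_outputs.
    if attention_mask is not None:
        attention_mask = prepend_prompt_mask(attention_mask)

    return original_forward(input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds,
                            decoder_input_ids=decoder_input_ids,
                            decoder_attention_mask=decoder_attention_mask,
                            labels=labels, **kwargs)

# generate() runs the encoder directly, so the prompt is prepended inside the encoder.
# The encoder is patched here; model.forward is replaced in the next cell.
encoder.forward = encoder_forward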

Apply the custom forward method

In [None]:
model.forward = model_forward.__get__(model)

Preprocessing function

These parameters were adjusted to keep the prompt-tuning run efficient. Here are the changes and their rationale:

1. max_input_length = 512 (changed from 1024). Reasons for the change:
- Efficiency: Halving the maximum input length reduces memory use and compute per training step (the cost of encoder self-attention grows quadratically with sequence length), so more training iterations fit into the same amount of time.
- Model limits: T5 was pretrained with 512-token inputs, and the t5-small tokenizer reports model_max_length = 512.
- Dataset characteristics: For XSum, 512 tokens are often sufficient to capture the main content of an article.
- Prompt-tuning focus: The goal is to learn the prompt embeddings, not to process very long sequences.
- GPU memory constraints: Shorter inputs leave room for larger batch sizes, which can make training more stable on GPUs with limited memory.
2. max_target_length = 128 (unchanged). Reasons for keeping it:
- Summarization goal: XSum targets extreme summarization, typically single-sentence summaries, so 128 tokens are usually more than enough.
- Consistency: Keeping the target length the same keeps the model's output consistent with the original task specification.

The adjustment of these parameters, particularly the reduction of max_input_length, serves several purposes:
- Faster training: Shorter sequences mean faster forward and backward passes through the network.
- Lower memory usage: This allows larger batch sizes or training on GPUs with less memory.
- More iterations: With faster processing, we can run more epochs or process more examples in the same amount of time.
- Relative weight of the prompt: With shorter inputs, the 20 virtual tokens make up a larger share of each full-length encoder input (20 of 532 positions rather than 20 of 1,044), which may increase their influence on the model's behavior.

These changes come with a trade-off. Truncation keeps only the first 512 tokens of each document, so if many articles are longer than that, the model never sees their later parts. Whether this hurts summary quality depends on the data, so it is worth measuring how often truncation actually happens before settling on a limit.

If 512 tokens turn out to be insufficient for your use case, you can raise max_input_length. The goal is to balance having enough context for good summaries against keeping training efficient.
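To check how much truncation actually costs, one can count how many tokenized documents exceed max_input_length. A minimal helper, assuming you already have a list of token counts (truncation_stats is a hypothetical name, and the lengths below are made up for illustration, not measured on XSum):

```python
def truncation_stats(token_lengths, max_length):
    """Return (fraction of examples longer than max_length, mean tokens dropped per truncated example)."""
    over = [n - max_length for n in token_lengths if n > max_length]
    fraction = len(over) / len(token_lengths) if token_lengths else 0.0
    mean_dropped = sum(over) / len(over) if over else 0.0
    return fraction, mean_dropped

# Toy lengths for illustration only
lengths = [300, 450, 512, 600, 1200]
print(truncation_stats(lengths, 512))  # (0.4, 388.0)
```

On the real data, the counts could come from the tokenizer, for example [len(ids) for ids in tokenizer(raw_datasets["validation"]["document"])["input_ids"]]; the tokenizer may warn that some sequences exceed its maximum length, which is expected here. The same helper works for the summaries with max_target_length.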

In [None]:
# Maximum token lengths for articles (inputs) and summaries (targets); see the rationale above
max_input_length = 512
max_target_length = 128

def preprocess_function(examples):
    inputs = examples["document"]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    labels = tokenizer(text_target=examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

Tokenize datasets

In [None]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

Compute metrics

In [None]:
import nltk
import numpy as np

# nltk.sent_tokenize needs the Punkt sentence models; newer NLTK releases load "punkt_tab" instead of "punkt"
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Replace -100 (label padding, also used when the Trainer concatenates predictions) as we can't decode it
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # rougeLsum expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    # Report the ROUGE scores as percentages
    result = {key: value * 100 for key, value in result.items()}

    # Add mean generated length (number of non-padding tokens)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(float(v), 4) for k, v in result.items()}

Set up training arguments

In [None]:
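batch_size = 8
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    f"{model_name}-prompt-tuned-xsum",
    eval_strategy="epoch",  # named evaluation_strategy in older transformers releases
    learning_rate=1e-3,  # Higher learning rate for prompt-tuning
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=10,  # More epochs, since far fewer parameters are trained
    predict_with_generate=True,
    generation_max_length=max_target_length,  # otherwise generate() uses the model's default max length
    fp16=torch.cuda.is_available(),  # fp16 mixed precision requires a CUDA GPU
    push_to_hub=True,  # requires being logged in to the Hugging Face Hub
)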

Define data collator

In [None]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

Set up trainer

In [None]:
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    processing_class=tokenizer,  # named tokenizer= in older transformers releases
    compute_metrics=compute_metrics,
)

Train the model

In [None]:
trainer.train()

Push the model to the Hub

In [None]:
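# Uncomment to upload the final model (requires a Hugging Face login)
#trainer.push_to_hub()

Save, reload and test the model locally

In [None]:
import inspect
import os

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Assumes the trained model is called 'model', the tokenizer 'tokenizer',
# and the learned soft prompt 'prompt_embeddings'

def save_model_locally(model, tokenizer, prompt_embeddings, path):
    """
    Save the model, the tokenizer and the learned prompt embeddings to a local directory.
    """
    os.makedirs(path, exist_ok=True)
    model.save_pretrained(path)
    tokenizer.save_pretrained(path)

    # Save the prompt embeddings separately; they are what prompt-tuning learned
    torch.save(prompt_embeddings.state_dict(), os.path.join(path, "prompt_embeddings.pt"))
    print(f"Model, tokenizer and prompt embeddings saved to {path}")

def attach_prompt(model, prompt_embeddings):
    """
    Re-apply the prompt-tuning patches to a freshly loaded model: prepend the prompt
    inside the encoder and extend the attention mask used by the decoder's cross-attention.
    """
    num_virtual = prompt_embeddings.num_embeddings
    model.prompt_embeddings = prompt_embeddings  # registered as a submodule, so model.to(device) moves it
    encoder = model.get_encoder()
    original_encoder_forward = encoder.forward
    encoder_params = set(inspect.signature(original_encoder_forward).parameters)
    original_forward = model.forward

    def prepend_ones(mask):
        return torch.cat([mask.new_ones(mask.shape[0], num_virtual), mask], dim=1)

    def encoder_forward(input_ids=None, attention_mask=None, inputs_embeds=None, **kwargs):
        if inputs_embeds is None:
            inputs_embeds = encoder.embed_tokens(input_ids)
        batch_size, seq_len = inputs_embeds.shape[:2]
        prompt_embeds = prompt_embeddings.weight.unsqueeze(0).expand(batch_size, -1, -1)
        inputs_embeds = torch.cat([prompt_embeds.to(inputs_embeds.dtype), inputs_embeds], dim=1)
        if attention_mask is not None and attention_mask.shape[1] == seq_len:
            attention_mask = prepend_ones(attention_mask)
        # Drop arguments the real encoder does not accept
        kwargs = {k: v for k, v in kwargs.items() if k in encoder_params}
        return original_encoder_forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)

    def model_forward(input_ids=None, attention_mask=None, **kwargs):
        if attention_mask is not None:
            attention_mask = prepend_ones(attention_mask)
        # Call the saved original forward; calling model.forward here would recurse forever
        return original_forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

    encoder.forward = encoder_forward
    model.forward = model_forward
    return model

def load_model_locally(path, device=None):
    """
    Load the model, the tokenizer and the prompt embeddings from a local directory.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForSeq2SeqLM.from_pretrained(path)
    tokenizer = AutoTokenizer.from_pretrained(path)

    # Read the prompt size from the saved weights instead of hard-coding 20
    state_dict = torch.load(os.path.join(path, "prompt_embeddings.pt"), map_location="cpu", weights_only=True)
    num_virtual, d_model = state_dict["weight"].shape
    prompt_embeddings = torch.nn.Embedding(num_virtual, d_model)
    prompt_embeddings.load_state_dict(state_dict)

    attach_prompt(model, prompt_embeddings)
    model.to(device)
    model.eval()
    return model, tokenizer

def summarize_text(model, tokenizer, text, max_input_length=512):
    """
    Use the model to summarize a given text.
    """
    inputs = tokenizer([text], max_length=max_input_length, truncation=True, return_tensors="pt").to(model.device)
    with torch.no_grad():
        summary_ids = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"],
                                     num_beams=4, max_length=100, early_stopping=True)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Example usage:

# 1. Save the model locally
save_model_locally(model, tokenizer, prompt_embeddings, "./my_prompt_tuned_model")

# 2. Load the model from local storage
loaded_model, loaded_tokenizer = load_model_locally("./my_prompt_tuned_model")

# 3. Run the model on a test text
test_text = """
The United Nations has warned that millions of people in South Sudan are facing severe food shortages. 
The UN's World Food Programme (WFP) says more than seven million people - about two-thirds of the population - are in need of food aid. 
The agency says the situation has been made worse by flooding, conflict and the economic crisis. 
South Sudan has been plagued by instability since it gained independence from Sudan in 2011.
"""

summary = summarize_text(loaded_model, loaded_tokenizer, test_text)
print("Original text:")
print(test_text)
print("\nGenerated summary:")
print(summary)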