# Prefix-Tuning Reproduction 
This notebook reproduces the results from the ACL 2021 paper "Prefix-Tuning: Optimizing Continuous Prompts for Generation" and extends it with a sentiment-controlled prefix tuning example.

## Reproduction
I validate the prefix-tuning method on a small text dataset to verify generation capabilities on CPU.

In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
import torch
import json
from torch.utils.data import Dataset
from evaluate import load
from evaluate import load
import pandas as pd


In [None]:
# === Sentiment Prefix Model ===
class SentimentPrefixModel(torch.nn.Module):
    def __init__(self, base_model, prefix_len=10, hidden_size=768, use_mlp=True):
        super().__init__()
        self.base_model = base_model
        self.prefix_len = prefix_len
        self.hidden_size = hidden_size
        self.prefix_embeddings = torch.nn.Embedding(2, prefix_len * hidden_size)
        self.use_mlp = use_mlp

        if self.use_mlp:
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(hidden_size, hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(hidden_size, hidden_size)
            )

    def forward(self, input_ids, sentiment_id, labels=None):
        prefix_embed = self.prefix_embeddings(sentiment_id).view(-1, self.prefix_len, self.hidden_size)
        if self.use_mlp:
            prefix_embed = self.mlp(prefix_embed)

        input_embed = self.base_model.transformer.wte(input_ids)
        input_with_prefix = torch.cat((prefix_embed, input_embed), dim=1)

        if labels is not None:
            pad = torch.full((labels.size(0), self.prefix_len), -100).to(labels.device)
            padded_labels = torch.cat([pad, labels], dim=1)
        else:
            padded_labels = None

        return self.base_model(inputs_embeds=input_with_prefix, labels=padded_labels)

# === Tokenizer & Data Loading ===
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

with open("src/smiles_sentiment_dataset_explicit.json") as f:
    data = json.load(f)

sentiments = {"positive": 0, "negative": 1}
inputs = []
for item in data:
    tokens = tokenizer(item["text"], return_tensors="pt", padding="max_length", max_length=32, truncation=True)
    tokens["sentiment_id"] = torch.tensor([sentiments[item["prefix"]]])
    tokens["labels"] = tokens["input_ids"].clone()
    inputs.append(tokens)

class MyDataset(Dataset):
    def __init__(self, items): self.items = items
    def __len__(self): return len(self.items)
    def __getitem__(self, i): return {k: v.squeeze(0) for k, v in self.items[i].items()}

# === Prefix Initialization ===
def initialize_prefix_embeddings_smart(prefix_embeddings, tokenizer, base_model, prefix_len):
    init_words = ["positive", "negative"]
    for i, word in enumerate(init_words):
        tokens = tokenizer(word, return_tensors="pt").input_ids
        with torch.no_grad():
            embed = base_model.transformer.wte(tokens).mean(dim=1)
        prefix_embeddings.weight.data[i] = embed.repeat(1, prefix_len).view(-1)

# === Load Model & Train ===
base_model = GPT2LMHeadModel.from_pretrained("gpt2")
model = SentimentPrefixModel(base_model)
initialize_prefix_embeddings_smart(model.prefix_embeddings, tokenizer, base_model, model.prefix_len)

args = TrainingArguments(
    output_dir="output/sentiment_prefix",
    per_device_train_batch_size=2,
    num_train_epochs=3,
    logging_steps=10,
    save_strategy="no",
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=MyDataset(inputs),
)

trainer.train()


 11%|█         | 10/90 [00:14<01:46,  1.33s/it]

{'loss': 4.2387, 'grad_norm': 23.64571762084961, 'learning_rate': 4.4444444444444447e-05, 'epoch': 0.33}


 22%|██▏       | 20/90 [00:27<01:41,  1.45s/it]

{'loss': 1.6231, 'grad_norm': 22.093475341796875, 'learning_rate': 3.888888888888889e-05, 'epoch': 0.67}


 33%|███▎      | 30/90 [00:41<01:22,  1.37s/it]

{'loss': 1.5956, 'grad_norm': 13.479415893554688, 'learning_rate': 3.3333333333333335e-05, 'epoch': 1.0}


 44%|████▍     | 40/90 [00:57<01:18,  1.57s/it]

{'loss': 1.2707, 'grad_norm': 19.204885482788086, 'learning_rate': 2.777777777777778e-05, 'epoch': 1.33}


 56%|█████▌    | 50/90 [01:13<01:00,  1.51s/it]

{'loss': 1.1147, 'grad_norm': 18.654693603515625, 'learning_rate': 2.2222222222222223e-05, 'epoch': 1.67}


 67%|██████▋   | 60/90 [01:28<00:49,  1.65s/it]

{'loss': 1.1731, 'grad_norm': 20.207866668701172, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}


 78%|███████▊  | 70/90 [01:46<00:34,  1.72s/it]

{'loss': 1.0669, 'grad_norm': 18.6934871673584, 'learning_rate': 1.1111111111111112e-05, 'epoch': 2.33}


 89%|████████▉ | 80/90 [02:04<00:18,  1.80s/it]

{'loss': 0.8665, 'grad_norm': 17.730138778686523, 'learning_rate': 5.555555555555556e-06, 'epoch': 2.67}


100%|██████████| 90/90 [02:19<00:00,  1.55s/it]

{'loss': 0.8804, 'grad_norm': 13.691146850585938, 'learning_rate': 0.0, 'epoch': 3.0}
{'train_runtime': 139.6458, 'train_samples_per_second': 1.289, 'train_steps_per_second': 0.644, 'train_loss': 1.536633268992106, 'epoch': 3.0}





TrainOutput(global_step=90, training_loss=1.536633268992106, metrics={'train_runtime': 139.6458, 'train_samples_per_second': 1.289, 'train_steps_per_second': 0.644, 'total_flos': 0.0, 'train_loss': 1.536633268992106, 'epoch': 3.0})

In [None]:
def generate(sentiment="positive", prompt="I started my day and", max_length=40):
    sid = torch.tensor([0 if sentiment == "positive" else 1])
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    attention_mask = tokenizer(prompt, return_tensors="pt").attention_mask

    prefix = model.prefix_embeddings(sid).view(1, model.prefix_len, model.hidden_size)
    if model.use_mlp:
        prefix = model.mlp(prefix)

    input_embed = model.base_model.transformer.wte(input_ids)
    input_with_prefix = torch.cat((prefix, input_embed), dim=1)

    prefix_mask = torch.ones((1, model.prefix_len), dtype=torch.long)
    full_attention_mask = torch.cat((prefix_mask, attention_mask), dim=1)

    output = model.base_model.generate(
        inputs_embeds=input_with_prefix,
        attention_mask=full_attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        max_length=max_length,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
    )
    print(tokenizer.decode(output[0], skip_special_tokens=True))


# Examples 

In [103]:
generate("positive", "For me SMILES was")

 a joy to be a part of.


In [105]:
generate("negative", "For me SMILES was")

 just a buzzy-tastes-free experience.


# Evaluation

In [None]:
bleu = load("bleu")
rouge = load("rouge")

# === Define Evaluation Function ===
def evaluate_model(model, tokenizer, test_data, prefix_len=10, use_mlp=True, max_gen_length=40, num_samples=100):
    predictions = []
    references = []

    for item in test_data[:num_samples]:
        sentiment = item["prefix"]
        prompt = item.get("prompt", item["text"][:30]) 
        reference = item["text"]

        sid = torch.tensor([0 if sentiment == "positive" else 1])
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        attention_mask = tokenizer(prompt, return_tensors="pt").attention_mask

        # Get prefix embeddings
        prefix = model.prefix_embeddings(sid).view(1, prefix_len, model.hidden_size)
        if use_mlp:
            prefix = model.mlp(prefix)

        input_embed = model.base_model.transformer.wte(input_ids)
        input_with_prefix = torch.cat((prefix, input_embed), dim=1)

        prefix_mask = torch.ones((1, prefix_len), dtype=torch.long)
        full_attention_mask = torch.cat((prefix_mask, attention_mask), dim=1)

        # Generate
        output = model.base_model.generate(
            inputs_embeds=input_with_prefix,
            attention_mask=full_attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            max_length=max_gen_length,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )

        pred = tokenizer.decode(output[0], skip_special_tokens=True)
        predictions.append(pred)
        references.append([reference])  

    bleu_result = bleu.compute(predictions=predictions, references=references)
    rouge_result = rouge.compute(predictions=predictions, references=[r[0] for r in references])

    # Display results
    df = pd.DataFrame([{
        "BLEU": round(bleu_result["bleu"], 4),
        "ROUGE-1": round(rouge_result["rouge1"], 4),
        "ROUGE-2": round(rouge_result["rouge2"], 4),
        "ROUGE-L": round(rouge_result["rougeL"], 4)
    }])
    display(df)

    return predictions, references


In [None]:

with open("src/smiles_sentiment_dataset_explicit.json") as f:
    test_data = json.load(f)

preds, refs = evaluate_model(model, tokenizer, test_data)


Unnamed: 0,BLEU,ROUGE-1,ROUGE-2,ROUGE-L
0,0.0142,0.1448,0.0618,0.1294
