# Prompt Tuning with FLAN-T5 (PEFT)
Learn soft prompt embeddings that steer a frozen seq2seq model (FLAN-T5-small) to answer a sentiment task by generating "positive" or "negative".

In [None]:
!pip -q install -U transformers datasets peft accelerate


In [None]:
import json
import random

import torch
from datasets import Dataset
from peft import PromptTuningConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    Trainer,
    TrainingArguments,
)

# Checkpoint id shared by the tokenizer and the encoder-decoder LM weights.
base = "google/flan-t5-small"
tok, model = (
    AutoTokenizer.from_pretrained(base),
    AutoModelForSeq2SeqLM.from_pretrained(base),
)


In [None]:
# Tiny labeled dataset: sentiment-as-generation (positive/negative).
# Two hand-written reviews, alternated 200 times, give 400 rows in total.
rows = [
    {"input_text": review, "target_text": label}
    for _ in range(200)
    for review, label in (
        ("The movie was fantastic and moving.", "positive"),
        ("The plot was dull and predictable.", "negative"),
    )
]
ds = Dataset.from_list(rows)


In [None]:
# Soft prompt of 20 learnable virtual tokens for a seq2seq LM task.
peft_cfg = PromptTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    num_virtual_tokens=20,
)
# Wrap the base model so training updates the prompt-tuning parameters.
model = get_peft_model(model, peft_cfg)
# Sanity check: print trainable vs. total parameter counts.
model.print_trainable_parameters()


In [None]:
def preprocess(ex):
    """Tokenize a batch of examples into encoder inputs plus decoder labels.

    Args:
        ex: A batch from ``Dataset.map(..., batched=True)`` with list-valued
            ``"input_text"`` and ``"target_text"`` columns.

    Returns:
        The tokenizer's encoding of ``input_text`` (e.g. ``input_ids`` and
        ``attention_mask``) with an added ``"labels"`` entry holding the token
        ids of ``target_text``.
    """
    # `text_target=` tokenizes the labels in the same call and stores their
    # ids under "labels". It replaces the deprecated `as_target_tokenizer()`
    # context manager, which newer transformers releases drop.
    return tok(ex["input_text"], text_target=ex["target_text"], truncation=True)

# Tokenize once up front, and drop the raw string columns so only token-id
# columns reach the collator.
tok_ds = ds.map(preprocess, batched=True, remove_columns=ds.column_names)
# Pads each batch's inputs and labels dynamically.
collator = DataCollatorForSeq2Seq(tok, model=model)

# `predict_with_generate` exists only on Seq2SeqTrainingArguments. Plain
# TrainingArguments rejects it with a TypeError.
# FLAN-T5 (T5 v1.1-based) checkpoints are widely reported to overflow under
# fp16 (NaN loss), so use bf16 where the GPU supports it, otherwise fp32.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
args = Seq2SeqTrainingArguments(
    output_dir="prompt-tuning-t5",
    per_device_train_batch_size=8,
    num_train_epochs=1,
    learning_rate=5e-4,
    logging_steps=20,
    save_steps=200,
    predict_with_generate=True,
    bf16=use_bf16,
    report_to="none",
)

# The tokenizer is not handed to the trainer, because the `tokenizer=`
# keyword is deprecated. The collator already holds it, and it is saved
# explicitly below.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tok_ds,
    data_collator=collator,
)
trainer.train()
model.save_pretrained("prompt-tuning-t5/model")
tok.save_pretrained("prompt-tuning-t5/tokenizer")


In [None]:
# Inference demo: label an unseen review with the tuned soft prompt.
text = "The acting was mediocre and the pacing slow."
# trainer.train() leaves the model in training mode; switch to eval so
# dropout does not perturb the prediction.
model.eval()
# Keep input_ids and attention_mask together, on whatever device holds the model.
inputs = tok(text, return_tensors="pt").to(model.device)
ids = inputs.input_ids
# PEFT's seq2seq generate() takes keyword arguments only; passing input_ids
# positionally raises a TypeError.
gen = model.generate(
    input_ids=ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=3,
)
print("Prediction:", tok.decode(gen[0], skip_special_tokens=True))
