In [None]:
!pip install transformers datasets peft accelerate torch

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training

In [None]:
# Hub identifier of the base checkpoint; it is reused for the tokenizer and for every model load.
model_checkpoint: str = "google/flan-t5-small"

In [None]:
# Load the tokenizer for the checkpoint (the model itself is loaded in a later cell)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

In [None]:
!pip install bitsandbytes

In [None]:
# Load the base model. We use device_map="auto" to leverage accelerate for placing layers across devices.
# NOTE: no quantization argument is passed here, so the weights are NOT loaded in 8-bit.
# 8-bit loading is optional (useful for larger models) and would have to be requested explicitly.
# NOTE(review): without 8-bit loading, the later prepare_model_for_kbit_training call is probably unnecessary — confirm.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, device_map="auto")

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [None]:
# NOTE(review): the model above is loaded without a quantization config, so this k-bit
# preparation step is presumably not required — confirm before removing it.
model = prepare_model_for_kbit_training(model)


In [None]:
# Load the dataset
dataset_name = "spencer/samsum_reformat"
dataset = load_dataset(dataset_name, split="train[:1%]") # Using only 1% for demo
# Seed the shuffle-based split so the train/test partition (and therefore any test example
# later picked by index) is the same on every run.
dataset = dataset.train_test_split(test_size=0.1, seed=42) # Create train/test splits

print(f"Train dataset size: {len(dataset['train'])}")
print(f"Test dataset size: {len(dataset['test'])}")
# Example: Train dataset size: 132
# Example: Test dataset size: 15

dataset_infos.json: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/16.6M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/923k [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/968k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/14732 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/818 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/819 [00:00<?, ? examples/s]

Train dataset size: 132
Test dataset size: 15


In [None]:
# Show the split summary (column names and row counts) instead of dumping the raw tables
# behind every split, which floods the notebook with dialogue text.
dataset

{'train': MemoryMappedTable
 id: string
 dialogue: string
 summary: string
 sentences: list<item: string>
   child 0, item: string
 sentence_id: list<item: string>
   child 0, item: string
 dialog_id: string
 ----
 id: [["13818513","13728867","13681000","13730747","13728094",...,"13829461","13828209","13864466","13862429","13819578"]]
 dialogue: [["Amanda: I baked  cookies. Do you want some?
 Jerry: Sure!
 Amanda: I'll bring you tomorrow :-)","Olivia: Who are you voting for in this election? 
 Oliver: Liberals as always.
 Olivia: Me too!!
 Oliver: Great","Tim: Hi, what's up?
 Kim: Bad mood tbh, I was going to do lots of stuff but ended up procrastinating
 Tim: What did you plan on doing?
 Kim: Oh you know, uni stuff and unfucking my room
 Kim: Maybe tomorrow I'll move my ass and do everything
 Kim: We were going to defrost a fridge so instead of shopping I'll eat some defrosted veggies
 Tim: For doing stuff I recommend Pomodoro technique where u use breaks for doing chores
 Tim: It rea

In [None]:
# Token budgets for preprocessing: the first bounds the tokenized dialogue (model input),
# the second bounds the tokenized summary (target).
max_input_length, max_target_length = 512, 128

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of dialogue/summary pairs for seq2seq training.

    Args:
        examples: Batch mapping with a "dialogue" and a "summary" entry, each a list of
            strings (the shape produced by ``dataset.map(..., batched=True)``).

    Returns:
        The tokenizer output for the prefixed dialogues, with an added "labels" entry
        holding the summary token ids in which every pad id is replaced by -100.
    """
    # Add prefix for T5 models
    inputs = ["summarize: " + doc for doc in examples["dialogue"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, padding="max_length")

    # Tokenize the targets through `text_target`; the `as_target_tokenizer()` context
    # manager is deprecated in transformers and this is its documented replacement.
    labels = tokenizer(
        text_target=examples["summary"], max_length=max_target_length, truncation=True, padding="max_length"
    )

    # Replace tokenizer.pad_token_id in the labels by -100 to ignore padding in the loss calculation
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [(token_id if token_id != pad_id else -100) for token_id in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs

In [None]:
# Tokenize every split, handing the preprocessing function whole batches at a time
tokenized_datasets = dataset.map(function=preprocess_function, batched=True)

Map:   0%|          | 0/132 [00:00<?, ? examples/s]



Map:   0%|          | 0/15 [00:00<?, ? examples/s]

In [None]:
# Remove columns not needed for training
tokenized_datasets = tokenized_datasets.remove_columns(["id", "dialogue", "summary"])

print(f"Columns in tokenized dataset: {tokenized_datasets['train'].column_names}")
# Note: this dataset also carries 'sentences', 'sentence_id' and 'dialog_id'; they are still
# present at this point (see the printed output) and are dropped in a following cell.

Columns in tokenized dataset: ['sentences', 'sentence_id', 'dialog_id', 'input_ids', 'attention_mask', 'labels']


In [None]:
# Drop the remaining dataset-specific columns so only the tokenizer outputs and labels are left.
# Re-running this cell raises once the columns are gone; re-run the notebook from the mapping cell instead.
tokenized_datasets = tokenized_datasets.remove_columns(["sentences", "sentence_id", "dialog_id"])

In [None]:
# Confirm which columns remain on the train split after the removals above
print(f"Columns in tokenized dataset: {tokenized_datasets['train'].column_names}")

Columns in tokenized dataset: ['input_ids', 'attention_mask', 'labels']


In [None]:
# Batch collator for seq2seq training
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    # Optional: optimizes hardware usage
    pad_to_multiple_of=8,
    # Important: ensure labels are padded correctly
    label_pad_token_id=-100,
)

In [None]:
# LoRA hyper-parameters for the adapter
lora_config = LoraConfig(
    # Task type for sequence-to-sequence models
    task_type=TaskType.SEQ_2_SEQ_LM,
    # Apply LoRA to query and value projections
    target_modules=["q", "v"],
    # Rank of the update matrices
    r=16,
    # Scaling factor
    lora_alpha=32,
    # Dropout probability
    lora_dropout=0.05,
    # Do not train biases
    bias="none",
)

In [None]:
# Get the PEFT model
peft_model = get_peft_model(model, lora_config)

# Print the number of trainable parameters
peft_model.print_trainable_parameters()
# Output recorded for this run: trainable params: 688,128 || all params: 77,649,280 || trainable%: 0.8862

trainable params: 688,128 || all params: 77,649,280 || trainable%: 0.8862


In [None]:
# Training configuration for the LoRA run
output_dir = "flan-t5-small-samsum-lora"
training_args = TrainingArguments(
    output_dir=output_dir,
    # Optimisation: higher learning rate typical for LoRA, three epochs,
    # and let the trainer search for a batch size that fits.
    learning_rate=1e-3,
    num_train_epochs=3,
    auto_find_batch_size=True,
    # Bookkeeping: log and checkpoint once per epoch, no external reporting (wandb/tensorboard).
    logging_strategy="epoch",
    save_strategy="epoch",
    report_to="none",
    # Left disabled in this example:
    # evaluation_strategy="epoch",  # evaluate every epoch if eval data is available
    # fp16=torch.cuda.is_available(),  # fp16 for faster training if supported
)

In [None]:
# Create Trainer instance
trainer = Trainer(
    model=peft_model, # Pass the PEFT model
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"], # Optional: Pass eval dataset
    data_collator=data_collator,
    # `tokenizer=` is deprecated on Trainer; `processing_class=` is the replacement argument.
    processing_class=tokenizer,
)

  trainer = Trainer(
The model is already on multiple devices. Skipping the move to device specified in `args`.


In [None]:
# Turn off `use_cache` on the model config before training.
# This line does not change which parameters are trainable.
peft_model.config.use_cache = False # Disable caching for training

In [None]:
# Start training
# The trainer runs with the arguments and datasets configured in the earlier cells; the
# two prints only bracket the run so its start and end are visible in the cell output.
print("Starting LoRA training...")
trainer.train()
print("Training finished.")

Starting LoRA training...




Step,Training Loss
17,2.6713
34,2.2238




Step,Training Loss
17,2.6713
34,2.2238
51,2.1561


Training finished.


In [None]:
# Define path to save the adapter (a sub-directory of the training output directory)
adapter_path = f"{output_dir}/final_adapter"

# Save the adapter weights
peft_model.save_pretrained(adapter_path)
tokenizer.save_pretrained(adapter_path) # Save tokenizer alongside adapter

print(f"LoRA adapter saved to: {adapter_path}")

# You can check the size of the saved adapter - it should be relatively small (MBs).
# For example, using: !ls -lh {adapter_path}

LoRA adapter saved to: flan-t5-small-samsum-lora/final_adapter


In [None]:
!ls -lh /content/flan-t5-small-samsum-lora

total 16K
drwxr-xr-x 2 root root 4.0K Dec  1 15:43 checkpoint-17
drwxr-xr-x 2 root root 4.0K Dec  1 15:49 checkpoint-34
drwxr-xr-x 2 root root 4.0K Dec  1 15:54 checkpoint-51
drwxr-xr-x 2 root root 4.0K Dec  1 16:06 final_adapter


In [None]:
# NOTE(review): PeftConfig is imported here but not used in the visible cells.
from peft import PeftModel, PeftConfig

# Load the base model again (if not already in memory)
# The checkpoint is reloaded in float16 so the saved adapter can be attached to a fresh copy.
# NOTE(review): the library reports `torch_dtype` as deprecated in favour of `dtype`;
# switch once the minimum transformers version for this notebook is settled.
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

`torch_dtype` is deprecated! Use `dtype` instead!


In [None]:
# Load the PEFT model with the saved adapter
lora_model = PeftModel.from_pretrained(base_model, adapter_path)
lora_model = lora_model.to("cuda" if torch.cuda.is_available() else "cpu") # Ensure model is on correct device
# Set model to evaluation mode. The trailing `;` stops the notebook from echoing the
# full module tree that eval() returns.
lora_model.eval();

PeftModelForSeq2SeqLM(
  (base_model): LoraModel(
    (model): T5ForConditionalGeneration(
      (shared): Embedding(32128, 512)
      (encoder): T5Stack(
        (embed_tokens): Embedding(32128, 512)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): lora.Linear(
                    (base_layer): Linear(in_features=512, out_features=384, bias=False)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.05, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=512, out_features=16, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=16, out_features=384, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
            

In [None]:
# Pick one held-out example (any new dialogue would work too) and pull out its
# dialogue text together with the human-written reference summary.
sample_idx = 5
dialogue, reference_summary = (
    dataset['test'][sample_idx][field] for field in ('dialogue', 'summary')
)

In [None]:
# Build the prompt with the same "summarize: " prefix used in preprocessing.
input_text = "summarize: " + dialogue
# Truncate to the same input token budget that the training preprocessing applies, so an
# unusually long dialogue is handled the same way at inference time.
input_ids = tokenizer(
    input_text, max_length=max_input_length, truncation=True, return_tensors="pt"
).input_ids.to(lora_model.device)

In [None]:
# Show the chosen dialogue followed by its reference summary, one item per line
print("Dialogue:", dialogue, "\nReference Summary:", reference_summary, sep="\n")

Dialogue:
Julia: Adam, are you coming today?
Julia: Adam, you are already an hour late, let me know asap
Kate: He texted me before that he isn't feeling very well
Julia: Thanks
Adam: I had an appointment, sorry, but I have a stomach flu

Reference Summary:
Adam has a stomach flu. 


In [None]:
# Generate summary using the LoRA model
# Generation samples (do_sample=True), so seed torch's RNG first to make the printed
# summary repeatable across runs.
torch.manual_seed(42)
with torch.no_grad():
    outputs = lora_model.generate(input_ids=input_ids, max_new_tokens=100, do_sample=True, top_p=0.9)
generated_summary = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("\nGenerated Summary (LoRA):")
print(generated_summary)


Generated Summary (LoRA):
A has been called to an appointment at an earlier time.
