In [1]:
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig

# Pick the compute backend in priority order: NVIDIA GPU, then Apple Silicon
# (Metal Performance Shaders), then plain CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: mps


In [2]:
# DialogSum: dialogue/summary pairs with "train" and "test" splits.
dataset = load_dataset("knkarthick/dialogsum")

for split_name in ("train", "test"):
    print(f"{split_name.capitalize()} examples: {len(dataset[split_name])}")

# Preview the first training record (dialogue truncated to 300 characters).
first_record = dataset["train"][0]
print("\nExample:")
print(f"Dialogue: {first_record['dialogue'][:300]}...")
print(f"\nSummary: {first_record['summary']}")

Train examples: 12460
Test examples: 1500

Example:
Dialogue: #Person1#: Hi, Mr. Smith. I'm Doctor Hawkins. Why are you here today?
#Person2#: I found it would be a good idea to get a check-up.
#Person1#: Yes, well, you haven't had one for 5 years. You should have one every year.
#Person2#: I know. I figure as long as there is nothing wrong, why go see the doc...

Summary: Mr. Smith's getting a check-up, and Doctor Hawkins advises him to have one every year. Hawkins'll give some information about their classes and medications to help Mr. Smith quit smoking.


In [3]:
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Tokenizer: reuse EOS as the padding token and pad on the right.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# from_pretrained arguments per backend: CUDA loads 4-bit NF4-quantised
# weights (double quantisation, float16 compute) to save memory; MPS and CPU
# load full float32 weights.
if device == "cuda":
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    load_kwargs = {"quantization_config": bnb_config, "device_map": "auto"}
else:
    load_kwargs = {"dtype": torch.float32, "low_cpu_mem_usage": True}

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)

# Backend-specific post-processing: k-bit training prep for the quantised
# CUDA model, explicit device move for Apple Silicon; CPU needs nothing.
if device == "cuda":
    model = prepare_model_for_kbit_training(model)
elif device == "mps":
    model = model.to(device)

print(f"Model loaded: {MODEL_NAME}")
print(f"Model parameters: {model.num_parameters():,}")

Model loaded: TinyLlama/TinyLlama-1.1B-Chat-v1.0
Model parameters: 1,100,048,384


In [4]:
# LoRA: add small trainable low-rank updates to the attention query/value
# projections only; get_peft_model freezes the rest of the model (see the
# trainable-parameter count printed below).
lora_config = LoraConfig(
    r=8,                  # rank of the low-rank update matrices
    lora_alpha=16,        # PEFT scales the update by lora_alpha / r (= 2 here)
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# get_peft_model randomly initialises the new adapter weights. Seed first so a
# fresh Restart & Run All starts training from the same adapter initialisation.
torch.manual_seed(42)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 1,126,400 || all params: 1,101,174,784 || trainable%: 0.1023


In [None]:
def format_example(example):
    """Render one DialogSum record as a single chat-formatted training string.

    The layout is a system turn, a user turn containing the dialogue and an
    assistant turn containing the summary, each closed with ``</s>``. Up to the
    assistant reply it must match the prompt built in ``generate_summary``
    character-for-character. Otherwise the model is trained on one prompt shape
    and queried with another.

    Args:
        example: A record with ``"dialogue"`` and ``"summary"`` string fields.

    Returns:
        ``{"text": formatted}``. ``Dataset.map`` adds this as the ``text``
        column read via ``SFTConfig(dataset_text_field="text")``.
    """
    # The blank line after the instruction mirrors the inference prompt exactly.
    text = f"""<|system|>
You are a helpful assistant that summarizes conversations.</s>
<|user|>
Summarize the following conversation:

{example['dialogue']}</s>
<|assistant|>
{example['summary']}</s>"""

    return {"text": text}

# Format both splits into the single "text" column the trainer consumes.
train_dataset = dataset["train"].map(format_example)
test_dataset = dataset["test"].map(format_example)

# Train on a small prefix of the training split to keep the demo fast.
# Clamp to the split size so raising N_TRAIN_EXAMPLES never indexes past the end.
N_TRAIN_EXAMPLES = 300
train_subset = train_dataset.select(range(min(N_TRAIN_EXAMPLES, len(train_dataset))))

print(f"Training on {len(train_subset)} examples")
print(f"\nFormatted example:\n{train_subset[0]['text'][:500]}...")

Training on 50 examples

Formatted example:
<|system|>
You are a helpful assistant that summarizes conversations.</s>
<|user|>
Summarize the following conversation:

#Person1#: Hi, Mr. Smith. I'm Doctor Hawkins. Why are you here today?
#Person2#: I found it would be a good idea to get a check-up.
#Person1#: Yes, well, you haven't had one for 5 years. You should have one every year.
#Person2#: I know. I figure as long as there is nothing wrong, why go see the doctor?
#Person1#: Well, the best way to avoid serious illnesses is to find out a...


In [None]:
def generate_summary(model, tokenizer, dialogue, max_new_tokens=128):
    """Generate a summary for one dialogue with greedy decoding.

    The prompt uses the same chat layout as the training text: a system turn,
    a user turn containing the dialogue, then an open assistant turn. The model
    therefore continues directly with the summary.

    Args:
        model: Causal LM (base or PEFT-wrapped) exposing ``generate``,
            ``device``, ``training``, ``eval()`` and ``train()``.
        tokenizer: Tokenizer matching ``model``.
        dialogue: Dialogue text in DialogSum's ``#PersonN#: ...`` format.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated continuation only (prompt excluded), decoded with special
        tokens skipped and surrounding whitespace stripped.

    Side effects:
        Switches ``model`` to eval mode for the call and restores its previous
        train/eval mode before returning, even if generation raises.
    """
    prompt = f"""<|system|>
You are a helpful assistant that summarizes conversations.</s>
<|user|>
Summarize the following conversation:

{dialogue}</s>
<|assistant|>
"""
    # Gradient checkpointing is deliberately left as configured. transformers
    # only applies it in train mode, and disabling it here would also switch it
    # off for any later training run.
    was_training = model.training
    model.eval()
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.eos_token_id,
            )
    finally:
        # Don't leave the caller's model stuck in eval mode.
        if was_training:
            model.train()

    # For this decoder-only model, generate() returns prompt + continuation.
    # Slicing off the prompt tokens is more robust than splitting the full
    # decode on "<|assistant|>". That split breaks if the marker is stripped as
    # a special token or also appears inside the dialogue.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True).strip()

# Baseline: summarise the first test dialogue before any fine-tuning.
test_example = dataset["test"][0]
divider = "=" * 50

header_lines = [
    divider,
    "DIALOGUE:",
    test_example["dialogue"],
    "\n" + divider,
    "GROUND TRUTH SUMMARY:",
    test_example["summary"],
    "\n" + divider,
    "MODEL OUTPUT (BEFORE FINE-TUNING):",
]
print("\n".join(header_lines))

before_summary = generate_summary(model, tokenizer, test_example["dialogue"])
print(before_summary)

DIALOGUE:
#Person1#: Ms. Dawson, I need you to take a dictation for me.
#Person2#: Yes, sir...
#Person1#: This should go out as an intra-office memorandum to all employees by this afternoon. Are you ready?
#Person2#: Yes, sir. Go ahead.
#Person1#: Attention all staff... Effective immediately, all office communications are restricted to email correspondence and official memos. The use of Instant Message programs by employees during working hours is strictly prohibited.
#Person2#: Sir, does this apply to intra-office communications only? Or will it also restrict external communications?
#Person1#: It should apply to all communications, not only in this office between employees, but also any outside communications.
#Person2#: But sir, many employees use Instant Messaging to communicate with their clients.
#Person1#: They will just have to change their communication methods. I don't want any - one using Instant Messaging in this office. It wastes too much time! Now, please continue with th

In [None]:
# Fine-tune the LoRA adapter on train_subset, then save adapter + tokenizer.
OUTPUT_DIR = "./lora_dialogsum"

# generate_summary() calls model.eval(); switch back explicitly before training.
model.train()

sft_config = SFTConfig(
    output_dir=OUTPUT_DIR,
    # Optimisation: batch of 1 with 4 accumulation steps per optimizer update.
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    max_grad_norm=1.0,
    # No mixed precision and no gradient checkpointing.
    fp16=False,
    bf16=False,
    gradient_checkpointing=False,
    # Data: read the "text" column built by format_example, max 1024 tokens.
    dataset_text_field="text",
    max_length=1024,
    # Logging / persistence: log every 10 steps; no intermediate checkpoints.
    logging_steps=10,
    save_strategy="no",
    report_to="none",
)

trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=train_subset,
    processing_class=tokenizer,
)

print("Starting fine-tuning...")
trainer.train()
print("Fine-tuning complete")

# save_strategy="no" writes nothing during training, so persist the
# fine-tuned LoRA adapter and the tokenizer explicitly.
for artifact in (model, tokenizer):
    artifact.save_pretrained(OUTPUT_DIR)
print(f"Model saved to {OUTPUT_DIR}")

Truncating train dataset:   0%|          | 0/50 [00:00<?, ? examples/s]

Starting fine-tuning...


Step,Training Loss
10,1.5978
20,1.5604
30,1.5649


Fine-tuning complete
Model saved to ./lora_dialogsum


 model.save_pretrained("./lora_dialogsum"); tokenizer.save_pretrained("./lora_dialogsum")

In [20]:
# Re-run the same test dialogue after fine-tuning for a side-by-side comparison.
divider = "=" * 50

header_lines = [
    divider,
    "DIALOGUE:",
    test_example["dialogue"],
    "\n" + divider,
    "GROUND TRUTH SUMMARY:",
    test_example["summary"],
    "\n" + divider,
    "MODEL OUTPUT (AFTER FINE-TUNING):",
]
print("\n".join(header_lines))

after_summary = generate_summary(model, tokenizer, test_example["dialogue"])
print(after_summary)

DIALOGUE:
#Person1#: Ms. Dawson, I need you to take a dictation for me.
#Person2#: Yes, sir...
#Person1#: This should go out as an intra-office memorandum to all employees by this afternoon. Are you ready?
#Person2#: Yes, sir. Go ahead.
#Person1#: Attention all staff... Effective immediately, all office communications are restricted to email correspondence and official memos. The use of Instant Message programs by employees during working hours is strictly prohibited.
#Person2#: Sir, does this apply to intra-office communications only? Or will it also restrict external communications?
#Person1#: It should apply to all communications, not only in this office between employees, but also any outside communications.
#Person2#: But sir, many employees use Instant Messaging to communicate with their clients.
#Person1#: They will just have to change their communication methods. I don't want any - one using Instant Messaging in this office. It wastes too much time! Now, please continue with th

In [21]:
# Spot-check the fine-tuned model on a few more test-set dialogues.
SPOT_CHECK_INDICES = [5, 10, 15]
banner = "=" * 60

for idx in SPOT_CHECK_INDICES:
    sample = dataset["test"][idx]
    print("\n" + banner)
    print(f"TEST EXAMPLE {idx}")
    print(banner)
    print(f"\nDIALOGUE:\n{sample['dialogue']}")
    print(f"\nGROUND TRUTH: {sample['summary']}")
    predicted = generate_summary(model, tokenizer, sample["dialogue"])
    print(f"\nMODEL OUTPUT: {predicted}")


TEST EXAMPLE 5

DIALOGUE:
#Person1#: You're finally here! What took so long?
#Person2#: I got stuck in traffic again. There was a terrible traffic jam near the Carrefour intersection.
#Person1#: It's always rather congested down there during rush hour. Maybe you should try to find a different route to get home.
#Person2#: I don't think it can be avoided, to be honest.
#Person1#: perhaps it would be better if you started taking public transport system to work.
#Person2#: I think it's something that I'll have to consider. The public transport system is pretty good.
#Person1#: It would be better for the environment, too.
#Person2#: I know. I feel bad about how much my car is adding to the pollution problem in this city.
#Person1#: Taking the subway would be a lot less stressful than driving as well.
#Person2#: The only problem is that I'm going to really miss having the freedom that you have with a car.
#Person1#: Well, when it's nicer outside, you can start biking to work. That will giv

In [22]:
# Test with a custom dialogue (using training data format)
custom_dialogue = """#Person1#: Hey, are you coming to the party tonight?
#Person2#: What party?
#Person1#: Sarah's birthday party! It starts at 8pm.
#Person2#: Oh right! I totally forgot. Where is it?
#Person1#: At her place. I can pick you up at 7:30 if you want.
#Person2#: That would be great, thanks!
#Person1#: No problem. See you then!"""

print("CUSTOM DIALOGUE:")
print(custom_dialogue)
print("\nMODEL SUMMARY:")
print(generate_summary(model, tokenizer, custom_dialogue))

CUSTOM DIALOGUE:
#Person1#: Hey, are you coming to the party tonight?
#Person2#: What party?
#Person1#: Sarah's birthday party! It starts at 8pm.
#Person2#: Oh right! I totally forgot. Where is it?
#Person1#: At her place. I can pick you up at 7:30 if you want.
#Person2#: That would be great, thanks!
#Person1#: No problem. See you then!

MODEL SUMMARY:
#Person2#: Sure, see you then!
