## Conversation Summarization Project




This notebook focuses on the **fine-tuning** of Facebook’s BART model for the task of abstractive conversation summarization. Using the SAMSum dataset, we train the model to produce summaries that are concise, fluent, and semantically faithful to the content of informal chat conversations. The dataset reflects real-life messaging styles, making it ideal for building a practical chat summarizer. This fine-tuning process serves as the foundation for evaluating our custom model against a state-of-the-art LLM (Gemini) in later stages of the project.

## Install the necessary libraries



In [None]:
!pip install -U datasets fsspec evaluate rouge_score

import transformers
from datasets import load_dataset, load_from_disk
import numpy as np
import nltk
import evaluate

nltk.download('punkt')

Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting fsspec
  Downloading fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... done
Collecting fsspec
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.6.0-py3-none-any.whl (491 kB)
Downloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
Building

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

## Load dataset

The original SAMSum dataset published by Samsung was no longer available on Hugging Face when I started this project, so I used a copy hosted at `knkarthick/samsum`.

In [None]:
data = load_dataset('knkarthick/samsum')

metric = evaluate.load('rouge')
model_checkpoints = 'facebook/bart-large-xsum'

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


## Data tokenization

In [None]:
max_input = 512
max_target = 128
tokenizer = transformers.AutoTokenizer.from_pretrained(model_checkpoints)

In [None]:
def preprocess_data(batch):
    dialogues = batch['dialogue']
    summaries = batch['summary']

    valid_inputs = []
    valid_targets = []


    for d, s in zip(dialogues, summaries):
        if isinstance(d, str) and isinstance(s, str):
            valid_inputs.append(d)
            valid_targets.append(s)


    if len(valid_inputs) == 0:
        return {
            'input_ids': [],
            'attention_mask': [],
            'labels': []
        }

    # tokenize input
    model_inputs = tokenizer(valid_inputs, max_length=max_input, padding='max_length', truncation=True)

    # tokenize target; `text_target` replaces the deprecated `as_target_tokenizer()` context manager
    labels = tokenizer(text_target=valid_targets, max_length=max_target, padding='max_length', truncation=True)

    model_inputs['labels'] = labels['input_ids']
    return model_inputs
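
The padding and truncation that `preprocess_data` requests from the tokenizer can be illustrated without the library. The sketch below is a simplified stand-in, not the tokenizer's implementation (the real tokenizer also keeps the end-of-sequence token when it truncates). `pad_or_truncate` and `mask_label_padding` are hypothetical helpers, and the pad id of 1 is an assumption matching BART's usual vocabulary. The second helper shows an optional step that the notebook as written does not perform: replacing label padding with -100 so that the loss skips those positions.

```python
PAD_ID = 1           # assumption: pad token id used by BART checkpoints
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def pad_or_truncate(token_ids, max_length, pad_id=PAD_ID):
    """Cut a sequence to max_length, or right-pad it with pad_id."""
    clipped = token_ids[:max_length]
    return clipped + [pad_id] * (max_length - len(clipped))


def mask_label_padding(label_ids, pad_id=PAD_ID):
    """Replace padding positions in the labels with IGNORE_INDEX."""
    return [IGNORE_INDEX if t == pad_id else t for t in label_ids]


short_seq = pad_or_truncate([0, 713, 16, 2], max_length=6)
long_seq = pad_or_truncate(list(range(10, 20)), max_length=6)
masked = mask_label_padding(short_seq)

print(short_seq)  # padded on the right up to length 6
print(long_seq)   # cut down to the first 6 ids
print(masked)     # padding positions replaced by -100
```

With `padding='max_length'`, every example occupies the full 512 (input) or 128 (target) positions, which is simple but spends memory on padding for short chats.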

In [None]:
tokenize_data = data.map(
    preprocess_data,
    batched=True,
    remove_columns=['id', 'dialogue', 'summary']
)




## Sampling the dataset

For a quick test run before full training, we randomly sampled 1,000 training, 500 validation, and 200 test examples from the SAMSum dataset, which made experimentation much faster. The final fine-tuning was carried out on the entire dataset to get the best model performance; the training log further down (`global_step=14732`) comes from that full run.

In [None]:
#sample the data
train_sample = tokenize_data['train'].shuffle(seed=123).select(range(1000))
validation_sample = tokenize_data['validation'].shuffle(seed=123).select(range(500))
test_sample = tokenize_data['test'].shuffle(seed=123).select(range(200))

In [None]:
tokenize_data['train'] = train_sample
tokenize_data['validation'] = validation_sample
tokenize_data['test'] = test_sample

In [None]:
tokenize_data

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1000
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 500
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 200
    })
})

## Training process

In [None]:
#load model
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_checkpoints)

In [None]:
batch_size = 1

In [None]:
collator = transformers.DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
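# metrics

def compute_rouge(pred):
    predictions, labels = pred
    # the Trainer may pad with -100; swap it for the pad token so decoding does not fail
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # decode the predictions and the reference labels
    decode_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decode_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # compute ROUGE and convert the scores to percentages
    res = metric.compute(predictions=decode_predictions, references=decode_labels, use_stemmer=True)
    res = {key: value * 100 for key, value in res.items()}

    # average generated length, counting non-padding tokens only
    pred_lens = [np.count_nonzero(p != tokenizer.pad_token_id) for p in predictions]
    res['gen_len'] = np.mean(pred_lens)

    return {k: round(v, 4) for k, v in res.items()}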
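
To make the ROUGE numbers reported by `compute_rouge` easier to read, here is a minimal sketch of what ROUGE-1 measures: the F1 score over unigrams shared by a prediction and its reference. It is a simplification under stated assumptions (whitespace tokenization, lowercasing, no stemming), so it will not reproduce the `rouge_score` package exactly. `rouge1_f1` and the example sentences are hypothetical.

```python
from collections import Counter


def rouge1_f1(prediction, reference):
    """Simplified ROUGE-1: F1 over overlapping unigrams (no stemming)."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # multiset intersection counts each shared word at most as often as it appears in both
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1("amanda baked cookies", "amanda baked cookies for jerry")
print(round(score * 100, 2))  # scaled to a percentage, as in compute_rouge
```

Here all three predicted words appear in the reference (precision 1.0), but only three of the five reference words are covered (recall 0.6), so a very short summary is penalized through recall.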

In [None]:
args = transformers.Seq2SeqTrainingArguments(
    output_dir='conversation-summ',
    eval_strategy='epoch',
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=2,  # effective train batch size = batch_size * 2
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=2,
    predict_with_generate=True,  # run generate() during evaluation so ROUGE can be computed
    eval_accumulation_steps=1,
    fp16=True,
    load_best_model_at_end=True,
    metric_for_best_model='rougeL',
    greater_is_better=True,
    logging_dir='./logs',
    logging_steps=100,
    report_to='none'
)


In [None]:
trainer = transformers.Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenize_data['train'],
    eval_dataset=tokenize_data['validation'],
    data_collator=collator,
    tokenizer=tokenizer,
    compute_metrics=compute_rouge
)



In [None]:
trainer.train()

| Epoch | Training Loss | Validation Loss | ROUGE-1 | ROUGE-2 | ROUGE-L | ROUGE-Lsum | Gen Len |
|---|---|---|---|---|---|---|---|
| 1 | 0.3086 | 0.303632 | 53.3311 | 28.7487 | 44.1695 | 44.1146 | 27.5159 |
| 2 | 0.2157 | 0.316653 | 54.5278 | 30.1794 | 44.9862 | 44.9901 | 29.1834 |
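
The training arguments set `load_best_model_at_end=True` with `metric_for_best_model="rougeL"` and `greater_is_better=True`. The sketch below illustrates that selection rule on the two evaluation rows reported above. It is an illustration of the idea rather than the Trainer's internal code; `pick_best` and the dictionary keys are hypothetical names chosen for this example.

```python
# Evaluation results copied from the two epochs reported in the training log.
eval_history = [
    {"epoch": 1, "eval_loss": 0.303632, "eval_rougeL": 44.1695},
    {"epoch": 2, "eval_loss": 0.316653, "eval_rougeL": 44.9862},
]


def pick_best(history, metric, greater_is_better=True):
    """Return the evaluation row with the best value of `metric`."""
    chooser = max if greater_is_better else min
    return chooser(history, key=lambda row: row[metric])


best_by_rouge = pick_best(eval_history, "eval_rougeL", greater_is_better=True)
best_by_loss = pick_best(eval_history, "eval_loss", greater_is_better=False)

print(best_by_rouge["epoch"])  # epoch with the highest ROUGE-L
print(best_by_loss["epoch"])   # epoch with the lowest validation loss
```

The two criteria disagree in this run: validation loss is lowest after epoch 1, while ROUGE-L is highest after epoch 2, so selecting on ROUGE-L keeps the second checkpoint.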


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].


TrainOutput(global_step=14732, training_loss=0.2766784338273984, metrics={'train_runtime': 7578.8014, 'train_samples_per_second': 3.887, 'train_steps_per_second': 1.944, 'total_flos': 3.192361789371187e+16, 'train_loss': 0.2766784338273984, 'epoch': 2.0})
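
The `global_step=14732` in this output can be checked against the training arguments. The sketch below is a back-of-the-envelope calculation, assuming a single device and the full SAMSum training split of 14,732 dialogues; `optimizer_steps` is a hypothetical helper, not a Trainer function.

```python
import math


def optimizer_steps(num_examples, per_device_batch_size, grad_accum_steps,
                    epochs, num_devices=1):
    """Estimate how many optimizer updates a training run performs."""
    batches_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))
    # one optimizer update happens every `grad_accum_steps` batches
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return updates_per_epoch * epochs


full_run = optimizer_steps(14732, per_device_batch_size=1, grad_accum_steps=2, epochs=2)
sample_run = optimizer_steps(1000, per_device_batch_size=1, grad_accum_steps=2, epochs=2)
effective_batch = 1 * 2  # per-device batch size times gradient accumulation steps

print(full_run, sample_run, effective_batch)
```

With a batch size of 1 and two accumulation steps, each update sees two dialogues, so two epochs over 14,732 examples give 14,732 updates, whereas the 1,000-example sample would stop at 1,000.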

In [None]:
# save on Google Drive
from google.colab import drive
drive.mount('/content/drive')

!cp -r /content/conversation-summ /content/drive/MyDrive/

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
^C


In [None]:
!zip -r conversation-summ.zip /content/conversation-summ

  adding: content/conversation-summ/ (stored 0%)
  adding: content/conversation-summ/checkpoint-14732/ (stored 0%)
  adding: content/conversation-summ/checkpoint-14732/training_args.bin (deflated 51%)
  adding: content/conversation-summ/checkpoint-14732/model.safetensors (deflated 7%)
  adding: content/conversation-summ/checkpoint-14732/config.json (deflated 61%)
  adding: content/conversation-summ/checkpoint-14732/special_tokens_map.json (deflated 52%)
  adding: content/conversation-summ/checkpoint-14732/merges.txt (deflated 53%)
  adding: content/conversation-summ/checkpoint-14732/trainer_state.json (deflated 76%)
  adding: content/conversation-summ/checkpoint-14732/rng_state.pth (deflated 25%)
  adding: content/conversation-summ/checkpoint-14732/scheduler.pt (deflated 55%)
  adding: content/conversation-summ/checkpoint-14732/scaler.pt (deflated 60%)
  adding: content/conversation-summ/checkpoint-14732/vocab.json (deflated 59%)
  adding: content/conversation-summ/checkpoint-14732/opt