# Preprocessing

## Environment & imports

In [None]:
# ---------------------------------------------
# 0)  Install (first‚Äêtime only) & import libs
# ---------------------------------------------
# !pip install -q datasets transformers emoji==2.10.0 tqdm

from pathlib import Path
import re
import random
import json
from collections import defaultdict
from typing import List, Dict, Tuple

import emoji
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


## Load SAMSum

In [2]:
# ---------------------------------------------------------
# 1) Load SAMSum ‚Äî 14 732 / 819 / 818 dialogues
# ---------------------------------------------------------
raw_ds: DatasetDict = load_dataset("samsum")
print({k: len(v) for k, v in raw_ds.items()})

{'train': 14732, 'test': 819, 'validation': 818}


## Build an emoji vocabulary and speaker token & Build / extend the tokenizer

count [UNK] occurrences in one HF Dataset

In [3]:
from tqdm import tqdm
import numpy as np
import torch

def count_unk(ds, tokenizer, field="dialogue", batch_size=1024):
    unk_id = tokenizer.unk_token_id
    total_unk, total_tokens = 0, 0

    for i in tqdm(range(0, len(ds), batch_size), desc="Tokenising"):
        batch_texts = ds[i : i + batch_size][field]
        enc = tokenizer(batch_texts, add_special_tokens=True, padding=False, truncation=False)
        for ids in enc["input_ids"]:
            arr = np.array(ids)
            total_unk += np.sum(arr == unk_id)
            total_tokens += len(arr)
    return total_unk, total_tokens

BEFORE adding emojis

In [4]:
tok_base = AutoTokenizer.from_pretrained("bert-base-uncased")
unk_stats_before = {}
for split in ["train", "validation", "test"]:
    unk_stats_before[split] = count_unk(raw_ds[split], tok_base)
print("\n[UNK] counts BEFORE adding emoji tokens")
for split, (u, t) in unk_stats_before.items():
    print(f"{split:<10}: {u:8d}  ({u/t:.3%} of tokens)")

Tokenising:   0%|          | 0/15 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (523 > 512). Running this sequence through the model will result in indexing errors
Tokenising: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15/15 [00:00<00:00, 22.20it/s]
Tokenising: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:00<00:00, 37.81it/s]
Tokenising: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:00<00:00, 39.67it/s]


[UNK] counts BEFORE adding emoji tokens
train     :     3758  (0.185% of tokens)
validation:      191  (0.174% of tokens)
test      :      195  (0.170% of tokens)





‡∏™‡∏£‡πâ‡∏≤‡∏á EMOJI_TOKENS

In [5]:
# ‡∏ñ‡πâ‡∏≤ kernel ‡πÄ‡∏û‡∏¥‡πà‡∏á‡∏£‡∏µ‡∏™‡∏ï‡∏≤‡∏£‡πå‡∏ï ‡∏ï‡∏±‡∏ß‡πÅ‡∏õ‡∏£‡∏à‡∏∞‡∏´‡∏≤‡∏¢‡∏´‡∏°‡∏î
# ‡∏™‡∏£‡πâ‡∏≤‡∏á‡∏ä‡∏∏‡∏î emoji ‡πÉ‡∏´‡∏°‡πà‡∏à‡∏≤‡∏Å raw_ds
from typing import List
import emoji

def extract_emojis(text: str) -> List[str]:
    return [ch for ch in text if ch in emoji.EMOJI_DATA]

emoji_set = set()
for split in ["train", "validation", "test"]:
    for dlg in raw_ds[split]["dialogue"]:
        emoji_set.update(extract_emojis(dlg))

EMOJI_TOKENS = sorted(emoji_set)          # ‚âà 300-320 ‡∏£‡∏≤‡∏¢‡∏Å‡∏≤‡∏£
print(f"Unique emojis found: {len(EMOJI_TOKENS)}")

Unique emojis found: 305


In [6]:
# # Use a pipeline as a high-level helper
# from transformers import pipeline

# pipe = pipeline("text-generation", model="meta-llama/Llama-3.2-1B")

Extend tokenizer with emojis + speaker tags

In [7]:
from transformers import AutoTokenizer

# ---------- 1) ‡πÇ‡∏´‡∏•‡∏î tokenizer ‡∏î‡∏±‡πâ‡∏á‡πÄ‡∏î‡∏¥‡∏° ----------
tok_base = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
vocab_orig = len(tok_base)

# ---------- 2) ‡πÄ‡∏ï‡∏£‡∏µ‡∏¢‡∏°‡∏ä‡∏∏‡∏î token ‡πÉ‡∏´‡∏°‡πà ----------
#   ‚Ä¢ EMOJI_TOKENS  : ‡∏ó‡∏∏‡∏Å‡∏≠‡∏¥‡πÇ‡∏°‡∏à‡∏¥‡∏ó‡∏µ‡πà ‚Äú‡∏û‡∏ö‡∏≠‡∏¢‡πà‡∏≤‡∏á‡∏ô‡πâ‡∏≠‡∏¢ 1 ‡∏Ñ‡∏£‡∏±‡πâ‡∏á‚Äù ‡πÉ‡∏ô SAMSum
#   ‚Ä¢ SPEAKER_TOKENS: [S1] ‚Äì [S10]
SPEAKER_TOKENS = [f"[S{i}]" for i in range(1, 11)]
new_tokens = EMOJI_TOKENS + SPEAKER_TOKENS

# ---------- 3) ‡∏™‡∏£‡πâ‡∏≤‡∏á tokenizer ‡∏™‡∏≥‡πÄ‡∏ô‡∏≤‡πÅ‡∏•‡πâ‡∏ß‡πÄ‡∏û‡∏¥‡πà‡∏° token ----------
tok_ext = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
added = tok_ext.add_tokens(new_tokens)
vocab_new = len(tok_ext)

# ---------- 4) ‡πÅ‡∏™‡∏î‡∏á‡∏ú‡∏• ----------
print(f"Original vocab size : {vocab_orig}")
print(f"Added new tokens     : {added}  "
      f"(emoji = {len(EMOJI_TOKENS)}, speaker = {len(SPEAKER_TOKENS)})")
print(f"New vocab size       : {vocab_new}")

# (Optional) ‡∏û‡∏¥‡∏°‡∏û‡πå‡∏ï‡∏±‡∏ß‡∏≠‡∏¢‡πà‡∏≤‡∏á‡∏≠‡∏¥‡πÇ‡∏°‡∏à‡∏¥ 20 ‡∏ï‡∏±‡∏ß‡πÅ‡∏£‡∏Å
print("\nFirst 20 emoji tokens:", EMOJI_TOKENS[:20])

tok_ext.save_pretrained("tokenizer_samsum_su")   # ‡πÇ‡∏ü‡∏•‡πÄ‡∏î‡∏≠‡∏£‡πå‡πÉ‡∏´‡∏°‡πà

Original vocab size : 128256
Added new tokens     : 315  (emoji = 305, speaker = 10)
New vocab size       : 128571

First 20 emoji tokens: ['‚Äº', '‚è±', '‚òÄ', '‚òÇ', '‚òî', '‚òï', '‚òò', '‚òù', '‚ò†', '‚ò¢', '‚òπ', '‚ò∫', '‚ôÄ', '‚ôÇ', '‚ô•', '‚ôª', '‚ö™', '‚ö´', '‚ö∞', '‚öΩ']


('tokenizer_samsum_su/tokenizer_config.json',
 'tokenizer_samsum_su/special_tokens_map.json',
 'tokenizer_samsum_su/tokenizer.json')

AFTER adding emojis

In [8]:
unk_stats_after = {}
for split in ["train", "validation", "test"]:
    unk_stats_after[split] = count_unk(raw_ds[split], tok_ext)
print("\n[UNK] counts AFTER adding emoji tokens")
for split, (u, t) in unk_stats_after.items():
    print(f"{split:<10}: {u:8d}  ({u/t:.3%} of tokens)")

Tokenising: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15/15 [00:00<00:00, 27.50it/s]
Tokenising: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:00<00:00, 36.80it/s]
Tokenising: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:00<00:00, 36.31it/s]


[UNK] counts AFTER adding emoji tokens
train     :        0  (0.000% of tokens)
validation:        0  (0.000% of tokens)
test      :        0  (0.000% of tokens)





reduction check in UNKs

In [9]:
print("\nŒî [UNK] (before ‚ûú after):")
for split in ["train", "validation", "test"]:
    u0, _ = unk_stats_before[split]
    u1, _ = unk_stats_after[split]
    print(f"{split:<10}: {u0-u1:+d}  fewer UNKs  (‚Üì{(u0-u1)/u0:.2%})")


Œî [UNK] (before ‚ûú after):
train     : +3758  fewer UNKs  (‚Üì100.00%)
validation: +191  fewer UNKs  (‚Üì100.00%)
test      : +195  fewer UNKs  (‚Üì100.00%)


## Preprocess SAMSum Dateset

Speaker-name mapping ‚Üí [S#]

In [10]:
# ---------------------------------------------------------
# 4) Helper to replace speaker names by [S#]
# ---------------------------------------------------------
SPEAKER_RE = re.compile(r"^([^:]+):\s*(.*)$")

def map_speakers(dialogue: str, max_speakers: int = 10
                 ) -> Tuple[str, Dict[str, str]]:
    """
    Returns dialogue with names replaced by [S#] and a mapping dict.
    """
    speaker_map, next_id = {}, 1
    new_lines = []
    for line in dialogue.split("\n"):
        m = SPEAKER_RE.match(line)
        if not m:                # safety ‚Äì keep line as is
            new_lines.append(line)
            continue
        name, utt = m.groups()
        if name not in speaker_map:
            if next_id > max_speakers:      # truncate extra speakers
                name_token = "[SUNK]"
            else:
                name_token = f"[S{next_id}]"
                speaker_map[name] = name_token
                next_id += 1
        new_lines.append(f"{speaker_map.get(name, '[SUNK]')}: {utt}")
    return "\n".join(new_lines), speaker_map


Insert [SEP] after every utterance

In [11]:
def add_sep_every_utt(dialogue: str) -> str:
    lines = [l + " [SEP]" for l in dialogue.split("\n") if l.strip()]
    return " ".join(lines)

Switching-Utterance corruption
- Hyper-parameters: Pu = 1.0, Pn = 0/1

‡πÇ‡∏î‡∏¢‡∏ó‡∏µ‡πà

Pu (permute-utterance prob.) ‡∏Ñ‡∏ß‡∏≤‡∏°‡∏ô‡πà‡∏≤‡∏à‡∏∞‡πÄ‡∏õ‡πá‡∏ô‡∏ó‡∏µ‡πà ‡πÅ‡∏ï‡πà‡∏•‡∏∞ utterance ‡∏à‡∏∞‡∏ñ‡∏π‡∏Å‡πÄ‡∏•‡∏∑‡∏≠‡∏Å ‡πÉ‡∏™‡πà‡∏•‡∏á‡πÉ‡∏ô‡∏ä‡∏∏‡∏î‡∏ó‡∏µ‡πà‡∏ô‡∏≥‡πÑ‡∏õ‡∏™‡∏±‡∏ö‡∏ï‡∏≥‡πÅ‡∏´‡∏ô‡πà‡∏á

- pu = 1.0 ‡πÅ‡∏™‡∏î‡∏á‡∏ß‡πà‡∏≤‡∏ö‡∏±‡∏á‡∏Ñ‡∏±‡∏ö‡πÄ‡∏•‡∏∑‡∏≠‡∏Å‡∏ó‡∏∏‡∏Å‡∏ö‡∏£‡∏£‡∏ó‡∏±‡∏î‡πÅ‡∏•‡πâ‡∏ß‡∏Ñ‡πà‡∏≠‡∏¢‡∏™‡∏±‡∏ö‡∏Ñ‡∏≥‡πÅ‡∏ö‡∏ö‡∏™‡∏∏‡πà‡∏°

Pn (name-mask prob.) ‡∏Ñ‡∏ß‡∏≤‡∏°‡∏ô‡πà‡∏≤‡∏à‡∏∞‡πÄ‡∏õ‡πá‡∏ô‡∏ó‡∏µ‡πà token [S#] ‡∏î‡πâ‡∏≤‡∏ô‡∏´‡∏ô‡πâ‡∏≤‡∏à‡∏∞‡∏ñ‡∏π‡∏Å‡πÄ‡∏õ‡∏•‡∏µ‡πà‡∏¢‡∏ô‡πÄ‡∏õ‡πá‡∏ô [MASK]

- pn = 0.0 ‡πÅ‡∏™‡∏î‡∏á‡∏ß‡πà‡∏≤ ‡πÑ‡∏°‡πà mask, ‡πÇ‡∏°‡πÄ‡∏î‡∏•‡πÄ‡∏´‡πá‡∏ô speaker tag

- pn = 1.0 ‡πÅ‡∏™‡∏î‡∏á‡∏ß‡πà‡∏≤ mask ‡∏´‡∏°‡∏î, ‡∏ö‡∏±‡∏á‡∏Ñ‡∏±‡∏ö‡∏î‡∏π context

In [12]:
def make_switching_utterance(dialogue: str,
                             pu: float = 1.0,
                             pn: float = 0.0,
                             rng: random.Random = random
                            ) -> Tuple[str, List[int]]:
    """
    ‚Ä¢ dialogue  - speaker-tokenised, SEP-inserted string
    ‚Ä¢ pu        - prob. an utterance is selected for permutation
    ‚Ä¢ pn        - prob. we MASK the speaker token (‚áí [MASK])
    Returns:
        corrupted_dialogue, labels_per_utt  (1 = permuted (‡∏™‡∏•‡∏±‡∏ö‡∏ö‡∏ó‡∏û‡∏π‡∏î), 0 = original)
    """
    # 1) split back into utterances
    utts = [u.strip() for u in dialogue.split("[SEP]") if u.strip()]
    idxs = list(range(len(utts)))

    # 2) pick indices to permute
    perm_idx = [i for i in idxs if rng.random() < pu]
    shuffled = perm_idx.copy()
    rng.shuffle(shuffled)                 # in-place
    perm_map = dict(zip(perm_idx, shuffled))

    # 3) build new utterance list, labels
    new_utts, labels = [], []
    for i in idxs:
        src = perm_map.get(i, i)          # swapped or same
        u = utts[src]
        # optionally mask speaker token ([S#]: ‚Üí [MASK]:)
        if rng.random() < pn:
            u = re.sub(r"^\[S\d+\]", "[MASK]", u)
        new_utts.append(u)
        labels.append(int(src != i))      # 1 if permuted
    corrupted = " [SEP] ".join(new_utts) + " [SEP]"
    return corrupted, labels


## Switching-Utterance (SU) pre-training dataset

In [13]:
# ---------------------------------------------------------
# 7) Create HF Datasets with tokenised inputs, attention,
#    SEP positions, and per-utterance labels
# ---------------------------------------------------------
MAX_LEN = 512                          # paper setting
Pu, Pn = 1.0, 0.0                      # best config in Table 2


def preprocess_example(example, split):

    if tok_ext.pad_token is None:
        tok_ext.add_special_tokens({'pad_token': '[PAD]'})
        
    # a) replace speakers & add SEP
    dlg, _ = map_speakers(example["dialogue"])
    dlg = add_sep_every_utt(dlg)

    # b) corruption
    corrupted, labels = make_switching_utterance(dlg, Pu, Pn)

    # c) tokenize (truncate if >512 tokens)
    enc = tok_ext(corrupted,
              truncation=True, max_length=MAX_LEN,
              padding="max_length")
    
    # d) find SEP token positions (needed for loss later)
    sep_id = tok_ext("[SEP]")["input_ids"][0]
    sep_positions = [i for i, id_ in enumerate(enc["input_ids"])
                     if id_ == sep_id][:len(labels)]  # clip if truncated

    enc["labels"] = labels[:len(sep_positions)]
    enc["sep_positions"] = sep_positions
    enc["dialogue_len"] = len(labels)
    return enc

su_ds = DatasetDict()
for split in ["train", "validation", "test"]:
    su_ds[split] = raw_ds[split].map(
        preprocess_example,
        fn_kwargs={"split": split},
        remove_columns=raw_ds[split].column_names,
        desc=f"Building SU {split}"
    )

su_ds.save_to_disk("data/samsum_switching_utterance")
print(su_ds)

Saving the dataset (1/1 shards): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14732/14732 [00:00<00:00, 551464.43 examples/s]
Saving the dataset (1/1 shards): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 818/818 [00:00<00:00, 206422.04 examples/s]
Saving the dataset (1/1 shards): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 819/819 [00:00<00:00, 228172.37 examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'sep_positions', 'dialogue_len'],
        num_rows: 14732
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'sep_positions', 'dialogue_len'],
        num_rows: 818
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'sep_positions', 'dialogue_len'],
        num_rows: 819
    })
})





‡πÑ‡∏ü‡∏•‡πå Arrow ‡∏ñ‡∏π‡∏Å‡∏ö‡∏±‡∏ô‡∏ó‡∏∂‡∏Å‡πÑ‡∏ß‡πâ‡∏ó‡∏µ‡πà data/samsum_switching_utterance/ ‡∏û‡∏£‡πâ‡∏≠‡∏°‡∏ü‡∏¥‡∏•‡∏î‡πå input_ids‚ÄÜ/‚ÄÜattention_mask‚ÄÜ/‚ÄÜlabels‚ÄÜ/‚ÄÜsep_positions‚ÄÜ/‚ÄÜdialogue_len.

## Self-supervised Pre-training

‡πÉ‡∏ä‡πâ Dataset ‡πÄ‡∏â‡∏û‡∏≤‡∏∞‡∏™‡πà‡∏ß‡∏ô‡∏Ç‡∏≠‡∏á train ‡∏Ç‡∏≠‡∏á SAMSum ‡∏°‡∏≤‡∏ó‡∏≥‡∏Å‡∏≤‡∏£ pre_train ‡πÅ‡∏•‡πâ‡∏ß‡πÉ‡∏ä‡πâ validation ‡πÑ‡∏ß‡πâ‡∏î‡∏π early-stopping / tuning ‡∏™‡πà‡∏ß‡∏ô test ‡∏ï‡πâ‡∏≠‡∏á‡πÑ‡∏°‡πà‡∏ñ‡∏π‡∏Å‡πÅ‡∏ï‡∏∞ ‡πÄ‡∏û‡∏∑‡πà‡∏≠‡πÑ‡∏°‡πà‡πÉ‡∏´‡πâ‡πÇ‡∏°‡πÄ‡∏î‡∏• ‚Äú‡πÄ‡∏´‡πá‡∏ô‚Äù ‡∏ö‡∏ó‡∏™‡∏ô‡∏ó‡∏ô‡∏≤‡∏ó‡∏µ‡πà‡∏à‡∏∞‡πÉ‡∏ä‡πâ‡∏ß‡∏±‡∏î ROUGE ‡∏†‡∏≤‡∏¢‡∏´‡∏•‡∏±‡∏á

In [14]:
import torch
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

Imports & helpers

In [15]:
import math
import torch
import random
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    AutoModelForCausalLM
)
from datasets import load_from_disk

# -------------------------------
# CONFIG
# -------------------------------
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16 # paper 128
MAX_LEN    = 512 # paper 512
LR         = 3e-5
WARMUP     = 500
MAX_STEPS  = 40000

Dataset & collate

In [16]:
# -------------------------------
# LOAD DATASET
# -------------------------------
dataset = load_from_disk("data/samsum_switching_utterance")

# -------------------------------
# COLLATE FUNCTION
# -------------------------------
# def collate_fn(batch):
#     keys = ["input_ids", "token_type_ids", "attention_mask"]
#     inputs = {k: torch.tensor([b[k] for b in batch]) for k in keys}
#     labels = [torch.tensor(b["labels"], dtype=torch.float) for b in batch]
#     sep_pos = [torch.tensor(b["sep_positions"]) for b in batch]
#     return inputs, labels, sep_pos

def collate_fn(batch):
    keys = batch[0].keys()
    inputs = {
        k: torch.stack([torch.tensor(b[k]) if not isinstance(b[k], torch.Tensor) else b[k] for b in batch])
        for k in keys if k not in ["labels", "sep_positions"]
    }
    labels = [torch.tensor(b["labels"], dtype=torch.float) for b in batch]
    sep_pos = [torch.tensor(b["sep_positions"]) for b in batch]
    return inputs, labels, sep_pos



Model

In [17]:
# -------------------------------
# MODEL
# -------------------------------
class SepClassifier(nn.Module):
    def __init__(self, model_name="meta-llama/Llama-3.2-1B", dropout=0.1):
        super().__init__()
        config = AutoConfig.from_pretrained(model_name)
        self.llama = AutoModelForCausalLM.from_pretrained(
            model_name,
            config=config,
            torch_dtype=torch.float16  # use float16 to reduce memory usage
        )
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, sep_positions, **kwargs):
        hidden_states = self.llama(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        # Collect hidden states at each [SEP] position
        sep_vecs = []
        for i, pos_tensor in enumerate(sep_positions):
            pos_tensor = pos_tensor.to(hidden_states.device).long()  # <-- Ensure position tensor is long
            sep_vecs.append(hidden_states[i].index_select(0, pos_tensor))  # (U_i, H)

        sep_vecs = torch.cat(sep_vecs, dim=0)  # Shape: (total_seps, hidden_size)
        logits = self.classifier(self.dropout(sep_vecs)).squeeze(-1)
        return logits

Training loop (train model until the train loss converged (upper bounded by 5k steps)

In [18]:
# -------------------------------
# INITIALIZATION
# -------------------------------
torch.cuda.empty_cache()

# Load tokenizer and ensure it has all required special tokens
tokenizer = AutoTokenizer.from_pretrained("tokenizer_samsum_su")

# Add special tokens if missing
special_tokens = {}
if tokenizer.pad_token is None:
    special_tokens["pad_token"] = "[PAD]"
if tokenizer.sep_token is None:
    special_tokens["sep_token"] = "[SEP]"

if special_tokens:
    tokenizer.add_special_tokens(special_tokens)

# Initialize model
model = SepClassifier()

# Resize token embeddings to match new tokenizer length
model.llama.resize_token_embeddings(len(tokenizer))

# Setup DataLoader
train_loader = DataLoader(
    dataset["train"],
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn
)

# Optimizer, scheduler, loss
optimizer = AdamW(model.parameters(), lr=LR)
total_steps = min(MAX_STEPS, len(train_loader))
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP,
    num_training_steps=total_steps
)
loss_fn = nn.BCEWithLogitsLoss()

# -------------------------------
# TRAINING LOOP
# -------------------------------
step = 0
running_loss = 0.0
model.train()

for epoch in range(100):  # loop until MAX_STEPS reached
    print(f"Ep : {epoch}")
    for inputs, label_lists, sep_lists in train_loader:
        print(f"Step {step} is training")
        if step >= MAX_STEPS:
            break

        # Get model's device
        model_device = model.llama.device if hasattr(model.llama, 'device') else next(model.parameters()).device

        # Move inputs
        inputs = {k: v.to(model_device) for k, v in inputs.items()}
        sep_lists = [s.to(model_device) for s in sep_lists]
        flat_labels = torch.cat(label_lists).to(model_device)

        # Safety check: ensure all input IDs are within vocab
        vocab_size = model.llama.get_input_embeddings().weight.shape[0]
        if (inputs["input_ids"] >= vocab_size).any():
            raise ValueError(f"Input ID exceeds model embedding size (vocab_size={vocab_size})")

        # Forward + loss + backward
        logits = model(**inputs, sep_positions=sep_lists)
        loss = loss_fn(logits, flat_labels)

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        step += 1
        print(f"Step {step:4d}/{total_steps} | Loss: {running_loss}")

        if step % 100 == 0:
            print(f"Step {step:4d}/{total_steps} | AVG Loss: {running_loss / 100:.4f}")
            running_loss = 0.0

    if step >= MAX_STEPS:
        break

# -------------------------------
# SAVE MODEL
# -------------------------------
torch.save(model.state_dict(), "llama_su_pretrained.pt")
print("Model saved to 'llama_su_pretrained.pt'")


Ep : 0
Step 0 is training


Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


AttributeError: 'CausalLMOutputWithPast' object has no attribute 'last_hidden_state'

In [None]:
torch.save(model.bert.state_dict(), "llama_su_pretrained.pt")

Validation & early-stop (optional)

- Use the same DataLoader/loop on su_ds["validation"], compute average BCE loss; if it plateaus you can stop earlier than 5 k steps (what the authors mean by ‚Äúuntil train loss converged‚Äù).

# Create Summarization Dataset

‡∏Ç‡∏±‡πâ‡∏ô‡∏ï‡∏≠‡∏ô‡∏Å‡∏≤‡∏£‡∏ó‡∏≥ preprocess
1. ‡πÇ‡∏´‡∏•‡∏î‡∏ä‡∏∏‡∏î‡∏Ç‡πâ‡∏≠‡∏°‡∏π‡∏• SAMSum
2. ‡∏ó‡∏≥ preprocessing:
    - ‡πÅ‡∏ó‡∏ô‡∏ä‡∏∑‡πà‡∏≠ speaker ‡∏î‡πâ‡∏ß‡∏¢ [S1]‚Äì[S10]
    - ‡πÄ‡∏ï‡∏¥‡∏° [SEP] ‡∏ó‡πâ‡∏≤‡∏¢‡∏ó‡∏∏‡∏Å‡∏õ‡∏£‡∏∞‡πÇ‡∏¢‡∏Ñ
    - ‡πÉ‡∏ä‡πâ tokenizer ‡πÄ‡∏î‡∏¥‡∏°‡∏à‡∏≤‡∏Å pretraining (tokenizer_samsum_su)
    - truncate/pad ‡∏Ñ‡∏ß‡∏≤‡∏°‡∏¢‡∏≤‡∏ß‡∏ó‡∏µ‡πà max_length = 512
3. ‡πÅ‡∏õ‡∏•‡∏á‡πÉ‡∏´‡πâ‡∏≠‡∏¢‡∏π‡πà‡πÉ‡∏ô‡∏£‡∏π‡∏õ‡πÅ‡∏ö‡∏ö‡∏ó‡∏µ‡πà‡∏û‡∏£‡πâ‡∏≠‡∏°‡πÉ‡∏ä‡πâ‡∏™‡∏≥‡∏´‡∏£‡∏±‡∏ö Seq2SeqTrainer
4. Save ‡πÄ‡∏õ‡πá‡∏ô‡πÑ‡∏ü‡∏•‡πå .pt ‡∏´‡∏£‡∏∑‡∏≠ DatasetDict ‡∏ó‡∏µ‡πà‡∏û‡∏£‡πâ‡∏≠‡∏°‡πÉ‡∏ä‡πâ‡∏á‡∏≤‡∏ô

Load SAMSum Dataset

In [10]:
raw_ds: DatasetDict = load_dataset("samsum")
print({k: len(v) for k, v in raw_ds.items()})

{'train': 14732, 'test': 819, 'validation': 818}


Load Pretrained Tokenizer (same as used during pretraining)

In [11]:
tokenizer = AutoTokenizer.from_pretrained("tokenizer_samsum_su")
MAX_LEN = 512

Speaker Normalization Helpers

In [12]:
SPEAKER_RE = re.compile(r"^([^:]+):\s*(.*)$")

def map_speakers(dialogue: str, max_speakers: int = 10) -> Tuple[str, Dict[str, str]]:
    """
    Replace speaker names with generic [S1], [S2], ... tokens.
    """
    speaker_map, next_id = {}, 1
    new_lines = []
    for line in dialogue.split("\n"):
        m = SPEAKER_RE.match(line)
        if not m:
            new_lines.append(line)
            continue
        name, utt = m.groups()
        if name not in speaker_map:
            if next_id > max_speakers:
                name_token = "[SUNK]"
            else:
                name_token = f"[S{next_id}]"
                speaker_map[name] = name_token
                next_id += 1
        name_token = speaker_map.get(name, "[SUNK]")
        new_lines.append(f"{name_token}: {utt}")
    return "\n".join(new_lines), speaker_map

def add_sep_every_utt(dialogue: str) -> str:
    lines = [l + " [SEP]" for l in dialogue.split("\n") if l.strip()]
    return " ".join(lines)


Preprocessing Function

In [13]:
def preprocess_fn(example):
    normed_dialogue, _ = map_speakers(example["dialogue"])
    sep_dialogue = add_sep_every_utt(normed_dialogue)

    inputs = tokenizer(
        sep_dialogue,
        truncation=True,
        padding='max_length',
        max_length=MAX_LEN,
    )
    targets = tokenizer(
        example["summary"],
        truncation=True,
        padding='max_length',
        max_length=MAX_LEN,
    )

    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": targets["input_ids"]
    }

Apply Preprocessing

In [14]:
tokenized_ds = raw_ds.map(preprocess_fn, batched=False)
tokenized_ds.save_to_disk("samsum_finetune_ready")
print("Preprocessed dataset saved to 'samsum_finetune_ready'")

Map: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14732/14732 [00:08<00:00, 1769.02 examples/s]
Map: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 819/819 [00:00<00:00, 1693.37 examples/s]
Map: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 818/818 [00:00<00:00, 1804.79 examples/s]
Saving the dataset (1/1 shards): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14732/14732 [00:00<00:00, 327887.96 examples/s]
Saving the dataset (1/1 shards): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 819/819 [00:00<00:00, 172195.85 examples/s]
Saving the dataset (1/1 shards): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 818/818 [00:00<00:00, 167281.36 examples/s]

Preprocessed dataset saved to 'samsum_finetune_ready'





In [None]:
# MAX_LEN = 512  # paper setting

# def preprocess_example(example, split):
#     # a) replace speakers & add SEP (same as pretraining)
#     dlg, _ = map_speakers(example["dialogue"])  # ‡πÅ‡∏õ‡∏•‡∏á‡∏ä‡∏∑‡πà‡∏≠‡πÉ‡∏´‡πâ‡πÄ‡∏õ‡πá‡∏ô token ‡∏™‡∏±‡πâ‡∏ô ‡πÜ ‡πÄ‡∏ä‡πà‡∏ô <USR1>
#     dlg = add_sep_every_utt(dlg)                # ‡πÄ‡∏û‡∏¥‡πà‡∏° [SEP] ‡∏ó‡∏∏‡∏Å‡∏ó‡πâ‡∏≤‡∏¢‡∏õ‡∏£‡∏∞‡πÇ‡∏¢‡∏Ñ

#     # b) tokenize dialogue input
#     enc = tok_base(dlg,
#               truncation=True,
#               max_length=MAX_LEN,
#               padding="max_length")

#     # c) tok_baseenize target summary
#     with tok_base.as_target_tokenizer():
#         summary = example["summary"]
#         summary_enc = tok_base(summary,
#                           truncation=True,
#                           max_length=MAX_LEN,
#                           padding="max_length")
    
#     # d) pack input and label
#     enc["labels"] = summary_enc["input_ids"]
#     return enc

# # ‡∏™‡∏£‡πâ‡∏≤‡∏á dataset ‡πÉ‡∏´‡∏°‡πà‡∏™‡∏≥‡∏´‡∏£‡∏±‡∏ö fine-tune
# finetune_ds = DatasetDict()
# for split in ["train", "validation", "test"]:
#     finetune_ds[split] = raw_ds[split].map(
#         preprocess_example,
#         fn_kwargs={"split": split},
#         remove_columns=raw_ds[split].column_names,
#         desc=f"Building Fine-tuning {split}"
#     )

# finetune_ds.save_to_disk("data/samsum_finetune")
# print(finetune_ds)

Building Fine-tuning train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14732/14732 [00:09<00:00, 1561.26 examples/s]
Saving the dataset (1/1 shards): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14732/14732 [00:00<00:00, 183818.74 examples/s]
Saving the dataset (1/1 shards): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 818/818 [00:00<00:00, 129347.43 examples/s]
Saving the dataset (1/1 shards): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 819/819 [00:00<00:00, 126710.99 examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 14732
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 818
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 819
    })
})





# Fine-tuning 


**‡πÄ‡∏ó‡∏µ‡∏¢‡∏ö‡∏Å‡∏±‡∏ö Paper**

| **Parameter**       | **Code**                          | **Paper (Section 3.2)**          | **‡πÄ‡∏´‡∏°‡∏∑‡∏≠‡∏ô / ‡πÑ‡∏°‡πà‡πÄ‡∏´‡∏°‡∏∑‡∏≠‡∏ô**        |
| ------------------- | --------------------------------- | -------------------------------- | --------------       |
| Model               | BERT2BERT (EncoderDecoderModel)   | BERT2BERT                        | ‡πÄ‡∏´‡∏°‡∏∑‡∏≠‡∏ô                 |
| Tokenizer           | bert-base-uncased + custom tokens | ‡πÉ‡∏ä‡πâ tokenizer ‡∏î‡∏±‡∏î‡πÅ‡∏õ‡∏•‡∏á              | ‡πÄ‡∏´‡∏°‡∏∑‡∏≠‡∏ô              |
| Batch Size          | 8                                 | **16 (per step)**                | ‡πÑ‡∏°‡πà‡πÄ‡∏´‡∏°‡∏∑‡∏≠‡∏ô ‚Üí ‡πÄ‡∏•‡πá‡∏Å‡∏Å‡∏ß‡πà‡∏≤  |
| Epochs              | 3                                 | 3                                | ‡πÄ‡∏´‡∏°‡∏∑‡∏≠‡∏ô              |
| Learning Rate       | 5e-5                              | **3e-5**                         | ‡πÑ‡∏°‡πà‡πÄ‡∏´‡∏°‡∏∑‡∏≠‡∏ô ‚Üí ‡∏™‡∏π‡∏á‡∏Å‡∏ß‡πà‡∏≤   |
| Warmup Steps        | 500                               | ‡πÉ‡∏ä‡πâ scheduler (‡πÅ‡∏ï‡πà‡πÑ‡∏°‡πà‡∏£‡∏∞‡∏ö‡∏∏ exact)     | ‡πÄ‡∏´‡∏°‡∏∑‡∏≠‡∏ô (‡∏™‡∏°‡πÄ‡∏´‡∏ï‡∏∏‡∏™‡∏°‡∏ú‡∏•) |
| Max Length (input)  | 512                               | 512                              | ‡πÄ‡∏´‡∏°‡∏∑‡∏≠‡∏ô              |
| Max Length (output) | 128                               | 128                              | ‡πÄ‡∏´‡∏°‡∏∑‡∏≠‡∏ô              |
| Beam Search         | 4                                 | 4                                | ‡πÄ‡∏´‡∏°‡∏∑‡∏≠‡∏ô              |

---


In [None]:
import torch
from transformers import (
    BertTokenizerFast,
    EncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
from datasets import load_from_disk

# ------------------------------
# Load processed dataset & tokenizer
# ------------------------------
dataset = load_from_disk("data/samsum_finetune_ready")
tokenizer = BertTokenizerFast.from_pretrained("tokenizer_samsum_su")

# ------------------------------
# Load pretrained EncoderDecoderModel
# ------------------------------
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)
model.encoder.resize_token_embeddings(len(tokenizer))
model.decoder.resize_token_embeddings(len(tokenizer))

# Load your pretrained encoder weights
model.encoder.load_state_dict(torch.load("bert_su_pretrained.pt", map_location="cpu"))
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.encoder.vocab_size
model.config.max_length = 128
model.config.num_beams = 4

# ------------------------------
# Define training arguments
# ------------------------------
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    evaluation_strategy="steps",
    logging_steps=500,
    save_steps=1000,
    num_train_epochs=3,
    learning_rate=5e-5,
    warmup_steps=500,
    fp16=torch.cuda.is_available(),
    save_total_limit=2,
)

# ------------------------------
# Data Collator & Trainer
# ------------------------------
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# ------------------------------
# Start Training
# ------------------------------
trainer.train()
model.save_pretrained("bert_samsum_finetuned")
tokenizer.save_pretrained("tokenizer_samsum_su_finetune")

  from .autonotebook import tqdm as notebook_tqdm
Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.enco

{'loss': 1.4418, 'grad_norm': 0.3646565079689026, 'learning_rate': 4.96e-05, 'epoch': 0.27}


                                                  
  9%|‚ñâ         | 501/5526 [01:41<2:22:17,  1.70s/it]

{'eval_loss': 0.22444657981395721, 'eval_runtime': 5.0061, 'eval_samples_per_second': 163.402, 'eval_steps_per_second': 20.575, 'epoch': 0.27}


 18%|‚ñà‚ñä        | 1000/5526 [03:17<14:36,  5.17it/s] 

{'loss': 0.2172, 'grad_norm': 0.4135509133338928, 'learning_rate': 4.506565857540788e-05, 'epoch': 0.54}


                                                   


{'eval_loss': 0.19904808700084686, 'eval_runtime': 5.0413, 'eval_samples_per_second': 162.259, 'eval_steps_per_second': 20.431, 'epoch': 0.54}


 27%|‚ñà‚ñà‚ñã       | 1500/5526 [05:00<12:53,  5.21it/s]  

{'loss': 0.2008, 'grad_norm': 0.274005264043808, 'learning_rate': 4.009152407481098e-05, 'epoch': 0.81}


                                                   
 27%|‚ñà‚ñà‚ñã       | 1501/5526 [05:06<1:53:34,  1.69s/it]

{'eval_loss': 0.1895398199558258, 'eval_runtime': 4.9978, 'eval_samples_per_second': 163.672, 'eval_steps_per_second': 20.609, 'epoch': 0.81}


 36%|‚ñà‚ñà‚ñà‚ñå      | 2000/5526 [06:42<11:18,  5.20it/s]  

{'loss': 0.1864, 'grad_norm': 0.31300443410873413, 'learning_rate': 3.511738957421409e-05, 'epoch': 1.09}


                                                   
 36%|‚ñà‚ñà‚ñà‚ñå      | 2000/5526 [06:47<11:18,  5.20it/s]

{'eval_loss': 0.18359985947608948, 'eval_runtime': 4.9794, 'eval_samples_per_second': 164.278, 'eval_steps_per_second': 20.685, 'epoch': 1.09}


 45%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 2500/5526 [08:25<09:41,  5.21it/s]  

{'loss': 0.17, 'grad_norm': 0.36899664998054504, 'learning_rate': 3.0143255073617192e-05, 'epoch': 1.36}


                                                   
 45%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 2501/5526 [08:30<1:25:32,  1.70s/it]

{'eval_loss': 0.17844413220882416, 'eval_runtime': 5.0089, 'eval_samples_per_second': 163.311, 'eval_steps_per_second': 20.564, 'epoch': 1.36}


 54%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç    | 3000/5526 [10:06<08:05,  5.20it/s]  

{'loss': 0.1689, 'grad_norm': 0.43778106570243835, 'learning_rate': 2.5169120573020293e-05, 'epoch': 1.63}


                                                   
 54%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç    | 3000/5526 [10:11<08:05,  5.20it/s]

{'eval_loss': 0.17464406788349152, 'eval_runtime': 4.984, 'eval_samples_per_second': 164.126, 'eval_steps_per_second': 20.666, 'epoch': 1.63}


 63%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé   | 3500/5526 [11:50<06:25,  5.26it/s]  

{'loss': 0.1632, 'grad_norm': 0.3436298966407776, 'learning_rate': 2.01949860724234e-05, 'epoch': 1.9}


                                                   
 63%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé   | 3501/5526 [11:55<56:57,  1.69s/it]

{'eval_loss': 0.17154935002326965, 'eval_runtime': 4.9824, 'eval_samples_per_second': 164.179, 'eval_steps_per_second': 20.673, 'epoch': 1.9}


 72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 4000/5526 [13:31<04:52,  5.21it/s]

{'loss': 0.1487, 'grad_norm': 0.4152975380420685, 'learning_rate': 1.5220851571826503e-05, 'epoch': 2.17}


                                                   
 72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 4000/5526 [13:36<04:52,  5.21it/s]

{'eval_loss': 0.17133557796478271, 'eval_runtime': 4.98, 'eval_samples_per_second': 164.259, 'eval_steps_per_second': 20.683, 'epoch': 2.17}


 81%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè | 4500/5526 [15:14<03:17,  5.21it/s]  

{'loss': 0.1356, 'grad_norm': 0.35845550894737244, 'learning_rate': 1.0246717071229607e-05, 'epoch': 2.44}


                                                   
 81%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè | 4501/5526 [15:19<28:52,  1.69s/it]

{'eval_loss': 0.1692614108324051, 'eval_runtime': 4.9857, 'eval_samples_per_second': 164.068, 'eval_steps_per_second': 20.659, 'epoch': 2.44}


 90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 5000/5526 [16:55<01:41,  5.20it/s]

{'loss': 0.1402, 'grad_norm': 0.3714558780193329, 'learning_rate': 5.27258257063271e-06, 'epoch': 2.71}


                                                   
 90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 5000/5526 [17:00<01:41,  5.20it/s]

{'eval_loss': 0.1681322604417801, 'eval_runtime': 4.9793, 'eval_samples_per_second': 164.281, 'eval_steps_per_second': 20.686, 'epoch': 2.71}


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ| 5500/5526 [18:38<00:04,  5.21it/s]

{'loss': 0.1387, 'grad_norm': 0.4560355544090271, 'learning_rate': 2.984480700358138e-07, 'epoch': 2.99}


                                                   
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ| 5501/5526 [18:43<00:42,  1.69s/it]

{'eval_loss': 0.1668214201927185, 'eval_runtime': 4.9802, 'eval_samples_per_second': 164.251, 'eval_steps_per_second': 20.682, 'epoch': 2.99}


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5526/5526 [18:50<00:00,  4.89it/s]


{'train_runtime': 1130.5819, 'train_samples_per_second': 39.091, 'train_steps_per_second': 4.888, 'train_loss': 0.28212739849194124, 'epoch': 3.0}


('tokenizer_samsum_su_finetune/tokenizer_config.json',
 'tokenizer_samsum_su_finetune/special_tokens_map.json',
 'tokenizer_samsum_su_finetune/vocab.txt',
 'tokenizer_samsum_su_finetune/added_tokens.json',
 'tokenizer_samsum_su_finetune/tokenizer.json')

## Evaluation

1. ROUGE (Recall-Oriented Understudy for Gisting Evaluation) ‡πÉ‡∏ä‡πâ‡∏ß‡∏±‡∏î‡∏Ñ‡∏ß‡∏≤‡∏°‡∏Ñ‡∏•‡πâ‡∏≤‡∏¢‡∏Å‡∏±‡∏ô‡∏£‡∏∞‡∏´‡∏ß‡πà‡∏≤‡∏á‡∏™‡∏£‡∏∏‡∏õ‡∏ó‡∏µ‡πà‡πÇ‡∏°‡πÄ‡∏î‡∏•‡∏™‡∏£‡πâ‡∏≤‡∏á‡∏Ç‡∏∂‡πâ‡∏ô‡∏Å‡∏±‡∏ö‡∏™‡∏£‡∏∏‡∏õ‡∏≠‡πâ‡∏≤‡∏á‡∏≠‡∏¥‡∏á ‡πÇ‡∏î‡∏¢‡πÄ‡∏ô‡πâ‡∏ô‡πÑ‡∏õ‡∏ó‡∏µ‡πà recall ‡πÄ‡∏õ‡πá‡∏ô‡∏´‡∏•‡∏±‡∏Å
	- ROUGE-1 (R-1) = Unigram overlap (‡∏Ñ‡∏≥‡πÄ‡∏î‡∏µ‡πà‡∏¢‡∏ß)
	- ROUGE-2 (R-2) = Bigram overlap (‡∏Ñ‡∏≥‡∏ï‡∏¥‡∏î‡∏Å‡∏±‡∏ô 2 ‡∏Ñ‡∏≥)
	- ROUGE-L (R-L) = ‡πÉ‡∏ä‡πâ Longest common subsequence (LCS) ‡πÉ‡∏ô‡∏Å‡∏≤‡∏£‡∏ß‡∏±‡∏î‡∏Ñ‡∏ß‡∏≤‡∏°‡∏Ñ‡∏•‡πâ‡∏≤‡∏¢‡πÄ‡∏ä‡∏¥‡∏á‡∏•‡∏≥‡∏î‡∏±‡∏ö‡∏Ñ‡∏≥‡∏ó‡∏µ‡πà‡∏¢‡∏≤‡∏ß‡∏ó‡∏µ‡πà‡∏™‡∏∏‡∏î‡∏ó‡∏µ‡πà‡∏õ‡∏£‡∏≤‡∏Å‡∏è‡πÉ‡∏ô‡∏ó‡∏±‡πâ‡∏á‡∏™‡∏≠‡∏á‡∏™‡∏£‡∏∏‡∏õ ‡πÇ‡∏î‡∏¢‡∏Ñ‡∏≥‡∏ô‡∏∂‡∏á‡∏ñ‡∏∂‡∏á‡∏•‡∏≥‡∏î‡∏±‡∏ö‡∏Ñ‡∏≥‡∏î‡πâ‡∏ß‡∏¢

2. BLEU (Bilingual Evaluation Understudy) ‡πÄ‡∏î‡∏¥‡∏°‡∏ó‡∏µ‡πÉ‡∏ä‡πâ‡πÉ‡∏ô‡∏á‡∏≤‡∏ô‡πÅ‡∏õ‡∏•‡∏†‡∏≤‡∏©‡∏≤ ‡πÅ‡∏ï‡πà‡∏ñ‡∏π‡∏Å‡∏õ‡∏£‡∏∞‡∏¢‡∏∏‡∏Å‡∏ï‡πå‡πÉ‡∏ä‡πâ‡πÉ‡∏ô‡∏á‡∏≤‡∏ô‡∏™‡∏£‡∏∏‡∏õ‡∏Ç‡πâ‡∏≠‡∏Ñ‡∏ß‡∏≤‡∏°‡πÑ‡∏î‡πâ‡πÄ‡∏ä‡πà‡∏ô‡∏Å‡∏±‡∏ô ‡πÇ‡∏î‡∏¢ BLEU ‡∏à‡∏∞‡πÄ‡∏ô‡πâ‡∏ô‡∏Å‡∏≤‡∏£‡∏ß‡∏±‡∏î precision ‡∏Ñ‡∏∑‡∏≠‡∏î‡∏π‡∏ß‡πà‡∏≤ ‡∏Ñ‡∏≥‡∏ó‡∏µ‡πà‡πÇ‡∏°‡πÄ‡∏î‡∏•‡∏™‡∏£‡πâ‡∏≤‡∏á ‡∏°‡∏µ‡πÄ‡∏ó‡πà‡∏≤‡πÑ‡∏£‡∏ó‡∏µ‡πà‡∏ï‡∏£‡∏á‡∏Å‡∏±‡∏ö‡∏™‡∏£‡∏∏‡∏õ‡∏à‡∏£‡∏¥‡∏á ‡∏ï‡πà‡∏≤‡∏á‡∏à‡∏≤‡∏Å ROUGE ‡∏ó‡∏µ‡πà‡πÄ‡∏ô‡πâ‡∏ô recall
	- BLEU ‡∏ß‡∏±‡∏î‡∏Å‡∏≤‡∏£‡∏ó‡∏±‡∏ö‡∏ã‡πâ‡∏≠‡∏ô‡∏Ç‡∏≠‡∏á n-gram ‡πÄ‡∏ä‡πà‡∏ô unigram, bigram, trigram
	- ‡∏°‡∏µ‡∏Å‡∏≤‡∏£‡πÉ‡∏ä‡πâ brevity penalty ‡∏´‡∏≤‡∏Å‡∏™‡∏£‡∏∏‡∏õ‡∏™‡∏±‡πâ‡∏ô‡∏Å‡∏ß‡πà‡∏≤‡∏ó‡∏µ‡πà‡∏Ñ‡∏ß‡∏£‡∏à‡∏∞‡πÄ‡∏õ‡πá‡∏ô

3. BERTScore (BS) ‡πÉ‡∏ä‡πâ embedding ‡∏à‡∏≤‡∏Å‡πÇ‡∏°‡πÄ‡∏î‡∏• BERT ‡∏´‡∏£‡∏∑‡∏≠ Transformer ‡∏ï‡∏±‡∏ß‡∏≠‡∏∑‡πà‡∏ô ‡πÜ ‡πÉ‡∏ô‡∏Å‡∏≤‡∏£‡∏ß‡∏±‡∏î semantic similarity (‡∏Ñ‡∏ß‡∏≤‡∏°‡πÉ‡∏Å‡∏•‡πâ‡πÄ‡∏Ñ‡∏µ‡∏¢‡∏á‡∏î‡πâ‡∏≤‡∏ô‡∏Ñ‡∏ß‡∏≤‡∏°‡∏´‡∏°‡∏≤‡∏¢) ‡∏£‡∏∞‡∏´‡∏ß‡πà‡∏≤‡∏á‡∏™‡∏£‡∏∏‡∏õ‡∏Ç‡∏≠‡∏á‡πÇ‡∏°‡πÄ‡∏î‡∏•‡∏Å‡∏±‡∏ö‡∏™‡∏£‡∏∏‡∏õ‡∏à‡∏£‡∏¥‡∏á ‡πÇ‡∏î‡∏¢‡πÑ‡∏°‡πà‡∏à‡∏≥‡πÄ‡∏õ‡πá‡∏ô‡∏ï‡πâ‡∏≠‡∏á‡πÉ‡∏ä‡πâ‡∏Ñ‡∏≥‡πÄ‡∏´‡∏°‡∏∑‡∏≠‡∏ô‡∏Å‡∏±‡∏ô‡πÄ‡∏õ‡πä‡∏∞‡πÄ‡∏´‡∏°‡∏∑‡∏≠‡∏ô‡∏Å‡∏±‡∏ö ROUGE ‡∏´‡∏£‡∏∑‡∏≠ BLEU ‡πÅ‡∏ï‡πà BERTScore ‡∏à‡∏∞‡∏ß‡∏±‡∏î‡∏ß‡πà‡∏≤‡∏Ñ‡∏≥‡∏´‡∏£‡∏∑‡∏≠‡∏ß‡∏•‡∏µ‡∏°‡∏µ‡∏Ñ‡∏ß‡∏≤‡∏°‡∏´‡∏°‡∏≤‡∏¢‡πÉ‡∏Å‡∏•‡πâ‡πÄ‡∏Ñ‡∏µ‡∏¢‡∏á‡∏Å‡∏±‡∏ô‡∏´‡∏£‡∏∑‡∏≠‡πÑ‡∏°‡πà
	- ‡∏ß‡∏±‡∏î‡∏Ñ‡∏ß‡∏≤‡∏°‡∏Ñ‡∏•‡πâ‡∏≤‡∏¢‡∏Å‡∏±‡∏ô‡∏Ç‡∏≠‡∏á‡∏Ñ‡∏≥‡πÉ‡∏ô embedding space ‡πÄ‡∏ä‡πà‡∏ô "car" vs "vehicle" ‡∏Å‡πá‡∏¢‡∏±‡∏á‡∏ñ‡∏∑‡∏≠‡∏ß‡πà‡∏≤‡πÉ‡∏Å‡∏•‡πâ‡πÄ‡∏Ñ‡∏µ‡∏¢‡∏á
	- ‡πÉ‡∏ä‡πâ precision / recall / F1 score ‡∏ï‡∏≤‡∏°‡∏£‡∏∞‡∏¢‡∏∞‡∏´‡πà‡∏≤‡∏á‡∏Ç‡∏≠‡∏á vector


In [1]:
from datasets import load_from_disk
from transformers import BertTokenizer, EncoderDecoderModel
from sklearn.metrics import precision_score, recall_score, f1_score
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu
import torch
from bert_score import score
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ‡πÇ‡∏´‡∏•‡∏î dataset
dataset = load_from_disk("data/samsum_finetune_ready")

# ‡πÇ‡∏´‡∏•‡∏î‡πÇ‡∏°‡πÄ‡∏î‡∏•‡πÅ‡∏•‡∏∞ tokenizer
tokenizer = BertTokenizer.from_pretrained('tokenizer_samsum_su_finetune')
model = EncoderDecoderModel.from_pretrained('bert_samsum_finetuned')

EncoderDecoderModel has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From üëâv4.50üëà onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


In [None]:
# ‡∏¢‡πâ‡∏≤‡∏¢‡πÇ‡∏°‡πÄ‡∏î‡∏•‡πÅ‡∏•‡∏∞‡∏Ç‡πâ‡∏≠‡∏°‡∏π‡∏•‡πÑ‡∏õ‡∏¢‡∏±‡∏á‡∏≠‡∏∏‡∏õ‡∏Å‡∏£‡∏ì‡πå‡∏ó‡∏µ‡πà‡πÄ‡∏´‡∏°‡∏≤‡∏∞‡∏™‡∏° (GPU ‡∏´‡∏£‡∏∑‡∏≠ CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# ‡∏ß‡∏¥‡πÄ‡∏Ñ‡∏£‡∏≤‡∏∞‡∏´‡πå‡∏Ç‡πâ‡∏≠‡∏°‡∏π‡∏•‡πÅ‡∏•‡∏∞‡∏ó‡∏≥‡∏Å‡∏≤‡∏£ summary ‡∏î‡πâ‡∏ß‡∏¢ bert_samsum_finetuned
def generate_summary(input_text):
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    inputs = {key: value.to(device) for key, value in inputs.items()}  # ‡∏¢‡πâ‡∏≤‡∏¢‡∏Ç‡πâ‡∏≠‡∏°‡∏π‡∏•‡πÑ‡∏õ‡∏¢‡∏±‡∏á device
    
    with torch.no_grad():
        output = model.generate(
            inputs['input_ids'], 
            max_length=512, 
            num_beams=4, 
            early_stopping=True,
            decoder_start_token_id=model.config.decoder_start_token_id,  # ‡∏Å‡∏≥‡∏´‡∏ô‡∏î‡∏ó‡∏µ‡πà‡∏ô‡∏µ‡πà
            pad_token_id=model.config.pad_token_id  # ‡∏Å‡∏≥‡∏´‡∏ô‡∏î pad_token_id ‡∏ñ‡πâ‡∏≤‡∏à‡∏≥‡πÄ‡∏õ‡πá‡∏ô
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)

In [4]:
# 1. ROUGE Score Calculation
def calculate_rouge(predictions, references):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
    
    for pred, ref in zip(predictions, references):
        score = scorer.score(ref, pred)
        scores['rouge1'].append(score['rouge1'].fmeasure)
        scores['rouge2'].append(score['rouge2'].fmeasure)
        scores['rougeL'].append(score['rougeL'].fmeasure)
    
    return {key: sum(value)/len(value) for key, value in scores.items()}

# 2. BLEU Score Calculation
def calculate_bleu(predictions, references):
    bleu_scores = []
    for pred, ref in zip(predictions, references):
        pred_tokens = pred.split()
        ref_tokens = [ref.split()]
        bleu_scores.append(sentence_bleu(ref_tokens, pred_tokens))
    return sum(bleu_scores) / len(bleu_scores)

# 3. BERTScore Calculation
def calculate_bertscore(predictions, references):
    P, R, F1 = score(predictions, references, lang='en')
    return P.mean().item(), R.mean().item(), F1.mean().item()

# ‡∏Å‡∏≤‡∏£‡∏ó‡∏î‡∏™‡∏≠‡∏ö‡∏Å‡∏±‡∏ö dataset
def evaluate_model(dataset):
    predictions = []
    references = []
    
    # ‡πÉ‡∏ä‡πâ‡∏Ç‡πâ‡∏≠‡∏°‡∏π‡∏•‡∏à‡∏≤‡∏Å train ‡∏™‡∏≥‡∏´‡∏£‡∏±‡∏ö‡∏ó‡∏≥‡∏ô‡∏≤‡∏¢ ‡πÅ‡∏•‡∏∞‡∏Ç‡πâ‡∏≠‡∏°‡∏π‡∏•‡∏à‡∏≤‡∏Å test ‡∏™‡∏≥‡∏´‡∏£‡∏±‡∏ö‡∏Å‡∏≤‡∏£‡πÄ‡∏õ‡∏£‡∏µ‡∏¢‡∏ö‡πÄ‡∏ó‡∏µ‡∏¢‡∏ö
    for i in tqdm(range(len(dataset['test'])), desc="Evaluating", unit="sample"):
        # ‡∏™‡∏£‡πâ‡∏≤‡∏á‡∏™‡∏£‡∏∏‡∏õ‡∏à‡∏≤‡∏Å‡πÇ‡∏°‡πÄ‡∏î‡∏•
        input_text = dataset['train'][i]['dialogue']  # ‡πÉ‡∏ä‡πâ 'dialogue' ‡∏à‡∏≤‡∏Å train ‡πÄ‡∏û‡∏∑‡πà‡∏≠‡∏™‡∏£‡πâ‡∏≤‡∏á‡∏™‡∏£‡∏∏‡∏õ
        reference_summary = dataset['test'][i]['summary']  # ‡πÉ‡∏ä‡πâ 'summary' ‡∏à‡∏≤‡∏Å test ‡πÄ‡∏õ‡πá‡∏ô‡∏™‡∏£‡∏∏‡∏õ‡∏à‡∏£‡∏¥‡∏á
        pred_summary = generate_summary(input_text)  # ‡∏™‡∏£‡πâ‡∏≤‡∏á‡∏™‡∏£‡∏∏‡∏õ‡∏à‡∏≤‡∏Å‡πÇ‡∏°‡πÄ‡∏î‡∏•
        
        predictions.append(pred_summary)
        references.append(reference_summary)
    
    # ROUGE Score
    rouge_scores = calculate_rouge(predictions, references)
    print("ROUGE Scores:", rouge_scores)

    # BLEU Score
    bleu_score = calculate_bleu(predictions, references)
    print("BLEU Score:", bleu_score)

    # BERTScore
    P, R, F1 = calculate_bertscore(predictions, references)
    print("BERTScore - Precision:", P, "Recall:", R, "F1:", F1)

# ‡πÄ‡∏£‡∏µ‡∏¢‡∏Å‡πÉ‡∏ä‡πâ‡∏á‡∏≤‡∏ô
evaluate_model(dataset)

Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 819/819 [39:14<00:00,  2.88s/sample]
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


ROUGE Scores: {'rouge1': 0.08103071050965537, 'rouge2': 0.005501493938462314, 'rougeL': 0.07293283175759344}
BLEU Score: 8.859648156109322e-05


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BERTScore - Precision: 0.8399370312690735 Recall: 0.8468782305717468 F1: 0.843207597732544
