# Preprocessing

## Environment & imports

In [3]:
# ---------------------------------------------
# 0)  Install (first‐time only) & import libs
# ---------------------------------------------
# !pip install -q datasets transformers emoji==2.10.0 tqdm

from pathlib import Path
import re
import random
import json
from collections import defaultdict
from typing import List, Dict, Tuple

import emoji
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer
from tqdm import tqdm


## Load SAMSum

In [5]:
# ---------------------------------------------------------
# 1) Load SAMSum — 14 732 / 819 / 818 dialogues
# ---------------------------------------------------------
raw_ds: DatasetDict = load_dataset("samsum")
print({k: len(v) for k, v in raw_ds.items()})

Generating train split: 100%|██████████| 14732/14732 [00:00<00:00, 42644.42 examples/s]
Generating test split: 100%|██████████| 819/819 [00:00<00:00, 7962.06 examples/s]
Generating validation split: 100%|██████████| 818/818 [00:00<00:00, 8153.20 examples/s]

{'train': 14732, 'test': 819, 'validation': 818}





## Build an emoji vocabulary and speaker token & Build / extend the tokenizer

count [UNK] occurrences in one HF Dataset

In [6]:
from tqdm import tqdm
import numpy as np
import torch

def count_unk(ds, tokenizer, field="dialogue", batch_size=1024):
    unk_id = tokenizer.unk_token_id
    total_unk, total_tokens = 0, 0

    for i in tqdm(range(0, len(ds), batch_size), desc="Tokenising"):
        batch_texts = ds[i : i + batch_size][field]
        enc = tokenizer(batch_texts, add_special_tokens=True, padding=False, truncation=False)
        for ids in enc["input_ids"]:
            arr = np.array(ids)
            total_unk += np.sum(arr == unk_id)
            total_tokens += len(arr)
    return total_unk, total_tokens

BEFORE adding emojis

In [7]:
tok_base = AutoTokenizer.from_pretrained("bert-base-uncased")
unk_stats_before = {}
for split in ["train", "validation", "test"]:
    unk_stats_before[split] = count_unk(raw_ds[split], tok_base)
print("\n[UNK] counts BEFORE adding emoji tokens")
for split, (u, t) in unk_stats_before.items():
    print(f"{split:<10}: {u:8d}  ({u/t:.3%} of tokens)")

Tokenising:   0%|          | 0/15 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (523 > 512). Running this sequence through the model will result in indexing errors
Tokenising: 100%|██████████| 15/15 [00:00<00:00, 21.67it/s]
Tokenising: 100%|██████████| 1/1 [00:00<00:00, 35.38it/s]
Tokenising: 100%|██████████| 1/1 [00:00<00:00, 33.18it/s]


[UNK] counts BEFORE adding emoji tokens
train     :     3758  (0.185% of tokens)
validation:      191  (0.174% of tokens)
test      :      195  (0.170% of tokens)





สร้าง EMOJI_TOKENS

In [9]:
# ถ้า kernel เพิ่งรีสตาร์ต ตัวแปรจะหายหมด
# สร้างชุด emoji ใหม่จาก raw_ds
from typing import List
import emoji

def extract_emojis(text: str) -> List[str]:
    return [ch for ch in text if ch in emoji.EMOJI_DATA]

emoji_set = set()
for split in ["train", "validation", "test"]:
    for dlg in raw_ds[split]["dialogue"]:
        emoji_set.update(extract_emojis(dlg))

EMOJI_TOKENS = sorted(emoji_set)          # ≈ 300-320 รายการ
print(f"Unique emojis found: {len(EMOJI_TOKENS)}")


Unique emojis found: 305


Extend tokenizer with emojis + speaker tags

In [18]:
from transformers import AutoTokenizer

# ---------- 1) โหลด tokenizer ดั้งเดิม ----------
tok_base = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_orig = len(tok_base)

# ---------- 2) เตรียมชุด token ใหม่ ----------
#   • EMOJI_TOKENS  : ทุกอิโมจิที่ “พบอย่างน้อย 1 ครั้ง” ใน SAMSum
#   • SPEAKER_TOKENS: [S1] – [S10]
SPEAKER_TOKENS = [f"[S{i}]" for i in range(1, 11)]
new_tokens = EMOJI_TOKENS + SPEAKER_TOKENS

# ---------- 3) สร้าง tokenizer สำเนาแล้วเพิ่ม token ----------
tok_ext = AutoTokenizer.from_pretrained("bert-base-uncased")
added = tok_ext.add_tokens(new_tokens)
vocab_new = len(tok_ext)

# ---------- 4) แสดงผล ----------
print(f"Original vocab size : {vocab_orig}")
print(f"Added new tokens     : {added}  "
      f"(emoji = {len(EMOJI_TOKENS)}, speaker = {len(SPEAKER_TOKENS)})")
print(f"New vocab size       : {vocab_new}")

# (Optional) พิมพ์ตัวอย่างอิโมจิ 20 ตัวแรก
print("\nFirst 20 emoji tokens:", EMOJI_TOKENS[:20])

tok_ext.save_pretrained("tokenizer_samsum_su")   # โฟลเดอร์ใหม่

Original vocab size : 30522
Added new tokens     : 315  (emoji = 305, speaker = 10)
New vocab size       : 30836

First 20 emoji tokens: ['‼', '⏱', '☀', '☂', '☔', '☕', '☘', '☝', '☠', '☢', '☹', '☺', '♀', '♂', '♥', '♻', '⚪', '⚫', '⚰', '⚽']


('tokenizer_samsum_su/tokenizer_config.json',
 'tokenizer_samsum_su/special_tokens_map.json',
 'tokenizer_samsum_su/vocab.txt',
 'tokenizer_samsum_su/added_tokens.json',
 'tokenizer_samsum_su/tokenizer.json')

AFTER adding emojis

In [12]:
unk_stats_after = {}
for split in ["train", "validation", "test"]:
    unk_stats_after[split] = count_unk(raw_ds[split], tok_ext)
print("\n[UNK] counts AFTER adding emoji tokens")
for split, (u, t) in unk_stats_after.items():
    print(f"{split:<10}: {u:8d}  ({u/t:.3%} of tokens)")

Tokenising:   0%|          | 0/15 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (523 > 512). Running this sequence through the model will result in indexing errors
Tokenising: 100%|██████████| 15/15 [00:00<00:00, 21.41it/s]
Tokenising: 100%|██████████| 1/1 [00:00<00:00, 33.54it/s]
Tokenising: 100%|██████████| 1/1 [00:00<00:00, 31.56it/s]


[UNK] counts AFTER adding emoji tokens
train     :      451  (0.022% of tokens)
validation:        4  (0.004% of tokens)
test      :       24  (0.021% of tokens)





reduction check in UNKs

In [13]:
print("\nΔ [UNK] (before ➜ after):")
for split in ["train", "validation", "test"]:
    u0, _ = unk_stats_before[split]
    u1, _ = unk_stats_after[split]
    print(f"{split:<10}: {u0-u1:+d}  fewer UNKs  (↓{(u0-u1)/u0:.2%})")


Δ [UNK] (before ➜ after):
train     : +3307  fewer UNKs  (↓88.00%)
validation: +187  fewer UNKs  (↓97.91%)
test      : +171  fewer UNKs  (↓87.69%)


## Preprocess SAMSum Dateset

Speaker-name mapping → [S#]

In [14]:
# ---------------------------------------------------------
# 4) Helper to replace speaker names by [S#]
# ---------------------------------------------------------
SPEAKER_RE = re.compile(r"^([^:]+):\s*(.*)$")

def map_speakers(dialogue: str, max_speakers: int = 10
                 ) -> Tuple[str, Dict[str, str]]:
    """
    Returns dialogue with names replaced by [S#] and a mapping dict.
    """
    speaker_map, next_id = {}, 1
    new_lines = []
    for line in dialogue.split("\n"):
        m = SPEAKER_RE.match(line)
        if not m:                # safety – keep line as is
            new_lines.append(line)
            continue
        name, utt = m.groups()
        if name not in speaker_map:
            if next_id > max_speakers:      # truncate extra speakers
                name_token = "[SUNK]"
            else:
                name_token = f"[S{next_id}]"
                speaker_map[name] = name_token
                next_id += 1
        new_lines.append(f"{speaker_map.get(name, '[SUNK]')}: {utt}")
    return "\n".join(new_lines), speaker_map


Insert [SEP] after every utterance

In [15]:
def add_sep_every_utt(dialogue: str) -> str:
    lines = [l + " [SEP]" for l in dialogue.split("\n") if l.strip()]
    return " ".join(lines)


Switching-Utterance corruption
- Hyper-parameters: Pu = 1.0, Pn = 0/1

โดยที่

Pu (permute-utterance prob.) ความน่าจะเป็นที่ แต่ละ utterance จะถูกเลือก ใส่ลงในชุดที่นำไปสับตำแหน่ง

- pu = 1.0 แสดงว่าบังคับเลือกทุกบรรทัดแล้วค่อยสับคำแบบสุ่ม

Pn (name-mask prob.) ความน่าจะเป็นที่ token [S#] ด้านหน้าจะถูกเปลี่ยนเป็น [MASK]

- pn = 0.0 แสดงว่า ไม่ mask, โมเดลเห็น speaker tag

- pn = 1.0 แสดงว่า mask หมด, บังคับดู context

In [None]:
def make_switching_utterance(dialogue: str,
                             pu: float = 1.0,
                             pn: float = 0.0,
                             rng: random.Random = random
                            ) -> Tuple[str, List[int]]:
    """
    • dialogue  - speaker-tokenised, SEP-inserted string
    • pu        - prob. an utterance is selected for permutation
    • pn        - prob. we MASK the speaker token (⇒ [MASK])
    Returns:
        corrupted_dialogue, labels_per_utt  (1 = permuted (สลับบทพูด), 0 = original)
    """
    # 1) split back into utterances
    utts = [u.strip() for u in dialogue.split("[SEP]") if u.strip()]
    idxs = list(range(len(utts)))

    # 2) pick indices to permute
    perm_idx = [i for i in idxs if rng.random() < pu]
    shuffled = perm_idx.copy()
    rng.shuffle(shuffled)                 # in-place
    perm_map = dict(zip(perm_idx, shuffled))

    # 3) build new utterance list, labels
    new_utts, labels = [], []
    for i in idxs:
        src = perm_map.get(i, i)          # swapped or same
        u = utts[src]
        # optionally mask speaker token ([S#]: → [MASK]:)
        if rng.random() < pn:
            u = re.sub(r"^\[S\d+\]", "[MASK]", u)
        new_utts.append(u)
        labels.append(int(src != i))      # 1 if permuted
    corrupted = " [SEP] ".join(new_utts) + " [SEP]"
    return corrupted, labels


## Switching-Utterance (SU) pre-training dataset

In [17]:
# ---------------------------------------------------------
# 7) Create HF Datasets with tokenised inputs, attention,
#    SEP positions, and per-utterance labels
# ---------------------------------------------------------
MAX_LEN = 512                          # paper setting
Pu, Pn = 1.0, 0.0                      # best config in Table 2

def preprocess_example(example, split):
    # a) replace speakers & add SEP
    dlg, _ = map_speakers(example["dialogue"])
    dlg = add_sep_every_utt(dlg)

    # b) corruption
    corrupted, labels = make_switching_utterance(dlg, Pu, Pn)

    # c) tokenize (truncate if >512 tokens)
    enc = tok(corrupted,
              truncation=True, max_length=MAX_LEN,
              padding="max_length")
    
    # d) find SEP token positions (needed for loss later)
    sep_id = tok("[SEP]")["input_ids"][0]
    sep_positions = [i for i, id_ in enumerate(enc["input_ids"])
                     if id_ == sep_id][:len(labels)]  # clip if truncated

    enc["labels"] = labels[:len(sep_positions)]
    enc["sep_positions"] = sep_positions
    enc["dialogue_len"] = len(labels)
    return enc

su_ds = DatasetDict()
for split in ["train", "validation", "test"]:
    su_ds[split] = raw_ds[split].map(
        preprocess_example,
        fn_kwargs={"split": split},
        remove_columns=raw_ds[split].column_names,
        desc=f"Building SU {split}"
    )

su_ds.save_to_disk("data/samsum_switching_utterance")
print(su_ds)


Building SU train: 100%|██████████| 14732/14732 [00:07<00:00, 1850.04 examples/s]
Building SU validation: 100%|██████████| 818/818 [00:00<00:00, 1841.35 examples/s]
Building SU test: 100%|██████████| 819/819 [00:00<00:00, 1815.59 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 14732/14732 [00:00<00:00, 581300.38 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 818/818 [00:00<00:00, 204161.90 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 819/819 [00:00<00:00, 201261.72 examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels', 'sep_positions', 'dialogue_len'],
        num_rows: 14732
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels', 'sep_positions', 'dialogue_len'],
        num_rows: 818
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels', 'sep_positions', 'dialogue_len'],
        num_rows: 819
    })
})





ไฟล์ Arrow ถูกบันทึกไว้ที่ data/samsum_switching_utterance/ พร้อมฟิลด์ input_ids / attention_mask / labels / sep_positions / dialogue_len.

## Self-supervised Pre-training

ใช้ Dataset เฉพาะส่วนของ train ของ SAMSum มาทำการ pre_train แล้วใช้ validation ไว้ดู early-stopping / tuning ส่วน test ต้องไม่ถูกแตะ เพื่อไม่ให้โมเดล “เห็น” บทสนทนาที่จะใช้วัด ROUGE ภายหลัง

In [1]:
import torch
torch.cuda.empty_cache()
torch.cuda.ipc_collect()


Imports & helpers

In [2]:
import math
import torch
import random
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    get_linear_schedule_with_warmup
)
from datasets import load_from_disk

# -------------------------------
# CONFIG
# -------------------------------
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16 # paper 128
MAX_LEN    = 512 # paper 512
LR         = 3e-5
WARMUP     = 500
MAX_STEPS  = 5000




  from .autonotebook import tqdm as notebook_tqdm


Dataset & collate

In [3]:
# -------------------------------
# LOAD DATASET
# -------------------------------
dataset = load_from_disk("data/samsum_switching_utterance")

# -------------------------------
# COLLATE FUNCTION
# -------------------------------
def collate_fn(batch):
    keys = ["input_ids", "token_type_ids", "attention_mask"]
    inputs = {k: torch.tensor([b[k] for b in batch]) for k in keys}
    labels = [torch.tensor(b["labels"], dtype=torch.float) for b in batch]
    sep_pos = [torch.tensor(b["sep_positions"]) for b in batch]
    return inputs, labels, sep_pos


Model

In [5]:
# -------------------------------
# MODEL
# -------------------------------
class SepClassifier(nn.Module):
    def __init__(self, model_name="bert-base-uncased", dropout=0.1):
        super().__init__()
        config = AutoConfig.from_pretrained(model_name)
        self.bert = AutoModel.from_pretrained(model_name, config=config)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, token_type_ids, sep_positions):
        hidden_states = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        ).last_hidden_state

        # Collect hidden states at each [SEP] position
        # sep_vecs = [hidden_states[i, pos] for i, pos in enumerate(sep_positions)]
        sep_vecs = []
        for i, pos_tensor in enumerate(sep_positions):
            pos_tensor = pos_tensor.to(hidden_states.device).long()  # <-- เพิ่มการ cast
            sep_vecs.append(hidden_states[i].index_select(0, pos_tensor))  # (U_i, H)

        sep_vecs = torch.cat(sep_vecs, dim=0)  # Shape: (total_seps, hidden_size)
        logits = self.classifier(self.dropout(sep_vecs)).squeeze(-1)
        return logits


Training loop (train model until the train loss converged (upper bounded by 5k steps)

In [None]:
# -------------------------------
# INITIALIZATION
# -------------------------------
model = SepClassifier().to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained("tokenizer_samsum_su")
model.bert.resize_token_embeddings(len(tokenizer)) # Adjust embedding size for extended tokens

train_loader = DataLoader(
    dataset["train"],
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn
)

optimizer = AdamW(model.parameters(), lr=LR)
total_steps = min(MAX_STEPS, len(train_loader))
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP,
    num_training_steps=total_steps
)
loss_fn = nn.BCEWithLogitsLoss()

# -------------------------------
# TRAINING LOOP
# -------------------------------
step = 0
running_loss = 0.0
model.train()

for epoch in range(100):  # loop until MAX_STEPS reached
    for inputs, label_lists, sep_lists in train_loader:
        if step >= MAX_STEPS:
            break

        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        flat_labels = torch.cat(label_lists).to(DEVICE)

        logits = model(**inputs, sep_positions=sep_lists)
        loss = loss_fn(logits, flat_labels)

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        step += 1

        if step % 100 == 0:
            print(f"Step {step:4d}/{total_steps} | Loss: {running_loss / 100:.4f}")
            running_loss = 0.0

    if step >= MAX_STEPS:
        break

# -------------------------------
# SAVE MODEL
# -------------------------------
torch.save(model.state_dict(), "bert_su_pretrained.pt")
print("Model saved to 'bert_su_pretrained.pt'")



Validation & early-stop (optional)

- Use the same DataLoader/loop on su_ds["validation"], compute average BCE loss; if it plateaus you can stop earlier than 5 k steps (what the authors mean by “until train loss converged”).

# Create Summarization Dataset

# Fine-tuning 