In [145]:
import random
import spacy

from datasets import load_dataset
from itertools import tee
from tqdm import tqdm

In [142]:
# spaCy pipeline shared by every helper below (named entities, POS tags, tokenization).
nlp = spacy.load("en_core_web_lg")
# SAMSum corpus; the records used below expose "id", "dialogue" and "summary" fields.
dataset = load_dataset("samsum")

Found cached dataset samsum (C:/Users/user/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)
100%|██████████| 3/3 [00:00<00:00, 130.42it/s]


In [516]:
def pairwise(iterable, reverse=False):
    """Yield overlapping neighbours: s -> (s0, s1), (s1, s2), (s2, s3), ...

    Args:
        iterable: The items to pair up. When ``reverse`` is true it must be
            accepted by ``reversed()``.
        reverse: Walk the items from last to first before pairing.

    Returns:
        A ``zip`` iterator over consecutive pairs (empty when there are fewer
        than two items).
    """
    source = reversed(iterable) if reverse else iterable
    leading, trailing = tee(source)
    # Shift the second copy one step ahead so zip() lines up neighbours.
    next(trailing, None)
    return zip(leading, trailing)

def extract_speaker_names(data_list):
    """Collect the distinct speaker names that open the dialogue turns.

    A turn is expected to look like ``"<speaker>: <utterance>"``. Lines that
    contain no ``":"`` are skipped, so a continuation line is never mistaken
    for a speaker name.

    Args:
        data_list: Iterable of records exposing a ``"dialogue"`` string whose
            turns are separated by ``"\\n"`` (a trailing ``"\\r"`` is tolerated).

    Returns:
        Sorted list of unique, whitespace-stripped speaker names. Sorting keeps
        the result independent of set iteration order.
    """
    speaker_names = set()
    for data in data_list:
        for line in data["dialogue"].split("\n"):
            speaker, separator, _ = line.partition(":")
            speaker = speaker.strip()
            # `separator` is empty when the line has no ":" at all.
            if separator and speaker:
                speaker_names.add(speaker)
    return sorted(speaker_names)

def extract_entities(data_list):
    """Build a pool of named entities found in the dialogues, grouped by label.

    Every dialogue is split into turns on ``"\\n"``; each turn is stripped of
    carriage returns and run through the module-level ``nlp`` pipeline.
    Entities whose text contains a digit are ignored.

    Args:
        data_list: Iterable of records exposing a ``"dialogue"`` string.

    Returns:
        Dict mapping an entity label to the sorted list of distinct entity
        texts seen with that label.
    """
    entities = {}
    for data in tqdm(data_list):
        for turn in data["dialogue"].split("\n"):
            # str.replace returns a new string; the result must be kept.
            turn = turn.replace("\r", "")
            for ent in nlp(turn).ents:
                if any(char.isdigit() for char in ent.text):
                    continue
                entities.setdefault(ent.label_, set()).add(ent.text)

    # Sorted lists make later random draws reproducible under a fixed seed.
    return {label: sorted(texts) for label, texts in entities.items()}

def extract_digits(num):
    """Return the integer formed by the digit characters of ``num.text``.

    Args:
        num: Object exposing a ``text`` string (a spaCy token at the call sites).

    Returns:
        ``int`` built from the digits of the text, in order; every other
        character is dropped.

    Raises:
        ValueError: If the text contains no digit that ``int()`` can parse.
    """
    digits_only = "".join(char for char in num.text if char.isdigit())
    return int(digits_only)

In [165]:
# Global replacement pool: entity label -> entity texts found in the training dialogues.
# Runs the spaCy pipeline on every dialogue turn, so expect this cell to take a while.
entities = extract_entities(dataset["train"])

100%|██████████| 14732/14732 [08:26<00:00, 29.06it/s]


In [379]:
# Speaker names taken from the "<speaker>: ..." prefixes of the training dialogues.
speaker_names = extract_speaker_names(dataset["train"])
# Overrides whatever spaCy collected under PERSON with the speaker-name list.
entities["PERSON"] = speaker_names
# Replacement vocabulary for pronoun distortion (all lower-case).
pronouns = ["he", "she", "it", "they", "him", "her", "them", "his", "hers", "its", "theirs"]
# Modal verbs; merged with the members of `modal_groups` into `all_modals` below.
modals = ["shall", "should", "can", "could", "will", "would", "may", "must", "might"]
# Fine-grained verb tags handed to `token._.inflect(...)` in replace_verb_form.
verb_form = ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"]
# Modals grouped by family; replace_modal never substitutes a modal with a
# member of its own family.
modal_groups = {
        "shall": ["shall", "should"],
        "will": ["will", "'ll", "would"],
        "can": ["can", "could"],
        "may": ["may", "might"],
        "must": ["must"]
    }
# De-duplicated union of `modals` and every group member (adds "'ll").
# The order comes from a set, so it is not stable across interpreter runs.
all_modals = list(set(modals + [modal for group in modal_groups.values() for modal in group]))

In [728]:
def update_labels(labels, label_type, start_idx, end_idx, new_token):
    """Rewrite the label sequence after tokens ``[start_idx, end_idx)`` were replaced.

    The replaced positions collapse into one ``B-<label_type>`` tag followed by
    one ``I-<label_type>`` tag for each additional whitespace-separated word of
    ``new_token``; every other label is copied through.

    Args:
        labels: Current per-token labels.
        label_type: Tag suffix, e.g. ``"E"`` gives ``B-E`` / ``I-E``.
        start_idx: Index of the first replaced token.
        end_idx: Index one past the last replaced token.
        new_token: Replacement text (may contain several words).

    Returns:
        A new list of labels; ``labels`` itself is not modified. If
        ``start_idx`` is not a valid position, no ``B-`` tag is inserted.
    """
    extra_words = len(new_token.split()) - 1
    span_labels = [f"B-{label_type}"] + [f"I-{label_type}"] * extra_words

    rewritten = []
    for position, old_label in enumerate(labels):
        if position == start_idx:
            rewritten.extend(span_labels)
        elif start_idx < position < end_idx:
            # Remainder of the replaced span: already covered by span_labels.
            continue
        else:
            rewritten.append(old_label)
    return rewritten

In [729]:
# BIO tag inventory. The suffixes are the label types the distortion helpers
# pass to update_labels: E (entity), P (pronoun), V (verb/modal), N (noun),
# Q (quantity); "O" marks an untouched token.
label_tags = [
    "O",
    "B-E", "I-E",
    "B-P", "I-P",
    "B-V", "I-V",
    "B-N", "I-N",
    "B-Q", "I-Q",
]

# Tag <-> integer id lookups; the id is the tag's position in `label_tags`.
label2id = {tag: index for index, tag in enumerate(label_tags)}
id2label = dict(enumerate(label_tags))

## Entity & Pronoun Distortion

In [480]:
def replace_entities(summary, labels, alpha=0.5, **kwargs):
    """Swap a fraction of the summary's named entities for other entities of the same label.

    Replacements come from the dialogue when it offers more than one distinct
    entity with the matching label; otherwise from the module-level
    ``entities`` pool. A replacement always differs from the original text
    (compared case-insensitively); an entity with no usable replacement is
    left untouched.

    Args:
        summary: Summary text. ``labels`` is indexed with the ``start``/``end``
            token positions that ``nlp(summary)`` reports, so it is assumed to
            be aligned with that tokenization.
        labels: Per-token labels for ``summary``.
        alpha: Fraction of eligible entities to replace (rounded).
        **kwargs: Must contain ``dialog``, the source dialogue text.

    Returns:
        Tuple ``(summary, labels)`` after the replacements.
    """
    def _eligible(ent):
        # Entities containing digits are not handled by this distortion.
        return not any(char.isdigit() for char in ent.text)

    dialog = kwargs["dialog"]
    dialog_entities = []
    for turn in dialog.split("\n"):
        # str.replace returns a new string; the result must be kept.
        turn = turn.replace("\r", "")
        dialog_entities.extend(ent for ent in nlp(turn).ents if _eligible(ent))

    summary_doc = nlp(summary)
    summary_entities = [ent for ent in summary_doc.ents if _eligible(ent)]

    num_entities_to_replace = round(alpha * len(summary_entities))
    selected = random.sample(summary_entities, num_entities_to_replace)
    # Work right-to-left so earlier character/token offsets stay valid.
    selected.sort(key=lambda ent: ent.start, reverse=True)

    for entity in selected:
        current = entity.text.lower()
        pool = {ent.text for ent in dialog_entities if ent.label_ == entity.label_}
        if len(pool) <= 1:
            # Dialogue offers no real alternative: fall back to the global pool.
            pool = set(entities.get(entity.label_, ()))

        # Filter up front instead of re-drawing: no endless loop, and no
        # "replacement" that is the same entity in a different casing.
        candidates = sorted(text for text in pool if text.lower() != current)
        if not candidates:
            continue
        new_entity = random.choice(candidates)

        labels = update_labels(labels, "E", entity.start, entity.end, new_entity)
        summary = summary[:entity.start_char] + new_entity + summary[entity.end_char:]

    return summary, labels

In [481]:
def replace_pronoun(summary, labels, alpha=0.5, **kwargs):
    """Swap a fraction of the summary's pronouns for a different pronoun.

    Args:
        summary: Summary text; ``labels`` is indexed with the token positions
            reported by ``nlp(summary)``.
        labels: Per-token labels for ``summary``.
        alpha: Fraction of ``PRON`` tokens to replace (rounded).
        **kwargs: Accepted for a uniform distortion signature; not used here.

    Returns:
        Tuple ``(summary, labels)`` after the replacements.
    """
    summary_doc = nlp(summary)
    pronoun_tokens = [tok for tok in summary_doc if tok.pos_ == "PRON"]

    sample_size = round(alpha * len(pronoun_tokens))
    chosen = random.sample(pronoun_tokens, sample_size)

    # Walk backwards so character offsets of earlier tokens stay valid.
    for tok in reversed(summary_doc):
        if tok.pos_ != "PRON" or tok not in chosen:
            continue

        current = tok.text.lower()
        replacement = random.choice(pronouns)
        while replacement == current:
            replacement = random.choice(pronouns)

        labels = update_labels(labels, "P", tok.i, tok.i + 1, replacement)
        tail_start = tok.idx + len(tok)
        summary = summary[:tok.idx] + replacement + summary[tail_start:]

    return summary, labels

## Verb Distortion

In [482]:
def replace_verb(summary, labels, alpha=0.5, **kwargs):
    """Swap a fraction of the summary's verbs for verbs taken from the dialogue.

    Nothing is replaced unless the dialogue contains more than one distinct
    verb. A replacement always differs from the original verb (compared
    case-insensitively); a verb with no such alternative is left untouched.

    Args:
        summary: Summary text; ``labels`` is indexed with the token positions
            reported by ``nlp(summary)``.
        labels: Per-token labels for ``summary``.
        alpha: Fraction of ``VERB`` tokens to replace (rounded).
        **kwargs: Must contain ``dialog``, the source dialogue text.

    Returns:
        Tuple ``(summary, labels)`` after the replacements.
    """
    # Extract verbs from the dialog (sorted: reproducible under a fixed seed)
    dialog = kwargs["dialog"]
    dialog_doc = nlp(dialog)
    dialog_verbs = sorted({token.text for token in dialog_doc if token.pos_ == "VERB"})

    summary_doc = nlp(summary)
    summary_verbs = [token for token in summary_doc if token.pos_ == "VERB"]

    num_verb_to_replace = round(alpha * len(summary_verbs))
    selected = random.sample(summary_verbs, num_verb_to_replace)

    # Replace verbs in the summary only if there is more than 1 verb in the dialogue
    if len(dialog_verbs) <= 1:
        return summary, labels

    # Right-to-left so character offsets of earlier tokens stay valid.
    for verb in sorted(selected, key=lambda token: token.i, reverse=True):
        current = verb.text.lower()
        # Case-insensitive filter: "Go" is not a distortion of "go".
        candidates = [text for text in dialog_verbs if text.lower() != current]
        if not candidates:
            continue
        new_verb = random.choice(candidates)

        labels = update_labels(labels, "V", verb.i, verb.i + 1, new_verb)
        summary = summary[:verb.idx] + new_verb + summary[verb.idx + len(verb):]

    return summary, labels

In [495]:
def replace_verb_form(summary, labels, alpha=0.5, **kwargs):
    """Re-inflect a fraction of the summary's verbs into a different form.

    For each selected verb, every tag in the module-level ``verb_form`` list is
    tried through ``token._.inflect``; one of the distinct results that differs
    from the current text is chosen at random. A verb with no usable
    alternative (inflection unavailable, or all forms identical) is skipped.

    NOTE(review): ``token._.inflect`` is a custom token extension; nothing in
    this notebook visibly registers it (presumably a pyinflect/lemminflect
    import is needed) — confirm before running on a fresh kernel.

    Args:
        summary: Summary text; ``labels`` is indexed with the token positions
            reported by ``nlp(summary)``.
        labels: Per-token labels for ``summary``.
        alpha: Fraction of ``VERB`` tokens to re-inflect (rounded).
        **kwargs: Accepted for a uniform distortion signature; not used here.

    Returns:
        Tuple ``(summary, labels)`` after the replacements.
    """
    summary_doc = nlp(summary)
    summary_verbs = [token for token in summary_doc if token.pos_ == "VERB"]

    num_verb_to_replace = round(alpha * len(summary_verbs))
    selected = random.sample(summary_verbs, num_verb_to_replace)

    # Right-to-left so character offsets of earlier tokens stay valid.
    for verb in sorted(selected, key=lambda token: token.i, reverse=True):
        current = verb.text.lower()
        inflections = {verb._.inflect(form) for form in verb_form}
        # Drop failed inflections (None/empty) and forms equal to the current
        # text; enumerating once avoids both the crash on a late None and the
        # endless re-draw when every form is identical.
        candidates = sorted(text for text in inflections if text and text.lower() != current)
        if not candidates:
            continue
        new_verb = random.choice(candidates)

        labels = update_labels(labels, "V", verb.i, verb.i + 1, new_verb)
        summary = summary[:verb.idx] + new_verb + summary[verb.idx + len(verb):]

    return summary, labels

In [485]:
def replace_modal(summary, labels, alpha=0.5, **kwargs):
    """Swap a fraction of the summary's modal verbs for a modal of another family.

    Modals are the tokens whose fine-grained tag is ``MD`` (plus any ``'ll``
    token). The replacement is drawn from ``all_modals`` minus the members of
    the original modal's family in ``modal_groups``, and never equals the
    original text.

    Args:
        summary: Summary text; ``labels`` is indexed with the token positions
            reported by ``nlp(summary)``.
        labels: Per-token labels for ``summary``.
        alpha: Fraction of modal tokens to replace (rounded).
        **kwargs: Accepted for a uniform distortion signature; not used here.

    Returns:
        Tuple ``(summary, labels)`` after the replacements; replaced modals are
        tagged with the ``V`` label type.
    """
    summary_doc = nlp(summary)
    # "MD" is a fine-grained tag, so it has to be matched against `tag_`
    # (the coarse `pos_` never takes that value).
    summary_modals = [token for token in summary_doc if token.tag_ == "MD" or token.text == "'ll"]

    num_modal_to_replace = round(alpha * len(summary_modals))
    selected = random.sample(summary_modals, num_modal_to_replace)

    # Right-to-left so character offsets of earlier tokens stay valid.
    for modal in sorted(selected, key=lambda token: token.i, reverse=True):
        current = modal.text.lower()
        # Family of the current modal; empty when it belongs to no known group.
        same_family = next(
            (members for members in modal_groups.values() if current in members),
            [],
        )
        candidates = sorted(m for m in all_modals if m not in same_family and m != current)
        if not candidates:
            continue
        new_modal = random.choice(candidates)

        labels = update_labels(labels, "V", modal.i, modal.i + 1, new_modal)
        summary = summary[:modal.idx] + new_modal + summary[modal.idx + len(modal):]

    return summary, labels

## Quantity Distortion

In [525]:
def replace_quantity(summary, labels, alpha=0.5, **kwargs):
    """Perturb a fraction of the numeric tokens in the summary.

    Only ``NUM`` tokens that contain at least one digit are changed:

    * ``DATE`` numbers: another date number from the dialogue when it has more
      than one; otherwise a nearby year (value > 1000) or day-of-month.
    * ``TIME`` numbers, or tokens containing ``":"``: another time number from
      the dialogue when it has more than one; otherwise a nearby clock time
      (``H<sep>MM`` keeping the original ``":"``/``"."`` separator) or a
      nearby hour.
    * anything else: a nearby positive integer.

    Whenever the value range allows it, the replacement differs from the
    original value.

    Args:
        summary: Summary text; ``labels`` is indexed with the token positions
            reported by ``nlp(summary)``.
        labels: Per-token labels for ``summary``.
        alpha: Fraction of ``NUM`` tokens to sample for replacement (rounded).
        **kwargs: Must contain ``dialog``, the source dialogue text.

    Returns:
        Tuple ``(summary, labels)`` after the replacements.
    """

    def _different_int(low, high, current):
        # Random integer in [low, high] that is not `current`; falls back to
        # `current` only when the range offers nothing else.
        options = [value for value in range(low, high + 1) if value != current]
        return random.choice(options) if options else current

    def _other_from(pool, current_text):
        # `pool` holds more than one distinct text, so an alternative exists.
        return random.choice([text for text in pool if text != current_text])

    def _perturb_hour(value):
        # Bare hour: out-of-range values are simply redrawn from 0..23.
        if value > 23:
            return str(random.randint(0, 23))
        return str(_different_int(max(0, value - 5), min(23, value + 5), value))

    def _perturb_clock(text):
        # "H:MM" / "H.MM" style token. Returns None when either side has no digits.
        separator = ":" if ":" in text else "."
        parts = text.split(separator)
        hour_digits = "".join(filter(str.isdigit, parts[0]))
        minute_digits = "".join(filter(str.isdigit, parts[1]))
        if not hour_digits or not minute_digits:
            return None
        hour, minute = int(hour_digits), int(minute_digits)
        while True:
            # Values outside the clock range (scores, ratios such as "30:28")
            # would make the +/- window empty, so they are redrawn in full.
            if hour > 23:
                new_hour = random.randint(0, 23)
            else:
                new_hour = random.randint(max(0, hour - 5), min(23, hour + 5))
            if minute > 59:
                new_minute = random.randint(0, 59)
            else:
                new_minute = random.randint(max(0, minute - 10), min(59, minute + 10))
            if (new_hour, new_minute) != (hour, minute):
                # Minutes are zero-padded so "10:05" does not become "10:5".
                return f"{new_hour}{separator}{new_minute:02d}"

    dialog = kwargs["dialog"]
    dialog_doc = nlp(dialog)
    dialog_nums = [token for token in dialog_doc if token.pos_ == "NUM"]
    date_num = sorted({num.text for num in dialog_nums if num.ent_type_ == "DATE"})
    time_num = sorted({num.text for num in dialog_nums if num.ent_type_ == "TIME"})

    summary_doc = nlp(summary)
    summary_num = [token for token in summary_doc if token.pos_ == "NUM"]

    num_num_to_replace = round(alpha * len(summary_num))
    selected = random.sample(summary_num, num_num_to_replace)

    # Right-to-left so character offsets of earlier tokens stay valid.
    for num in sorted(selected, key=lambda token: token.i, reverse=True):
        if not any(char.isdigit() for char in num.text):
            # Spelled-out numbers ("two") are left alone.
            continue

        if num.ent_type_ == "DATE":
            if len(date_num) > 1:
                new_num = _other_from(date_num, num.text)
            else:
                value = extract_digits(num)
                if value > 1000:
                    # Looks like a year: stay within two decades.
                    new_num = str(_different_int(value - 20, value + 20, value))
                elif value > 31:
                    new_num = str(random.randint(1, 31))
                else:
                    new_num = str(_different_int(max(1, value - 5), min(31, value + 5), value))
        elif num.ent_type_ == "TIME" or ":" in num.text:
            if len(time_num) > 1:
                new_num = _other_from(time_num, num.text)
            else:
                new_num = None
                if ":" in num.text or "." in num.text:
                    new_num = _perturb_clock(num.text)
                if new_num is None:
                    new_num = _perturb_hour(extract_digits(num))
        else:
            value = extract_digits(num)
            new_num = str(_different_int(max(1, value - 10), value + 10, value))

        labels = update_labels(labels, "Q", num.i, num.i + 1, new_num)
        summary = summary[:num.idx] + new_num + summary[num.idx + len(num):]

    return summary, labels

## Noun Distortion

In [461]:
def replace_noun(summary, labels, alpha=0.5, **kwargs):
    """Swap a fraction of the summary's common nouns for nouns taken from the dialogue.

    Only ``NOUN`` tokens that are not part of a named entity are considered,
    on both sides. Nothing is replaced unless the dialogue contains more than
    one distinct such noun. A replacement always differs from the original
    noun (compared case-insensitively); a noun with no such alternative is
    left untouched.

    Args:
        summary: Summary text; ``labels`` is indexed with the token positions
            reported by ``nlp(summary)``.
        labels: Per-token labels for ``summary``.
        alpha: Fraction of eligible nouns to replace (rounded).
        **kwargs: Must contain ``dialog``, the source dialogue text.

    Returns:
        Tuple ``(summary, labels)`` after the replacements.
    """
    def _is_plain_noun(token):
        return token.pos_ == "NOUN" and not token.ent_type_

    # Extract nouns from the dialog (sorted: reproducible under a fixed seed)
    dialog = kwargs["dialog"]
    dialog_doc = nlp(dialog)
    dialog_nouns = sorted({token.text for token in dialog_doc if _is_plain_noun(token)})

    summary_doc = nlp(summary)
    summary_nouns = [token for token in summary_doc if _is_plain_noun(token)]

    num_noun_to_replace = round(alpha * len(summary_nouns))
    selected = random.sample(summary_nouns, num_noun_to_replace)

    # Replace nouns in the summary only if there is more than 1 noun in the dialogue
    if len(dialog_nouns) <= 1:
        return summary, labels

    # Right-to-left so character offsets of earlier tokens stay valid.
    for noun in sorted(selected, key=lambda token: token.i, reverse=True):
        current = noun.text.lower()
        # Case-insensitive filter: "Cat" is not a distortion of "cat".
        candidates = [text for text in dialog_nouns if text.lower() != current]
        if not candidates:
            continue
        new_noun = random.choice(candidates)

        labels = update_labels(labels, "N", noun.i, noun.i + 1, new_noun)
        summary = summary[:noun.idx] + new_noun + summary[noun.idx + len(noun):]

    return summary, labels

# Experiment

In [747]:
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, AdamW

# Load the pre-trained BERT tokenizer. The fast implementation is requested
# explicitly; the preprocessing step further down calls `word_ids()` on its output.
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', use_fast=True)

# Load the pre-trained BERT model with a token classification head whose
# output classes follow the `label_tags` inventory (via id2label / label2id).
model = AutoModelForTokenClassification.from_pretrained('bert-base-cased', id2label=id2label, label2id=label2id)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

In [735]:
from tqdm import tqdm
from datasets import Dataset, DatasetDict

X = 0.5  # fraction (0-1) of the records in each split that get distorted
Y = 0.5  # fraction (0-1) of the entries of `distort_funcs` applied to a distorted record
# Pool of distortion functions sampled by distort_data; all share the
# signature f(summary, labels, alpha=0.5, **kwargs) -> (summary, labels).
# NOTE(review): `replace_noun` is listed twice. Because the list is sampled
# without replacement, noun distortion is twice as likely to be picked and can
# run twice on the same summary; the duplicate also raises int(Y * len(...))
# from 3 to 4. Looks like a copy-paste leftover — confirm whether intended.
distort_funcs = [
    replace_entities,
    replace_noun,
    replace_pronoun,
    replace_verb,
    replace_verb_form,
    replace_modal,
    replace_quantity,
    replace_noun,
]

def distort_data(summary, labels, dialogue):
    """Apply a random subset of the distortion functions to one summary.

    ``int(Y * len(distort_funcs))`` functions are drawn without replacement
    from ``distort_funcs`` and applied one after another, each receiving the
    output of the previous one.

    Args:
        summary: Summary text to distort.
        labels: Per-token labels for ``summary``.
        dialogue: Source dialogue, forwarded to every function as ``dialog``.

    Returns:
        Tuple ``(summary, labels)`` after all selected distortions.
    """
    sample_size = int(Y * len(distort_funcs))
    for distort in random.sample(distort_funcs, sample_size):
        summary, labels = distort(summary, labels, dialog=dialogue)
    return summary, labels

def create_data_point(data, distort):
    """Turn one record into a (tokens, labels) training example.

    The summary is tokenized with ``nlp`` and re-joined with single spaces, so
    that whitespace splitting yields one item per token. Every token starts
    with the ``"O"`` label; when ``distort`` is true the summary and labels are
    passed through ``distort_data``.

    Args:
        data: Record exposing ``"summary"`` and ``"dialogue"`` strings.
        distort: Whether to apply distortions to this record.

    Returns:
        Tuple ``(tokens, labels)`` of two equally long lists.

    Raises:
        ValueError: If the distorted summary and its labels ended up with
            different lengths. The message carries both sequences.
    """
    summary_doc = nlp(data["summary"])
    summary = " ".join(token.text for token in summary_doc)
    labels = ["O"] * len(summary.split())
    if distort:
        summary, labels = distort_data(summary, labels, data["dialogue"])
    tokens = summary.split()

    if len(tokens) != len(labels):
        # Fail loudly with everything needed to diagnose the misalignment.
        raise ValueError(
            f"summary/label length mismatch: {len(tokens)} tokens vs {len(labels)} labels\n"
            f"original tokens: {[token.text for token in summary_doc]}\n"
            f"tokens: {tokens}\n"
            f"labels: {labels}"
        )

    return tokens, labels

# Build, for every split, a dataset of tokenized dialogues paired with a
# (possibly distorted) tokenized summary and its per-token label ids.
new_dataset = DatasetDict()

for split in dataset.keys():
    ids = []
    dialogues = []
    ref_summaries = []
    distorted_summaries = []
    labels = []

    total_data = len(dataset[split])
    num_to_distort = int(X * total_data)
    # A set keeps the per-record membership test O(1); with a list the loop
    # below would scan up to `num_to_distort` entries for every record.
    indices_to_distort = set(random.sample(range(total_data), num_to_distort))

    # Wrapping the dataset (not the enumerate object) lets tqdm show the total.
    for i, data in enumerate(tqdm(dataset[split])):
        distorted_summary, raw_labels = create_data_point(data, i in indices_to_distort)
        dialog_doc = nlp(data["dialogue"])
        dialog = [token.text for token in dialog_doc]

        ids.append(data["id"])
        dialogues.append(dialog)
        ref_summaries.append(data["summary"])
        distorted_summaries.append(distorted_summary)
        labels.append([label2id[label] for label in raw_labels])
    new_dataset[split] = Dataset.from_dict({
        "ids": ids,
        "dialogues": dialogues,
        "ref_summaries": ref_summaries,
        "distorted_summaries": distorted_summaries,
        "labels": labels,
    })

14732it [12:14, 20.07it/s]
819it [00:41, 19.97it/s]
818it [00:40, 20.27it/s]


In [739]:
def align_labels_with_tokens(labels, word_ids, context_len):
    """Expand word-level summary labels to one label per encoded token.

    The first ``context_len + 2`` positions are always ignored (``-100``).
    After that, the first token of each word receives that word's label;
    following tokens of the same word and positions whose word id is ``None``
    receive ``-100``.

    Args:
        labels: Label ids, one per summary word.
        word_ids: Word index (or ``None``) for every encoded token.
        context_len: Number of encoded tokens taken by the context sequence.

    Returns:
        List of label ids with the same length as ``word_ids``.
    """
    summary_offset = context_len + 2
    aligned = []
    previous_word = None
    for position, word_id in enumerate(word_ids):
        if position < summary_offset:
            aligned.append(-100)
            continue
        if word_id == previous_word:
            # Continuation of the previous word (or another special token).
            aligned.append(-100)
            continue
        # First token of a new word, or a transition to a special token.
        previous_word = word_id
        aligned.append(-100 if word_id is None else labels[word_id])
    return aligned

In [740]:
def preprocess_data(data):
    """Tokenize (dialogue, summary) pairs and attach token-level labels.

    Args:
        data: Batch exposing ``"dialogues"`` and ``"distorted_summaries"``
            (lists of word lists) and ``"labels"`` (lists of label ids, one
            per summary word).

    Returns:
        The tokenizer output with an added ``"labels"`` entry holding one
        aligned label list per example.
    """
    tokenized_inputs = tokenizer(data['dialogues'], data['distorted_summaries'], is_split_into_words=True, truncation=True, max_length=512)
    new_labels = []
    for i, labels in enumerate(data["labels"]):
        word_ids = tokenized_inputs.word_ids(i)
        # Count the context tokens actually present in the encoding. Measuring
        # the untruncated dialogue instead over-estimates the context whenever
        # truncation kicks in, which would mask the summary labels.
        context_len = sum(1 for sequence_id in tokenized_inputs.sequence_ids(i) if sequence_id == 0)
        new_labels.append(align_labels_with_tokens(labels, word_ids, context_len))
    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs

In [741]:
# Tokenize every split in batches. The original columns are dropped so that
# only the fields produced by preprocess_data remain.
tokenized_datasets = new_dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=new_dataset["train"].column_names,
)

                                                                   

In [750]:
import evaluate
import numpy as np

from transformers import DataCollatorForTokenClassification

# Batch collator for token classification, built around the same tokenizer.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
# seqeval metric loaded through `evaluate`; consumed by compute_metrics below.
metric = evaluate.load("seqeval")

def compute_metrics(eval_preds):
    """Compute overall precision/recall/F1/accuracy for an evaluation pass.

    Args:
        eval_preds: Pair ``(logits, labels)``; ``logits`` has one score per
            label id along its last axis, ``labels`` holds label ids with
            ``-100`` marking positions to ignore.

    Returns:
        Dict with the keys ``"precision"``, ``"recall"``, ``"f1"`` and
        ``"accuracy"`` taken from the metric's overall scores.
    """
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Remove ignored index (special tokens) and convert ids to tag strings.
    # `label_tags` is the notebook's tag inventory; ids are positions in it.
    true_labels = [[label_tags[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [label_tags[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }

In [751]:
from transformers import TrainingArguments
from transformers import Trainer

# Training configuration: checkpoints and logs share one directory, logging
# happens every 500 steps, evaluation and checkpointing once per epoch. At
# most two checkpoints are kept and the best one is reloaded when training ends.
args = TrainingArguments(
    output_dir=f"./model/span-predictor/bert-base-cased",
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_dir=f"./model/span-predictor/bert-base-cased",
    logging_strategy="steps",
    logging_steps=500,
    num_train_epochs=5,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
)

# Fine-tune on the tokenized train split and evaluate on the validation split
# with the seqeval-based compute_metrics.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)
trainer.train()

  0%|          | 0/9210 [00:00<?, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  0%|          | 11/9210 [01:52<32:18:18, 12.64s/it]

KeyboardInterrupt: 

In [753]:
# Report whether PyTorch can see a CUDA device in this environment.
torch.cuda.is_available()

False