In [104]:
from collections import defaultdict
from tqdm.auto import tqdm
from multiprocessing import cpu_count
import numpy as np

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    default_data_collator,
    TrainingArguments,
    Trainer,
    get_scheduler,
    pipeline
)
from datasets import load_dataset
from accelerate import Accelerator

In [34]:
# Hub model id, pinned to the exact commit that was current when this
# notebook was executed so the weights are reproducible
model_checkpoint, model_commit_id = (
    "distilbert-base-uncased",
    "6cdc0aad91f5ae2e6712e91bc7b65d1cf5c05411",
)

In [35]:
# Pretrained DistilBERT with its masked-LM head, pinned to a fixed revision
model = AutoModelForMaskedLM.from_pretrained(
    model_checkpoint,
    revision=model_commit_id,
)

In [37]:
# Worker processes for dataset.map: leave one core free, but never fewer than one
num_tgt_cores = max(cpu_count() - 1, 1)

# Off-the-shelf model

In [39]:
# Total parameter count as reported by the model itself
distilbert_num_params = model.num_parameters()

In [40]:
# Report the size in millions of parameters, rounded to the nearest million
print(f"{round(distilbert_num_params / 1e6)}M parameters")

67M parameters


In [41]:
# Tokenizer pinned to the same revision as the model
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
    revision=model_commit_id,
)

In [42]:
# Probe sentence for the off-the-shelf model: it must predict the [MASK] token
text = "This is a great [MASK]."

In [43]:
inputs = tokenizer(text, return_tensors="pt")
# Pure inference: no autograd graph is needed
with torch.no_grad():
    token_logits = model(**inputs)["logits"]

In [44]:
# One row of vocabulary logits per input token: (batch, sequence length, vocab size)
token_logits.shape

torch.Size([1, 8, 30522])

In [45]:
# Round-trip the ids to text to reveal the special tokens the tokenizer added
tokenizer.decode(inputs.input_ids.squeeze().tolist())

'[CLS] this is a great [MASK]. [SEP]'

In [46]:
# Same ids shown token by token, to locate the [MASK] position
tokenizer.convert_ids_to_tokens(inputs.input_ids.squeeze().tolist())

['[CLS]', 'this', 'is', 'a', 'great', '[MASK]', '.', '[SEP]']

In [47]:
# Position(s) of [MASK] in the (single) input sequence
mask_token_idx = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
# Vocabulary logits at those positions: one row per [MASK]
mask_token_logits = token_logits[0, mask_token_idx, :]
# Ids of the 5 highest-scoring candidates for the first [MASK]
top_5_tokens = mask_token_logits.topk(5, dim=1).indices[0].tolist()

In [48]:
# Substitute each candidate token back into the probe sentence
for i, token in enumerate(top_5_tokens):
    print(f"{i}: {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")

0: This is a great deal.
1: This is a great success.
2: This is a great adventure.
3: This is a great idea.
4: This is a great feat.


# Dataset

In [49]:
# Hub dataset id, pinned to the revision used when this notebook was run
dataset_checkpoint, dataset_commit_id = (
    "imdb",
    "9c6ede893febf99215a29cc7b72992bb1138b06b",
)

In [50]:
# Load the pinned revision of the IMDb dataset
imdb_dataset = load_dataset(
    dataset_checkpoint,
    revision=dataset_commit_id,
)

In [51]:
# Splits and row counts of the raw dataset
imdb_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

In [52]:
# Three reproducibly chosen training reviews for a quick look
sample = (
    imdb_dataset["train"]
    .shuffle(seed=42)
    .select(range(3))
)

In [53]:
# Print each sampled review and its label, separated by a blank line
for i, row in enumerate(sample):
    gap = "\n" if i else ""
    print(f"{gap}>>> Review: {row['text']}")
    print(f">>> Label: {row['label']}")

>>> Review: There is no relation at all between Fortier and Profiler but the fact that both are police series about violent crimes. Profiler looks crispy, Fortier looks classic. Profiler plots are quite simple. Fortier's plot are far more complicated... Fortier looks more like Prime Suspect, if we have to spot similarities... The main character is weak and weirdo, but have "clairvoyance". People like to compare, to judge, to evaluate. How about just enjoying? Funny thing too, people writing Fortier looks American but, on the other hand, arguing they prefer American series (!!!). Maybe it's the language, or the spirit, but I think this series is more English than American. By the way, the actors are really good and funny. The acting is not superficial at all...
>>> Label: 1

>>> Review: This movie is a great. The plot is very true to the book which is a classic written by Mark Twain. The movie starts of with a scene where Hank sings a song with a bunch of kids called "when you stub your

# Data preprocessing

In [54]:
def tokenize_function(examples):
    """Tokenize a batch of reviews, keeping the token-to-word alignment when possible.

    Args:
        examples: batch dict with a "text" column (list of strings), as passed
            by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output for the batch. With a fast tokenizer it also has a
        "word_ids" column: for each example, the word index of every token
        (None where a token belongs to no word).
    """
    encoded = tokenizer(examples["text"])
    # Only fast tokenizers expose word_ids(); slow ones return the plain encoding
    if not tokenizer.is_fast:
        return encoded
    encoded["word_ids"] = [
        encoded.word_ids(example_idx)
        for example_idx in range(len(encoded["input_ids"]))
    ]
    return encoded

In [56]:
# Tokenize every split in parallel; the raw text/label columns are not needed afterwards
tokenized_datasets = imdb_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=num_tgt_cores,
    remove_columns=["text", "label"],
)

Map (num_proc=15):   0%|          | 0/25000 [00:00<?, ? examples/s]

Map (num_proc=15):   0%|          | 0/25000 [00:00<?, ? examples/s]

Map (num_proc=15):   0%|          | 0/50000 [00:00<?, ? examples/s]

In [57]:
# text/label are replaced by the tokenizer outputs plus word_ids
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids'],
        num_rows: 50000
    })
})

In [58]:
# Maximum sequence length the tokenizer is configured for (compare with chunk_size below)
tokenizer.model_max_length

512

## Chunk the tokenized data

In [59]:
# Tokens per training example after re-chunking (128)
chunk_size = 2 ** 7

In [60]:
# First three tokenized reviews, as a dict of column -> list of per-review values
tokenized_samples = tokenized_datasets["train"][0:3]

In [61]:
# Reviews have very different token counts, which motivates concatenating and re-chunking
for i, sample in enumerate(tokenized_samples["input_ids"]):
    print(f">>> Review {i} length: {len(sample)}")

>>> Review 0 length: 363
>>> Review 1 length: 304
>>> Review 2 length: 133


In [62]:
# Flatten every column into one long list spanning all sampled reviews
concatenated_examples = {
    key: [item for per_review in column for item in per_review]
    for key, column in tokenized_samples.items()
}
total_length = len(concatenated_examples["input_ids"])

In [63]:
# Equals the sum of the individual review lengths printed above
print(f">>> Concatenated reviews length: {total_length}")

>>> Concatenated reviews length: 800


In [64]:
# Cut each concatenated column into consecutive windows of chunk_size items;
# the final window keeps whatever is left over and may be shorter
chunks = dict(
    (key, [seq[start:start + chunk_size] for start in range(0, total_length, chunk_size)])
    for key, seq in concatenated_examples.items()
)

In [65]:
# Every chunk is chunk_size long except the last, which holds the remainder
for chunk in chunks["input_ids"]:
    print(f">>> Chunk length: {len(chunk)}")

>>> Chunk length: 128
>>> Chunk length: 128
>>> Chunk length: 128
>>> Chunk length: 128
>>> Chunk length: 128
>>> Chunk length: 128
>>> Chunk length: 32


### Create a repeatable function to do this
It should accept batches of examples so it can be mapped over the full dataset with `batched=True`.

In [66]:
def group_texts(examples, block_size=None):
    """Concatenate a batch of tokenized examples and re-split it into fixed-size chunks.

    Every column in ``examples`` (e.g. ``input_ids``, ``attention_mask``,
    ``word_ids``) is flattened into one long sequence and then cut into
    consecutive windows of ``block_size`` items. The trailing remainder that
    does not fill a whole window is dropped, so every output row has exactly
    ``block_size`` items.

    Args:
        examples: batch dict mapping column name -> list of per-example lists,
            as passed by ``Dataset.map(..., batched=True)``.
        block_size: window length. Defaults to the notebook-level
            ``chunk_size``, so ``map(group_texts, batched=True)`` still works.

    Returns:
        Dict with the same columns, each a list of ``block_size``-long lists,
        plus a ``labels`` column that is an independent copy of ``input_ids``.
    """
    from itertools import chain

    size = chunk_size if block_size is None else block_size

    # chain.from_iterable flattens in linear time; sum(lists, []) re-copies the
    # growing accumulator on every addition and is quadratic in batch length
    concatenated_examples = {
        k: list(chain.from_iterable(examples[k])) for k in examples.keys()
    }
    first_key = next(iter(examples.keys()))
    total_length = len(concatenated_examples[first_key])
    # Drop the last chunk if it's smaller than the window size
    total_length = (total_length // size) * size
    result = {
        k: [t[i:i + size] for i in range(0, total_length, size)]
        for k, t in concatenated_examples.items()
    }
    # Copy each chunk, not just the outer list: a shallow copy would make every
    # labels chunk the *same* list object as its input_ids chunk, so masking
    # input_ids in place would silently corrupt the labels
    result["labels"] = [list(chunk) for chunk in result["input_ids"]]
    return result

In [70]:
# Batched mode hands group_texts many reviews at once so it can concatenate
# them and re-split the token stream into fixed-size chunks
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
)

In [71]:
# More rows than before: the concatenated token stream is split into
# chunk_size-token rows, and a labels column has been added
lm_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 61291
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 59904
    })
    unsupervised: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 122957
    })
})

In [72]:
# A chunk can span review boundaries: note the "[SEP] [CLS]" pair mid-chunk
tokenizer.decode(lm_datasets["train"][2]["input_ids"])

'arguably their answer to good old boy john ford, had sex scenes in his films. < br / > < br / > i do commend the filmmakers for the fact that any sex shown in the film is shown for artistic purposes rather than just to shock people and make money to be shown in pornographic theaters in america. i am curious - yellow is a good film for anyone wanting to study the meat and potatoes ( no pun intended ) of swedish cinema. but really, this film doesn\'t have much of a plot. [SEP] [CLS] " i am curious : yellow " is a risible and pretentious steaming pile. it doesn'

## Data collation
We need to set up random masking of the text.

In [73]:
# Stock token-level MLM collator: masks individual tokens with probability 0.15
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15,
)

How is masking applied to our data?

In [74]:
# Shuffle the split once and take its first two chunks (the previous version
# reshuffled the whole split for every index it read)
samples = list(lm_datasets["train"].shuffle(seed=42).select(range(2)))
# word_ids is only needed for whole-word masking; drop it before using the stock collator
for feature in samples:
    feature.pop("word_ids")
for chunk in data_collator(samples)["input_ids"]:
    print(f">>> {tokenizer.decode(chunk)}")
    # We are applying masks at the token level, not the word level
    # this means that sometimes only parts of words are masked
    print(f">>> {tokenizer.convert_ids_to_tokens(chunk)}")

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


>>> . because you got to [MASK] [MASK] as real characters this makes you like them more as an audience, and makes you more sympathetic to them as totally [MASK] victims of the [MASK] government, who you can not sympathise with [MASK] the singing of the students is correct because we [MASK] from accounts that the students in the riot [MASK] singing and dancing before it became violent. the clothing of the students in [MASK] [MASK] [MASK] is very similar to [MASK] clothing shown in photos [MASK] soweto [MASK] they made the movie actually [MASK] soweto, which is why it looks [MASK] accurate in many parts. all these things make the film [MASK] [MASK] for someone using
>>> ['.', 'because', 'you', 'got', 'to', '[MASK]', '[MASK]', 'as', 'real', 'characters', 'this', 'makes', 'you', 'like', 'them', 'more', 'as', 'an', 'audience', ',', 'and', 'makes', 'you', 'more', 'sympathetic', 'to', 'them', 'as', 'totally', '[MASK]', 'victims', 'of', 'the', '[MASK]', 'government', ',', 'who', 'you', 'can', 

### Custom data collation for whole-word masking

In [75]:
# Fraction of whole words to mask (1 in 5)
wwm_probability = 1 / 5

In [76]:
def whole_word_masking_data_collator(features, debug=False, mlm_probability=None, rng=None):
    """Collate a batch for masked-LM training with whole-word masking.

    For each feature, tokens are grouped into words using its ``word_ids``
    column. Each word is independently selected with probability
    ``mlm_probability``, and *all* of its tokens are replaced by
    ``tokenizer.mask_token_id``. Labels keep the original id only at masked
    positions and are -100 everywhere else. The processed features are then
    stacked with ``default_data_collator``.

    The caller's feature dicts (and the lists inside them) are left untouched,
    so the same samples can be collated more than once.

    Args:
        features: list of dicts with ``input_ids``, ``labels`` and ``word_ids``;
            any other keys are passed through to ``default_data_collator``.
        debug: if True, also emit the unmasked labels as ``full_labels``.
        mlm_probability: per-word masking probability; defaults to the
            notebook-level ``wwm_probability``.
        rng: optional ``numpy.random.Generator`` for reproducible masking;
            defaults to the global ``np.random`` state.

    Returns:
        Whatever ``default_data_collator`` returns for the processed features.
    """
    probability = wwm_probability if mlm_probability is None else mlm_probability
    draw = np.random.binomial if rng is None else rng.binomial

    processed = []
    for feature in features:
        # Work on a copy so popping/replacing keys never mutates the caller's data
        feature = dict(feature)
        word_ids = feature.pop("word_ids")

        # Group token positions by word. A new word starts whenever the word id
        # differs from the *previous token's*; tokens whose word id is None
        # belong to no word and are never masked. Resetting on None keeps two
        # concatenated documents from merging when the last word id of one
        # equals the first word id of the next.
        word_positions = []
        previous_word_id = None
        for idx, word_id in enumerate(word_ids):
            if word_id is None:
                previous_word_id = None
                continue
            if word_id != previous_word_id:
                word_positions.append([])
                previous_word_id = word_id
            word_positions[-1].append(idx)

        # Randomly select words to mask (1 = mask this word)
        mask = draw(1, probability, (len(word_positions),))
        input_ids = list(feature["input_ids"])
        labels = list(feature["labels"])

        # Labels stay -100 except at positions belonging to masked words
        new_labels = [-100] * len(labels)
        for word_index in np.flatnonzero(mask).tolist():
            for idx in word_positions[word_index]:
                new_labels[idx] = labels[idx]
                input_ids[idx] = tokenizer.mask_token_id

        feature["input_ids"] = input_ids
        if debug:
            feature["full_labels"] = labels
        feature["labels"] = new_labels
        processed.append(feature)

    return default_data_collator(processed)

Explore some samples using this collation approach

In [77]:
# Now, if a word is masked and is split up into multiple
# tokens, they are all masked together.
# Shuffle once and take the first two chunks (same rows as indexing the
# shuffled split, without reshuffling the whole split per index)
samples = list(lm_datasets["train"].shuffle(seed=42).select(range(2)))
batch = whole_word_masking_data_collator(samples, debug=True)

for input_ids, full_labels in zip(batch["input_ids"], batch["full_labels"]):
    print(f">>> {tokenizer.decode(input_ids)}")
    print(f">>> {tokenizer.convert_ids_to_tokens(input_ids)}")
    print(f">>> {tokenizer.convert_ids_to_tokens(full_labels)}")

>>> . because you got [MASK] see [MASK] as [MASK] [MASK] this makes you like them more [MASK] an audience [MASK] and makes you more [MASK] to [MASK] as totally the victims of the white [MASK], who you [MASK] not sympathise with. the [MASK] [MASK] the students is [MASK] because we know from [MASK] that the students [MASK] the riot [MASK] singing and [MASK] before [MASK] became [MASK]. [MASK] clothing of [MASK] students [MASK] sarafina is [MASK] similar to [MASK] clothing shown in [MASK] from [MASK] [MASK] [MASK]. they made the movie [MASK] in soweto, which is why it [MASK] very accurate in many parts. all [MASK] things make the [MASK] more accurate for someone using
>>> ['.', 'because', 'you', 'got', '[MASK]', 'see', '[MASK]', 'as', '[MASK]', '[MASK]', 'this', 'makes', 'you', 'like', 'them', 'more', '[MASK]', 'an', 'audience', '[MASK]', 'and', 'makes', 'you', 'more', '[MASK]', 'to', '[MASK]', 'as', 'totally', 'the', 'victims', 'of', 'the', 'white', '[MASK]', ',', 'who', 'you', '[MASK]',

# Fine-tuning

## Downsample the data to avoid long runtimes

In [78]:
# Keep runtimes manageable: 10k training chunks plus a 10% evaluation split
train_size = 10_000
test_size = int(train_size * 0.1)

In [79]:
# Seeded, fixed-size train/test split drawn from the chunked training data
downsamp_dataset = lm_datasets["train"].train_test_split(
    seed=42,
    train_size=train_size,
    test_size=test_size,
)

In [80]:
# Confirm the split sizes
downsamp_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 10000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 1000
    })
})

## Set up trainer

In [81]:
# Hyperparameters for the Trainer-based run
batch_size, lr, weight_decay = 64, 2e-5, 1e-2
# Roughly one log entry per epoch: full batches in one pass over the training split
logging_steps = len(downsamp_dataset["train"]) // batch_size
model_name = model_checkpoint

In [None]:
# The model does not use word_ids, so by default the Trainer would drop that
# column. remove_unused_columns=False keeps it, because
# whole_word_masking_data_collator needs word_ids to group tokens into words.
training_args = TrainingArguments(
    output_dir=f"../temp/07/{model_name}-finetuned-imdb",
    overwrite_output_dir=True,
    remove_unused_columns=False,
    # Optimisation
    learning_rate=lr,
    weight_decay=weight_decay,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=True,  # Mixed-precision training for speed boost (NOTE(review): presumably needs a GPU)
    # Evaluation / logging
    evaluation_strategy="epoch",
    logging_steps=logging_steps,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=whole_word_masking_data_collator,
    train_dataset=downsamp_dataset["train"],
    eval_dataset=downsamp_dataset["test"],
    tokenizer=tokenizer,
)

## Train and evaluate model

In [None]:
# Baseline: evaluate the pretrained model before any fine-tuning
eval_results = trainer.evaluate()

In [None]:
# Perplexity is exp(eval loss)
print(f">>> Perplexity: {np.exp(eval_results['eval_loss']).round(2)}")

In [None]:
# Fine-tune on the downsampled training split
trainer.train()

In [None]:
# Re-evaluate after fine-tuning. NOTE(review): the eval split is unmasked and the
# Trainer was given the random WWM collator, so masks are presumably re-drawn on
# each evaluate() call (see the "apply masking once" section below)
eval_results = trainer.evaluate()

In [None]:
# Perplexity is exp(eval loss); compare with the pre-training value above
print(f">>> Perplexity: {np.exp(eval_results['eval_loss']).round(2)}")

## Train and evaluate model with Accelerate & custom loop

### Apply masking once to validation set
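Masking the validation set once, up front, means every evaluation sees exactly the same masked tokens, so perplexity values are comparable across evaluations.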

In [82]:
def insert_random_mask(batch, collator):
    """Run ``collator`` over a batched-map batch and return numpy columns.

    Args:
        batch: dict of column name -> list of values (one entry per example),
            as passed by ``Dataset.map(..., batched=True)``.
        collator: callable taking a list of per-example dicts and returning a
            dict of tensors (anything exposing ``.numpy()``).

    Returns:
        Dict whose keys are the collator's output keys prefixed with
        ``"masked_"`` and whose values are numpy arrays.
    """
    columns = list(batch.keys())
    # Transpose column-oriented batch into one dict per example
    rows = [
        {column: value for column, value in zip(columns, row_values)}
        for row_values in zip(*(batch[column] for column in columns))
    ]
    collated = collator(rows)
    result = {}
    for key, tensor in collated.items():
        result[f"masked_{key}"] = tensor.numpy()
    return result

In [83]:
# Seed the global NumPy RNG (used by the WWM collator) so this one-off masking
# of the evaluation set is identical across kernel restarts, not just across
# evaluations within one run. Note: this also makes later np.random draws
# (e.g. training masks) deterministic.
np.random.seed(42)
# Mask the evaluation split once; word_ids is consumed by the collator and all
# original columns are replaced by the masked_* outputs
eval_dataset = downsamp_dataset["test"].map(
    lambda b: insert_random_mask(b, whole_word_masking_data_collator),
    batched=True,
    remove_columns=downsamp_dataset["test"].column_names
)
eval_dataset = eval_dataset.rename_columns({
    "masked_input_ids": "input_ids",
    "masked_attention_mask": "attention_mask",
    "masked_labels": "labels"
})

In [84]:
eval_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 1000
})

### Data loaders and model

In [93]:
# Hyperparameters for the custom Accelerate loop (these override the Trainer ones)
lr, num_train_epochs = 5e-5, 3
# batch_size = 64 won't fit in GPU memory on my machine
batch_size = 32

In [94]:
# Training batches are masked on the fly with the whole-word-masking collator
train_dataloader = DataLoader(
    downsamp_dataset["train"],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=whole_word_masking_data_collator,
)

# The evaluation set was already masked once above, so the default
# collator only needs to stack it into tensors
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=batch_size,
    collate_fn=default_data_collator,
)

In [95]:
# Fresh pretrained weights for the Accelerate run, pinned to the same revision
# as the first load so both experiments start from identical weights
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, revision=model_commit_id)

In [96]:
# AdamW over all model parameters; the Accelerator wraps everything in the next cell
optimizer = AdamW(params=model.parameters(), lr=lr)
accelerator = Accelerator()

In [97]:
# Hand everything to Accelerate; the prepared objects replace the originals
(model, optimizer, train_dataloader, eval_dataloader) = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [98]:
# One optimizer step per batch (no gradient accumulation), so the schedule
# length is batches-per-epoch times epochs
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_update_steps_per_epoch * num_train_epochs

In [99]:
# Linear learning-rate schedule over the full run, with no warmup steps
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_training_steps=num_training_steps,
    num_warmup_steps=0,
)

### Training loop

In [103]:
progress_bar = tqdm(range(num_training_steps))
output_dir = "../temp/07/distilbert-base-uncased-finetuned-imdb-accelerate"

for epoch in range(num_train_epochs):
    torch.cuda.empty_cache()
    # Training block
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Evaluation block
    model.eval()
    losses = []
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(**batch)

        # outputs.loss is a per-batch mean: repeat it once per example in *this*
        # batch (not the global batch_size, which may not match the dataloader)
        # so the overall mean weights every example equally
        loss = outputs.loss
        losses.append(accelerator.gather(loss.repeat(len(batch["input_ids"]))))

    # Trim to the dataset size in case gathering added extra entries
    losses = torch.cat(losses)[:len(eval_dataset)]
    # torch.exp yields inf on overflow instead of raising, so no try/except is needed
    perplexity = torch.exp(torch.mean(losses)).item()

    print(f">>> Epoch {epoch}, Perplexity: {perplexity}")
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    # Only the main process should write the checkpoint files
    unwrapped_model.save_pretrained(
        output_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
    )
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)

progress_bar.close()

  0%|          | 0/939 [00:00<?, ?it/s]

>>> Epoch 0, Perplexity: 25.044612884521484
>>> Epoch 1, Perplexity: 24.35580062866211
>>> Epoch 2, Perplexity: 24.35580062866211


## Use the fine-tuned model in a pipeline

In [105]:
# Fill-mask pipeline backed by the checkpoint saved by the training loop
mask_filler = pipeline("fill-mask", model=output_dir)

In [106]:
# Re-run the original probe sentence through the fine-tuned model
preds = mask_filler(text)

for pred in preds:
    print(">>> " + pred["sequence"])

>>> this is a great film.
>>> this is a great movie.
>>> this is a great idea.
>>> this is a great one.
>>> this is a great comedy.
