RTD - Replaced Token Detection

RTD is an alternative to the traditional Masked Language Model (MLM) objective. Instead of masking tokens and predicting them (as in BERT), RTD randomly replaces tokens in a sentence and trains the model to identify whether each token is the original or a replaced one.

In [1]:
from transformers import ElectraTokenizer, ElectraForPreTraining
import torch

# Load pre-trained ELECTRA model and tokenizer
tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
model = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

# Sample sentence
sentence = "The quick brown fox jumps over the lazy dog."

# Tokenize the input sentence
inputs = tokenizer(sentence, return_tensors="pt")

# Introduce token replacements to simulate the RTD task
# For demonstration, let's replace "fox" with "cat" and "lazy" with "sleepy"
inputs["input_ids"][0][4] = tokenizer.convert_tokens_to_ids("cat")  # Replacing "fox" with "cat"
inputs["input_ids"][0][7] = tokenizer.convert_tokens_to_ids("sleepy")  # Replacing "lazy" with "sleepy"

# Run the model on the modified sentence
outputs = model(**inputs)

# The model's output contains logits where each token has a "real" or "fake" prediction
predictions = torch.round(torch.sigmoid(outputs.logits))

# Decode the tokens and print out the RTD results
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
for token, prediction in zip(tokens, predictions[0]):
    status = "Real" if prediction.item() == 1 else "Replaced"
    print(f"Token: {token:<10} | Status: {status}")


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/54.2M [00:00<?, ?B/s]

Token: [CLS]      | Status: Replaced
Token: the        | Status: Replaced
Token: quick      | Status: Replaced
Token: brown      | Status: Replaced
Token: cat        | Status: Replaced
Token: jumps      | Status: Replaced
Token: over       | Status: Replaced
Token: sleepy     | Status: Replaced
Token: lazy       | Status: Replaced
Token: dog        | Status: Replaced
Token: .          | Status: Replaced
Token: [SEP]      | Status: Replaced


model.safetensors:   0%|          | 0.00/54.2M [00:00<?, ?B/s]