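## How ELECTRA works

ELECTRA (Efficiently Learning an Encoder that Classifies Token Replacements Accurately) addresses the mismatch between pretraining and fine-tuning in BERT.

![](1.png)

BERT is pretrained on inputs that contain the `[MASK]` token, but `[MASK]` never appears when the model is fine-tuned on downstream tasks, so the model sees different inputs in the two phases. ELECTRA's discriminator is instead trained on sequences in which the masked positions have been filled with tokens proposed by a generator, so its inputs contain no `[MASK]`.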

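### ELECTRA generator and discriminator
The generator and the discriminator are essentially two BERT-style encoders, each followed by a classification head. The generator's head predicts a probability distribution over the vocabulary at the masked positions; the discriminator's head predicts, for every token, whether it is original or replaced.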
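
The labeling scheme behind the discriminator's task can be sketched in a few lines of plain Python. This is a toy illustration, not ELECTRA's actual pipeline: `corrupt_and_label`, the tiny `vocab`, and the uniform random replacement standing in for the generator are all assumptions made for the example.

```python
import random

def corrupt_and_label(tokens, vocab, mask_prob=0.15, seed=0):
    """Toy replaced-token detection: corrupt some tokens and derive labels."""
    rng = random.Random(seed)
    corrupted, labels = [], []
    for tok in tokens:
        if rng.random() < mask_prob:
            # Stand-in for the generator: draw a token uniformly from `vocab`.
            # A real generator samples from its predicted distribution instead.
            new_tok = rng.choice(vocab)
        else:
            new_tok = tok
        corrupted.append(new_tok)
        # Label 1 only if the token actually changed. If the stand-in happens
        # to produce the original token, the position counts as original (0).
        labels.append(int(new_tok != tok))
    return corrupted, labels

tokens = "the quick brown fox jumps over the lazy dog".split()
vocab = ["grape", "cat", "fox", "runs", "the"]  # hypothetical mini vocabulary
corrupted, labels = corrupt_and_label(tokens, vocab, mask_prob=0.3, seed=1)
print(list(zip(corrupted, labels)))
```

The discriminator only ever receives `corrupted`; `labels` is the training target it has to predict for every position.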

During pretraining the two models are trained jointly, and their losses are combined into a single objective.
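
As a rough sketch of what a combined objective can look like, the snippet below adds a masked-language-modeling loss for the generator to a weighted binary cross-entropy loss for the discriminator. The probabilities are made-up numbers, averaging over positions is an assumption of this sketch, and the weight of 50 is the value reported in the ELECTRA paper.

```python
import math

def mlm_loss(probs_of_true_token):
    """Mean negative log-likelihood of the original token at masked positions."""
    return -sum(math.log(p) for p in probs_of_true_token) / len(probs_of_true_token)

def disc_loss(pred_replaced_probs, labels):
    """Mean binary cross-entropy over all positions (1 = replaced, 0 = original)."""
    total = 0.0
    for p, y in zip(pred_replaced_probs, labels):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(labels)

# Made-up values for illustration only.
gen_probs = [0.2, 0.6]               # generator probability of the true token at 2 masked positions
disc_probs = [0.1, 0.8, 0.3, 0.05]   # discriminator probability that each of 4 tokens was replaced
labels = [0, 1, 0, 0]                # ground truth: only the second token was replaced
lam = 50.0                           # discriminator loss weight

total = mlm_loss(gen_probs) + lam * disc_loss(disc_probs, labels)
print(total)
```

The generator term is computed only at masked positions, while the discriminator term covers every token in the sequence.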

### Using the pretrained ELECTRA generator and discriminator

In [1]:
# Pinning transformers==3.5.1 failed to install in this environment
# (its sentencepiece==0.1.91 dependency did not build);
# the outputs below were produced with transformers 4.44.2.
! pip install transformers==4.44.2

In [9]:
from transformers import ElectraTokenizer, ElectraModel
import torch
import torch.nn.functional as F

In [3]:
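# Load the generator (ElectraModel returns the bare encoder, without the pretrained MLM head)
generator_model = ElectraModel.from_pretrained('google/electra-small-generator')

# Load the discriminator (again the bare encoder, without the pretrained classification head)
discriminator_model = ElectraModel.from_pretrained('google/electra-small-discriminator')

# Load the ELECTRA tokenizer
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')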



In [4]:
generator_model.config

ElectraConfig {
  "_name_or_path": "google/electra-small-generator",
  "architectures": [
    "ElectraForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "embedding_size": 128,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 1024,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "electra",
  "num_attention_heads": 4,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "summary_activation": "gelu",
  "summary_last_dropout": 0.1,
  "summary_type": "first",
  "summary_use_proj": true,
  "transformers_version": "4.44.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [5]:
discriminator_model.config

ElectraConfig {
  "_name_or_path": "google/electra-small-discriminator",
  "architectures": [
    "ElectraForPreTraining"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "embedding_size": 128,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 1024,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "electra",
  "num_attention_heads": 4,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "summary_activation": "gelu",
  "summary_last_dropout": 0.1,
  "summary_type": "first",
  "summary_use_proj": true,
  "transformers_version": "4.44.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [6]:
# Input sentence
sentence = "The quick brown fox jumps over the lazy dog."

# Encode the sentence
inputs = tokenizer(sentence, return_tensors="pt")

print(inputs)

{'input_ids': tensor([[  101,  1996,  4248,  2829,  4419, 14523,  2058,  1996, 13971,  3899,
          1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [10]:
# Get the vocabulary size and the hidden size
vocab_size = tokenizer.vocab_size
hidden_size = generator_model.config.hidden_size
print(vocab_size)
print(hidden_size)

30522
256


In [11]:
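# Linear layer for the generator: maps hidden states to vocabulary-sized logits (newly created, randomly initialized)
generator_fc = torch.nn.Linear(hidden_size, vocab_size)

# Linear layer for the discriminator: binary classification of whether a token was replaced (also randomly initialized)
discriminator_fc = torch.nn.Linear(hidden_size, 1)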

In [12]:
mask_token_index = torch.tensor([4])  # assume 'fox' (position 4) is the masked token
inputs['input_ids'][0, mask_token_index] = tokenizer.mask_token_id  # replace it with [MASK]
print(inputs)

{'input_ids': tensor([[  101,  1996,  4248,  2829,   103, 14523,  2058,  1996, 13971,  3899,
          1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


As the output shows, the token ID at position 4 was 4419 before masking and is 103 (`[MASK]`) afterward.

In [13]:
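# Use the generator to predict the token at the masked position
with torch.no_grad():
    generator_outputs = generator_model(**inputs)
    hidden_states = generator_outputs.last_hidden_state  # hidden states of the last layer
    logits = generator_fc(hidden_states)  # project to vocabulary-sized logits with the generator's linear layer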

In [14]:
print(logits)

tensor([[[ 1.6760, -0.2789, -0.4634,  ...,  0.9555,  1.0139,  0.7684],
         [ 0.7700,  1.9130,  0.8008,  ...,  0.3682,  0.2014, -0.6246],
         [ 0.7751,  0.1468, -0.2396,  ...,  0.2797,  1.7145,  0.5617],
         ...,
         [-0.0614, -0.0751, -0.6678,  ..., -0.6557,  0.2386,  0.5648],
         [ 1.6732, -0.2821, -0.4623,  ...,  0.9514,  1.0100,  0.7638],
         [ 1.6760, -0.2789, -0.4633,  ...,  0.9555,  1.0139,  0.7684]]])


In [15]:
# Pick the predicted token at the masked position
mask_logits = logits[0, mask_token_index]  # logits at the masked position
predicted_token_id = torch.argmax(mask_logits, dim=-1)  # token with the highest score
predicted_token = tokenizer.decode(predicted_token_id)
print(predicted_token)

grape


In [17]:
# Replace the masked position in the input with the generator's prediction
inputs_with_replaced_token = inputs['input_ids'].clone()
inputs_with_replaced_token[0, mask_token_index] = predicted_token_id
print(inputs_with_replaced_token)

tensor([[  101,  1996,  4248,  2829, 14722, 14523,  2058,  1996, 13971,  3899,
          1012,   102]])


The generator produced token ID 14722 (`grape`).

In [18]:
# Feed the sentence produced by the generator to the discriminator
with torch.no_grad():
    discriminator_outputs = discriminator_model(input_ids=inputs_with_replaced_token)
    hidden_states = discriminator_outputs.last_hidden_state  # hidden states of the discriminator
    discriminator_logits = discriminator_fc(hidden_states)  # binary logits from the discriminator's linear layer

In [19]:
# Discriminator output: a binary decision per position (0 = original token, 1 = replaced token)
predictions = torch.round(torch.sigmoid(discriminator_logits))  # apply sigmoid, then round to 0 or 1

In [20]:
print(predictions)

tensor([[[1.],
         [0.],
         [1.],
         [1.],
         [1.],
         [0.],
         [1.],
         [0.],
         [1.],
         [0.],
         [0.],
         [1.]]])


In [21]:
# Print the discriminator's verdict for each token
for token_id, is_replaced in zip(inputs_with_replaced_token[0], predictions[0]):
    # Pass a plain int: calling decode() on a 0-dim tensor returned the token's characters separated by spaces
    token = tokenizer.convert_ids_to_tokens(token_id.item())
    label = "replaced" if is_replaced.item() == 1 else "original"
    print(f"Token: {token}, prediction: {label}")

Token: [CLS], prediction: replaced
Token: the, prediction: original
Token: quick, prediction: replaced
Token: brown, prediction: replaced
Token: grape, prediction: replaced
Token: jumps, prediction: original
Token: over, prediction: replaced
Token: the, prediction: original
Token: lazy, prediction: replaced
Token: dog, prediction: original
Token: ., prediction: original
Token: [SEP], prediction: replaced


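Most of these verdicts are wrong. Several factors may contribute.

#### Input the discriminator expects
The discriminator is trained to decide, for each token of a corrupted sentence, whether the generator replaced it. It never sees the original sentence next to the corrupted one, so feeding it the generator-modified sentence does match the pretraining setup. What differs here is how the replacement was produced: `grape` comes from an argmax over the untrained `generator_fc`, not from a trained generator head, so the corrupted input may be less representative of what the discriminator saw during pretraining.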
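
For context, ELECTRA pretraining draws each replacement by sampling from the generator's output distribution, whereas the cell above takes the argmax. The difference can be sketched with the standard library alone; the four-word `vocab` and the `logits` values are made up for illustration.

```python
import math
import random

def softmax(logits):
    """Convert logits to probabilities (shifted by the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def sample_replacement(logits, vocab, rng):
    """Draw one token from the distribution defined by the logits."""
    probs = softmax(logits)
    return rng.choices(vocab, weights=probs, k=1)[0]

vocab = ["fox", "dog", "cat", "grape"]   # hypothetical mini vocabulary
logits = [3.0, 2.0, 1.5, -1.0]           # made-up generator scores for one masked position

rng = random.Random(0)
argmax_token = vocab[logits.index(max(logits))]  # always picks the top-scoring token
samples = [sample_replacement(logits, vocab, rng) for _ in range(1000)]
counts = {t: samples.count(t) for t in vocab}
print(argmax_token, counts)
```

With argmax the same token is chosen every time, while sampling yields a mix of plausible alternatives in proportion to their probabilities.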

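#### Lack of contrast between original and replaced tokens
During ELECTRA pretraining, every training sequence mixes replaced tokens with tokens that were left untouched, and the discriminator learns from both kinds. This example replaces a single token (`fox`) in one short sentence, so it may be too small a test to show how well the discriminator separates original tokens from replaced ones.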

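#### Untrained discriminator head (linear layer not fine-tuned)
We defined the linear layer `discriminator_fc` ourselves, and it was never trained. During ELECTRA pretraining, the generator and the discriminator are trained together on large amounts of data, heads included. `ElectraModel` returns only the encoder's hidden states, so placing a freshly initialized linear layer on top of them without fine-tuning produces essentially arbitrary outputs. The same applies to `generator_fc`, which would explain the implausible prediction `grape`. The `architectures` fields in the configs printed earlier name the classes that carry the pretrained heads: `ElectraForMaskedLM` for the generator and `ElectraForPreTraining` for the discriminator.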
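
The effect of an untrained head can be illustrated without loading any model. In the sketch below, random vectors stand in for the encoder's hidden states, and `make_untrained_head` and `classify` are hypothetical helpers. The uniform initialization with bound `1 / sqrt(hidden_size)` is an assumption chosen to resemble a typical default for linear layers.

```python
import math
import random

HIDDEN_SIZE = 256  # same hidden_size as in the configs printed earlier
SEQ_LEN = 12       # number of tokens in the example sentence

def make_untrained_head(seed):
    """Weights and bias of a randomly initialized HIDDEN_SIZE -> 1 linear layer."""
    rng = random.Random(seed)
    bound = 1 / math.sqrt(HIDDEN_SIZE)
    weights = [rng.uniform(-bound, bound) for _ in range(HIDDEN_SIZE)]
    bias = rng.uniform(-bound, bound)
    return weights, bias

def classify(hidden_states, head):
    """Return 1 (replaced) or 0 (original) for each position."""
    weights, bias = head
    preds = []
    for h in hidden_states:
        logit = sum(w * x for w, x in zip(weights, h)) + bias
        # logit > 0 is equivalent to rounding sigmoid(logit) to 1
        preds.append(1 if logit > 0 else 0)
    return preds

# Random vectors standing in for the discriminator's hidden states.
data_rng = random.Random(0)
hidden_states = [[data_rng.gauss(0, 1) for _ in range(HIDDEN_SIZE)] for _ in range(SEQ_LEN)]

# The same hidden states, labeled by three differently initialized heads.
for seed in (1, 2, 3):
    print(seed, classify(hidden_states, make_untrained_head(seed)))
```

Each seed corresponds to re-running the `torch.nn.Linear(hidden_size, 1)` cell: the labels depend on the random initialization, not on whether a token was actually replaced.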