# ELECTRA on 🤗 Transformers 🤗

In [1]:
import warnings

import transformers

# This notebook was executed with transformers 3.0.2. The "Detail Review" cells import
# from 3.x module paths (`transformers.modeling_bert`, `transformers.modeling_electra`),
# which moved under `transformers.models.*` in 4.0 -- warn up front instead of failing
# halfway through "Restart & Run All".
if int(transformers.__version__.split(".")[0]) >= 4:
    warnings.warn(
        f"This notebook targets transformers 3.x, but {transformers.__version__} is installed.",
        RuntimeWarning,
    )

transformers.__version__

'3.0.2'

In [2]:
import logging

# Hide unneeded log messages: raise the root logger's threshold to WARNING so that
# INFO/DEBUG records from loggers without their own level are dropped.
logging.root.setLevel(logging.WARNING)

## Model Architecture

![](https://user-images.githubusercontent.com/28896432/80024445-0f444e00-851a-11ea-9137-9da2abfd553d.png)

## TL;DR (Example)

### 1. Discriminator
- [electra-base-discriminator](https://huggingface.co/google/electra-base-discriminator#how-to-use-the-discriminator-in-transformers)
- Fake Token Detection

In [3]:
from pprint import pprint

import torch
from transformers import ElectraForPreTraining, ElectraTokenizer

# Load the KoELECTRA discriminator and its tokenizer from one checkpoint name, so the
# token ids produced by the tokenizer always match the model's vocabulary.
KOELECTRA_DISCRIMINATOR = "monologg/koelectra-base-discriminator"

discriminator = ElectraForPreTraining.from_pretrained(KOELECTRA_DISCRIMINATOR)
tokenizer = ElectraTokenizer.from_pretrained(KOELECTRA_DISCRIMINATOR)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=467.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=443133216.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=279173.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=51.0, style=ProgressStyle(description_w…




In [26]:
# The real sentence would be "나는 방금 밥을 먹었다."; here "방금" is replaced by "왜"
# to build a fake sentence for the discriminator to inspect.
fake_sentence = "나는 왜 밥을 먹었다."

# encode(): sentence -> tensor of ids, special tokens included (see the decode output).
fake_inputs = tokenizer.encode(fake_sentence, return_tensors="pt")  # tensors
print(fake_sentence, '-[encode]->', fake_inputs)
# decode(): ids -> a string with the special tokens
print(fake_inputs[0], '-[decode]->', tokenizer.decode(fake_inputs[0]))
print()

# tokenize() + convert_tokens_to_ids(): the same ids, but without special tokens.
fake_tokens = tokenizer.tokenize(fake_sentence)
fake_ids = tokenizer.convert_tokens_to_ids(fake_tokens)  # list of int
print(fake_sentence, '-[tokenize]->', fake_tokens)
print(fake_tokens, '-[convert_tokens_to_ids]->', fake_ids)
print(fake_ids, '-[convert_ids_to_tokens]->', tokenizer.convert_ids_to_tokens(fake_ids))

나는 왜 밥을 먹었다. -[encode]-> tensor([[    2, 10841,  3579, 21509, 27660,    18,     3]])
tensor([    2, 10841,  3579, 21509, 27660,    18,     3]) -[decode]-> [CLS] 나는 왜 밥을 먹었다. [SEP]

나는 왜 밥을 먹었다. -[tokenize]-> ['나는', '왜', '밥을', '먹었다', '.']
['나는', '왜', '밥을', '먹었다', '.'] -[convert_tokens_to_ids]-> [10841, 3579, 21509, 27660, 18]
[10841, 3579, 21509, 27660, 18] -[convert_ids_to_tokens]-> ['나는', '왜', '밥을', '먹었다', '.']


In [25]:
# Inference only: disable autograd instead of stripping it afterwards with `.data`.
with torch.no_grad():
    discriminator_outputs = discriminator(fake_inputs)
logits = discriminator_outputs[0]  # one logit per token

print(fake_inputs)
print('y:', logits)
print('sign(y):', torch.sign(logits))
print('(sign(y)+1)/2:', (torch.sign(logits) + 1) / 2)
# sign(y) is in {-1, 0, 1}, so (sign(y)+1)/2 is in {0, 0.5, 1}; torch.round maps 0.5 to 0
# (round half to even), so a token gets 1.0 exactly when its logit is positive.
predictions = torch.round((torch.sign(logits) + 1) / 2)

# Flatten so this works whether or not the batch axis was squeezed away, then drop the
# [CLS]/[SEP] positions so the predictions line up with `fake_tokens`.
pprint(list(zip(fake_tokens, predictions.flatten().tolist()[1:-1])))

tensor([[    2, 10841,  3579, 21509, 27660,    18,     3]])
y: tensor([-10.2365,  -2.9496,   1.1244,  -3.4443,  -2.1990,  -3.4055, -10.2365])
sin(y): tensor([-1., -1.,  1., -1., -1., -1., -1.])
(sin(y)+1)/2: tensor([0., 0., 1., 0., 0., 0., 0.])
[('나는', 0.0), ('왜', 1.0), ('밥을', 0.0), ('먹었다', 0.0), ('.', 0.0)]


### 2. Generator

- [electra-base-generator](https://huggingface.co/google/electra-base-generator#how-to-use-the-generator-in-transformers)
- 기존 BERT의 Masked Token Prediction(MLM)과 동일하다고 생각하면 됨!

In [31]:
from pprint import pprint

from transformers import pipeline

# The ELECTRA generator is a plain masked-LM, so the standard fill-mask pipeline works.
fill_mask = pipeline(
    task="fill-mask",
    model="monologg/koelectra-base-generator",
    tokenizer="monologg/koelectra-base-generator",
)

# Mask the adverb slot and show the generator's top candidates for it.
s = "나는 {} 밥을 먹었다.".format(fill_mask.tokenizer.mask_token)
print(s)
pprint(fill_mask(s))

나는 [MASK] 밥을 먹었다.
[{'score': 0.07130879908800125,
  'sequence': '[CLS] 나는 식당에서 밥을 먹었다. [SEP]',
  'token': 26194,
  'token_str': '식당에서'},
 {'score': 0.04359052702784538,
  'sequence': '[CLS] 나는 방금 밥을 먹었다. [SEP]',
  'token': 24499,
  'token_str': '방금'},
 {'score': 0.029709946364164352,
  'sequence': '[CLS] 나는 다시 밥을 먹었다. [SEP]',
  'token': 10715,
  'token_str': '다시'},
 {'score': 0.02787844091653824,
  'sequence': '[CLS] 나는 앉아서 밥을 먹었다. [SEP]',
  'token': 23755,
  'token_str': '앉아서'},
 {'score': 0.025679906830191612,
  'sequence': '[CLS] 나는 내 밥을 먹었다. [SEP]',
  'token': 783,
  'token_str': '내'}]


## Detail Review for Code

- Pretraining의 경우는 원 저자의 TensorFlow 코드를 쓰는 것을 추천
- Hugging Face Transformers의 코드는 Pretraining이 완료된 모델을 가져다 쓰는 용도
- [modeling_electra.py (v3.0.2)](https://github.com/huggingface/transformers/blob/v3.0.2/src/transformers/modeling_electra.py)

In [1]:
import torch
import torch.nn as nn

from transformers.activations import get_activation
from transformers.modeling_bert import BertEmbeddings
from transformers.modeling_electra import ElectraModel, ElectraPreTrainedModel

### 1. ElectraEmbeddings

BertEmbeddings와 동일하지만, 임베딩 차원으로 `hidden_size` 대신 `embedding_size`를 사용!

In [2]:
class ElectraEmbeddings(BertEmbeddings):
    """Construct the embeddings from word, position and token_type embeddings.

    Identical to ``BertEmbeddings`` except that the three embedding tables and the
    LayerNorm are sized with ``config.embedding_size`` rather than the parent's sizes.
    ``forward`` is inherited unchanged from ``BertEmbeddings``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Overwrite the parent's tables with embedding_size-wide ones.
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file.
        # `BertLayerNorm` is never imported in this notebook; in transformers 3.x it is only an
        # alias of torch.nn.LayerNorm, so use nn.LayerNorm directly.
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)

### 2. ElectraModel

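BertModel과 동일하지만 Pooler는 없음! (`embedding_size`와 `hidden_size`가 다르면, 임베딩 출력을 `hidden_size`로 맞추는 projection layer(`embeddings_project`)가 추가됨)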

### 3. ElectraForPreTraining

- Electra model with a binary classification head on top as used during pre-training for identifying generated tokens.
- 개인적으로 이 클래스 이름을 좋아하진 않음
- 논문에서는 `ElectraModel`도 discriminator의 것을 사용하라고 하고 있음

```python
model = ElectraForPreTraining.from_pretrained("google/electra-base-discriminator")
```

In [15]:
from transformers import ElectraForPreTraining

# Pretrained ELECTRA discriminator checkpoint (google, base size).
model = ElectraForPreTraining.from_pretrained(
    pretrained_model_name_or_path="google/electra-base-discriminator"
)

- `nn.Linear(config.hidden_size, config.hidden_size)`와 activation을 거친 뒤 `nn.Linear(config.hidden_size, 1)`로 토큰별 logit 하나를 출력. Sigmoid는 따로 적용하지 않고 `BCEWithLogitsLoss` 안에서 함께 계산됨!

In [25]:
class ElectraDiscriminatorPredictions(nn.Module):
    """Prediction module for the discriminator, made up of two dense layers.

    hidden_size -> hidden_size, then ``config.hidden_act``, then hidden_size -> 1:
    one logit per token.
    """

    def __init__(self, config):
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dense_prediction = nn.Linear(config.hidden_size, 1)
        self.config = config  # kept so forward can look up config.hidden_act

    def forward(self, discriminator_hidden_states, attention_mask):
        """Map encoder states to per-token logits.

        Args:
            discriminator_hidden_states: encoder output whose last axis is
                hidden_size (presumably (batch, seq_len, hidden_size)).
            attention_mask: unused here; kept for interface compatibility.

        Returns:
            The logits with only the trailing size-1 prediction axis removed,
            i.e. the input shape minus its last axis.
        """
        hidden_states = self.dense(discriminator_hidden_states)
        hidden_states = get_activation(self.config.hidden_act)(hidden_states)
        # squeeze(-1) removes only the size-1 prediction axis. A bare squeeze() would
        # also drop the batch (or sequence) axis whenever it happens to be 1.
        logits = self.dense_prediction(hidden_states).squeeze(-1)

        return logits

In [26]:
class ElectraForPreTraining(ElectraPreTrainedModel):
    """ELECTRA discriminator: an ``ElectraModel`` encoder plus a per-token binary head.

    Used during pre-training to identify which tokens were generated (replaced).
    """

    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.discriminator_predictions = ElectraDiscriminatorPredictions(config)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        """Score every token with one logit.

        Args:
            labels: optional per-token binary targets; compared against the logits
                viewed as (-1, seq_len) with ``BCEWithLogitsLoss``.

        Returns:
            ``(loss, logits, *extra)`` when ``labels`` is given, otherwise
            ``(logits, *extra)``. ``extra`` is everything ``self.electra`` returned after
            its first element (hidden states / attentions when requested).
        """
        encoder_outputs = self.electra(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = encoder_outputs[0]
        logits = self.discriminator_predictions(sequence_output, attention_mask)

        result = (logits,) + encoder_outputs[1:]
        if labels is None:
            return result  # scores, (hidden_states), (attentions)

        # Restore a (-1, seq_len) view of the logits before comparing with the labels.
        seq_len = sequence_output.shape[1]
        per_token_logits = logits.view(-1, seq_len)
        bce = nn.BCEWithLogitsLoss()
        if attention_mask is None:
            loss = bce(per_token_logits, labels.float())
        else:
            # Only positions whose attention_mask equals 1 contribute to the loss.
            keep = attention_mask.view(-1, seq_len) == 1
            loss = bce(per_token_logits[keep], labels[keep].float())

        return (loss,) + result  # loss, scores, (hidden_states), (attentions)

### 4. ElectraForMaskedLM

- Electra model with a language modeling head on top.
- 우리가 아는 BERT의 Masked Token Prediction

```python
model = ElectraForMaskedLM.from_pretrained('google/electra-base-generator')
```

In [27]:
from transformers import ElectraForMaskedLM

# Pretrained ELECTRA generator checkpoint (google, base size).
model = ElectraForMaskedLM.from_pretrained(
    pretrained_model_name_or_path="google/electra-base-generator"
)

In [29]:
class ElectraGeneratorPredictions(nn.Module):
    """Prediction module for the generator: one dense layer, GELU, then LayerNorm.

    Projects encoder states from hidden_size down to embedding_size; a separate LM
    head then maps embedding_size to vocabulary scores.
    """

    def __init__(self, config):
        super().__init__()

        # Use the configured epsilon, as ElectraEmbeddings does. nn.LayerNorm's default
        # eps (1e-5) would silently differ from config.layer_norm_eps. (`BertLayerNorm`
        # is never imported in this notebook; it is only an alias of nn.LayerNorm.)
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)

    def forward(self, generator_hidden_states):
        """Project hidden states (last axis hidden_size) to normalized embedding_size features."""
        hidden_states = self.dense(generator_hidden_states)
        hidden_states = get_activation("gelu")(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)

        return hidden_states

In [30]:
class ElectraForMaskedLM(ElectraPreTrainedModel):
    """ELECTRA generator: an ``ElectraModel`` encoder plus a masked-language-modeling head."""

    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.generator_predictions = ElectraGeneratorPredictions(config)

        # Maps embedding_size features to one score per vocabulary entry.
        self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
        self.init_weights()

    def get_output_embeddings(self):
        """Return the output projection (the LM head)."""
        return self.generator_lm_head

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        **kwargs
    ):
        """Predict vocabulary scores for every position.

        Args:
            labels: optional target token ids, flattened with ``view(-1)`` for the loss.
                Targets equal to -100 are ignored (CrossEntropyLoss's default ignore_index).
            **kwargs: only the deprecated ``masked_lm_labels`` alias of ``labels``.

        Returns:
            ``(loss, prediction_scores, *extra)`` when labels are given, otherwise
            ``(prediction_scores, *extra)``. ``extra`` is everything ``self.electra``
            returned after its first element.

        Raises:
            AssertionError: if any other keyword argument is passed.
        """
        if "masked_lm_labels" in kwargs:
            # Local import: `warnings` is not guaranteed to be imported in this section of
            # the notebook, and without it this branch raised NameError instead of warning.
            import warnings

            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                DeprecationWarning,
            )
            labels = kwargs.pop("masked_lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        # Keyword arguments so the call does not depend on ElectraModel.forward's parameter order.
        generator_hidden_states = self.electra(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        generator_sequence_output = generator_hidden_states[0]

        prediction_scores = self.generator_predictions(generator_sequence_output)
        prediction_scores = self.generator_lm_head(prediction_scores)

        output = (prediction_scores,)

        # Masked language modeling loss: softmax cross-entropy over the vocabulary.
        if labels is not None:
            # -100 targets are skipped (default ignore_index), so typically only the masked
            # positions carry real token ids -- not a padding-token convention.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            output = (loss,) + output

        output += generator_hidden_states[1:]

        return output  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)