In [1]:
import torch.nn as nn
from torchcrf import CRF
from transformers import ElectraModel, ElectraTokenizer

In [36]:
class BERT_BiLSTM_CRF(nn.Module):

    def __init__(self, bert, config, need_birnn=False, rnn_dim=128):
        super(BERT_BiLSTM_CRF, self).__init__()
        
        self.num_tags = config.num_labels
        self.bert = bert
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        out_dim = config.hidden_size
        self.need_birnn = need_birnn

        # if False, no use of BiLSTM
        if need_birnn:
            self.birnn = nn.LSTM(config.hidden_size, rnn_dim, num_layers=1, bidirectional=True, batch_first=True)
            out_dim = rnn_dim*2
        
        self.hidden2tag = nn.Linear(out_dim, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)
    

    def forward(self, input_ids, tags, token_type_ids=None, input_mask=None):
        emissions = self.tag_outputs(input_ids, token_type_ids, input_mask)
        loss = -1*self.crf(emissions, tags, mask=input_mask)
        # loss = -1*self.crf(emissions, tags, mask=input_mask.byte())

        return loss

    
    def tag_outputs(self, input_ids, token_type_ids=None, input_mask=None):

        outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=input_mask)

        sequence_output = outputs[0]
        
        if self.need_birnn:
            sequence_output, _ = self.birnn(sequence_output)

        sequence_output = self.dropout(sequence_output)
        emissions = self.hidden2tag(sequence_output)

        return emissions
    
    def predict(self, input_ids, token_type_ids=None, input_mask=None):
        emissions = self.tag_outputs(input_ids, token_type_ids, input_mask)
        return self.crf.decode(emissions, input_mask.byte())


In [37]:
ckpt = "monologg/koelectra-base-v3-discriminator"
bert = ElectraModel.from_pretrained(ckpt, num_labels=22)
tokenizer = ElectraTokenizer.from_pretrained('./tokenizer/')
bert.resize_token_embeddings(len(tokenizer))
config = bert.config
print(config.num_labels)
print(config.hidden_dropout_prob)
print(config.hidden_size)

Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense.weight']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


22
0.1
768


In [38]:
model = BERT_BiLSTM_CRF(bert, config, need_birnn=True, rnn_dim=128)

In [39]:
text = '야호!!!'
input_ids, token_type_ids, attention_mask = dict(tokenizer(text, return_tensors='pt')).values()

In [40]:
input_ids

tensor([[   2, 3102, 4029,    5,    5,    5,    3]])

In [41]:
model.forward(input_ids, token_type_ids, attention_mask)

tensor(21.7547, grad_fn=<MulBackward0>)

In [42]:
model.predict(input_ids, token_type_ids, attention_mask)

[[8, 8, 11, 8, 11, 8, 11]]