In [139]:
import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import ElectraModel, ElectraTokenizer

In [140]:
# class KobertBiLSTMCRF(nn.Module):
#     """ koBERT with CRF """
#     def __init__(self, config, num_classes, vocab=None) -> None:
#         super(KobertBiLSTMCRF, self).__init__()

#         if vocab is None: # pretraining model 사용
#             self.bert, self.vocab = get_pytorch_kobert_model()
#         else: # finetuning model 사용           
#             self.bert = BertModel(config=BertConfig.from_dict(bert_config))
#             self.vocab = vocab
#         self._pad_id = self.vocab.token_to_idx[self.vocab.padding_token]

#         self.dropout = nn.Dropout(config.dropout)
#         self.bilstm = nn.LSTM(config.hidden_size, (config.hidden_size) // 2, dropout=config.dropout, batch_first=True, bidirectional=True)
#         self.position_wise_ff = nn.Linear(config.hidden_size, num_classes)
#         self.crf = CRF(num_labels=num_classes)

#     def forward(self, input_ids, token_type_ids=None, tags=None, using_pack_sequence=True):

#         seq_length = input_ids.ne(self._pad_id).sum(dim=1)
#         attention_mask = input_ids.ne(self._pad_id).float()
#         outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
#         last_encoder_layer = outputs[0]
#         last_encoder_layer = self.dropout(last_encoder_layer)
#         if using_pack_sequence is True:
#             pack_padded_last_encoder_layer = pack_padded_sequence(last_encoder_layer, seq_length, batch_first=True, enforce_sorted=False)
#             outputs, hc = self.bilstm(pack_padded_last_encoder_layer)
#             outputs = pad_packed_sequence(outputs, batch_first=True, padding_value=self._pad_id)[0]
#         else:
#             outputs, hc = self.bilstm(last_encoder_layer)
#         emissions = self.position_wise_ff(outputs)

#         if tags is not None: # crf training
#             log_likelihood, sequence_of_tags = self.crf(emissions, tags), self.crf.decode(emissions)
#             return log_likelihood, sequence_of_tags
#         else: # tag inference
#             sequence_of_tags = self.crf.decode(emissions)
#             return sequence_of_tags

In [141]:
class BERT_BiLSTM_CRF(nn.Module):
    """Sequence tagger: encoder -> (optional BiLSTM) -> dropout -> linear -> CRF.

    Args:
        bert: encoder module, called as ``bert(input_ids, attention_mask=mask)``;
            element 0 of its output is used as the per-token hidden states.
        config: object providing ``num_labels``, ``hidden_dropout_prob`` and
            ``hidden_size``.
        need_birnn: if True, insert a single-layer bidirectional LSTM between the
            encoder and the tag projection.
        rnn_dim: hidden size of each LSTM direction (the projection then takes
            ``2 * rnn_dim`` features).

    ``input_mask`` is a 0/1 tensor shaped like ``input_ids`` (1 = real token).
    When it is omitted, every position is treated as valid. The BiLSTM path
    packs sequences by ``input_mask.sum(1)``, so it assumes right padding
    (valid tokens form a prefix) -- TODO confirm for all callers.
    """

    def __init__(self, bert, config, need_birnn=False, rnn_dim=128):
        super(BERT_BiLSTM_CRF, self).__init__()

        self.num_tags = config.num_labels
        self.bert = bert
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        out_dim = config.hidden_size
        self.need_birnn = need_birnn

        # If need_birnn is False, encoder states feed the projection directly.
        if need_birnn:
            self.birnn = nn.LSTM(config.hidden_size, rnn_dim, num_layers=1,
                                 bidirectional=True, batch_first=True)
            out_dim = rnn_dim * 2  # forward and backward states are concatenated

        self.hidden2tag = nn.Linear(out_dim, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)

    @staticmethod
    def _crf_mask(input_ids, input_mask):
        """Return a bool mask for the CRF; all-True (same shape as ids) if none given."""
        if input_mask is None:
            return torch.ones_like(input_ids, dtype=torch.bool)
        return input_mask.bool()

    def forward(self, input_ids, tags, input_mask=None):
        """Return the negative CRF log-likelihood of ``tags`` (the training loss).

        Args:
            input_ids: token ids, shaped (batch, seq).
            tags: gold tag ids aligned with ``input_ids``; cast to long here.
            input_mask: optional 0/1 mask; None means every position is valid.
        """
        emissions = self.tag_outputs(input_ids, input_mask)
        mask = self._crf_mask(input_ids, input_mask)
        return -self.crf(emissions, tags.long(), mask)

    def tag_outputs(self, input_ids, input_mask=None):
        """Compute per-token tag scores (emissions); last dim is ``config.num_labels``."""
        outputs = self.bert(input_ids, attention_mask=input_mask)
        sequence_output = outputs[0]

        if self.need_birnn:
            sequence_output = self._run_birnn(sequence_output, input_mask)

        sequence_output = self.dropout(sequence_output)
        return self.hidden2tag(sequence_output)

    def _run_birnn(self, sequence_output, input_mask):
        """Run the BiLSTM; when a mask is given, pack by length so pads are never read."""
        if input_mask is None:
            output, _ = self.birnn(sequence_output)
            return output
        # pack_padded_sequence needs CPU lengths; clamp so an all-padding row
        # does not make packing itself raise.
        lengths = input_mask.long().sum(dim=1).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            sequence_output, lengths, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.birnn(packed)
        # total_length keeps the original seq dim so emissions align with the mask.
        output, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True, total_length=sequence_output.size(1))
        return output

    def predict(self, input_ids, input_mask=None):
        """Viterbi-decode tag ids with the CRF.

        Returns whatever ``crf.decode`` returns. In this notebook that is a list
        of per-sequence tag-id lists. If ``input_mask`` is None, all positions
        are decoded.
        """
        emissions = self.tag_outputs(input_ids, input_mask)
        return self.crf.decode(emissions, self._crf_mask(input_ids, input_mask))

In [142]:
# Load the KoELECTRA encoder and the project tokenizer (from a local directory).
# The "weights not used" warning for discriminator_predictions.* is expected:
# only the encoder body is loaded, not the discriminator head.
ckpt = "monologg/koelectra-base-v3-discriminator"
TOKENIZER_DIR = './tokenizer/'
NUM_LABELS = 22  # size of the tag set; the CRF head reads it back via config.num_labels

bert = ElectraModel.from_pretrained(ckpt, num_labels=NUM_LABELS)
tokenizer = ElectraTokenizer.from_pretrained(TOKENIZER_DIR)
# Resize the embedding matrix to the tokenizer vocabulary; presumably the local
# tokenizer adds tokens to the checkpoint's vocab -- confirm.
bert.resize_token_embeddings(len(tokenizer))
config = bert.config
assert config.num_labels == NUM_LABELS
print(config.num_labels)
print(config.hidden_dropout_prob)
print(config.hidden_size)

Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense.weight']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


22
0.1
768


In [143]:
# Seed before construction so the randomly initialised LSTM/linear weights are reproducible.
SEED = 42
torch.manual_seed(SEED)
model = BERT_BiLSTM_CRF(bert, config, need_birnn=True, rnn_dim=128)

In [144]:
# Smoke test on one sentence: tokenize, then score a hand-written tag sequence.
text = '야호!!!'
encoded = tokenizer(text, return_tensors='pt')
# Look tensors up by key instead of relying on the order of the dict's values.
input_ids = encoded['input_ids']
token_type_ids = encoded['token_type_ids']
attention_mask = encoded['attention_mask']
# One integer tag per position in input_ids.
tags = torch.tensor([[1, 5, 9, 5, 9, 5, 14]], dtype=torch.long)
assert tags.shape == input_ids.shape, f"tags {tuple(tags.shape)} vs tokens {tuple(input_ids.shape)}"

emissions = model.tag_outputs(input_ids, attention_mask)
loss = -1 * model.crf(emissions, tags.long(), attention_mask.bool())

In [147]:
# Call the module itself (not .forward) so nn.Module hooks run.
loss = model(input_ids, tags, attention_mask)

In [146]:
# Decode in eval mode so dropout (model.dropout and the encoder's own) is off,
# and skip autograd; then restore the previous mode so later cells are unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        predicted_tags = model.predict(input_ids, attention_mask)
finally:
    model.train(was_training)
predicted_tags

[[18, 10, 10, 10, 10, 10, 7]]