In [36]:
import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import ElectraModel, ElectraTokenizer

In [37]:
# class KobertBiLSTMCRF(nn.Module):
#     """ koBERT with CRF """
#     def __init__(self, config, num_classes, vocab=None) -> None:
#         super(KobertBiLSTMCRF, self).__init__()

#         if vocab is None: # pretraining model 사용
#             self.bert, self.vocab = get_pytorch_kobert_model()
#         else: # finetuning model 사용           
#             self.bert = BertModel(config=BertConfig.from_dict(bert_config))
#             self.vocab = vocab
#         self._pad_id = self.vocab.token_to_idx[self.vocab.padding_token]

#         self.dropout = nn.Dropout(config.dropout)
#         self.bilstm = nn.LSTM(config.hidden_size, (config.hidden_size) // 2, dropout=config.dropout, batch_first=True, bidirectional=True)
#         self.position_wise_ff = nn.Linear(config.hidden_size, num_classes)
#         self.crf = CRF(num_labels=num_classes)

#     def forward(self, input_ids, token_type_ids=None, tags=None, using_pack_sequence=True):

#         seq_length = input_ids.ne(self._pad_id).sum(dim=1)
#         attention_mask = input_ids.ne(self._pad_id).float()
#         outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
#         last_encoder_layer = outputs[0]
#         last_encoder_layer = self.dropout(last_encoder_layer)
#         if using_pack_sequence is True:
#             pack_padded_last_encoder_layer = pack_padded_sequence(last_encoder_layer, seq_length, batch_first=True, enforce_sorted=False)
#             outputs, hc = self.bilstm(pack_padded_last_encoder_layer)
#             outputs = pad_packed_sequence(outputs, batch_first=True, padding_value=self._pad_id)[0]
#         else:
#             outputs, hc = self.bilstm(last_encoder_layer)
#         emissions = self.position_wise_ff(outputs)

#         if tags is not None: # crf training
#             log_likelihood, sequence_of_tags = self.crf(emissions, tags), self.crf.decode(emissions)
#             return log_likelihood, sequence_of_tags
#         else: # tag inference
#             sequence_of_tags = self.crf.decode(emissions)
#             return sequence_of_tags

In [38]:
class BERT_BiLSTM_CRF(nn.Module):

    def __init__(self, bert, config, need_birnn=False, rnn_dim=128):
        super(BERT_BiLSTM_CRF, self).__init__()
        
        self.num_tags = config.num_labels
        self.bert = bert
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        out_dim = config.hidden_size
        self.need_birnn = need_birnn

        # if False, no use of BiLSTM
        if need_birnn:
            self.birnn = nn.LSTM(config.hidden_size, rnn_dim, num_layers=1, bidirectional=True, batch_first=True)
            out_dim = rnn_dim*2
        
        self.hidden2tag = nn.Linear(out_dim, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)
    

    def forward(self, input_ids, tags, input_mask=None):
        emissions = self.tag_outputs(input_ids, input_mask)
        loss = -1*self.crf(emissions, tags.long(), input_mask.byte()) # log_likelihood
        prediction = self.crf.decode(emissions, input_mask.byte())
        return loss, prediction

    def tag_outputs(self, input_ids, input_mask=None):

        outputs = self.bert(input_ids, attention_mask=input_mask)

        sequence_output = outputs[0]
        
        if self.need_birnn:
            sequence_output, _ = self.birnn(sequence_output)

        sequence_output = self.dropout(sequence_output)
        emissions = self.hidden2tag(sequence_output)

        return emissions
    
    def predict(self, input_ids, input_mask=None):
        emissions = self.tag_outputs(input_ids, input_mask)
        return self.crf.decode(emissions, input_mask.byte())

In [39]:
ckpt = "monologg/koelectra-base-v3-discriminator"
bert = ElectraModel.from_pretrained(ckpt, num_labels=22)
tokenizer = ElectraTokenizer.from_pretrained('./tokenizer/')
bert.resize_token_embeddings(len(tokenizer))
config = bert.config
print(config.num_labels)
print(config.hidden_dropout_prob)
print(config.hidden_size)

Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


22
0.1
768


In [40]:
model = BERT_BiLSTM_CRF(bert, config, need_birnn=True, rnn_dim=128)

In [103]:
text = '야호!!!'
input_ids, token_type_ids, attention_mask = dict(tokenizer(text, return_tensors='pt')).values()
tags = torch.Tensor([[1, 5, 9, 5, 9, 5, 14]])

emissions = model.tag_outputs(input_ids, attention_mask)
loss = -1*model.crf(emissions, tags.long(), attention_mask.byte()) # log_likelihood
loss

tensor(22.3188, grad_fn=<MulBackward0>)

In [104]:
input_ids = torch.cat([input_ids, input_ids, input_ids])
attention_mask = torch.cat([attention_mask, attention_mask, attention_mask])
tags = torch.cat([tags, tags, tags])

In [105]:
loss = model(input_ids, tags, attention_mask)
prediction = model.predict(input_ids, attention_mask)
loss, prediction

(tensor(66.5115, grad_fn=<MulBackward0>),
 [[11, 11, 11, 11, 11, 11, 11],
  [11, 11, 11, 11, 11, 11, 8],
  [11, 11, 11, 11, 11, 11, 11]])

In [91]:
loss.mean()

tensor(22.2125, grad_fn=<MeanBackward0>)

In [43]:
loss.item()

22.313861846923828

In [44]:
emissions = model.tag_outputs(input_ids, attention_mask)
type(emissions)

torch.Tensor

In [89]:
import random

el = [15, 15, 15, 15, 15, 9, 8]
el = [random.randint(0, 21) for _ in range(150)]
sample = [el, el, el, el, el]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sample = list(map(lambda x: x + [0 for _ in range(128 - len(x))], sample))
sample = torch.LongTensor(sample).to(device)
sample = sample.view(-1, sample.shape[-1]) # shape(batchsize, length, ouputdim) --> shape(batchsize, length*outputdim)
sample.shape

torch.Size([5, 150])

In [80]:
def categorical_accuracy(preds, y, tag_pad_idx):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    # max_preds = preds.argmax(dim = -1, keepdim = True) # get the index of the max probability
    non_pad_elements = torch.nonzero((y != tag_pad_idx))
    correct = preds[non_pad_elements].squeeze(1).eq(y[non_pad_elements])
    return correct.sum() / torch.FloatTensor([y[non_pad_elements].shape[0]]).to(device)

In [82]:
categorical_accuracy(sample[0], sample[0], 0)

tensor([7.5972], device='cuda:0')

In [83]:
y = sample[0]

In [84]:
tag_pad_idx = 0
# non_pad_elements = (y != tag_pad_idx).nonzero()
non_pad_elements = torch.nonzero((y != tag_pad_idx))
non_pad_elements

torch.FloatTensor([y[non_pad_elements].shape[0]])

tensor([144.])

In [88]:
y[non_pad_elements].squeeze(1)

tensor([18, 16,  9,  7,  6,  4,  9,  7, 19,  2, 18,  8, 12, 20,  5,  4, 13,  1,
        12,  8,  9,  1, 14, 20, 19, 13,  3,  4,  1,  4,  6,  5, 19, 15, 20, 14,
        20, 13,  7,  2, 10,  5,  8,  2, 10,  2,  6,  8, 12, 10, 11,  3, 20,  3,
         9,  4, 13,  9,  9, 18, 21, 12, 13,  7,  2, 14,  9,  3, 20, 15, 13,  2,
        17,  3, 17, 15, 11,  4,  6,  3, 16, 17,  1, 19, 20, 17, 12, 13, 20, 20,
        14,  8,  5, 12, 19, 10, 12,  9,  9,  4,  5, 14, 20,  5, 14,  4,  7, 14,
        13, 21,  4, 19,  2, 16,  9,  3,  5, 16, 16, 12, 18, 10, 15, 16, 21,  6,
        12, 19,  8,  8,  7, 21,  2, 16, 12,  2,  5, 11, 18,  2,  5,  7,  6, 14],
       device='cuda:0')

In [85]:
y[non_pad_elements].squeeze(1).eq(y[non_pad_elements])

tensor([[ True, False, False,  ..., False, False, False],
        [False,  True, False,  ..., False, False, False],
        [False, False,  True,  ..., False, False, False],
        ...,
        [False, False, False,  ...,  True, False, False],
        [False, False, False,  ..., False,  True, False],
        [False, False, False,  ..., False, False,  True]], device='cuda:0')