In [1]:
import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel
import torch.nn.functional as F

from kobert_tokenizer import KoBERTTokenizer


In [44]:
class BiEncoder(BertPreTrainedModel):
    """Bi-encoder that embeds contexts and candidates with one shared BERT.

    Both sides are encoded by the same ``self.bert`` module and represented by
    the second element of its output (``outputs[1]``, the pooled output).
    With ``training=True`` every context in the batch is scored against every
    candidate via a dot product, and an in-batch cross-entropy loss is computed
    in which candidate ``i`` is the positive for context ``i`` and all other
    candidates in the batch act as negatives.
    """

    def __init__(self, config):
        super().__init__(config)

        # A single encoder is shared by the context side and the candidate side.
        self.bert = BertModel(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        candidate_input_ids=None,
        candidate_attention_mask=None,
        candidate_token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):
        """Encode contexts and, when training, score them against candidates.

        Args:
            input_ids, attention_mask, token_type_ids: context-side inputs,
                forwarded to ``self.bert``.
            candidate_input_ids, candidate_attention_mask,
            candidate_token_type_ids: candidate-side inputs. They are only used
                when ``training`` is True, and then ``candidate_input_ids`` is
                required.
            position_ids, head_mask, output_attentions, output_hidden_states:
                forwarded to both encoder passes. NOTE(review): sharing
                ``position_ids`` assumes contexts and candidates have the same
                sequence length whenever it is given.
            inputs_embeds: context-side embeddings used instead of
                ``input_ids``. They are deliberately NOT reused for candidates.
            training: if True, return the in-batch loss and score matrix.
                Otherwise return only the context representation.

        Returns:
            ``training=False``: ``context_output[1]`` (the pooled context output).
            ``training=True``: ``(loss, dot_product)``, where ``dot_product`` is
            the ``(num_contexts, num_candidates)`` score matrix. The loss targets
            ``arange(num_contexts)``, so ``num_candidates >= num_contexts`` is
            required.

        Raises:
            ValueError: if ``training`` is True and ``candidate_input_ids`` is None.
        """
        context_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        context_vec = context_output[1]

        if not training:
            return context_vec

        if candidate_input_ids is None:
            raise ValueError("candidate_input_ids is required when training=True")

        # inputs_embeds describe the context side only. Forwarding them here
        # would either encode the context a second time as the "candidates" or
        # clash with candidate_input_ids (BertModel rejects getting both).
        candidate_output = self.bert(
            input_ids=candidate_input_ids,
            attention_mask=candidate_attention_mask,
            token_type_ids=candidate_token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        candidate_vec = candidate_output[1]

        # Row i holds context i's score against every candidate in the batch.
        dot_product = torch.matmul(context_vec, candidate_vec.t())

        # Candidate i is the positive for context i. Build the targets on the
        # logits' device so the loss also works when the model runs on GPU.
        labels = torch.arange(dot_product.shape[0], device=dot_product.device)
        loss_fnt = nn.CrossEntropyLoss()
        loss = loss_fnt(dot_product, labels)

        return loss, dot_product


In [45]:
# Hugging Face Hub checkpoint shared by the tokenizer and the bi-encoder.
# Its config names 'XLNetTokenizer', so loading it through KoBERTTokenizer
# emits a class-mismatch warning (captured in this cell's output).
CHECKPOINT = 'skt/kobert-base-v1'

tokenizer = KoBERTTokenizer.from_pretrained(CHECKPOINT)
model = BiEncoder.from_pretrained(CHECKPOINT)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'XLNetTokenizer'. 
The class this function is called from is 'KoBERTTokenizer'.


In [46]:
# Toy batch: the same two sentences serve as contexts and as candidates, so
# each context's positive candidate is an identical sentence.
sentences = ['i am a student', 'wqwei am a student']
tokenize_kwargs = {
    'padding': 'max_length',
    'max_length': 50,
    'truncation': True,
    'return_tensors': 'pt',
}

seq = tokenizer(sentences, **tokenize_kwargs)   # context side
seq2 = tokenizer(sentences, **tokenize_kwargs)  # candidate side

output = model(
    input_ids=seq['input_ids'],
    attention_mask=seq['attention_mask'],
    token_type_ids=seq['token_type_ids'],
    candidate_input_ids=seq2['input_ids'],
    candidate_attention_mask=seq2['attention_mask'],
    candidate_token_type_ids=seq2['token_type_ids'],
    training=True,
)
output

tensor([[  2, 517, 405, 517, 373, 517, 367, 517, 441, 446, 388, 393,   3,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1],
        [  2, 517, 455, 433, 455, 389, 405, 517, 373, 517, 367, 517, 441, 446,
         388, 393,   3,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1]])


(tensor(3.3868, grad_fn=<NllLossBackward>),
 tensor([[ 66.0941,  72.8666],
         [ 72.8666, 169.8437]], grad_fn=<MmBackward>))