In [1]:
import json
import time
import numpy as np

In [2]:
context_corpus = json.load(open('dataset/paragraph_context.json', 'r', encoding='utf8')) 
question_corpus = json.load(open('dataset/question_context.json', 'r', encoding='utf8'))
train_labels_json = json.load(open('dataset/train_labels.json', 'r', encoding='utf8'))
test_labels_json = json.load(open('dataset/test_labels.json', 'r', encoding='utf8'))

In [3]:
from torch import nn
import torch
from transformers import ElectraTokenizer, ElectraModel, ElectraConfig, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, Dataset

In [4]:
# monologg/kobert, "monologg/koelectra-base-v3-discriminator"

In [5]:
contexts = list(context_corpus.values())

In [8]:
class CreateDataset(Dataset):
    def __init__(self, context_corpus, question_corpus, labels):
        self.context_corpus = context_corpus
        self.question_corpus = question_corpus
        self.labels = self.create_labels(labels)
    
    def create_labels(self, labels):
        new_labels = []
        for key in labels.keys():
            for t in labels[key]:
                new_labels.append([key, t])
        print(new_labels[0])
        return new_labels
    
    
    def __len__(self):
        return len(self.labels)
        
    def __getitem__(self, idx):
        doc_id, que_id = self.labels[idx]
        return {'context': self.context_corpus[doc_id], 'question': self.question_corpus[que_id]}

In [9]:
class Model(nn.Module):
    def __init__(self, args):
        super().__init__()
        config = ElectraConfig.from_pretrained(args.model_name, local_file_only=True)
        self.model  = ElectraModel.from_pretrained(args.model_name, config=config)
        self.tokenizer = ElectraTokenizer.from_pretrained(args.model_name, local_file_only=True)
        
        self.punctation_idx = self.tokenizer.get_vocab()['.']
        self.pad_token_idx = self.tokenizer.pad_token_id
        self.mask_token_idx = self.tokenizer.mask_token_id
        self.d = self.tokenizer.get_vocab()['[unused0]']
        self.q = self.tokenizer.get_vocab()['[unused1]']
        self.linear = nn.Linear(config.hidden_size, 256)
        
        self.doc_maxlen = args.doc_maxlen
        self.query_maxlen = args.query_maxlen
        self.device = args.device
        self.criterion = nn.CrossEntropyLoss()
        
        
    def forward(self, feature):
        q_output = self.query(feature['question'])
        d_output = self.doc(feature['context'])
        prediction = self.similarity(q_output, d_output)
        loss = self.calc_loss(prediction)
        return loss
    
    def calc_loss(self, prediction):
        batch_size = prediction.shape[0]
        label = torch.arange(batch_size).to(self.device)
        return self.criterion(prediction, label)
        
    
    def similarity(self, q_output, d_output):
        # q_output = [batch, 128, 256]
        # d_output = [batch, seq_len, 256]
        prediction = torch.einsum('ijk,abk->iajb', q_output, d_output)
        prediction, _ = torch.max(prediction, dim=-1)
        prediction = torch.sum(prediction, dim=-1)
        return prediction
    
    
    def doc(self, D):
        inputs = self.tokenizer(D, return_tensors='pt', padding=True, truncation=True, max_length=self.doc_maxlen)
        
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        
        batch = input_ids.shape[0]
        
        b = torch.LongTensor([self.d]* batch).view(-1, 1)
        a = torch.ones(size=(batch, 1))
        
        input_ids = torch.cat([input_ids[:, :1], b, input_ids[:, 1:]], dim=1).to(self.device)
        attention_mask = torch.cat([attention_mask[:, :1], a, attention_mask[:, 1:]], dim=1).to(self.device)
        
        punctation = input_ids
        
        model_input = {'input_ids': input_ids,
                      'attention_mask': attention_mask}
        
        output = self.model(**model_input)['last_hidden_state']
        output = self.linear(output)
        
        new_mask = attention_mask * (punctation != self.punctation_idx)
        output = output * new_mask.unsqueeze(-1)
        output = torch.nn.functional.normalize(output, p=2, dim=2)
        return output
    
    
    def query(self, Q):
        inputs = self.tokenizer(Q, return_tensors='pt', truncation=True, max_length=self.query_maxlen,
                               pad_to_max_length=True)
        
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        
        input_ids = input_ids.masked_fill(attention_mask==self.pad_token_idx, self.mask_token_idx)
        
        batch = input_ids.shape[0]
        
        b = torch.LongTensor([self.d]* batch).view(-1, 1)
        a = torch.zeros(size=(batch, 1))
        
        input_ids = torch.cat([input_ids[:, :1], b, input_ids[:, 1:]], dim=1).to(self.device)
        new_mask = torch.ones_like(input_ids).to(self.device)
        
        punctation = input_ids
        
        model_input = {'input_ids': input_ids,
                      'attention_mask': new_mask}
        
        output = self.model(**model_input)['last_hidden_state']
        output = self.linear(output)
        output = torch.nn.functional.normalize(output, p=2, dim=2)
        return output


In [10]:
import easydict
from tqdm import tqdm
args = easydict.EasyDict({
    'model_name': 'monologg/koelectra-base-v3-discriminator',
    'doc_maxlen': 512-1,
    'query_maxlen': 128-1,
    'device': 'cuda',
    'epochs': 5,
    'warmup': 0.1,
    'batch_size': 16
})

In [11]:
model = Model(args).to(args.device).eval()

Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
dataset = CreateDataset(context_corpus, question_corpus, train_labels_json)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

['PARS_1agoYxToKo', 'QUES_TNC71lb33r']


In [15]:
optimizer = AdamW(model.parameters(), lr=2e-5)
max_step = len(dataloader) * args.epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup * max_step, num_training_steps=max_step)
pre_loss = float('inf')

In [16]:
for epoch in range(args.epochs):
    model.train()
    avg_loss = []
    for x in tqdm(dataloader):
        optimizer.zero_grad()
        loss = model(x)
        loss.backward()
        avg_loss.append(loss.item())
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        optimizer.step()
        scheduler.step()
        
    print('epoch=', epoch, 'loss=', np.mean(avg_loss))
    
    if np.mean(avg_loss) < pre_loss:
        torch.save(model, 'best.pt')
torch.save(model, 'last.pt')

100%|████████████████████████████████████████████████████████████████████████████| 11652/11652 [57:17<00:00,  3.39it/s]


epoch= 0 loss= 0.279783604182856


100%|████████████████████████████████████████████████████████████████████████████| 11652/11652 [58:56<00:00,  3.29it/s]


epoch= 1 loss= 0.03537065577154934


100%|████████████████████████████████████████████████████████████████████████████| 11652/11652 [57:51<00:00,  3.36it/s]


epoch= 2 loss= 0.01118345380851512


100%|████████████████████████████████████████████████████████████████████████████| 11652/11652 [56:20<00:00,  3.45it/s]


epoch= 3 loss= 0.0043108137073284035


100%|████████████████████████████████████████████████████████████████████████████| 11652/11652 [56:43<00:00,  3.42it/s]


epoch= 4 loss= 0.0017395397259050965
