In [None]:
import os
import json
import numpy as np
from tqdm import tqdm, trange
import random
import torch
import torch.nn.functional as F
from transformers import BertModel, BertPreTrainedModel, AdamW, TrainingArguments, get_linear_schedule_with_warmup

SEED = 2021
# Seed every RNG the notebook relies on (torch CPU, torch CUDA, NumPy, stdlib `random`)
# in the same order as before so runs are reproducible.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

# 1. 데이터 로드

In [None]:
data_path = "../../data/"
context_path = "wikipedia_documents.json"
wiki_path = os.path.join(data_path, context_path)

with open(wiki_path, "r", encoding="utf-8") as wiki_file:
    wiki = json.load(wiki_file)

# Passage texts, deduplicated while keeping their first-seen order.
corpus = list(dict.fromkeys(doc["text"] for doc in wiki.values()))
print('context len :', len(corpus))

In [None]:
from datasets import load_from_disk, concatenate_datasets

dataset_dir = '../../data/train_dataset'
dataset = load_from_disk(dataset_dir)

# The train and validation splits together (in that order) form the retriever's training data.
training_dataset = concatenate_datasets(
    [dataset[split].flatten_indices() for split in ("train", "validation")]
)
print(len(dataset['train']), len(training_dataset))

In [None]:
from transformers import AutoTokenizer

# The same checkpoint name is reused for the TF-IDF tokenizer and the BERT encoders.
model_checkpoint = "klue/bert-base"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

# 2. Sparse Embedding Retrieval

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer

# TF-IDF over the BERT tokenizer's tokens (unigrams + bigrams, at most 50k features).
# token_pattern=None: the regex is ignored whenever a custom tokenizer is given, and
# leaving the default in place only triggers a "token_pattern will not be used" warning.
tfidfv = TfidfVectorizer(tokenizer=tokenizer.tokenize, token_pattern=None,
                         ngram_range=(1, 2), max_features=50000)
p_embedding = tfidfv.fit_transform(corpus)  # sparse (num_passages, num_features)

In [None]:
def get_topk_similarity(qeury_vec, k, passage_matrix=None):
    """Return the k passages most similar to each query row, best first.

    Args:
        qeury_vec: (num_queries, num_features) query matrix, sparse or dense
            (e.g. the output of ``tfidfv.transform([...])``).
        k: number of passages to keep per query. Values larger than the number
            of passages are clamped to it; ``k <= 0`` yields empty results.
        passage_matrix: (num_passages, num_features) passage matrix. Defaults
            to the global ``p_embedding``.

    Returns:
        (doc_scores, doc_indices), both shaped (num_queries, k). ``doc_scores``
        is sorted in descending order, and ``doc_indices[r, j]`` is the passage
        index whose score is ``doc_scores[r, j]``.
    """
    if passage_matrix is None:
        passage_matrix = p_embedding
    # `@` is matrix multiplication for both scipy sparse matrices and sparse arrays
    # (`*` is elementwise for the latter).
    scores = qeury_vec @ passage_matrix.T
    scores = scores.toarray() if hasattr(scores, "toarray") else np.asarray(scores)

    num_queries, num_passages = scores.shape
    k = min(k, num_passages)
    if k <= 0:
        return (np.empty((num_queries, 0), dtype=scores.dtype),
                np.empty((num_queries, 0), dtype=np.intp))

    # A single argpartition selects the top-k indices. The scores are gathered from
    # those indices, so scores and indices cannot get out of sync.
    top_idx = np.argpartition(scores, -k, axis=-1)[:, -k:]
    top_scores = np.take_along_axis(scores, top_idx, axis=-1)
    order = np.argsort(-top_scores, axis=-1, kind="stable")
    return (np.take_along_axis(top_scores, order, axis=-1),
            np.take_along_axis(top_idx, order, axis=-1))

# TF-IDF evaluation: for each validation question, record the 0-based rank of its
# ground-truth passage among all corpus passages (0 = retrieved first).
# Passage texts were deduplicated at load time, so text -> index is one-to-one.
corpus_index = {text: i for i, text in enumerate(corpus)}
k = len(corpus)  # rank against the whole corpus

sparse_answer_rank_list = []
for example in tqdm(dataset['validation']):
    query = example['question']
    ground_truth = example['context']
    query_vec = tfidfv.transform([query])

    doc_scores, doc_indices = get_topk_similarity(query_vec, k)
    # Position of the ground-truth passage in this query's ranking.
    answer_rank = int(np.flatnonzero(doc_indices[0] == corpus_index[ground_truth])[0])
    sparse_answer_rank_list.append(answer_rank)

sparse_answer_rank_list

# 3. Dense Embedding Retrieval

In [None]:
training_dataset[0]['context']

In [None]:
def get_resverse_topk_similarity(qeury_vec, k, passage_matrix=None):
    """Return the k passages LEAST similar to each query row, least similar first.

    Args:
        qeury_vec: (num_queries, num_features) query matrix, sparse or dense
            (e.g. the output of ``tfidfv.transform([...])``).
        k: number of passages to keep per query. Values larger than the number
            of passages are clamped to it; ``k <= 0`` yields empty results.
        passage_matrix: (num_passages, num_features) passage matrix. Defaults
            to the global ``p_embedding``.

    Returns:
        (doc_scores, doc_indices), both shaped (num_queries, k). ``doc_scores``
        is sorted in ascending order, and ``doc_indices[r, j]`` is the passage
        index whose score is ``doc_scores[r, j]``.
    """
    if passage_matrix is None:
        passage_matrix = p_embedding
    # `@` is matrix multiplication for both scipy sparse matrices and sparse arrays.
    scores = qeury_vec @ passage_matrix.T
    scores = scores.toarray() if hasattr(scores, "toarray") else np.asarray(scores)

    num_queries, num_passages = scores.shape
    k = min(k, num_passages)
    if k <= 0:
        return (np.empty((num_queries, 0), dtype=scores.dtype),
                np.empty((num_queries, 0), dtype=np.intp))

    # The first k slots after partitioning around kth=k-1 hold the k smallest scores.
    # The scores are gathered from the selected indices so both outputs stay aligned
    # (the previous version reversed the scores but not the indices).
    bottom_idx = np.argpartition(scores, k - 1, axis=-1)[:, :k]
    bottom_scores = np.take_along_axis(scores, bottom_idx, axis=-1)
    order = np.argsort(bottom_scores, axis=-1, kind="stable")
    return (np.take_along_axis(bottom_scores, order, axis=-1),
            np.take_along_axis(bottom_idx, order, axis=-1))

In [None]:
sample_context = training_dataset['context'][1]
query_vec = tfidfv.transform([sample_context])
result = query_vec @ p_embedding.T  # (1, num_passages) similarity row

In [None]:
get_resverse_topk_similarity(query_vec, k=3)

In [None]:
# Display one corpus passage by index.
# NOTE(review): 25734 is hard-coded and not derived from any earlier cell; it was
# presumably copied from a previous run's output. Confirm what it is meant to show.
corpus[25734]

In [None]:
training_dataset[0]['context']

In [None]:
# Number of negative passages paired with each positive (training) passage.
num_neg = 3
# Object array: enables fancy indexing (corpus[neg_idxs[i]]) without the huge
# fixed-width '<U' buffer that would reserve room for the longest passage in every slot.
corpus = np.array(corpus, dtype=object)

# For every training passage, use the num_neg corpus passages with the LOWEST TF-IDF
# similarity to it as negatives (easy negatives, not the most similar "hard" ones).
# The dense score matrix is (num_contexts, len(corpus)), so it is built in row chunks
# to bound memory instead of all at once.
neg_chunk_size = 256
query_vec = tfidfv.transform(training_dataset['context'])
score_chunks, index_chunks = [], []
for start in range(0, query_vec.shape[0], neg_chunk_size):
    chunk_scores, chunk_indices = get_resverse_topk_similarity(
        query_vec[start:start + neg_chunk_size], num_neg)
    score_chunks.append(chunk_scores)
    index_chunks.append(chunk_indices)
doc_scores = np.concatenate(score_chunks)
doc_indices = np.concatenate(index_chunks)
neg_idxs = doc_indices  # (num_contexts, num_neg) corpus indices

In [None]:
neg_idxs[0, :]

In [None]:
# Flatten into [pos_0, neg_0_1 .. neg_0_N, pos_1, ...]: each training passage followed
# by its num_neg negatives, so it can later be reshaped to (num_examples, num_neg + 1).
p_with_neg = []
for context, neg_row in zip(training_dataset['context'], neg_idxs):
    p_with_neg.append(context)
    p_with_neg.extend(str(p) for p in corpus[neg_row])
assert len(p_with_neg) == len(training_dataset) * (num_neg + 1)

# Show the second training example's group: its first entry is the positive passage.
sample_start = num_neg + 1
print('[Positive context]')
print(p_with_neg[sample_start], '\n')
print('[Negative context]')
print(*p_with_neg[sample_start + 1:sample_start + 1 + num_neg], sep='\n\n')

In [None]:
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset)

# Questions and (positive + negative) passages, each padded to the tokenizer's max length.
q_seqs = tokenizer(training_dataset['question'], padding="max_length", truncation=True, return_tensors='pt')
p_seqs = tokenizer(p_with_neg, padding="max_length", truncation=True, return_tensors='pt')

# Regroup the flat passage tensors so each row holds one positive followed by its negatives.
max_len = p_seqs['input_ids'].size(-1)
encoding_keys = ('input_ids', 'attention_mask', 'token_type_ids')
for key in encoding_keys:
    p_seqs[key] = p_seqs[key].view(-1, num_neg + 1, max_len)

print(p_seqs['input_ids'].size())  #(num_example, pos + neg, max_len)

# Item layout: the three passage tensors, then the three question tensors.
train_dataset = TensorDataset(*[p_seqs[key] for key in encoding_keys],
                              *[q_seqs[key] for key in encoding_keys])

In [None]:
class BertEncoder(BertPreTrainedModel):
    """Wraps a BertModel and returns its pooled output as the text embedding."""

    def __init__(self, config):
        super().__init__(config)
        # Attribute name `bert` is kept so pre-trained BERT weights load into it.
        self.bert = BertModel(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Encode a tokenized batch; returns element [1] (pooled output) of the BertModel outputs."""
        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return bert_outputs[1]

In [None]:
# Load the pre-trained checkpoint into separate passage and question encoders,
# on the GPU when one is available and on the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
p_encoder = BertEncoder.from_pretrained(model_checkpoint).to(device)
q_encoder = BertEncoder.from_pretrained(model_checkpoint).to(device)

In [None]:
def train(args, num_neg, dataset, p_model, q_model):
  """Jointly fine-tune a passage encoder and a question encoder.

  Each dataset item is one question with ``num_neg + 1`` passages, the first of
  which is the positive. Question/passage dot products are log-softmaxed over
  those passages and trained with NLL so that the positive (index 0) wins.

  Args:
    args: TrainingArguments-like object. Reads per_device_train_batch_size,
      weight_decay, learning_rate, adam_epsilon, gradient_accumulation_steps,
      num_train_epochs and warmup_steps (``fp16`` is NOT honored here).
    num_neg: number of negative passages per question.
    dataset: TensorDataset of (p_input_ids, p_attention_mask, p_token_type_ids,
      q_input_ids, q_attention_mask, q_token_type_ids), with the passage
      tensors shaped (num_examples, num_neg + 1, max_len).
    p_model: passage encoder returning one embedding per input row.
    q_model: question encoder returning one embedding per input row.

  Returns:
    (p_model, q_model) after training.
  """
  # Batches follow the models' device (assumes both encoders live on the same device).
  device = next(p_model.parameters()).device

  # Dataloader
  train_sampler = RandomSampler(dataset)
  train_dataloader = DataLoader(dataset, sampler=train_sampler, batch_size=args.per_device_train_batch_size)

  # Optimizer: no weight decay for biases and LayerNorm weights.
  no_decay = ['bias', 'LayerNorm.weight']
  optimizer_grouped_parameters = []
  for model in (p_model, q_model):
    named = list(model.named_parameters())
    optimizer_grouped_parameters += [
        {'params': [p for n, p in named if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in named if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
  optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

  accum_steps = max(1, int(args.gradient_accumulation_steps))
  num_epochs = int(args.num_train_epochs)
  # Optimizer updates per epoch, counting a final partial accumulation window.
  updates_per_epoch = (len(train_dataloader) + accum_steps - 1) // accum_steps
  t_total = updates_per_epoch * num_epochs
  scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

  # Start training!
  global_step = 0

  p_model.zero_grad()
  q_model.zero_grad()
  torch.cuda.empty_cache()

  for _ in trange(num_epochs, desc="Epoch"):
    for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
      # Use the models passed in (not notebook globals).
      q_model.train()
      p_model.train()

      batch = tuple(t.to(device) for t in batch)
      # The last batch can be smaller than per_device_train_batch_size.
      batch_size = batch[0].size(0)
      group = num_neg + 1

      p_inputs = {'input_ids': batch[0].view(batch_size * group, -1),
                  'attention_mask': batch[1].view(batch_size * group, -1),
                  'token_type_ids': batch[2].view(batch_size * group, -1)}
      q_inputs = {'input_ids': batch[3],
                  'attention_mask': batch[4],
                  'token_type_ids': batch[5]}

      p_outputs = p_model(**p_inputs)  # (batch_size * (num_neg + 1), emb_dim)
      q_outputs = q_model(**q_inputs)  # (batch_size, emb_dim)

      # Rows of p_outputs are ordered question by question, passage by passage, so the
      # passage axis must come before the embedding axis. view(batch_size, -1, num_neg + 1)
      # would mix embedding values of different passages together.
      p_outputs = p_outputs.view(batch_size, group, -1)  # (batch_size, num_neg + 1, emb_dim)
      q_outputs = q_outputs.view(batch_size, 1, -1)      # (batch_size, 1, emb_dim)

      # (batch_size, 1, emb_dim) @ (batch_size, emb_dim, num_neg + 1) -> (batch_size, num_neg + 1)
      sim_scores = torch.bmm(q_outputs, p_outputs.transpose(1, 2)).squeeze(1)
      sim_scores = F.log_softmax(sim_scores, dim=1)

      # The positive passage is always at index 0 of its group.
      targets = torch.zeros(batch_size, dtype=torch.long, device=sim_scores.device)
      loss = F.nll_loss(sim_scores, targets) / accum_steps
      loss.backward()

      # One optimizer/scheduler step per accumulation window (and at the end of the epoch).
      if (step + 1) % accum_steps == 0 or step + 1 == len(train_dataloader):
        optimizer.step()
        scheduler.step()
        q_model.zero_grad()
        p_model.zero_grad()
        global_step += 1

  return p_model, q_model

In [None]:
# Only the fields read by train() matter: batch size, epochs, learning rate and weight
# decay here, plus the library defaults for adam_epsilon, warmup_steps and
# gradient_accumulation_steps.
# Dropped arguments:
# - `save_epochs` is not a TrainingArguments field (passing it raises TypeError).
# - `evaluation_strategy` and `fp16` were never used by train(). fp16=True can also fail
#   on machines without CUDA, and `evaluation_strategy` is renamed in newer transformers.
args = TrainingArguments(
    output_dir="dense_retireval",
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=30,
    weight_decay=0.01,
)
p_encoder, q_encoder = train(args, num_neg, train_dataset, p_encoder, q_encoder)

In [None]:
# Dense-retrieval evaluation: for each validation question, rank the distinct validation
# passages by question/passage dot product and record the 0-based rank of the gold passage.
corpus_index = {text: i for i, text in enumerate(corpus)}  # passage text -> corpus index

# dict.fromkeys keeps first-seen order; iterating a set of strings is not reproducible
# across interpreter runs (string hash randomization).
valid_corpus = list(dict.fromkeys(example['context'] for example in dataset['validation']))
valid_corpus_idx = np.array([corpus_index[text] for text in valid_corpus])  # corpus index of each validation passage

p_device = next(p_encoder.parameters()).device
q_device = next(q_encoder.parameters()).device
p_encoder.eval()
q_encoder.eval()

# Passage embeddings do not depend on the question: encode them once, not per question.
with torch.no_grad():
    p_embs = []
    for p in tqdm(valid_corpus, desc="Encoding passages"):
        p_tok = tokenizer(p, padding="max_length", truncation=True, return_tensors='pt').to(p_device)
        p_embs.append(p_encoder(**p_tok).to('cpu'))
p_embs = torch.cat(p_embs, dim=0)  # (num_passage, emb_dim)

answer_dense_rank_list = []
for example in tqdm(dataset['validation']):
    query = example['question']
    ground_truth = example['context']

    with torch.no_grad():
        q_seqs_val = tokenizer([query], padding="max_length", truncation=True, return_tensors='pt').to(q_device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')  # (1, emb_dim)

    # Similarity to every validation passage, highest first.
    dot_prod_scores = torch.matmul(q_emb, p_embs.T)  # (1, num_passage)
    rank = torch.argsort(dot_prod_scores, dim=1, descending=True)[0]

    # Map ranked positions back to corpus indices and find the gold passage's rank.
    rank_doc_idx = valid_corpus_idx[rank.numpy()]
    answer_rank = int(np.flatnonzero(rank_doc_idx == corpus_index[ground_truth])[0])
    answer_dense_rank_list.append(answer_rank)

answer_ranks = np.array(answer_dense_rank_list)
print(f"top-1 acc: {np.mean(answer_ranks < 1):.4f}, top-5 acc: {np.mean(answer_ranks < 5):.4f}")