# 5강) BERT를 활용한 Dense Passage Retrieval 실습

### Requirements

!pip install datasets
!pip install transformers

In [1]:
import os
import json

## 데이터셋 로딩


In [2]:
data_path  = "../../data/"
context_path = "wikipedia_documents.json"
with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
    wiki = json.load(f)

corpus = list(dict.fromkeys([v["text"] for v in wiki.values()]))
print('context len :', len(corpus))

context len : 56737


In [3]:
from transformers import AutoTokenizer
import numpy as np

model_checkpoint = "klue/bert-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

## Dense encoder (BERT) 학습 시키기

HuggingFace BERT를 활용하여 question encoder, passage encoder 학습

In [6]:
from tqdm import tqdm, trange
import argparse
import random
import torch
import torch.nn.functional as F
from transformers import BertModel, BertPreTrainedModel, AdamW, TrainingArguments, get_linear_schedule_with_warmup

torch.manual_seed(2021)
torch.cuda.manual_seed(2021)
np.random.seed(2021)
random.seed(2021)

1) Training Dataset 준비하기 (question, passage pairs)

---



In [7]:
from datasets import load_from_disk

dataset_dir = '../../data/train_dataset'
dataset = load_from_disk(dataset_dir)

training_dataset = dataset['train']#[sample_idx]
print(len(dataset['train']), len(training_dataset))

Negative sampling을 위한 negative sample들을 샘플링

In [10]:
# set number of neagative sample
num_neg = 3

corpus = np.array(corpus)
p_with_neg = []

for c in training_dataset['context']:
  while True:
    neg_idxs = np.random.randint(len(corpus), size=num_neg)

    if not c in corpus[neg_idxs]:
      p_neg = corpus[neg_idxs]

      p_with_neg.append(c)
      p_with_neg.extend(p_neg)
      break

print('[Positive context]')
print(p_with_neg[0], '\n')
print('[Negative context]')
print(p_with_neg[1], '\n', p_with_neg[2])

In [12]:
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset)

q_seqs = tokenizer(training_dataset['question'], padding="max_length", truncation=True, return_tensors='pt')
p_seqs = tokenizer(p_with_neg, padding="max_length", truncation=True, return_tensors='pt')

In [13]:
max_len = p_seqs['input_ids'].size(-1)
p_seqs['input_ids'] = p_seqs['input_ids'].view(-1, num_neg+1, max_len)
p_seqs['attention_mask'] = p_seqs['attention_mask'].view(-1, num_neg+1, max_len)
p_seqs['token_type_ids'] = p_seqs['token_type_ids'].view(-1, num_neg+1, max_len)

print(p_seqs['input_ids'].size())  #(num_example, pos + neg, max_len)

torch.Size([3952, 4, 512])


In [14]:
train_dataset = TensorDataset(p_seqs['input_ids'], p_seqs['attention_mask'], p_seqs['token_type_ids'], 
                        q_seqs['input_ids'], q_seqs['attention_mask'], q_seqs['token_type_ids'])

2) BERT encoder 학습시키기

BertEncoder 모델 정의 후, question encoder, passage encoder에 pre-trained weight 불러오기

In [16]:
class BertEncoder(BertPreTrainedModel):
  def __init__(self, config):
    super(BertEncoder, self).__init__(config)

    self.bert = BertModel(config)
    self.init_weights()
      
  def forward(self, input_ids, 
              attention_mask=None, token_type_ids=None): 
  
      outputs = self.bert(input_ids,
                          attention_mask=attention_mask,
                          token_type_ids=token_type_ids)
      
      pooled_output = outputs[1]

      return pooled_output

In [17]:
# load pre-trained model on cuda (if available)
p_encoder = BertEncoder.from_pretrained(model_checkpoint)
q_encoder = BertEncoder.from_pretrained(model_checkpoint)

if torch.cuda.is_available():
  p_encoder.cuda()
  q_encoder.cuda()

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.bi

Train function 정의 후, 두개의 encoder fine-tuning 하기


In [18]:
def train(args, num_neg, dataset, p_model, q_model):
  
  # Dataloader
  train_sampler = RandomSampler(dataset)
  train_dataloader = DataLoader(dataset, sampler=train_sampler, batch_size=args.per_device_train_batch_size)

  # Optimizer
  no_decay = ['bias', 'LayerNorm.weight']
  optimizer_grouped_parameters = [
        {'params': [p for n, p in p_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in p_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        {'params': [p for n, p in q_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in q_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
  optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
  t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
  scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

  # Start training!
  global_step = 0
  
  p_model.zero_grad()
  q_model.zero_grad()
  torch.cuda.empty_cache()
  
  train_iterator = trange(int(args.num_train_epochs), desc="Epoch")

  for _ in train_iterator:
    epoch_iterator = tqdm(train_dataloader, desc="Iteration")

    for step, batch in enumerate(epoch_iterator):
      q_encoder.train()
      p_encoder.train()
      
      targets = torch.zeros(args.per_device_train_batch_size).long()
      if torch.cuda.is_available():
        batch = tuple(t.cuda() for t in batch)
        targets = targets.cuda()

      p_inputs = {'input_ids': batch[0].view(
                                    args.per_device_train_batch_size*(num_neg+1), -1),
                  'attention_mask': batch[1].view(
                                    args.per_device_train_batch_size*(num_neg+1), -1),
                  'token_type_ids': batch[2].view(
                                    args.per_device_train_batch_size*(num_neg+1), -1)
                  }
      
      q_inputs = {'input_ids': batch[3],
                  'attention_mask': batch[4],
                  'token_type_ids': batch[5]}
      
      p_outputs = p_model(**p_inputs)  #(batch_size*(num_neg+1), emb_dim)
      q_outputs = q_model(**q_inputs)  #(batch_size*, emb_dim)

      # Calculate similarity score & loss
      p_outputs = p_outputs.view(args.per_device_train_batch_size, -1, num_neg+1)
      q_outputs = q_outputs.view(args.per_device_train_batch_size, 1, -1)

      sim_scores = torch.bmm(q_outputs, p_outputs).squeeze()  #(batch_size, num_neg+1)
      sim_scores = sim_scores.view(args.per_device_train_batch_size, -1)
      sim_scores = F.log_softmax(sim_scores, dim=1)

      loss = F.nll_loss(sim_scores, targets)
      #print(loss)

      loss.backward()
      optimizer.step()
      scheduler.step()
      q_model.zero_grad()
      p_model.zero_grad()
      global_step += 1
      
      torch.cuda.empty_cache()
  return p_model, q_model

In [19]:
args = TrainingArguments(
    output_dir="dense_retireval",
    evaluation_strategy="epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=2,
    weight_decay=0.01
)

In [20]:
p_encoder, q_encoder = train(args, num_neg, train_dataset, p_encoder, q_encoder)

Iteration: 100%|██████████| 988/988 [14:53<00:00,  1.11it/s]
Iteration: 100%|██████████| 988/988 [14:46<00:00,  1.11it/s]
Epoch: 100%|██████████| 2/2 [29:40<00:00, 890.09s/it]


## Dense Embedding을 활용하여 passage retrieval 실습해보기

In [77]:
valid_corpus = list(set([example['context'] for example in dataset['validation']]))
valid_corpus_idx = [np.where(corpus==e)[0].tolist()[0] for e in valid_corpus] # valid문서들의 전체문서 idx
valid_corpus_idx = np.array(valid_corpus_idx)

# Rank를 비교해서 검증

### Dense

In [137]:
answer_rank_list = []
for idx in tqdm(range(len(dataset['validation']))):
    # 1. query와 정답을 뽑아내기
    query = dataset['validation'][idx]['question']
    ground_truth = dataset['validation'][idx]['context']

    if not ground_truth in valid_corpus:
        valid_corpus.append(ground_truth) 
    # print(query)
    # print(ground_truth, '\n\n')

    # 2. passage encoder, question encoder을 이용해 dense embedding 생성
    with torch.no_grad():
        p_encoder.eval()
        q_encoder.eval()

        q_seqs_val = tokenizer([query], padding="max_length", truncation=True, return_tensors='pt').to('cuda')
        q_emb = q_encoder(**q_seqs_val).to('cpu')  #(num_query, emb_dim)

        p_embs = []
        for p in valid_corpus:
            p = tokenizer(p, padding="max_length", truncation=True, return_tensors='pt').to('cuda')
            p_emb = p_encoder(**p).to('cpu').numpy()
            p_embs.append(p_emb)

    p_embs = torch.Tensor(p_embs).squeeze()  # (num_passage, emb_dim)
    #print(p_embs.size(), q_emb.size())

    # 3. 생성된 embedding에 dot product를 수행 => Document들의 similarity ranking을 구함
    dot_prod_scores = torch.matmul(q_emb, torch.transpose(p_embs, 0, 1))
    #print(dot_prod_scores.size())
    rank = torch.argsort(dot_prod_scores, dim=1, descending=True).squeeze()
    #print(dot_prod_scores)
    #print(rank)

    # 4. rank에서 groundtruth와 비교
    rank_doc_idx = valid_corpus_idx[rank.tolist()] # rank된 문서들의 전체문서 idx
    ground_truth_doc_idx = np.where(corpus==ground_truth)[0].tolist()[0] # ground truth의 전체문서 idx
    answer_rank = np.where(rank_doc_idx==ground_truth_doc_idx)[0].tolist()[0] # 정답문서의 rank
    print(answer_rank)
    answer_rank_list.append(answer_rank)

  0%|          | 1/240 [00:04<18:18,  4.60s/it]

214


  1%|          | 2/240 [00:09<18:26,  4.65s/it]

56


  1%|▏         | 3/240 [00:13<18:17,  4.63s/it]

127


  2%|▏         | 4/240 [00:18<18:11,  4.62s/it]

90


  2%|▏         | 5/240 [00:23<18:51,  4.81s/it]

181


  2%|▎         | 6/240 [00:28<18:33,  4.76s/it]

42


  3%|▎         | 7/240 [00:33<18:26,  4.75s/it]

191


  3%|▎         | 8/240 [00:37<18:08,  4.69s/it]

221


  4%|▍         | 9/240 [00:42<17:53,  4.65s/it]

171


  4%|▍         | 10/240 [00:46<17:41,  4.61s/it]

177


  5%|▍         | 11/240 [00:59<26:35,  6.97s/it]

125


  5%|▌         | 12/240 [01:03<23:48,  6.27s/it]

84


  5%|▌         | 13/240 [01:08<21:52,  5.78s/it]

157


  6%|▌         | 14/240 [01:13<20:44,  5.51s/it]

182


  6%|▋         | 15/240 [01:18<19:37,  5.23s/it]

29


  7%|▋         | 16/240 [01:22<18:51,  5.05s/it]

218


  7%|▋         | 17/240 [01:27<18:58,  5.11s/it]

57


  8%|▊         | 18/240 [01:33<19:09,  5.18s/it]

9


  8%|▊         | 19/240 [01:37<18:25,  5.00s/it]

34


  8%|▊         | 20/240 [01:42<17:51,  4.87s/it]

194


  9%|▉         | 21/240 [01:46<17:28,  4.79s/it]

103


  9%|▉         | 22/240 [01:51<17:11,  4.73s/it]

157


 10%|▉         | 23/240 [01:56<16:59,  4.70s/it]

177


 10%|█         | 24/240 [02:02<18:48,  5.23s/it]

75


 10%|█         | 25/240 [02:07<18:14,  5.09s/it]

206


 11%|█         | 26/240 [02:12<17:37,  4.94s/it]

0


 11%|█▏        | 27/240 [02:23<24:04,  6.78s/it]

152


 12%|█▏        | 28/240 [02:30<24:14,  6.86s/it]

63


 12%|█▏        | 29/240 [02:39<26:31,  7.54s/it]

92


 12%|█▎        | 30/240 [02:45<24:54,  7.12s/it]

104


 13%|█▎        | 31/240 [02:49<22:05,  6.34s/it]

79


 13%|█▎        | 32/240 [02:54<20:11,  5.82s/it]

200


 14%|█▍        | 33/240 [02:59<18:52,  5.47s/it]

86


 14%|█▍        | 34/240 [03:03<17:53,  5.21s/it]

127


 15%|█▍        | 35/240 [03:08<17:12,  5.03s/it]

17


 15%|█▌        | 36/240 [03:13<16:42,  4.91s/it]

168


 15%|█▌        | 37/240 [03:17<16:18,  4.82s/it]

226


 16%|█▌        | 38/240 [03:22<16:10,  4.80s/it]

24


 16%|█▋        | 39/240 [03:26<15:52,  4.74s/it]

210


 17%|█▋        | 40/240 [03:31<15:36,  4.68s/it]

227


 17%|█▋        | 41/240 [03:36<15:36,  4.70s/it]

47


 18%|█▊        | 42/240 [03:40<15:27,  4.69s/it]

24


 18%|█▊        | 43/240 [03:45<15:15,  4.65s/it]

77


 18%|█▊        | 44/240 [03:50<15:05,  4.62s/it]

232


 19%|█▉        | 45/240 [03:54<15:01,  4.62s/it]

38


 19%|█▉        | 46/240 [03:59<14:55,  4.62s/it]

207


 20%|█▉        | 47/240 [04:03<14:50,  4.61s/it]

77


 20%|██        | 48/240 [04:08<14:55,  4.66s/it]

181


 20%|██        | 49/240 [04:13<14:46,  4.64s/it]

146


 21%|██        | 50/240 [04:17<14:37,  4.62s/it]

188


 21%|██▏       | 51/240 [04:22<14:34,  4.62s/it]

130


 22%|██▏       | 52/240 [04:27<14:25,  4.60s/it]

28


 22%|██▏       | 53/240 [04:31<14:41,  4.71s/it]

17


 22%|██▎       | 54/240 [04:36<14:30,  4.68s/it]

53


 23%|██▎       | 55/240 [04:41<14:20,  4.65s/it]

204


 23%|██▎       | 56/240 [04:45<14:22,  4.69s/it]

155


 24%|██▍       | 57/240 [04:50<14:09,  4.64s/it]

12


 24%|██▍       | 58/240 [04:55<14:02,  4.63s/it]

190


 25%|██▍       | 59/240 [04:59<13:54,  4.61s/it]

25


 25%|██▌       | 60/240 [05:04<13:48,  4.60s/it]

134


 25%|██▌       | 61/240 [05:08<13:43,  4.60s/it]

117


 26%|██▌       | 62/240 [05:13<13:38,  4.60s/it]

82


 26%|██▋       | 63/240 [05:17<13:30,  4.58s/it]

128


 27%|██▋       | 64/240 [05:22<13:25,  4.58s/it]

125


 27%|██▋       | 65/240 [05:27<13:22,  4.58s/it]

94


 28%|██▊       | 66/240 [05:31<13:17,  4.59s/it]

205


 28%|██▊       | 67/240 [05:37<14:23,  4.99s/it]

149


 28%|██▊       | 68/240 [05:42<13:57,  4.87s/it]

143


 29%|██▉       | 69/240 [05:46<13:37,  4.78s/it]

64


 29%|██▉       | 70/240 [05:51<13:26,  4.74s/it]

164


 30%|██▉       | 71/240 [05:56<13:23,  4.76s/it]

181


 30%|███       | 72/240 [06:06<18:05,  6.46s/it]

219


 30%|███       | 73/240 [06:14<19:24,  6.97s/it]

40


 31%|███       | 74/240 [06:23<20:45,  7.50s/it]

99


 31%|███▏      | 75/240 [06:31<21:09,  7.69s/it]

201


 32%|███▏      | 76/240 [06:41<22:55,  8.39s/it]

175


 32%|███▏      | 77/240 [06:46<19:49,  7.30s/it]

226


 32%|███▎      | 78/240 [06:51<17:36,  6.52s/it]

74


 33%|███▎      | 79/240 [06:56<16:24,  6.11s/it]

181


 33%|███▎      | 80/240 [07:02<16:41,  6.26s/it]

188


 34%|███▍      | 81/240 [07:11<18:30,  6.98s/it]

43


 34%|███▍      | 82/240 [07:16<17:02,  6.47s/it]

190


 35%|███▍      | 83/240 [07:22<15:56,  6.09s/it]

0


 35%|███▌      | 84/240 [07:26<14:39,  5.64s/it]

62


 35%|███▌      | 85/240 [07:31<13:43,  5.32s/it]

205


 36%|███▌      | 86/240 [07:35<13:04,  5.09s/it]

87


 36%|███▋      | 87/240 [07:40<12:34,  4.93s/it]

93


 37%|███▋      | 88/240 [07:45<12:29,  4.93s/it]

220


 37%|███▋      | 89/240 [07:49<12:05,  4.81s/it]

122


 38%|███▊      | 90/240 [07:54<11:47,  4.72s/it]

119


 38%|███▊      | 91/240 [07:59<11:42,  4.71s/it]

85


 38%|███▊      | 92/240 [08:03<11:29,  4.66s/it]

4


 39%|███▉      | 93/240 [08:08<11:19,  4.62s/it]

82


 39%|███▉      | 94/240 [08:12<11:11,  4.60s/it]

230


 40%|███▉      | 95/240 [08:17<11:06,  4.59s/it]

137


 40%|████      | 96/240 [08:21<10:59,  4.58s/it]

195


 40%|████      | 97/240 [08:26<11:09,  4.68s/it]

141


 41%|████      | 98/240 [08:31<11:17,  4.77s/it]

31


 41%|████▏     | 99/240 [08:36<11:03,  4.71s/it]

154


 42%|████▏     | 100/240 [08:40<10:53,  4.67s/it]

159


 42%|████▏     | 101/240 [08:45<10:47,  4.66s/it]

198


 42%|████▎     | 102/240 [08:52<12:39,  5.50s/it]

150


 43%|████▎     | 103/240 [08:58<12:49,  5.62s/it]

209


 43%|████▎     | 104/240 [09:04<12:31,  5.53s/it]

21


 44%|████▍     | 105/240 [09:08<11:48,  5.25s/it]

74


 44%|████▍     | 106/240 [09:13<11:16,  5.05s/it]

70


 45%|████▍     | 107/240 [09:17<10:54,  4.92s/it]

195


 45%|████▌     | 108/240 [09:22<10:46,  4.90s/it]

111


 45%|████▌     | 109/240 [09:27<10:28,  4.80s/it]

157


 46%|████▌     | 110/240 [09:32<10:21,  4.78s/it]

160


 46%|████▋     | 111/240 [09:36<10:10,  4.73s/it]

149


 47%|████▋     | 112/240 [09:41<09:59,  4.69s/it]

173


 47%|████▋     | 113/240 [09:45<09:53,  4.67s/it]

46


 48%|████▊     | 114/240 [09:51<10:04,  4.80s/it]

116


 48%|████▊     | 115/240 [09:55<09:49,  4.72s/it]

180


 48%|████▊     | 116/240 [10:02<11:16,  5.46s/it]

11


 49%|████▉     | 117/240 [10:11<13:09,  6.42s/it]

161


 49%|████▉     | 118/240 [10:17<12:43,  6.26s/it]

133


 50%|████▉     | 119/240 [10:22<11:41,  5.80s/it]

73


 50%|█████     | 120/240 [10:30<13:13,  6.61s/it]

15


 50%|█████     | 121/240 [10:40<14:49,  7.47s/it]

115


 51%|█████     | 122/240 [10:45<13:25,  6.83s/it]

38


 51%|█████▏    | 123/240 [10:49<12:00,  6.16s/it]

1


 52%|█████▏    | 124/240 [10:54<10:58,  5.68s/it]

85


 52%|█████▏    | 125/240 [10:59<10:14,  5.34s/it]

178


 52%|█████▎    | 126/240 [11:05<10:33,  5.55s/it]

163


 53%|█████▎    | 127/240 [11:11<10:57,  5.82s/it]

27


 53%|█████▎    | 128/240 [11:16<10:09,  5.44s/it]

172


 54%|█████▍    | 129/240 [11:20<09:34,  5.17s/it]

133


 54%|█████▍    | 130/240 [11:25<09:15,  5.05s/it]

17


 55%|█████▍    | 131/240 [11:31<09:53,  5.44s/it]

58


 55%|█████▌    | 132/240 [11:37<10:09,  5.64s/it]

215


 55%|█████▌    | 133/240 [11:45<11:16,  6.33s/it]

15


 56%|█████▌    | 134/240 [11:53<12:03,  6.82s/it]

162


 56%|█████▋    | 135/240 [12:01<12:26,  7.11s/it]

1


 57%|█████▋    | 136/240 [12:07<11:32,  6.65s/it]

7


 57%|█████▋    | 137/240 [12:15<12:23,  7.22s/it]

111


 57%|█████▊    | 138/240 [12:20<10:54,  6.41s/it]

90


 58%|█████▊    | 139/240 [12:25<10:11,  6.05s/it]

13


 58%|█████▊    | 140/240 [12:30<09:21,  5.61s/it]

166


 59%|█████▉    | 141/240 [12:34<08:48,  5.34s/it]

133


 59%|█████▉    | 142/240 [12:39<08:39,  5.30s/it]

8


 60%|█████▉    | 143/240 [12:44<08:20,  5.16s/it]

68


 60%|██████    | 144/240 [12:52<09:26,  5.90s/it]

216


 60%|██████    | 145/240 [12:58<09:21,  5.91s/it]

76


 61%|██████    | 146/240 [13:10<12:08,  7.75s/it]

116


 61%|██████▏   | 147/240 [13:22<14:11,  9.16s/it]

20


 62%|██████▏   | 148/240 [13:35<15:34, 10.16s/it]

74


 62%|██████▏   | 149/240 [13:47<16:27, 10.86s/it]

132


 62%|██████▎   | 150/240 [14:00<16:59, 11.33s/it]

81


 63%|██████▎   | 151/240 [14:12<17:18, 11.67s/it]

135


 63%|██████▎   | 152/240 [14:25<17:26, 11.89s/it]

1


 64%|██████▍   | 153/240 [14:34<16:21, 11.28s/it]

39


 64%|██████▍   | 154/240 [14:39<13:17,  9.27s/it]

159


 65%|██████▍   | 155/240 [14:45<11:36,  8.19s/it]

87


 65%|██████▌   | 156/240 [14:55<12:11,  8.70s/it]

58


 65%|██████▌   | 157/240 [15:05<12:47,  9.25s/it]

31


 66%|██████▌   | 158/240 [15:16<13:08,  9.62s/it]

87


 66%|██████▋   | 159/240 [15:27<13:47, 10.22s/it]

216


 67%|██████▋   | 160/240 [15:38<13:42, 10.28s/it]

104


 67%|██████▋   | 161/240 [15:48<13:40, 10.39s/it]

117


 68%|██████▊   | 162/240 [15:59<13:37, 10.48s/it]

71


 68%|██████▊   | 163/240 [16:12<14:16, 11.12s/it]

40


 68%|██████▊   | 164/240 [16:24<14:38, 11.56s/it]

24


 69%|██████▉   | 165/240 [16:37<14:49, 11.87s/it]

217


 69%|██████▉   | 166/240 [16:49<14:52, 12.06s/it]

152


 70%|██████▉   | 167/240 [16:54<11:57,  9.84s/it]

111


 70%|███████   | 168/240 [17:01<10:46,  8.98s/it]

81


 70%|███████   | 169/240 [17:13<11:38,  9.84s/it]

24


 71%|███████   | 170/240 [17:19<10:13,  8.76s/it]

170


 71%|███████▏  | 171/240 [17:30<10:57,  9.53s/it]

128


 72%|███████▏  | 172/240 [17:37<09:47,  8.63s/it]

56


 72%|███████▏  | 173/240 [17:42<08:37,  7.72s/it]

186


 72%|███████▎  | 174/240 [17:47<07:34,  6.89s/it]

210


 73%|███████▎  | 175/240 [17:53<07:03,  6.52s/it]

89


 73%|███████▎  | 176/240 [17:58<06:22,  5.97s/it]

20


 74%|███████▍  | 177/240 [18:08<07:28,  7.12s/it]

220


 74%|███████▍  | 178/240 [18:20<08:56,  8.66s/it]

209


 75%|███████▍  | 179/240 [18:30<09:21,  9.20s/it]

119


 75%|███████▌  | 180/240 [18:39<09:04,  9.08s/it]

9


 75%|███████▌  | 181/240 [18:50<09:27,  9.62s/it]

64


 75%|███████▌  | 181/240 [18:52<06:09,  6.26s/it]


KeyboardInterrupt: 

In [None]:
# 한 배치안에서 Negative 섞고, BM25, 
rank_doc_idx

In [135]:
answer_rank_list

[]

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer

tfidfv = TfidfVectorizer(tokenizer=tokenizer.tokenize, ngram_range=(1, 2))#, max_features=50000)
p_embedding = tfidfv.fit_transform(corpus)

def get_topk_similarity(qeury_vec, k):
    result = qeury_vec * p_embedding.T
    result = result.toarray()

    doc_scores3 = np.partition(result, -k)[:, -k:][:, ::-1]
    ind = np.argsort(doc_scores3, axis=-1)[:, ::-1]
    doc_scores3 = np.sort(doc_scores3, axis=-1)[:, ::-1]
    doc_indices3 = np.argpartition(result, -k)[:, -k:][:, ::-1]
    r, c = ind.shape
    ind = ind + np.tile(np.arange(r).reshape(-1, 1), (1, c)) * c
    doc_indices3 = doc_indices3.ravel()[ind].reshape(r, c)

    return doc_scores3, doc_indices3

In [134]:
rank

tensor([ 23, 125, 124,  47,  76,  55,  40, 195, 160, 208,  54, 110,  93, 233,
        174, 232,  44, 158, 228,  12,  84,  41,  26, 227, 169, 101,  92, 216,
         62,  33,  85,  96, 191, 103, 177,  88,   7, 139,  14,   5,  17, 221,
        100,  13,  43, 218, 154, 157,  29,  60, 179, 176, 171,  64, 187, 196,
        204, 202,  19, 126,  50, 234, 105, 140, 225,  71, 188, 133,  87, 152,
        224,  10,  11,  99,  49,  39, 181,  72,  90, 127,  31, 143, 207, 163,
         78,  28, 166,  35, 197, 150, 213, 151,  16,  30, 111,  98,  91, 148,
         82, 215, 189, 167, 209, 205, 117,  38, 200,   1,  32, 220,  83, 159,
        165,   2, 114, 129, 156, 206, 130,  95, 104,  25,  74,  18, 122, 162,
        226, 223, 173, 184,  53, 146, 109, 211,   6, 145,  86, 155,  75,  37,
         56, 170, 190,  24, 168,  80, 194,   0,  73,  89, 118,  77, 132,  65,
         66,  61, 121,  68,  48, 137,  63, 199,  46,  81, 185, 134, 210,  70,
        183, 128, 164, 147, 172, 102, 186, 144,  21, 198, 178,  

In [121]:
valid_corpus_idx[rank.tolist()]

array([55500,  5153, 23027, 52576, 11676, 40413, 11939, 55828, 29185,
       24486, 48820,  4792, 23816, 23335, 26381, 54179,  4938,  4620,
       39395, 51965, 55243, 30932, 20503,  9353, 49488,  5263, 23667,
       16171,  4959, 10976, 32363, 22337, 34686,  5238, 31555, 23695,
        5762,  5081, 23517, 44248, 29373, 38315, 16121, 25833, 25992,
        7993, 55851, 35401, 18142,  4651, 52635, 48939, 29001, 33104,
       12290, 14631,  5287,  9613, 23291,  5006,  6183, 54652, 21801,
        5188, 12748, 42759, 51565, 52707, 29979,  5256, 21174,  5159,
       37389, 30877, 45254,  4846, 54997, 34100, 37030, 31148, 49540,
        4871, 42016,  8387, 26055,  4834, 49702, 26999, 28253,  4963,
       40091, 55663,  8658, 54074, 56430,  5269, 32220, 33013,  6925,
       14415, 15653,  4832, 35175, 34822, 55399, 32778, 13676, 37140,
        5177,  5259, 37476, 25018, 46731, 51048, 32530,  9735, 13194,
       23856,  4583, 12004, 37785,  9855, 28583, 14171, 36572, 40862,
        5103, 12069,

In [120]:
print(ground_truth_doc_idx)
print(valid_corpus_idx[rank.tolist()])

11891
[ 9353 40862 36572 35401 54997 14631 29373 44463 53651 53324 12290 37476
 54074  5068 49022  5139 25992  5048  5067 23816 26055 38315 23667 21751
 18935  4832  8658  4693 21801  5238  4834 32220 55699 34822 49502 28253
 55828 30722 26381 40413  4620 42566 15653 23335 25833 49449  5617 27704
 10976  6183 37574 55273 11017 12748  4919 14801  4930  4789 51965  5103
 52635 15015 32778  8538 44764  5159 47950 42569 26999 12031 55702 48820
  4792 14415  4651 44248  5705 37389 40091 12069 22337  7714 42467 20729
 37030  4959  4842 23695 49874 55756 42568 52190  4938 32363 25018  6925
 55663 29834 42016 24286 23598 31254 16383  5264 23856 23517  6315  5153
 34686 27392  8387 21658 20988 23027 32530 11891  4948 21562 23623  5269
 55399  5263 45254 39395 28583 31426 16711 31963 30775 41755 33104  5983
  5259  4804 11939 35358 49702 53318  4846  5081  5287  8142  5298 49488
 22206 49540 51001 55500 30877  4963  4583 34100 42209 42759 51565 54652
  9855 29979 18142 21571  5188  4953 55851  4

In [118]:
rank

tensor([ 23, 125, 124,  47,  76,  55,  40, 195, 160, 208,  54, 110,  93, 233,
        174, 232,  44, 158, 228,  12,  84,  41,  26, 227, 169, 101,  92, 216,
         62,  33,  85,  96, 191, 103, 177,  88,   7, 139,  14,   5,  17, 221,
        100,  13,  43, 218, 154, 157,  29,  60, 179, 176, 171,  64, 187, 196,
        204, 202,  19, 126,  50, 234, 105, 140, 225,  71, 188, 133,  87, 152,
        224,  10,  11,  99,  49,  39, 181,  72,  90, 127,  31, 143, 207, 163,
         78,  28, 166,  35, 197, 150, 213, 151,  16,  30, 111,  98,  91, 148,
         82, 215, 189, 167, 209, 205, 117,  38, 200,   1,  32, 220,  83, 159,
        165,   2, 114, 129, 156, 206, 130,  95, 104,  25,  74,  18, 122, 162,
        226, 223, 173, 184,  53, 146, 109, 211,   6, 145,  86, 155,  75,  37,
         56, 170, 190,  24, 168,  80, 194,   0,  73,  89, 118,  77, 132,  65,
         66,  61, 121,  68,  48, 137,  63, 199,  46,  81, 185, 134, 210,  70,
        183, 128, 164, 147, 172, 102, 186, 144,  21, 198, 178,  

In [116]:
answer_rank_list[-10:]

[112, 60, 159, 37, 94, 33, 57, 57, 0, 0]

In [96]:
k = 1
print("[Search query]\n", query, "\n")
print("[Ground truth passage]")
print(ground_truth, "\n")

for i in range(k):
  print("Top-%d passage with score %.4f" % (i+1, dot_prod_scores.squeeze()[rank[i]]))
  print(valid_corpus[rank[i]])

[Search query]
 마르크스주의자들의 사상은? 

[Ground truth passage]
사회주의 혁명은 오로지 선진노동자계급에 기초한 계급투쟁으로서 이루어질 수 있다고 주장한 레닌의 노선은 본질적으로는 프랑스의 사회주의자인 오귀스트 블랑키(Auguste Blanqui)의 비밀결사주의와 동일하지 않다. 다음은 블랑키주의에 대한 레닌의 비판이다.\n블랑키주의는 계급투쟁을 긍정하는 이론이다. 그러나 블랑키주의는 프롤레타리아의 계급투쟁에 의거하지 않고 소수 인텔리겐차의 음모로써 인류가 임금노예제로부터 해방될 것을 기대한 것이다. 블랑키의 행동 지침은 부르주아 민주주의의 계급 모순을 지각한 혁명적 부르주아의 일반적인 경향이며, 노동계급에 의한 계급의식의 표출과는 무관한 것이다.|블라디미르 레닌, 『대회의 총결과에 붙여서』(1906년) \n동시에 레닌은 블랑키주의가 소수 지식인의 음모에 의한 혁명 방식이며, 소수에 의한 쿠데타와 다를 바가 없다고 비판하기도 하였다.\n\n그러나, 경제적 후진성과 러시아 정교회를 비롯한 여러 반동적 사상 조류가 극심했던 당시 러시아 사회의 특성을 고려하여, 소수 직업혁명가의 역량 확보를 강조하였고, 이 지점에서 블랑키의 사상과 밀접한 연관을 이루게 됐다. 특히 폭력혁명에 대한 긍정 및 합법 활동과 비합법 활동을 혁명의 성취라는 목적에 따라 적절히 배합해야 한다는 레닌의 주장은 블랑키의 주장과 상당히 흡사한 지점이다.\n\n특히, 당원의 지적 수련, 금욕적 생활, 사생취의(捨生取義) 정신을 강조했다는 점과, 일반적인 노동자계급과, 노동자계급을 지도하는 직업 혁명가의 뚜렷한 구분은 기존 마르크스주의자들과 달랐던 지점이다. 이러한 지점은 여러 학자들에 의해 블랑키주의의 영향을 받았다고 평가받는다. 공산주의 혁명가인 로자 룩셈부르크는 레닌과 블랑키의 차이는 지엽적이며, 본질적으로는 같은 것이라고 평가하였다. 이러한 비판은 러시아 10월 혁명이 “순수한 프롤레타리아 계급에 의한 혁명인가?”, 아니면 “소수 혁명적 지식인에 의한 쿠데타인

## sparse

In [100]:
from sklearn.feature_extraction.text import TfidfVectorizer

tfidfv = TfidfVectorizer(tokenizer=tokenizer.tokenize, ngram_range=(1, 2))#, max_features=50000)
p_embedding = tfidfv.fit_transform(corpus)

def get_topk_similarity(qeury_vec, k):
    result = qeury_vec * p_embedding.T
    result = result.toarray()

    doc_scores3 = np.partition(result, -k)[:, -k:][:, ::-1]
    ind = np.argsort(doc_scores3, axis=-1)[:, ::-1]
    doc_scores3 = np.sort(doc_scores3, axis=-1)[:, ::-1]
    doc_indices3 = np.argpartition(result, -k)[:, -k:][:, ::-1]
    r, c = ind.shape
    ind = ind + np.tile(np.arange(r).reshape(-1, 1), (1, c)) * c
    doc_indices3 = doc_indices3.ravel()[ind].reshape(r, c)

    return doc_scores3, doc_indices3

In [140]:
sparse_answer_rank_list = []
for idx in tqdm(range(len(dataset['validation']))):
    query = dataset['validation'][idx]['question']
    ground_truth = dataset['validation'][idx]['context']

    query_vec = tfidfv.transform([query])
    k=len(dataset['validation'])

    doc_scores, doc_indices = get_topk_similarity(query_vec, k)
    #print(doc_indices, doc_scores)
    # 4. rank에서 groundtruth와 비교
    ground_truth_doc_idx = np.where(corpus==ground_truth)[0].tolist()[0] # ground truth의 전체문서 idx
    #print(ground_truth_doc_idx)
    answer_rank = np.where(doc_indices[0]==ground_truth_doc_idx)[0].tolist()[0] # 정답문서의 rank
    #print(doc_indices)
    sparse_answer_rank_list.append(answer_rank)
    #break

 26%|██▋       | 63/240 [01:40<04:43,  1.60s/it]


IndexError: list index out of range

In [139]:
sparse_answer_rank_list

[1]

In [92]:
dense_rank = np.array(answer_rank_list).mean()
sparse_rank = np.array().mean()

112.05416666666666

#### Top-5개의 passage를 retrieve 하고 ground truth와 비교하기