In [1]:
import json
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from tqdm import trange

from pprint import pprint

from sklearn.feature_extraction.text import TfidfVectorizer

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BertModel, RobertaModel,
    BertPreTrainedModel,
    AdamW, get_linear_schedule_with_warmup,
    TrainingArguments,
)
from datasets import (
    Dataset,
    load_from_disk,
    concatenate_datasets,
)

from typing import List

In [2]:
# 난수 고정
def set_seed(random_seed):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # if use multi-GPU
    random.seed(random_seed)
    np.random.seed(random_seed)
    
set_seed(42) # magic number :)

In [3]:
print ("PyTorch version:[%s]."%(torch.__version__))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print ("device:[%s]."%(device))

PyTorch version:[1.6.0].
device:[cuda:0].


## Training

In [4]:
class BertEncoder(BertPreTrainedModel):
    def __init__(self, config):
        super(BertEncoder, self).__init__(config)

        self.bert = BertModel(config)
        self.init_weights()
      
    def forward(
            self,
            input_ids, 
            attention_mask=None,
            token_type_ids=None
        ): 

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        
        pooled_output = outputs[1]
        return pooled_output

In [5]:
dataset = load_from_disk('/opt/ml/data/train_dataset')
train_dataset = dataset['train']

In [6]:
# Anwer
class DenseRetrieval:
    def __init__(self,
        args,
        dataset,
        tokenizer,
        p_encoder,
        q_encoder
    ):
        """
        학습과 추론에 사용될 여러 셋업을 마쳐봅시다.
        """

        self.args = args
        self.dataset = dataset

        self.tokenizer = tokenizer
        self.p_encoder = p_encoder
        self.q_encoder = q_encoder

    def train(self, args=None, tokenizer = None):
        if args is None:
            args = self.args
        if tokenizer is None :
            tokenizer = self.tokenizer

        for i in (range(len(self.dataset))) :
            if i == 0 :
                q_seqs = tokenizer(
                    self.dataset[i]['question'],
                    padding='max_length',
                    truncation = True,
                    return_tensors='pt'
                )
                p_seqs = tokenizer(
                    self.dataset[i]['context'],
                    truncation = True,
                    stride = 128,
                    padding='max_length',
                    return_overflowing_tokens=True,
                    return_offsets_mapping=True,
                    return_tensors='pt'
                )

                p_seqs.pop('overflow_to_sample_mapping')
                p_seqs.pop('offset_mapping')

                for k in q_seqs.keys() :
                    q_seqs[k] = q_seqs[k].tolist()
                    p_seqs[k] = p_seqs[k].tolist()
            else :
                tmp_q_seq = tokenizer(
                    self.dataset[i]['question'],
                    padding='max_length',
                    truncation = True,
                    return_tensors='pt')
                tmp_p_seq = tokenizer(
                    self.dataset[i]['context'],
                    truncation = True,
                    stride = 128,
                    padding='max_length',
                    return_overflowing_tokens=True,
                    return_offsets_mapping=True,
                    return_tensors='pt'
                )

                tmp_p_seq.pop('overflow_to_sample_mapping')
                tmp_p_seq.pop('offset_mapping')

                for k in tmp_p_seq.keys() :
                    tmp_p_seq[k] = tmp_p_seq[k].tolist()
                    tmp_q_seq[k] = tmp_q_seq[k].tolist()


                for j in range(len(tmp_p_seq['input_ids'])) :
                    q_seqs['input_ids'].append(tmp_q_seq['input_ids'][0])
                    q_seqs['token_type_ids'].append(tmp_q_seq['token_type_ids'][0])
                    q_seqs['attention_mask'].append(tmp_q_seq['attention_mask'][0])
                    p_seqs['input_ids'].append(tmp_p_seq['input_ids'][j])
                    p_seqs['token_type_ids'].append(tmp_p_seq['token_type_ids'][j])
                    p_seqs['attention_mask'].append(tmp_p_seq['attention_mask'][j])

        for k in q_seqs.keys() :
            q_seqs[k] = torch.tensor(q_seqs[k])
            p_seqs[k] = torch.tensor(p_seqs[k])

        train_dataset = TensorDataset(p_seqs['input_ids'], p_seqs['attention_mask'], p_seqs['token_type_ids'], 
                        q_seqs['input_ids'], q_seqs['attention_mask'], q_seqs['token_type_ids'])
        train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size, shuffle = True)

        no_decay = ["bias" ,"LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {"params": [p for n, p in self.p_encoder.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
            {"params": [p for n, p in self.p_encoder.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
            {"params": [p for n, p in self.q_encoder.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
            {"params": [p for n, p in self.q_encoder.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            # eps=args.adam_epsilon
        )

        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
        
        global_step = 0

        self.p_encoder.zero_grad()
        self.q_encoder.zero_grad()
        # torch.cuda.empty_cache()

        train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
        self.q_encoder.train()
        self.p_encoder.train()
        for epoch, _ in enumerate(train_iterator):
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            loss_value=0 # Accumulation할 때 진행
            losses = 0
            for step, batch in enumerate(epoch_iterator):
                if torch.cuda.is_available():
                    batch = tuple(t.cuda() for t in batch)

                p_inputs = {'input_ids': batch[0],
                            'attention_mask': batch[1],
                            'token_type_ids': batch[2]
                            }
                
                q_inputs = {'input_ids': batch[3],
                            'attention_mask': batch[4],
                            'token_type_ids': batch[5]}

                p_outputs = self.p_encoder(**p_inputs)  # (batch_size, emb_dim)
                q_outputs = self.q_encoder(**q_inputs)  # (batch_size, emb_dim)

                # Calculate similarity score & loss
                sim_scores = torch.matmul(q_outputs, torch.transpose(p_outputs, 0, 1))  # (batch_size, emb_dim) x (emb_dim, batch_size) = (batch_size, batch_size)

                # target: position of positive samples = diagonal element 
                # targets = torch.arange(0, args.per_device_train_batch_size).long()
                targets = torch.arange(0, len(p_inputs['input_ids'])).long()
                
                if torch.cuda.is_available():
                    targets = targets.to('cuda')
                    
                sim_scores = F.log_softmax(sim_scores, dim=1)

                loss = F.nll_loss(sim_scores, targets)
                losses += loss.item()
                if step % 100 == 0 :
                    print(f'{epoch}epoch loss: {losses/(step+1)}') # Accumulation할 경우 주석처리

                # ################ACCUMULATION###############################
                # if (step+1) % args.gradient_accumulation_steps == 0 :
                #     optimizer.step()
                #     scheduler.step()
                #     self.q_encoder.zero_grad()
                #     self.p_encoder.zero_grad()

                # losses += loss.item()
                # if (step+1) % 64 == 0:
                #     train_loss = losses / 64
                #     print(f'training loss: {train_loss:4.4}')
                #     losses = 0
                # ###########################################################
                self.q_encoder.zero_grad()
                self.p_encoder.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                # global_step += 1
                
                #torch.cuda.empty_cache()
                del p_inputs, q_inputs

        return self.p_encoder, self.q_encoder

In [8]:
# p_encoder.cpu()
# q_encoder.cpu()
# del p_encoder
# del q_encoder
# torch.cuda.empty_cache()

## Make Q_Encoder & P_Embedding

In [7]:
args = TrainingArguments(
    output_dir="dense_retireval",
    evaluation_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    num_train_epochs=100,
    weight_decay=0.01
)
model_checkpoint = "klue/bert-base"

# 혹시 위에서 사용한 encoder가 있다면 주석처리 후 진행해주세요 (CUDA ...)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
p_encoder = BertEncoder.from_pretrained(model_checkpoint).to(args.device)
q_encoder = BertEncoder.from_pretrained(model_checkpoint).to(args.device)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.bi

In [8]:
# Retriever는 아래와 같이 사용할 수 있도록 코드를 짜봅시다.
retriever = DenseRetrieval(
    args=args,
    dataset=train_dataset,
    tokenizer=tokenizer,
    p_encoder=p_encoder,
    q_encoder=q_encoder
)
p_encoder, q_encoder = retriever.train()

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

0epoch loss: 35.39319610595703
0epoch loss: 3.243883357042133
0epoch loss: 2.173376719212502
0epoch loss: 1.6864827290622895


Epoch:   1%|          | 1/100 [05:57<9:49:27, 357.25s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

1epoch loss: 0.14257673919200897
1epoch loss: 0.3896944067570021
1epoch loss: 0.39320520209426535
1epoch loss: 0.3737643955297075


Epoch:   2%|▏         | 2/100 [11:54<9:43:21, 357.16s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

2epoch loss: 0.17028941214084625
2epoch loss: 0.2654158337373692
2epoch loss: 0.2372228529939048
2epoch loss: 0.24242241738974799


Epoch:   3%|▎         | 3/100 [17:51<9:37:22, 357.14s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

3epoch loss: 0.23236776888370514
3epoch loss: 0.18807194980857248
3epoch loss: 0.15850694561445222
3epoch loss: 0.16960547505290033


Epoch:   4%|▍         | 4/100 [23:48<9:31:17, 357.06s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

4epoch loss: 0.0815662294626236
4epoch loss: 0.11345460191242268
4epoch loss: 0.11748132268540255
4epoch loss: 0.1244617241132692


Epoch:   5%|▌         | 5/100 [29:45<9:25:19, 357.05s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

5epoch loss: 0.009182417765259743
5epoch loss: 0.09690173946339453
5epoch loss: 0.09024142714810096
5epoch loss: 0.0979728851029073


Epoch:   6%|▌         | 6/100 [35:41<9:19:15, 356.98s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

6epoch loss: 0.04348301514983177
6epoch loss: 0.07601996329091466
6epoch loss: 0.08326804263845076
6epoch loss: 0.09045864614829974


Epoch:   7%|▋         | 7/100 [41:39<9:13:23, 357.03s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

7epoch loss: 0.009857789613306522
7epoch loss: 0.07908879441228156
7epoch loss: 0.08452791205729851
7epoch loss: 0.07508691501524455


Epoch:   8%|▊         | 8/100 [47:35<9:07:18, 356.94s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

8epoch loss: 0.002726450562477112
8epoch loss: 0.07993279260523141
8epoch loss: 0.07385686041269407
8epoch loss: 0.0726416710312456


Epoch:   9%|▉         | 9/100 [53:32<9:01:12, 356.84s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

9epoch loss: 0.005537758581340313
9epoch loss: 0.05652460201718616
9epoch loss: 0.059765428435775224
9epoch loss: 0.05569502237376824


Epoch:  10%|█         | 10/100 [59:29<8:55:07, 356.75s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

10epoch loss: 0.0005611367523670197
10epoch loss: 0.06483058178434142
10epoch loss: 0.07017238254598825
10epoch loss: 0.06472529302803036


Epoch:  11%|█         | 11/100 [1:05:25<8:49:08, 356.73s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

11epoch loss: 0.00044668660848401487
11epoch loss: 0.045631461358011216
11epoch loss: 0.05507418273606814
11epoch loss: 0.053417210214368564


Epoch:  12%|█▏        | 12/100 [1:11:22<8:43:11, 356.72s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

12epoch loss: 0.3103611171245575
12epoch loss: 0.051021352980290494
12epoch loss: 0.05644907920742384
12epoch loss: 0.05795966412422913


Epoch:  13%|█▎        | 13/100 [1:17:19<8:37:15, 356.73s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

13epoch loss: 0.1715841442346573
13epoch loss: 0.058709149139866725
13epoch loss: 0.05058319387247668
13epoch loss: 0.051578907014414276


Epoch:  14%|█▍        | 14/100 [1:23:15<8:31:11, 356.64s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

14epoch loss: 0.00015783015987835824
14epoch loss: 0.06356728319078324
14epoch loss: 0.05516583227543026
14epoch loss: 0.05405735607922925


Epoch:  15%|█▌        | 15/100 [1:29:12<8:25:30, 356.82s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

15epoch loss: 0.03195515275001526
15epoch loss: 0.04658971175549686
15epoch loss: 0.05088331224483702
15epoch loss: 0.047518344431878135


Epoch:  16%|█▌        | 16/100 [1:35:09<8:19:36, 356.86s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

16epoch loss: 0.001754599274136126
16epoch loss: 0.043921626593493045
16epoch loss: 0.044682040334652326
16epoch loss: 0.04612442840740992


Epoch:  17%|█▋        | 17/100 [1:41:06<8:13:36, 356.83s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

17epoch loss: 0.0009717377834022045
17epoch loss: 0.04291302589054155
17epoch loss: 0.03728031469138466
17epoch loss: 0.044500238042818774


Epoch:  18%|█▊        | 18/100 [1:47:02<8:07:29, 356.71s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

18epoch loss: 0.0007923158700577915
18epoch loss: 0.03182609235399567
18epoch loss: 0.030238451297059078
18epoch loss: 0.027743996135057784


Epoch:  19%|█▉        | 19/100 [1:52:59<8:01:38, 356.77s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

19epoch loss: 0.12508225440979004
19epoch loss: 0.03372115002429661
19epoch loss: 0.04312611354589309
19epoch loss: 0.03989581439419601


Epoch:  20%|██        | 20/100 [1:58:56<7:55:41, 356.76s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

20epoch loss: 0.06931550800800323
20epoch loss: 0.024844334896181746
20epoch loss: 0.025689263718333282
20epoch loss: 0.024592479893924993


Epoch:  21%|██        | 21/100 [2:04:53<7:49:35, 356.65s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

21epoch loss: 8.5825304267928e-06
21epoch loss: 0.036074695248217804
21epoch loss: 0.0394533821021179
21epoch loss: 0.03522248437197536


Epoch:  22%|██▏       | 22/100 [2:10:49<7:43:37, 356.64s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

22epoch loss: 0.0017766659148037434
22epoch loss: 0.041953617276694874
22epoch loss: 0.03072899951888785
22epoch loss: 0.033537617906976784


Epoch:  23%|██▎       | 23/100 [2:16:46<7:37:42, 356.66s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

23epoch loss: 2.168108267142088e-06
23epoch loss: 0.02783040331999221
23epoch loss: 0.03250926925827752
23epoch loss: 0.03346561095370558


Epoch:  24%|██▍       | 24/100 [2:22:42<7:31:45, 356.65s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

24epoch loss: 0.00022234527568798512
24epoch loss: 0.028236998073162517
24epoch loss: 0.03538929072201834
24epoch loss: 0.042069938391020974


Epoch:  25%|██▌       | 25/100 [2:28:39<7:25:48, 356.64s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

25epoch loss: 0.004246781580150127
25epoch loss: 0.03572068940192736
25epoch loss: 0.031668784930255454
25epoch loss: 0.02688982514945054


Epoch:  26%|██▌       | 26/100 [2:34:36<7:19:48, 356.60s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

26epoch loss: 0.00031886593205854297
26epoch loss: 0.03004251657200419
26epoch loss: 0.030311628134350982
26epoch loss: 0.02867366596734183


Epoch:  27%|██▋       | 27/100 [2:40:32<7:13:55, 356.65s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

27epoch loss: 0.01052794884890318
27epoch loss: 0.01501742076791941
27epoch loss: 0.014812172903853409
27epoch loss: 0.022863822057456445


Epoch:  28%|██▊       | 28/100 [2:46:29<7:07:59, 356.66s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

28epoch loss: 0.002032653661444783
28epoch loss: 0.025796805828519502
28epoch loss: 0.022383853532429766
28epoch loss: 0.024632183896453497


Epoch:  29%|██▉       | 29/100 [2:52:26<7:02:00, 356.62s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

29epoch loss: 5.287792009767145e-05
29epoch loss: 0.021027377392403363
29epoch loss: 0.022650906830451154
29epoch loss: 0.026712168832140743


Epoch:  30%|███       | 30/100 [2:58:22<6:56:07, 356.68s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

30epoch loss: 5.930579845880857e-06
30epoch loss: 0.018846014091053857
30epoch loss: 0.022240276518896936
30epoch loss: 0.02550958593874981


Epoch:  31%|███       | 31/100 [3:04:19<6:50:16, 356.76s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

31epoch loss: 0.4265448749065399
31epoch loss: 0.017176454848033068
31epoch loss: 0.02637294579223315
31epoch loss: 0.02174207971521401


Epoch:  32%|███▏      | 32/100 [3:10:16<6:44:18, 356.74s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

32epoch loss: 0.025244249030947685
32epoch loss: 0.021351488738937045
32epoch loss: 0.015511826716131643
32epoch loss: 0.01894679964076519


Epoch:  33%|███▎      | 33/100 [3:16:13<6:38:20, 356.73s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

33epoch loss: 0.11079074442386627
33epoch loss: 0.0303896324547681
33epoch loss: 0.03559573255358092
33epoch loss: 0.03148644917483018


Epoch:  34%|███▍      | 34/100 [3:22:09<6:32:23, 356.73s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

34epoch loss: 0.0007288837223313749
34epoch loss: 0.022973122514689758
34epoch loss: 0.022781096603638102
34epoch loss: 0.022078057279278796


Epoch:  35%|███▌      | 35/100 [3:28:07<6:26:33, 356.83s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

35epoch loss: 6.028224743204191e-05
35epoch loss: 0.030193841557463224
35epoch loss: 0.02712952971703944
35epoch loss: 0.022602538200271464


Epoch:  36%|███▌      | 36/100 [3:34:03<6:20:36, 356.82s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

36epoch loss: 1.1183023161720484e-05
36epoch loss: 0.023286198280565672
36epoch loss: 0.029149372182569185
36epoch loss: 0.03018232734941225


Epoch:  37%|███▋      | 37/100 [3:40:00<6:14:42, 356.87s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

37epoch loss: 0.022592470049858093
37epoch loss: 0.03579067936689261
37epoch loss: 0.025919022138415262
37epoch loss: 0.024457110203835054


Epoch:  38%|███▊      | 38/100 [3:45:57<6:08:42, 356.81s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

38epoch loss: 8.418972356594168e-06
38epoch loss: 0.013844939548042782
38epoch loss: 0.018069094309144145
38epoch loss: 0.016666568694155212


Epoch:  39%|███▉      | 39/100 [3:51:54<6:02:42, 356.76s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

39epoch loss: 0.0005753502482548356
39epoch loss: 0.02032099511873162
39epoch loss: 0.013561226490790983
39epoch loss: 0.014842836601491492


Epoch:  40%|████      | 40/100 [3:57:50<5:56:41, 356.70s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

40epoch loss: 0.22114665806293488
40epoch loss: 0.02536685206026653
40epoch loss: 0.022146871364051846
40epoch loss: 0.02239727534360432


Epoch:  41%|████      | 41/100 [4:03:47<5:50:38, 356.59s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

41epoch loss: 0.44080060720443726
41epoch loss: 0.023907745863711496
41epoch loss: 0.018671829952004837
41epoch loss: 0.018440303999728975


Epoch:  42%|████▏     | 42/100 [4:09:43<5:44:40, 356.56s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

42epoch loss: 0.0033354992046952248
42epoch loss: 0.021495893380451415
42epoch loss: 0.017225594448554593
42epoch loss: 0.015173684715694207


Epoch:  43%|████▎     | 43/100 [4:15:39<5:38:42, 356.54s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

43epoch loss: 9.72132183960639e-05
43epoch loss: 0.009981770692558772
43epoch loss: 0.012189156998017632
43epoch loss: 0.014172189932389964


Epoch:  44%|████▍     | 44/100 [4:21:36<5:32:46, 356.55s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

44epoch loss: 4.097751570952823e-06
44epoch loss: 0.015440415663598436
44epoch loss: 0.012591564191778567
44epoch loss: 0.015812732404720878


Epoch:  45%|████▌     | 45/100 [4:27:33<5:26:58, 356.70s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

45epoch loss: 0.006328011862933636
45epoch loss: 0.016901549618160146
45epoch loss: 0.018870257313610556
45epoch loss: 0.01995930403125987


Epoch:  46%|████▌     | 46/100 [4:33:30<5:21:04, 356.75s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

46epoch loss: 0.026788605377078056
46epoch loss: 0.007131962530194761
46epoch loss: 0.009885149306807594
46epoch loss: 0.010367634995696847


Epoch:  47%|████▋     | 47/100 [4:39:26<5:15:03, 356.67s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

47epoch loss: 0.07060685008764267
47epoch loss: 0.014585675343028437
47epoch loss: 0.012464364844599297
47epoch loss: 0.014585889259094579


Epoch:  48%|████▊     | 48/100 [4:45:23<5:09:06, 356.66s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

48epoch loss: 0.002968807006254792
48epoch loss: 0.018274380446083604
48epoch loss: 0.01852076654593269
48epoch loss: 0.018723099475835357


Epoch:  49%|████▉     | 49/100 [4:51:22<5:03:39, 357.25s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

49epoch loss: 0.00046213489258661866
49epoch loss: 0.02309348977498735
49epoch loss: 0.024156668637147615
49epoch loss: 0.019302991901847092


Epoch:  50%|█████     | 50/100 [4:57:23<4:58:50, 358.61s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

50epoch loss: 9.215995305567048e-06
50epoch loss: 0.012157130961275533
50epoch loss: 0.013978314076188783
50epoch loss: 0.014456840683113075


Epoch:  51%|█████     | 51/100 [5:03:22<4:52:43, 358.45s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

51epoch loss: 0.000621188897639513
51epoch loss: 0.010451428877600377
51epoch loss: 0.009395310493272092
51epoch loss: 0.0092050746212156


Epoch:  52%|█████▏    | 52/100 [5:09:18<4:46:22, 357.97s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

52epoch loss: 0.0011346102692186832
52epoch loss: 0.008582994913750792
52epoch loss: 0.0073311530611724195
52epoch loss: 0.009042983476693712


Epoch:  53%|█████▎    | 53/100 [5:15:15<4:40:05, 357.57s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

53epoch loss: 2.8590831789188087e-05
53epoch loss: 0.013039367941361757
53epoch loss: 0.010846125843793667
53epoch loss: 0.014180297264112598


Epoch:  54%|█████▍    | 54/100 [5:21:12<4:33:58, 357.35s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

54epoch loss: 0.0009172240388579667
54epoch loss: 0.012890580240497624
54epoch loss: 0.013368242482804967
54epoch loss: 0.013160417883160998


Epoch:  55%|█████▌    | 55/100 [5:27:09<4:27:51, 357.14s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

55epoch loss: 2.1084965737827588e-06
55epoch loss: 0.011801495490867115
55epoch loss: 0.01609066063703027
55epoch loss: 0.01439715460100117


Epoch:  56%|█████▌    | 56/100 [5:33:05<4:21:47, 356.99s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

56epoch loss: 8.783033263171092e-05
56epoch loss: 0.012300895889344776
56epoch loss: 0.014693141529133888
56epoch loss: 0.014482006299313864


Epoch:  57%|█████▋    | 57/100 [5:39:02<4:15:50, 356.99s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

57epoch loss: 0.008025750517845154
57epoch loss: 0.014412141275730022
57epoch loss: 0.012663103459233152
57epoch loss: 0.010835852126429788


Epoch:  58%|█████▊    | 58/100 [5:44:59<4:09:49, 356.89s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

58epoch loss: 2.3206855985336006e-05
58epoch loss: 0.013933425163376643
58epoch loss: 0.010756715130003352
58epoch loss: 0.01367606530605991


Epoch:  59%|█████▉    | 59/100 [5:50:56<4:03:53, 356.92s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

59epoch loss: 0.0004552770988084376
59epoch loss: 0.010979307056115346
59epoch loss: 0.012174006712711506
59epoch loss: 0.010726030364415946


Epoch:  60%|██████    | 60/100 [5:56:52<3:57:53, 356.84s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

60epoch loss: 2.5406163786101388e-06
60epoch loss: 0.006252927769096324
60epoch loss: 0.009849969532928042
60epoch loss: 0.012464416208467168


Epoch:  61%|██████    | 61/100 [6:02:49<3:51:56, 356.84s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

61epoch loss: 0.07620466500520706
61epoch loss: 0.008754743004812321
61epoch loss: 0.009875996442690944
61epoch loss: 0.008587858806438378


Epoch:  62%|██████▏   | 62/100 [6:08:46<3:45:54, 356.69s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

62epoch loss: 0.0009512215037830174
62epoch loss: 0.009026762693605856
62epoch loss: 0.01189593839322538
62epoch loss: 0.011363464990415715


Epoch:  63%|██████▎   | 63/100 [6:14:42<3:39:57, 356.68s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

63epoch loss: 1.4311930499388836e-05
63epoch loss: 0.006843209009279114
63epoch loss: 0.006698178292505266
63epoch loss: 0.005310386319427134


Epoch:  64%|██████▍   | 64/100 [6:20:39<3:33:59, 356.66s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

64epoch loss: 2.375686017330736e-05
64epoch loss: 0.01034412304963309
64epoch loss: 0.006857331537436794
64epoch loss: 0.011631562610084387


Epoch:  65%|██████▌   | 65/100 [6:26:36<3:28:08, 356.82s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

65epoch loss: 0.0021835819352418184
65epoch loss: 0.004564185298538165
65epoch loss: 0.007312638220784309
65epoch loss: 0.008651709787572128


Epoch:  66%|██████▌   | 66/100 [6:32:33<3:22:08, 356.71s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

66epoch loss: 0.14830544590950012
66epoch loss: 0.02399487359346186
66epoch loss: 0.020519670800203938
66epoch loss: 0.020011595591115657


Epoch:  67%|██████▋   | 67/100 [6:38:29<3:16:09, 356.66s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

67epoch loss: 0.0020324254874140024
67epoch loss: 0.0040632546378425545
67epoch loss: 0.014513575684162878
67epoch loss: 0.010515102431127


Epoch:  68%|██████▊   | 68/100 [6:44:26<3:10:13, 356.66s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

68epoch loss: 2.3474935005651787e-05
68epoch loss: 0.013537304328253293
68epoch loss: 0.013325887709745434
68epoch loss: 0.01230423027390942


Epoch:  69%|██████▉   | 69/100 [6:50:22<3:04:16, 356.65s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

69epoch loss: 0.0006812261417508125
69epoch loss: 0.009420982704013887
69epoch loss: 0.010411615962064177
69epoch loss: 0.011440455737990404


Epoch:  70%|███████   | 70/100 [6:56:19<2:58:18, 356.61s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

70epoch loss: 0.00026092157349921763
70epoch loss: 0.008347232330556065
70epoch loss: 0.011130682491423105
70epoch loss: 0.01041879357192312


Epoch:  71%|███████   | 71/100 [7:02:16<2:52:21, 356.60s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

71epoch loss: 0.001316916779614985
71epoch loss: 0.009080047972379744
71epoch loss: 0.009007710044478544
71epoch loss: 0.009187731425991351


Epoch:  72%|███████▏  | 72/100 [7:08:12<2:46:24, 356.57s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

72epoch loss: 1.188291935250163e-05
72epoch loss: 0.006873730956869221
72epoch loss: 0.01169566788974603
72epoch loss: 0.010540293116871892


Epoch:  73%|███████▎  | 73/100 [7:14:08<2:40:26, 356.54s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

73epoch loss: 2.4065211619017646e-06
73epoch loss: 0.00192479604117468
73epoch loss: 0.005166272600203425
73epoch loss: 0.009532459291711523


Epoch:  74%|███████▍  | 74/100 [7:20:05<2:34:32, 356.63s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

74epoch loss: 2.349782698729541e-05
74epoch loss: 0.006113568341168565
74epoch loss: 0.005691779866500372
74epoch loss: 0.005781932081436441


Epoch:  75%|███████▌  | 75/100 [7:26:02<2:28:37, 356.71s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

75epoch loss: 2.682205888504541e-07
75epoch loss: 0.01493366365991761
75epoch loss: 0.009926165075911153
75epoch loss: 0.008746717503605882


Epoch:  76%|███████▌  | 76/100 [7:31:59<2:22:40, 356.71s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

76epoch loss: 0.00027048191986978054
76epoch loss: 0.017391372228845913
76epoch loss: 0.015155583149952793
76epoch loss: 0.014257395284670253


Epoch:  77%|███████▋  | 77/100 [7:37:55<2:16:40, 356.56s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

77epoch loss: 0.00026504756533540785
77epoch loss: 0.010703559979106103
77epoch loss: 0.010347327122489314
77epoch loss: 0.00808779802252447


Epoch:  78%|███████▊  | 78/100 [7:43:52<2:10:45, 356.63s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

78epoch loss: 5.95281971982331e-06
78epoch loss: 0.0042858690790902945
78epoch loss: 0.005411402890567231
78epoch loss: 0.005323899191289999


Epoch:  79%|███████▉  | 79/100 [7:49:49<2:04:50, 356.68s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

79epoch loss: 0.0035857995972037315
79epoch loss: 0.003752878512780779
79epoch loss: 0.00825307279912414
79epoch loss: 0.010012672679430564


Epoch:  80%|████████  | 80/100 [7:55:46<1:58:54, 356.75s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

80epoch loss: 0.0009992034174501896
80epoch loss: 0.016158743702736404
80epoch loss: 0.010975904005322466
80epoch loss: 0.009314691196295539


Epoch:  81%|████████  | 81/100 [8:01:42<1:52:58, 356.77s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

81epoch loss: 2.2253167117014527e-05
81epoch loss: 0.0029503987796734165
81epoch loss: 0.0037617391629246595
81epoch loss: 0.00573573632264608


Epoch:  82%|████████▏ | 82/100 [8:07:39<1:47:01, 356.74s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

82epoch loss: 0.05940476432442665
82epoch loss: 0.009497276458354449
82epoch loss: 0.011910205123177288
82epoch loss: 0.013528129326224136


Epoch:  83%|████████▎ | 83/100 [8:13:36<1:41:05, 356.79s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

83epoch loss: 8.406581036979333e-05
83epoch loss: 0.006916501694176716
83epoch loss: 0.008287732846699624
83epoch loss: 0.006375618461863097


Epoch:  84%|████████▍ | 84/100 [8:19:33<1:35:09, 356.83s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

84epoch loss: 0.0036184529308229685
84epoch loss: 0.010796113535236025
84epoch loss: 0.012968245521910938
84epoch loss: 0.011050193588358804


Epoch:  85%|████████▌ | 85/100 [8:25:30<1:29:13, 356.87s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

85epoch loss: 8.359278581338003e-06
85epoch loss: 0.006380609580334337
85epoch loss: 0.005678963220112475
85epoch loss: 0.006752799188478435


Epoch:  86%|████████▌ | 86/100 [8:31:26<1:23:14, 356.78s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

86epoch loss: 2.1895257305004634e-05
86epoch loss: 0.01042900415934965
86epoch loss: 0.010740225140586939
86epoch loss: 0.01029672237411109


Epoch:  87%|████████▋ | 87/100 [8:37:23<1:17:18, 356.80s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

87epoch loss: 8.72498276294209e-05
87epoch loss: 0.00369717407488904
87epoch loss: 0.005884581596617906
87epoch loss: 0.006489317573810208


Epoch:  88%|████████▊ | 88/100 [8:43:20<1:11:21, 356.82s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

88epoch loss: 0.0003015085239894688
88epoch loss: 0.0049172734510747355
88epoch loss: 0.007924934587423342
88epoch loss: 0.008808387897099567


Epoch:  89%|████████▉ | 89/100 [8:49:17<1:05:25, 356.88s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

89epoch loss: 0.0002589902433101088
89epoch loss: 0.005663758717077198
89epoch loss: 0.010319694364893984
89epoch loss: 0.009910299245314527


Epoch:  90%|█████████ | 90/100 [8:55:14<59:29, 356.94s/it]  




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

90epoch loss: 0.0027719817589968443
90epoch loss: 0.004387489962560735
90epoch loss: 0.006176622895610315
90epoch loss: 0.009526808566880788


Epoch:  91%|█████████ | 91/100 [9:01:12<53:33, 357.05s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

91epoch loss: 2.3913375116535462e-05
91epoch loss: 0.017218774280435566
91epoch loss: 0.01155098235795608
91epoch loss: 0.008537613594585634


Epoch:  92%|█████████▏| 92/100 [9:07:08<47:35, 356.97s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

92epoch loss: 0.002006939146667719
92epoch loss: 0.006950330690184535
92epoch loss: 0.004992705903152325
92epoch loss: 0.008073642565911544


Epoch:  93%|█████████▎| 93/100 [9:13:06<41:39, 357.04s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

93epoch loss: 0.00013216132356319577
93epoch loss: 0.010263327874171599
93epoch loss: 0.00801541533596787
93epoch loss: 0.007002642305260236


Epoch:  94%|█████████▍| 94/100 [9:19:03<35:42, 357.07s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

94epoch loss: 0.1555933654308319
94epoch loss: 0.012265588291598977
94epoch loss: 0.011309879525147063
94epoch loss: 0.009136508692996845


Epoch:  95%|█████████▌| 95/100 [9:25:00<29:45, 357.10s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

95epoch loss: 9.365096957481e-06
95epoch loss: 0.006802244529244445
95epoch loss: 0.004807530872612765
95epoch loss: 0.006966427668955271


Epoch:  96%|█████████▌| 96/100 [9:30:57<23:48, 357.01s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

96epoch loss: 5.05391217302531e-05
96epoch loss: 0.00620603459966847
96epoch loss: 0.006947390603406277
96epoch loss: 0.0068761742412821614


Epoch:  97%|█████████▋| 97/100 [9:36:53<17:50, 356.90s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

97epoch loss: 1.4788811313337646e-05
97epoch loss: 0.010660536323397612
97epoch loss: 0.008831238453388972
97epoch loss: 0.008319933481140936


Epoch:  98%|█████████▊| 98/100 [9:42:50<11:53, 356.95s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

98epoch loss: 6.757536539225839e-06
98epoch loss: 0.009050257815625056
98epoch loss: 0.005678103845493344
98epoch loss: 0.004424278879769515


Epoch:  99%|█████████▉| 99/100 [9:48:47<05:56, 356.97s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

99epoch loss: 0.00036039017140865326
99epoch loss: 0.002426564289057278
99epoch loss: 0.004622302290174706
99epoch loss: 0.0033815273877227307


Epoch: 100%|██████████| 100/100 [9:54:44<00:00, 356.85s/it]







In [9]:
with open('/opt/ml/data/wikipedia_documents.json', "r", encoding="utf-8") as f:
    wiki = json.load(f)

corpus = list(
    dict.fromkeys([v["text"] for v in wiki.values()])
)  # set 은 매번 순서가 바뀌므로

In [10]:
# p_encoder = retriever.p_encoder
# q_encoder = retriever.q_encoder
with torch.no_grad() :
    p_encoder.eval()

    p_embs = []
    for p in tqdm(corpus) :
        p = tokenizer(p, padding='max_length', truncation=True, return_tensors='pt').to('cuda')
        p_emb = p_encoder(**p).to('cpu').numpy()
        p_embs.append(p_emb)
p_embs = torch.Tensor(p_embs).squeeze()

HBox(children=(FloatProgress(value=0.0, max=56737.0), HTML(value='')))




In [11]:
import pickle

file_path = '/opt/ml/custom/passage_embedding_special_shuffle_100.bin'
with open(file_path, 'wb') as file :
    pickle.dump(p_embs, file)

In [12]:
p_encoder.cpu()
del p_encoder
torch.cuda.empty_cache()

In [13]:
torch.save(q_encoder, '/opt/ml/custom/q_encoder_special_shuffle_100.pt')

## Get Relavant Documnet

In [44]:
with open('/opt/ml/data/wikipedia_documents.json', "r", encoding="utf-8") as f:
    wiki = json.load(f)

corpus = list(
    dict.fromkeys([v["text"] for v in wiki.values()])
)  # set 은 매번 순서가 바뀌므로

In [6]:
model_checkpoint = "klue/bert-base"

# 혹시 위에서 사용한 encoder가 있다면 주석처리 후 진행해주세요 (CUDA ...)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [19]:
import pickle
with open('/opt/ml/custom/passage_embedding.bin', 'rb') as file :
    p_embs = pickle.load(file)
p_embs = p_embs

In [20]:
q_encoder = torch.load('/opt/ml/custom/q_encoder.pt')

In [14]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

### 1개 확인

In [21]:
query = dataset['validation']['question'][0]
queries = dataset['validation']['question'][:2]

In [22]:
with torch.no_grad() :
    q_encoder.eval()
    q_seqs_val = tokenizer([query], padding="max_length", truncation=True, return_tensors="pt").to(device)
    q_emb = q_encoder(**q_seqs_val).to('cpu')

In [32]:
dot_prod_scores = torch.mm(q_emb, p_embs.T)
dot_prod_scores

tensor([[159.0737, 155.1720, 158.0787,  ..., 156.8211, 155.6962, 156.6308]])

In [33]:
dot_prod_scores = torch.mm(q_emb, p_embs.T)
rank = torch.argsort(dot_prod_scores, dim=1, descending=True)

In [34]:
rank

tensor([[ 5694,  6509,  9728,  ..., 20579,  4294,  4291]])

### 여러개

In [35]:
with torch.no_grad() :
    q_encoder.eval()
    q_seqs_val = tokenizer(queries, padding="max_length", truncation=True, return_tensors="pt").to(device)
    q_emb = q_encoder(**q_seqs_val).to('cpu')

In [47]:
dot_prod_scores = torch.mm(q_emb, p_embs.T)
sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

In [49]:
scores = sort_result[0]
ranks = sort_result[1]

In [None]:
k = 5
print("[Search query]\n", query, "\n")
print("[Ground truth passage]")
print(dataset['validation']['context'][0], "\n")

for i in range(k):
  print("Top-%d passage with score %.4f" % (i+1, scores[0].squeeze()[i]))
  print(corpus[ranks[0][i]])

### 연산 시작 - Batch_size:16/epoch:5

In [15]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [20]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 50)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_50 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [21]:
correct_length = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length.append(i)
print(len(correct_length) / len(dataset['validation']))

0.6333333333333333


### 연산 시작 - Batch_size:16/epoch:10

In [16]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [17]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 50)

In [19]:
total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [20]:
correct_length = []
for i in range(len(cqas)) :
    if cqas['original_context'][i] in cqas['context'][i] :
        correct_length.append(i)
print(len(correct_length) / len(dataset['validation']))

0.6625


In [21]:
cqas.to_csv('/opt/ml/custom/valid_dpr_b16_e10_t50.csv')

In [30]:
correct_length_50 = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length_50.append(i)
print(len(correct_length_50) / len(dataset['validation']))

0.6625


In [39]:
correct_length_100 = []
for i in range(len(cqas_100)) :
    if cqas_100['original_context'][i] in cqas_100['context'][i] :
        correct_length_100.append(i)
print(len(correct_length_100) / len(dataset['validation']))

0.7416666666666667


In [53]:
correct_length_150 = []
for i in range(len(cqas_150)) :
    if cqas_150['original_context'][i] in cqas_150['context'][i] :
        correct_length_150.append(i)
print(len(correct_length_150) / len(dataset['validation']))

0.7625


In [56]:
correct_length_200 = []
for i in range(len(cqas_200)) :
    if cqas_200['original_context'][i] in cqas_200['context'][i] :
        correct_length_200.append(i)
print(len(correct_length_200) / len(dataset['validation']))

0.825


In [75]:
cqas_200.to_csv('/opt/ml/custom/valid_dpr_200.csv', index = False)

In [50]:
correct_length_300 = []
for i in range(len(cqas_300)) :
    if cqas_300['original_context'][i] in cqas_300['context'][i] :
        correct_length_300.append(i)
print(len(correct_length_200) / len(dataset['validation']))

0.825


In [26]:
cqas.to_csv('/opt/ml/custom/valid_dpr.csv', index = False)

### 연산 시작 - Batch_size:16/epoch:40

In [14]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [31]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 200)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_50 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [32]:
correct_length = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length.append(i)
print(len(correct_length) / len(dataset['validation']))

0.85


In [33]:
cqas_50.to_csv('/opt/ml/custom/valid_dpr_b16_e40_t200.csv', index = False)

### 연산 시작 - Batch_size:16/epoch:100

In [15]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [26]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 200)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_50 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [27]:
correct_length = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length.append(i)
print(len(correct_length) / len(dataset['validation']))

0.8166666666666667


In [28]:
cqas_50.to_csv('/opt/ml/custom/valid_dpr_b16_e100_t200.csv', index = False)

### 연산 시작 - Batch_size 128/epoch:10


In [14]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [18]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 100)

In [19]:
total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_50 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [20]:
correct_length_50 = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length_50.append(i)
print(len(correct_length_50) / len(dataset['validation']))

0.6083333333333333


In [None]:
correct_length_50 = []
for i in range(len(cqas_20)) :
    if cqas_20['original_context'][i] in cqas_20['context'][i] :
        correct_length_20.append(i)
print(len(correct_length_20) / len(dataset['validation']))

### 연산 시작 - Batch_size 128/epoch:5


In [15]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [16]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 50)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_50 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [17]:
correct_length_50 = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length_50.append(i)
print(len(correct_length_50) / len(dataset['validation']))

0.30416666666666664


### 연산 시작 - Batch_size 128/epoch:20


In [16]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [21]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 20)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_100 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [22]:
correct_length_100 = []
for i in range(len(cqas_100)) :
    if cqas_100['original_context'][i] in cqas_100['context'][i] :
        correct_length_100.append(i)
print(len(correct_length_100) / len(dataset['validation']))

0.5041666666666667


In [19]:
len(correct_length_100)

120

### 연산 시작 - Batch_size 16/epoch:10/Special/NoShuffle


In [18]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [21]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 100)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_100 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [22]:
correct_length_100 = []
for i in range(len(cqas_100)) :
    if cqas_100['original_context'][i] in cqas_100['context'][i] :
        correct_length_100.append(i)
print(len(correct_length_100) / len(dataset['validation']))

0.6916666666666667


### 연산 시작 - Batch_size 16/epoch:10/Special/Shuffle


In [19]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 100)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_100 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [20]:
correct_length_100 = []
for i in range(len(cqas_100)) :
    if cqas_100['original_context'][i] in cqas_100['context'][i] :
        correct_length_100.append(i)
print(len(correct_length_100) / len(dataset['validation']))

0.775


### 연산 시작 - Batch_size 16/epoch:100/Special/Shuffle

In [17]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 50)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_100 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [18]:
correct_length_100 = []
for i in range(len(cqas_100)) :
    if cqas_100['original_context'][i] in cqas_100['context'][i] :
        correct_length_100.append(i)
print(len(correct_length_100) / len(dataset['validation']))

0.7458333333333333
