In [75]:
# !pip install torch==1.7.1
# !pip install transformers==4.11.3
# !pip install huggingface-hub==0.0.19
# !pip install datasets==1.5.0

In [1]:
import json
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from tqdm import trange

from pprint import pprint

from sklearn.feature_extraction.text import TfidfVectorizer

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BertModel, RobertaModel,
    BertPreTrainedModel,
    AdamW, get_linear_schedule_with_warmup,
    TrainingArguments,
)
from datasets import (
    Dataset,
    load_from_disk,
    concatenate_datasets,
)

from typing import List
from torch.utils.data import Sampler

In [2]:
# 난수 고정
def set_seed(random_seed):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # if use multi-GPU
    random.seed(random_seed)
    np.random.seed(random_seed)
    
set_seed(42) # magic number :)

In [3]:
print ("PyTorch version:[%s]."%(torch.__version__))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print ("device:[%s]."%(device))

PyTorch version:[1.7.1].
device:[cuda:0].


## Training

In [4]:
class BertEncoder(BertPreTrainedModel):
    def __init__(self, config):
        super(BertEncoder, self).__init__(config)

        self.bert = BertModel(config)
        self.init_weights()
        classifier_dropout=(
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = torch.nn.Dropout(classifier_dropout)
        self.linear = torch.nn.Linear(config.hidden_size, 1)
      
    def forward(
            self,
            input_ids, 
            attention_mask=None,
            token_type_ids=None
        ): 

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        output = self.linear(pooled_output)
        return output

In [5]:
dataset = load_from_disk('/opt/ml/data/train_dataset')
train_dataset = dataset['train']

In [6]:
class CustomSampler(Sampler) :
    def __init__(self, data_source, batch_size) :
        self.data_source = data_source
        self.batch_size = batch_size

    def __iter__(self) :
        n = len(self.data_source)
        index_list = []
        while True :
            out = True
            for i in range(self.batch_size) :
                tmp_data = random.randint(0, n-1)
                index_list.append(tmp_data)
            for f, s in zip(index_list, index_list[1:]) :
                if abs(s-f) <= 2 :
                    out = False
            if out == True :
                break

        while True : # 추가 삽입
            tmp_data = random.randint(0, n-1)
            if (tmp_data not in index_list) and \
                (abs(tmp_data-index_list[-i]) > 2 for i in range(1,self.batch_size+1)) \
            : 
                index_list.append(tmp_data)
            if len(index_list) == n :
                break
        return iter(index_list)

    def __len__(self) :
        return len(self.data_source)

In [7]:
# Anwer
class DenseRetrieval:
    def __init__(self,
        args,
        dataset,
        tokenizer,
        cross_encoder,
        sampler
    ):
        """
        학습과 추론에 사용될 여러 셋업을 마쳐봅시다.
        """

        self.args = args
        self.dataset = dataset

        self.tokenizer = tokenizer
        self.cross_encoder = cross_encoder
        self.sampler = sampler

    def train(self, args=None, tokenizer = None):
        if args is None:
            args = self.args
        if tokenizer is None :
            tokenizer = self.tokenizer
        
        tokenized_examples = tokenizer(
            self.dataset['question'],
            self.dataset['context'],
            truncation="only_second",
            max_length=512,
            stride=128,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            # return_token_type_ids=False,  # roberta모델을 사용할 경우 False, bert를 사용할 경우 True로 표기해야합니다.
            padding="max_length",
            return_tensors='pt'
        )

        train_dataset = TensorDataset(
            tokenized_examples['input_ids'],
            tokenized_examples['attention_mask'],
            tokenized_examples['token_type_ids']
        )

        sampler = self.sampler(train_dataset, args.per_device_train_batch_size)
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.per_device_train_batch_size,
                                      sampler = sampler,
                                      drop_last = True)
                                      
        no_decay = ["bias" ,"LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {"params": [p for n, p in self.cross_encoder.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
            {"params": [p for n, p in self.cross_encoder.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            # eps=args.adam_epsilon
        )

        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
        
        self.cross_encoder.zero_grad()
        
        train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
        self.cross_encoder.train()
        for epoch, _ in enumerate(train_iterator) :
            epoch_iterator = tqdm(train_dataloader, desc = 'Iteration')
            losses = 0
            for step, batch in enumerate(epoch_iterator) :
                # if torch.cuda.is_available() :
                #     batch = tuple(t.cuda() for t in batch)
                
                cross_inputs = {
                    'input_ids': batch[0],
                    'attention_mask' : batch[1],
                    'token_type_ids' : batch[2]
                }
                for k in cross_inputs.keys() :
                    cross_inputs[k] = cross_inputs[k].tolist()

                new_input_ids = []
                new_attention_mask = []
                new_token_type_ids = []
                for i in range(len(cross_inputs['input_ids'])) :
                    sep_index = cross_inputs['input_ids'][i].index(3) # [SEP] token의 index

                    for j in range(len(cross_inputs['input_ids'])) :
                        query_id = cross_inputs['input_ids'][i][:sep_index]
                        query_att = cross_inputs['attention_mask'][i][:sep_index]
                        query_tok = cross_inputs['token_type_ids'][i][:sep_index]
        
                        context_id = cross_inputs['input_ids'][j][sep_index:]
                        context_att = cross_inputs['attention_mask'][j][sep_index:]
                        context_tok = cross_inputs['token_type_ids'][j][sep_index:]
                        query_id.extend(context_id)
                        query_att.extend(context_att)
                        query_tok.extend(context_tok)
                        new_input_ids.append(query_id)
                        new_attention_mask.append(query_att)
                        new_token_type_ids.append(query_tok)

                change_cross_inputs = {
                    'input_ids' : torch.tensor(new_input_ids).to('cuda'),
                    'attention_mask' : torch.tensor(new_attention_mask).to('cuda'),
                    'token_type_ids' : torch.tensor(new_token_type_ids).to('cuda')
                }

                cross_output = self.cross_encoder(**change_cross_inputs)
                cross_output = cross_output.view(-1, args.per_device_train_batch_size)
                targets = torch.arange(0, args.per_device_train_batch_size).long()
                                
                if torch.cuda.is_available():
                    targets = targets.to('cuda')

                score = F.log_softmax(cross_output, dim = 1)
                loss = F.nll_loss(score, targets)
                #########################No ACCUMULATION#########################
                # losses += loss.item()
                # if step % 100 == 0 :
                #     print(f'{epoch}epoch loss: {losses/(step+1)}') # Accumulation할 경우 주석처리
                
                # self.cross_encoder.zero_grad()
                # loss.backward()
                # optimizer.step()
                # scheduler.step()
                #################################################################

                #############################ACCUMULATION#########################
                loss.backward()
                if (step+1) % args.gradient_accumulation_steps == 0 :
                    optimizer.step()
                    scheduler.step()
                    self.cross_encoder.zero_grad()

                losses += loss.item()
                if (step+1) % 100 == 0 :
                    train_loss = losses / 100
                    print(f'training loss: {train_loss:4.4}')
                    losses = 0
                ##################################################################
        
        return self.cross_encoder

In [8]:
args = TrainingArguments(
    output_dir="dense_retireval",
    evaluation_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=40,
    weight_decay=0.01
)
model_checkpoint = "klue/bert-base"

# 혹시 위에서 사용한 encoder가 있다면 주석처리 후 진행해주세요 (CUDA ...)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
cross_encoder = BertEncoder.from_pretrained(model_checkpoint).to('cuda')

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertEncoder were not initialized from the model checkpoint at klue/bert-base and are newly initialized: 

In [9]:
# Retriever는 아래와 같이 사용할 수 있도록 코드를 짜봅시다.
retriever = DenseRetrieval(
    args=args,
    dataset=train_dataset,
    tokenizer=tokenizer,
    cross_encoder=cross_encoder,
    sampler = CustomSampler
)
c_encoder = retriever.train()

Epoch:   0%|          | 0/40 [00:00<?, ?it/s]

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.3505
training loss: 0.0798
training loss: 0.0634
training loss: 0.02339
training loss: 0.01976
training loss: 0.01066
training loss: 0.01028
training loss: 0.01023
training loss: 0.01002
training loss: 0.007944
training loss: 0.009278
training loss: 0.008298
training loss: 0.005324
training loss: 0.007314


Epoch:   2%|▎         | 1/40 [12:06<7:51:55, 726.04s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.007787
training loss: 0.002427
training loss: 0.00171
training loss: 0.0009209
training loss: 0.007005
training loss: 0.003466
training loss: 0.001762
training loss: 0.001404
training loss: 0.001284
training loss: 0.0008469
training loss: 0.0007349
training loss: 0.008568
training loss: 0.002439
training loss: 0.004033


Epoch:   5%|▌         | 2/40 [24:11<7:39:42, 725.87s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.002342
training loss: 0.002154
training loss: 0.0003874
training loss: 0.004067
training loss: 0.001495
training loss: 0.003042
training loss: 0.003225
training loss: 0.002308
training loss: 0.001339
training loss: 0.001433
training loss: 0.008992
training loss: 0.001328
training loss: 0.001717
training loss: 0.0006858


Epoch:   8%|▊         | 3/40 [36:17<7:27:40, 725.95s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.001086
training loss: 0.002971
training loss: 0.002006
training loss: 0.0007831
training loss: 0.001253
training loss: 0.008128
training loss: 0.0005222
training loss: 0.004348
training loss: 0.003308
training loss: 0.003969
training loss: 0.001107
training loss: 0.007538
training loss: 0.001356
training loss: 0.0004068


Epoch:  10%|█         | 4/40 [48:22<7:15:25, 725.70s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.001041
training loss: 0.001102
training loss: 0.0006416
training loss: 0.0005035
training loss: 0.0008136
training loss: 0.0009362
training loss: 0.0004269
training loss: 0.001646
training loss: 0.0002132
training loss: 0.0004679
training loss: 0.0009488
training loss: 0.004387
training loss: 0.001558
training loss: 0.0007145


Epoch:  12%|█▎        | 5/40 [1:00:26<7:02:59, 725.12s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.0002177
training loss: 0.002533
training loss: 0.000258
training loss: 0.001007
training loss: 0.001799
training loss: 0.0003331
training loss: 0.0003569
training loss: 0.003693
training loss: 0.0004952
training loss: 0.001425
training loss: 0.0002621
training loss: 0.0005605
training loss: 0.002447
training loss: 0.0001875


Epoch:  15%|█▌        | 6/40 [1:12:29<6:50:34, 724.54s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.0001969
training loss: 0.0009694
training loss: 0.0003521
training loss: 0.0003656
training loss: 0.004948
training loss: 0.0005442
training loss: 0.0001471
training loss: 0.0004391
training loss: 0.0001453
training loss: 0.0001747
training loss: 0.0002683
training loss: 0.0003567
training loss: 0.008905
training loss: 0.0003699


Epoch:  18%|█▊        | 7/40 [1:24:33<6:38:18, 724.20s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.001427
training loss: 0.002409
training loss: 0.0005248
training loss: 0.0004248
training loss: 0.000239
training loss: 0.0002524
training loss: 7.74e-05
training loss: 0.0006951
training loss: 5.222e-05
training loss: 7.559e-05
training loss: 0.0004297
training loss: 0.0002785
training loss: 0.001017
training loss: 0.003409


Epoch:  20%|██        | 8/40 [1:36:36<6:26:09, 724.03s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.00377
training loss: 0.0006355
training loss: 0.0002538
training loss: 0.0006935
training loss: 0.0005247
training loss: 6.654e-05
training loss: 0.001272
training loss: 0.0001693
training loss: 0.000369
training loss: 0.001372
training loss: 4.026e-05
training loss: 0.000114
training loss: 0.0001726
training loss: 0.0001016


Epoch:  22%|██▎       | 9/40 [1:48:41<6:14:10, 724.20s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.0032
training loss: 0.001085
training loss: 0.002624
training loss: 0.0001307
training loss: 7.584e-05
training loss: 0.003107
training loss: 0.0008919
training loss: 0.000364
training loss: 0.001271
training loss: 0.0005053
training loss: 3.187e-05
training loss: 0.005179
training loss: 0.0003564
training loss: 0.0002757


Epoch:  25%|██▌       | 10/40 [2:00:44<6:01:59, 723.97s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.0003902
training loss: 4.071e-05
training loss: 0.003874
training loss: 0.001587
training loss: 0.000473
training loss: 8.921e-05
training loss: 0.004937
training loss: 0.001061
training loss: 0.00507
training loss: 0.001584
training loss: 0.001534
training loss: 0.0005116
training loss: 0.000334
training loss: 0.01157


Epoch:  28%|██▊       | 11/40 [2:12:49<5:49:57, 724.07s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.0004293
training loss: 0.003299
training loss: 5.423e-05
training loss: 6.387e-05
training loss: 0.001355
training loss: 1.914e-05
training loss: 0.000104
training loss: 4.987e-05
training loss: 0.0005306
training loss: 0.003811
training loss: 0.004983
training loss: 0.00024
training loss: 0.003989
training loss: 0.0002248


Epoch:  30%|███       | 12/40 [2:24:53<5:37:55, 724.13s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 3.855e-05
training loss: 0.0002369
training loss: 0.001533
training loss: 8.149e-05
training loss: 0.0007819
training loss: 1.622e-05
training loss: 4.696e-05
training loss: 5.638e-05
training loss: 0.000303
training loss: 0.001329
training loss: 0.0002316
training loss: 0.0008193
training loss: 0.0002188
training loss: 3.11e-05


Epoch:  32%|███▎      | 13/40 [2:36:57<5:25:51, 724.14s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.0001875
training loss: 0.0003357
training loss: 5.633e-05
training loss: 0.0002197
training loss: 6.111e-05
training loss: 0.0007435
training loss: 0.0005621
training loss: 0.0002745
training loss: 1.181e-05
training loss: 2.955e-05
training loss: 0.0002014
training loss: 0.0003657
training loss: 6.319e-05
training loss: 3.666e-05


Epoch:  35%|███▌      | 14/40 [2:49:01<5:13:46, 724.11s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 5.278e-05
training loss: 3.229e-05
training loss: 0.01447
training loss: 0.002361
training loss: 0.001795
training loss: 0.0003298
training loss: 5.309e-05
training loss: 0.008066
training loss: 0.004494
training loss: 0.003607
training loss: 0.0001025
training loss: 0.003733
training loss: 0.0004976
training loss: 7.054e-05


Epoch:  38%|███▊      | 15/40 [3:01:04<5:01:36, 723.87s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 3.416e-05
training loss: 9.99e-05
training loss: 4.667e-05
training loss: 4.478e-05
training loss: 0.0001289
training loss: 0.003132
training loss: 0.0001834
training loss: 0.002526
training loss: 0.0006992
training loss: 0.0001326
training loss: 0.0003289
training loss: 0.0001061
training loss: 3.385e-05
training loss: 0.000167


Epoch:  40%|████      | 16/40 [3:13:08<4:49:27, 723.66s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 6.801e-05
training loss: 5.803e-05
training loss: 9.799e-05
training loss: 0.003174
training loss: 0.0002081
training loss: 6.215e-05
training loss: 6.815e-05
training loss: 2.941e-05
training loss: 0.0001565
training loss: 9.991e-05
training loss: 0.0002029
training loss: 0.001852
training loss: 0.0001041
training loss: 9.781e-05


Epoch:  42%|████▎     | 17/40 [3:25:12<4:37:29, 723.88s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 4.822e-05
training loss: 2.19e-05
training loss: 2.672e-05
training loss: 6.142e-05
training loss: 0.0003118
training loss: 3.865e-05
training loss: 7.622e-05
training loss: 0.0001078
training loss: 0.005495
training loss: 8.457e-05
training loss: 3.411e-05
training loss: 0.0002521
training loss: 0.001793
training loss: 0.007513


Epoch:  45%|████▌     | 18/40 [3:37:16<4:25:27, 723.96s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.001637
training loss: 0.0002192
training loss: 0.006405
training loss: 4.204e-05
training loss: 4.422e-05
training loss: 0.001489
training loss: 0.0005255
training loss: 9.397e-05
training loss: 2.994e-05
training loss: 9.491e-05
training loss: 0.0008316
training loss: 1.936e-05
training loss: 0.001095
training loss: 0.0006553


Epoch:  48%|████▊     | 19/40 [3:49:20<4:13:21, 723.90s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 1.947e-05
training loss: 0.005049
training loss: 0.002862
training loss: 0.0002204
training loss: 0.0002051
training loss: 0.0006704
training loss: 0.0001622
training loss: 0.0001382
training loss: 4.776e-05
training loss: 3.332e-05
training loss: 0.000456
training loss: 8.366e-06
training loss: 2.107e-05
training loss: 0.00104


Epoch:  50%|█████     | 20/40 [4:01:24<4:01:18, 723.94s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.00245
training loss: 0.0002455
training loss: 0.0001415
training loss: 3.691e-05
training loss: 0.0001752
training loss: 9.692e-05
training loss: 2.576e-05
training loss: 1.542e-05
training loss: 0.004616
training loss: 7.137e-05
training loss: 0.0002435
training loss: 0.002687
training loss: 5.116e-05
training loss: 3.143e-05


Epoch:  52%|█████▎    | 21/40 [4:13:28<3:49:14, 723.92s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.0003854
training loss: 4.08e-05
training loss: 4.277e-05
training loss: 0.001202
training loss: 4.864e-05
training loss: 0.000137
training loss: 4.186e-05
training loss: 0.004008
training loss: 0.0004
training loss: 0.000417
training loss: 4.793e-05
training loss: 6.979e-05
training loss: 4.95e-05
training loss: 0.0002307


Epoch:  55%|█████▌    | 22/40 [4:25:31<3:37:08, 723.79s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.001241
training loss: 2.004e-05
training loss: 0.0003174
training loss: 2.78e-05
training loss: 2.786e-05
training loss: 0.0001504
training loss: 0.0001866
training loss: 1.433e-05
training loss: 0.0001144
training loss: 0.0008227
training loss: 1.629e-05
training loss: 0.0002608
training loss: 5.758e-06
training loss: 9.11e-05


Epoch:  57%|█████▊    | 23/40 [4:37:35<3:25:04, 723.77s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 2.185e-05
training loss: 2.327e-05
training loss: 2.603e-05
training loss: 9.641e-05
training loss: 3.928e-05
training loss: 0.005552
training loss: 0.004783
training loss: 1.844e-05
training loss: 0.004616
training loss: 0.0007286
training loss: 0.0003371
training loss: 7.842e-05
training loss: 0.004854
training loss: 0.001918


Epoch:  60%|██████    | 24/40 [4:49:39<3:13:00, 723.78s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.0001182
training loss: 0.00311
training loss: 0.0002064
training loss: 0.0002379
training loss: 0.003165
training loss: 0.0001685
training loss: 1.381e-05
training loss: 7.732e-06
training loss: 2.815e-05
training loss: 9.496e-05
training loss: 2.398e-05
training loss: 2.245e-05
training loss: 0.01037
training loss: 6.193e-05


Epoch:  62%|██████▎   | 25/40 [5:01:43<3:00:59, 723.99s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.000158
training loss: 0.0002663
training loss: 0.0002359
training loss: 0.008378
training loss: 0.0004428
training loss: 0.0002418
training loss: 0.004264
training loss: 8.535e-05
training loss: 0.0001838
training loss: 0.0004271
training loss: 0.0005582
training loss: 0.0002043
training loss: 0.0001847
training loss: 0.0003118


Epoch:  65%|██████▌   | 26/40 [5:13:48<2:48:57, 724.09s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 1.84e-05
training loss: 0.0002231
training loss: 1.995e-05
training loss: 0.0008134
training loss: 1.798e-05
training loss: 2.245e-05
training loss: 1.192e-05
training loss: 0.0001432
training loss: 0.001195
training loss: 0.003376
training loss: 7.274e-05
training loss: 3.601e-05
training loss: 0.0001614
training loss: 4.06e-05


Epoch:  68%|██████▊   | 27/40 [5:25:51<2:36:50, 723.91s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.002204
training loss: 8.163e-05
training loss: 0.0005679
training loss: 0.0001615
training loss: 0.001744
training loss: 5.213e-05
training loss: 3.213e-05
training loss: 0.004149
training loss: 0.0001344
training loss: 0.0002419
training loss: 0.0003039
training loss: 3.663e-05
training loss: 3.085e-05
training loss: 2.287e-05


Epoch:  70%|███████   | 28/40 [5:37:54<2:24:44, 723.74s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.0001181
training loss: 0.00359
training loss: 0.0008065
training loss: 3.833e-05
training loss: 0.003734
training loss: 2.715e-05
training loss: 6.239e-05
training loss: 1.074e-05
training loss: 2.941e-05
training loss: 0.004575
training loss: 6.001e-05
training loss: 2.708e-05
training loss: 0.007172
training loss: 5.03e-06


Epoch:  72%|███████▎  | 29/40 [5:49:58<2:12:42, 723.84s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.0005449
training loss: 3.247e-06
training loss: 7.748e-05
training loss: 0.00833
training loss: 8.24e-05
training loss: 0.0006284
training loss: 0.003671
training loss: 3.933e-06
training loss: 0.0008329
training loss: 0.0002434
training loss: 0.0004809
training loss: 0.0001319
training loss: 0.0001727
training loss: 5.503e-06


Epoch:  75%|███████▌  | 30/40 [6:02:02<2:00:37, 723.74s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 6.357e-05
training loss: 3.211e-05
training loss: 0.002569
training loss: 0.0001236
training loss: 0.002123
training loss: 4.966e-05
training loss: 1.966e-06
training loss: 0.0001368
training loss: 2.347e-05
training loss: 7.517e-05
training loss: 0.0006265
training loss: 0.0007982
training loss: 4.347e-05
training loss: 2.195e-06


Epoch:  78%|███████▊  | 31/40 [6:14:06<1:48:34, 723.88s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.002003
training loss: 1.408e-05
training loss: 0.003017
training loss: 4.902e-06
training loss: 0.003926
training loss: 1.194e-05
training loss: 4.011e-06
training loss: 0.005143
training loss: 0.001303
training loss: 0.003604
training loss: 1.981e-05
training loss: 5.223e-05
training loss: 6.534e-06
training loss: 1.126e-05


Epoch:  80%|████████  | 32/40 [6:26:10<1:36:31, 723.93s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.004037
training loss: 2.049e-06
training loss: 0.000147
training loss: 6.309e-06
training loss: 0.0003035
training loss: 7.727e-06
training loss: 7.828e-06
training loss: 0.0008264
training loss: 0.0002318
training loss: 1.127e-05
training loss: 7.573e-06
training loss: 0.00225
training loss: 1.424e-05
training loss: 0.000622


Epoch:  82%|████████▎ | 33/40 [6:38:14<1:24:26, 723.83s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 2.202e-05
training loss: 0.0001188
training loss: 0.02028
training loss: 0.0003871
training loss: 0.0001118
training loss: 1.143e-05
training loss: 3.567e-05
training loss: 1.098e-05
training loss: 0.0003817
training loss: 0.0001465
training loss: 0.0001211
training loss: 4.683e-05
training loss: 0.002093
training loss: 0.0001833


Epoch:  85%|████████▌ | 34/40 [6:50:19<1:12:24, 724.12s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.001294
training loss: 3.352e-06
training loss: 1.453e-05
training loss: 3.072e-06
training loss: 3.855e-05
training loss: 2.285e-05
training loss: 0.003578
training loss: 0.0007481
training loss: 2.068e-05
training loss: 3.525e-05
training loss: 6.191e-06
training loss: 0.000141
training loss: 3.082e-06
training loss: 0.00388


Epoch:  88%|████████▊ | 35/40 [7:02:23<1:00:20, 724.11s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.003266
training loss: 2.702e-05
training loss: 3.886e-05
training loss: 8.95e-05
training loss: 9.016e-06
training loss: 0.0001198
training loss: 3.098e-06
training loss: 3.673e-06
training loss: 0.003165
training loss: 0.004281
training loss: 0.0002946
training loss: 0.005452
training loss: 1.561e-05
training loss: 0.000985


Epoch:  90%|█████████ | 36/40 [7:14:27<48:17, 724.29s/it]  




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 2.465e-05
training loss: 2.735e-05
training loss: 1.919e-05
training loss: 4.198e-06
training loss: 4.287e-06
training loss: 1.09e-05
training loss: 0.0001368
training loss: 0.00149
training loss: 0.004442
training loss: 0.01746
training loss: 2.167e-06
training loss: 2.344e-05
training loss: 2.894e-05
training loss: 0.0002701


Epoch:  92%|█████████▎| 37/40 [7:26:32<36:12, 724.28s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.00067
training loss: 3.091e-06
training loss: 0.007828
training loss: 2.818e-05
training loss: 0.0009956
training loss: 0.0002
training loss: 0.00065
training loss: 0.005993
training loss: 5.063e-05
training loss: 0.0002076
training loss: 0.0001719
training loss: 3.02e-05
training loss: 1.212e-05
training loss: 8.095e-05


Epoch:  95%|█████████▌| 38/40 [7:43:03<26:49, 804.54s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 0.002083
training loss: 0.003962
training loss: 0.0004586
training loss: 2.468e-05
training loss: 0.001955
training loss: 0.003146
training loss: 5.406e-05
training loss: 4.448e-06
training loss: 4.7e-06
training loss: 0.005023
training loss: 5.499e-06
training loss: 9.262e-06
training loss: 0.0002058
training loss: 1.164e-05


Epoch:  98%|█████████▊| 39/40 [7:55:30<13:07, 787.15s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1442.0, style=ProgressStyle(description_w…

training loss: 4.615e-05
training loss: 2.516e-05
training loss: 5.389e-06
training loss: 2.925e-06
training loss: 1.48e-05
training loss: 9.347e-06
training loss: 0.0001307
training loss: 9.974e-06
training loss: 0.001009
training loss: 2.184e-05
training loss: 5.502e-05
training loss: 1.904e-05
training loss: 3.772e-05
training loss: 3.033e-05


Epoch: 100%|██████████| 40/40 [8:07:35<00:00, 731.38s/it]







In [10]:
torch.save(c_encoder, '/opt/ml/custom/c_encoder_e40_b16.pt')

## 실험

In [11]:
valid_corpus = list(set([example['context'] for example in dataset['validation']]))[:10]
sample_idx = random.choice(range(len(dataset['validation'])))
query = dataset['validation'][sample_idx]['question']
ground_truth = dataset['validation'][sample_idx]['context']

if not ground_truth in valid_corpus:
  valid_corpus.append(ground_truth)

print(query)
print(ground_truth)

볼드윈이 "당신들은 나를 야유합니까?"라는 말을 한 연도는?
퇴임에서 볼드윈의 세월은 조용하였다. 네빌 체임벌린이 사망하면서 전쟁 이전의 유화 정책에서 볼드윈의 지각된 부분은 제2차 세계 대전이 일어난 동안과 그 후에 그를 인기없는 인물로 만들었다. 신문의 캠페인은 그를 전쟁 생산에 자신의 시골 저택의 철문을 기부하지 않은 것으로 사냥하였다. 전쟁이 일어난 동안 윈스턴 처칠은 에이먼 데 벌레라의 아일랜드의 지속적인 중립을 향한 더욱 힘든 경향을 취하는 영국의 조언에 그를 단 한번 상담하였다.\n\n1945년 6월 부인 루시 여사가 사망하였다. 이제 볼드윈 자신은 관절염을 겪어 걸어다는 데 지팡이가 필요하였다. 조지 5세의 동상의 공개식에 1947년 런던에서 자신의 최종 공개적인 출연을 이루었다. 관중들은 전직 총리를 알아주어 그를 응원하였으나 이 당시 볼드윈은 귀머거리였고, 그들에게 "당신들은 나를 야유합니까?"라고 의문하였다. 1930년 케임브리지 대학교의 총장으로 만들어진 그는 1947년 12월 14일 80세의 나이에 우스터셔주 스투어포트온세번 근처 애슬리홀에서 수면 중 자신의 사망까지 이 수용력에 지속하였다. 그는 화장되었고, 그의 재는 우스터 대성당에 안치되었다.


In [12]:
with torch.no_grad() :
    c_encoder.eval()
    
    score_list = []
    for i in range(len(valid_corpus)) :
        passage = valid_corpus[i]
        tokenized_examples = tokenizer(
            query,
            passage,
            truncation="only_second",
            max_length=512,
            stride=128,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            #return_token_type_ids=False,  # roberta모델을 사용할 경우 False, bert를 사용할 경우 True로 표기해야합니다.
            padding="max_length",
            return_tensors='pt'
        )

        score = 0
        for i in range(len(tokenized_examples['input_ids'])) :
            c_input = {
                'input_ids' : torch.tensor(tokenized_examples['input_ids'][i].unsqueeze(dim=0)).to('cuda'),
                'attention_mask' : torch.tensor(tokenized_examples['attention_mask'][i].unsqueeze(dim=0)).to('cuda'),
                'token_type_ids' : torch.tensor(tokenized_examples['token_type_ids'][i].unsqueeze(dim=0)).to('cuda')
            }
            tmp_score = c_encoder(**c_input).to('cpu')
            score += tmp_score
        score = score / len(tokenized_examples['input_ids'])
        score_list.append(score)
    sort_result = torch.sort(torch.tensor(score_list), descending=True)

    scores, index_list = sort_result[0], sort_result[1]

  'input_ids' : torch.tensor(tokenized_examples['input_ids'][i].unsqueeze(dim=0)).to('cuda'),
  'attention_mask' : torch.tensor(tokenized_examples['attention_mask'][i].unsqueeze(dim=0)).to('cuda'),
  'token_type_ids' : torch.tensor(tokenized_examples['token_type_ids'][i].unsqueeze(dim=0)).to('cuda')


In [13]:
k = 5
print("[Search query]\n", query, "\n")
print("[Ground truth passage]")
print(ground_truth, "\n")

for i in range(k):
  print("Top-%d passage with score %.4f" % (i+1, scores[i]))
  print(valid_corpus[index_list[i]])

[Search query]
 볼드윈이 "당신들은 나를 야유합니까?"라는 말을 한 연도는? 

[Ground truth passage]
퇴임에서 볼드윈의 세월은 조용하였다. 네빌 체임벌린이 사망하면서 전쟁 이전의 유화 정책에서 볼드윈의 지각된 부분은 제2차 세계 대전이 일어난 동안과 그 후에 그를 인기없는 인물로 만들었다. 신문의 캠페인은 그를 전쟁 생산에 자신의 시골 저택의 철문을 기부하지 않은 것으로 사냥하였다. 전쟁이 일어난 동안 윈스턴 처칠은 에이먼 데 벌레라의 아일랜드의 지속적인 중립을 향한 더욱 힘든 경향을 취하는 영국의 조언에 그를 단 한번 상담하였다.\n\n1945년 6월 부인 루시 여사가 사망하였다. 이제 볼드윈 자신은 관절염을 겪어 걸어다는 데 지팡이가 필요하였다. 조지 5세의 동상의 공개식에 1947년 런던에서 자신의 최종 공개적인 출연을 이루었다. 관중들은 전직 총리를 알아주어 그를 응원하였으나 이 당시 볼드윈은 귀머거리였고, 그들에게 "당신들은 나를 야유합니까?"라고 의문하였다. 1930년 케임브리지 대학교의 총장으로 만들어진 그는 1947년 12월 14일 80세의 나이에 우스터셔주 스투어포트온세번 근처 애슬리홀에서 수면 중 자신의 사망까지 이 수용력에 지속하였다. 그는 화장되었고, 그의 재는 우스터 대성당에 안치되었다. 

Top-1 passage with score 9.6913
퇴임에서 볼드윈의 세월은 조용하였다. 네빌 체임벌린이 사망하면서 전쟁 이전의 유화 정책에서 볼드윈의 지각된 부분은 제2차 세계 대전이 일어난 동안과 그 후에 그를 인기없는 인물로 만들었다. 신문의 캠페인은 그를 전쟁 생산에 자신의 시골 저택의 철문을 기부하지 않은 것으로 사냥하였다. 전쟁이 일어난 동안 윈스턴 처칠은 에이먼 데 벌레라의 아일랜드의 지속적인 중립을 향한 더욱 힘든 경향을 취하는 영국의 조언에 그를 단 한번 상담하였다.\n\n1945년 6월 부인 루시 여사가 사망하였다. 이제 볼드윈 자신은 관절염을 겪어 걸어다는 데 지팡이가 필요하였다. 조지 5세의

In [28]:
index_list

tensor([10,  0,  4,  1,  5,  9,  6,  3,  7,  2,  8])

In [27]:
tokenized_examples['input_ids'][0].unsqueeze(dim=0).shape

torch.Size([1, 512])

In [None]:
top_k_index_list = []
for i in range(len(index_list)) :
    temp = index_list[i][:k]
    top_k_index_list.appedn(temp)

## 실제

In [30]:
with open('/opt/ml/data/wikipedia_documents.json', "r", encoding="utf-8") as f:
    wiki = json.load(f)

corpus = list(
    dict.fromkeys([v["text"] for v in wiki.values()])
)  # set 은 매번 순서가 바뀌므로

In [32]:
question_data = dataset['validation']['question']
with torch.no_grad() :
    c_encoder.eval()

    result_scores = []
    result_indices = []
    for i in tqdm(range(len(question_data))) :
        question = question_data[i]

        question_score = []
        for i in tqdm(range(len(corpus))) :
            passage = corpus[i]
            tokenized_examples = tokenizer(
                question,
                passage,
                truncation="only_second",
                max_length=512,
                stride=128,
                return_overflowing_tokens=True,
                return_offsets_mapping=True,
                #return_token_type_ids=False,  # roberta모델을 사용할 경우 False, bert를 사용할 경우 True로 표기해야합니다.
                padding="max_length",
                return_tensors='pt'
            )

            score = 0
            for i in range(len(tokenized_examples['input_ids'])) :
                c_input = {
                    'input_ids' : torch.tensor(tokenized_examples['input_ids'][i].unsqueeze(dim=0)).to('cuda'),
                    'attention_mask' : torch.tensor(tokenized_examples['attention_mask'][i].unsqueeze(dim=0)).to('cuda'),
                    'token_type_ids' : torch.tensor(tokenized_examples['token_type_ids'][i].unsqueeze(dim=0)).to('cuda')
                }
                tmp_score = c_encoder(**c_input).to('cpu')
                score += tmp_score
            score = score / len(tokenized_examples['input_ids'])
            question_score.append(score)

        sort_result = torch.sort(torch.tensor(question_score), descending=True)
        scores, index_list = sort_result[0], sort_result[1]

        result_scores.append(scores)
        result_indices.append(index_list)

HBox(children=(FloatProgress(value=0.0, max=240.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=56737.0), HTML(value='')))

  'input_ids' : torch.tensor(tokenized_examples['input_ids'][i].unsqueeze(dim=0)).to('cuda'),
  'attention_mask' : torch.tensor(tokenized_examples['attention_mask'][i].unsqueeze(dim=0)).to('cuda'),
  'token_type_ids' : torch.tensor(tokenized_examples['token_type_ids'][i].unsqueeze(dim=0)).to('cuda')


KeyboardInterrupt: 

In [None]:
top_k_index_list = []
for i in range(len(index_list)) :
    temp = index_list[i][:k]
    top_k_index_list.appedn(temp)

In [None]:
total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": top_k_index_list[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in top_k_index_list[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_100 = pd.DataFrame(total)

## Elastic

In [11]:
dataset = load_from_disk('/opt/ml/data/train_dataset')
train_dataset = dataset['train']

In [12]:
data = pd.read_csv('/opt/ml/custom/top100_wikipedia.csv')

In [13]:
doc_indices = []
for i in range(len(data)) :
    tmp = eval(data['document_id'][i])
    doc_indices.append(tmp)

In [14]:
with open('/opt/ml/data/wikipedia_documents.json', "r", encoding="utf-8") as f:
    wiki = json.load(f)

corpus = []
for v in wiki.values() :
    corpus.append(v['text'])

In [16]:
class BertEncoder(BertPreTrainedModel):
    def __init__(self, config):
        super(BertEncoder, self).__init__(config)

        self.bert = BertModel(config)
        self.init_weights()
        classifier_dropout=(
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = torch.nn.Dropout(classifier_dropout)
        self.linear = torch.nn.Linear(config.hidden_size, 1)
      
    def forward(
            self,
            input_ids, 
            attention_mask=None,
            token_type_ids=None
        ): 

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        output = self.linear(pooled_output)
        return output

In [12]:
c_encoder = torch.load('/opt/ml/custom/c_encoder_e20.pt')

In [15]:
model_checkpoint = "klue/bert-base"
# 혹시 위에서 사용한 encoder가 있다면 주석처리 후 진행해주세요 (CUDA ...)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
question_data = dataset['validation']['question']
with torch.no_grad() : 
    c_encoder.eval()

    result_scores = []
    result_indices = []
    for i in tqdm(range(len(question_data))) :
        question = question_data[i]
        question_score = []
        for indice in tqdm(doc_indices[i]) :
            passage = corpus[indice]
            tokenized_examples = tokenizer(
                question,
                passage,
                truncation="only_second",
                max_length=512,
                stride=128,
                return_overflowing_tokens=True,
                return_offsets_mapping=True,
                #return_token_type_ids=False,  # roberta모델을 사용할 경우 False, bert를 사용할 경우 True로 표기해야합니다.
                padding="max_length",
                return_tensors='pt'
            )
            score = 0
            for i in range(len(tokenized_examples['input_ids'])) :
                c_input = {
                    'input_ids' : torch.tensor(tokenized_examples['input_ids'][i].unsqueeze(dim=0)).to('cuda'),
                    'attention_mask' : torch.tensor(tokenized_examples['attention_mask'][i].unsqueeze(dim=0)).to('cuda'),
                    'token_type_ids' : torch.tensor(tokenized_examples['token_type_ids'][i].unsqueeze(dim=0)).to('cuda')
                }
                tmp_score = c_encoder(**c_input).to('cpu')
                score += tmp_score
            score = score / len(tokenized_examples['input_ids'])
            question_score.append(score)
        sort_result = torch.sort(torch.tensor(question_score), descending=True)
        scores, index_list = sort_result[0], sort_result[1]

        result_scores.append(scores.tolist())
        result_indices.append(index_list.tolist())        

In [33]:
final_indices = []
for i in range(len(doc_indices)) :
    t_list = [doc_indices[i][result_indices[i][k]] for k in range(7)]
    final_indices.append(t_list)

In [34]:
total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": final_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in final_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_50 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [35]:
correct_length = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length.append(i)
print(len(correct_length) / len(dataset['validation']))

0.875


In [36]:
cqas_50.to_csv('b16_special_shuffle_elastic_ce40_t7.csv', index = False)

## Test

In [7]:
dataset = load_from_disk('/opt/ml/data/test_dataset')

In [8]:
data = pd.read_csv('/opt/ml/custom/test_elastic_top100.csv')

doc_indices = []
for i in range(len(data)) :
    tmp = eval(data['document_id'][i])
    doc_indices.append(tmp)

In [9]:
# test에 대해서만 실행
for i in tqdm(range(len(doc_indices))) :
    for j in range(len(doc_indices[i])) :
        doc_indices[i][j] = int(doc_indices[i][j])

HBox(children=(FloatProgress(value=0.0, max=600.0), HTML(value='')))




In [10]:
with open('/opt/ml/data/wikipedia_documents.json', "r", encoding="utf-8") as f:
    wiki = json.load(f)

corpus = []
for v in wiki.values() :
    corpus.append(v['text'])

In [11]:
class BertEncoder(BertPreTrainedModel):
    def __init__(self, config):
        super(BertEncoder, self).__init__(config)

        self.bert = BertModel(config)
        self.init_weights()
        classifier_dropout=(
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = torch.nn.Dropout(classifier_dropout)
        self.linear = torch.nn.Linear(config.hidden_size, 1)
      
    def forward(
            self,
            input_ids, 
            attention_mask=None,
            token_type_ids=None
        ): 

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        output = self.linear(pooled_output)
        return output

In [12]:
c_encoder = torch.load('/opt/ml/custom/c_encoder_e40_b16.pt')
model_checkpoint = "klue/bert-base"
# 혹시 위에서 사용한 encoder가 있다면 주석처리 후 진행해주세요 (CUDA ...)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
question_data = dataset['validation']['question']
with torch.no_grad() : 
    c_encoder.eval()

    result_scores = []
    result_indices = []
    for i in tqdm(range(len(question_data))) :
        question = question_data[i]
        question_score = []
        for indice in tqdm(doc_indices[i]) :
            passage = corpus[int(indice)]
            tokenized_examples = tokenizer(
                question,
                passage,
                truncation="only_second",
                max_length=512,
                stride=128,
                return_overflowing_tokens=True,
                return_offsets_mapping=True,
                #return_token_type_ids=False,  # roberta모델을 사용할 경우 False, bert를 사용할 경우 True로 표기해야합니다.
                padding="max_length",
                return_tensors='pt'
            )
            score = 0
            for i in range(len(tokenized_examples['input_ids'])) :
                c_input = {
                    'input_ids' : torch.tensor(tokenized_examples['input_ids'][i].unsqueeze(dim=0)).to('cuda'),
                    'attention_mask' : torch.tensor(tokenized_examples['attention_mask'][i].unsqueeze(dim=0)).to('cuda'),
                    'token_type_ids' : torch.tensor(tokenized_examples['token_type_ids'][i].unsqueeze(dim=0)).to('cuda')
                }
                tmp_score = c_encoder(**c_input).to('cpu')
                score += tmp_score
            score = score / len(tokenized_examples['input_ids'])
            question_score.append(score)
        sort_result = torch.sort(torch.tensor(question_score), descending=True)
        scores, index_list = sort_result[0], sort_result[1]

        result_scores.append(scores.tolist())
        result_indices.append(index_list.tolist())        

### result_indices 저장 및 불러오기

In [26]:
import csv
with open('listfile.csv', 'w', newline='') as f: 
    writer = csv.writer(f)
    writer.writerow(result_indices)

In [29]:
with open('listfile.csv', 'r', encoding='utf-8') as f:
    rdr = csv.reader(f)
    for i, line in enumerate(rdr) :
        if i == 0 :
            kk = line

### 끝

In [14]:
final_indices = []
for i in range(len(doc_indices)) :
    t_list = [doc_indices[i][result_indices[i][k]] for k in range(5)]
    final_indices.append(t_list)

In [15]:
total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": final_indices[idx],
            "context": [corpus[pid] for pid in final_indices[idx]]
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_50 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=600.0, style=ProgressStyle(descri…




In [24]:
cqas_50.to_csv('elastic_crossencoder.csv', index = False)

In [59]:
total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": final_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in final_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_50 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=600.0, style=ProgressStyle(descri…




In [None]:
correct_length = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length.append(i)
print(len(correct_length) / len(dataset['validation']))

In [62]:
cqas_50.to_csv('test_b16_special_shuffle_elastic_ce40_t5.csv', index = False)