In [1]:
import json
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from tqdm import trange

from pprint import pprint

from sklearn.feature_extraction.text import TfidfVectorizer

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BertModel, RobertaModel,
    BertPreTrainedModel,
    AdamW, get_linear_schedule_with_warmup,
    TrainingArguments,
)
from datasets import (
    Dataset,
    load_from_disk,
    concatenate_datasets,
)

from typing import List

In [2]:
# Fix the random seeds
def set_seed(random_seed):
    """Seed every random number generator this notebook relies on.

    Args:
        random_seed: integer seed handed, in order, to PyTorch (CPU, current
            CUDA device, all CUDA devices), Python's `random` and NumPy.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
        random.seed,
        np.random.seed,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)

set_seed(42) # magic number :)

In [3]:
# Report the installed PyTorch version, then pick the first CUDA GPU when one
# is available and fall back to the CPU otherwise.
print ("PyTorch version:[%s]."%(torch.__version__))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print ("device:[%s]."%(device))

PyTorch version:[1.6.0].
device:[cuda:0].


## Training

In [4]:
class BertEncoder(BertPreTrainedModel):
    """BERT wrapper that returns one fixed-size vector per input sequence."""

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Encode a batch of token ids.

        Args:
            input_ids: token ids passed positionally to the wrapped BertModel.
            attention_mask: optional mask forwarded unchanged.
            token_type_ids: optional segment ids forwarded unchanged.

        Returns:
            The element at index 1 of the BertModel output (its pooled output).
        """
        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Index 1 is the pooled output; index 0 would be the per-token states.
        return bert_outputs[1]

In [5]:
# Load the dataset saved on disk and keep its 'train' split for retriever training.
dataset = load_from_disk('/opt/ml/data/train_dataset')
train_dataset = dataset['train']

In [6]:
# Answer
class DenseRetrieval:
    """Dual-encoder dense retriever trained with in-batch negatives.

    A passage encoder and a question encoder are optimised jointly so that the
    dot product between a question embedding and the embedding of its own
    passage is larger than its dot product with the other passages in the batch.
    """

    def __init__(self,
        args,
        dataset,
        tokenizer,
        p_encoder,
        q_encoder
    ):
        """Store everything needed for training and inference.

        Args:
            args: hyper-parameter object; `train` reads
                `per_device_train_batch_size`, `learning_rate`, `weight_decay`,
                `gradient_accumulation_steps`, `num_train_epochs` and
                `warmup_steps` from it.
            dataset: object indexable by the column names 'question' and
                'context'.
            tokenizer: callable that turns a list of strings into a mapping
                with 'input_ids', 'attention_mask' and 'token_type_ids' tensors.
            p_encoder: passage encoder module.
            q_encoder: question encoder module.
        """

        self.args = args
        self.dataset = dataset

        self.tokenizer = tokenizer
        self.p_encoder = p_encoder
        self.q_encoder = q_encoder

    def train(self, args=None, tokenizer=None):
        """Train both encoders with an in-batch negative log-likelihood loss.

        Args:
            args: optional override for the hyper-parameters given to
                `__init__`.
            tokenizer: optional override for the tokenizer given to `__init__`.

        Returns:
            Tuple `(p_encoder, q_encoder)` holding the trained modules.
        """
        if args is None:
            args = self.args
        if tokenizer is None:
            tokenizer = self.tokenizer

        q_seqs = tokenizer(self.dataset['question'], padding="max_length", truncation=True, return_tensors='pt')
        p_seqs = tokenizer(self.dataset['context'], padding="max_length", truncation=True, return_tensors='pt')

        train_dataset = TensorDataset(
            p_seqs['input_ids'], p_seqs['attention_mask'], p_seqs['token_type_ids'],
            q_seqs['input_ids'], q_seqs['attention_mask'], q_seqs['token_type_ids'],
        )
        # NOTE(review): no `shuffle=True`, so every epoch sees the same batches
        # and therefore the same in-batch negatives. Consider shuffling.
        train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size)

        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = []
        for encoder in (self.p_encoder, self.q_encoder):
            named_params = list(encoder.named_parameters())
            optimizer_grouped_parameters.append({
                "params": [p for n, p in named_params if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay,
            })
            optimizer_grouped_parameters.append({
                "params": [p for n, p in named_params if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            })
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
        )

        # The schedule length must equal the number of scheduler.step() calls:
        # one per optimizer update, i.e. per group of `accumulation_steps`
        # batches, with the trailing partial group of an epoch counted too.
        accumulation_steps = max(int(args.gradient_accumulation_steps), 1)
        num_epochs = int(args.num_train_epochs)
        steps_per_epoch = len(train_dataloader)
        updates_per_epoch = (steps_per_epoch + accumulation_steps - 1) // accumulation_steps
        t_total = updates_per_epoch * num_epochs
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

        # Feed batches to whatever device the encoders already live on instead
        # of assuming CUDA.
        model_device = next(self.p_encoder.parameters()).device

        self.p_encoder.zero_grad()
        self.q_encoder.zero_grad()

        self.q_encoder.train()
        self.p_encoder.train()
        for epoch in trange(num_epochs, desc="Epoch"):
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            losses = 0
            for step, batch in enumerate(epoch_iterator):
                batch = tuple(t.to(model_device) for t in batch)

                p_inputs = {'input_ids': batch[0],
                            'attention_mask': batch[1],
                            'token_type_ids': batch[2]}

                q_inputs = {'input_ids': batch[3],
                            'attention_mask': batch[4],
                            'token_type_ids': batch[5]}

                p_outputs = self.p_encoder(**p_inputs)  # (batch_size, emb_dim)
                q_outputs = self.q_encoder(**q_inputs)  # (batch_size, emb_dim)

                # Row i holds the scores of question i against every passage of
                # the batch: (batch_size, emb_dim) x (emb_dim, batch_size).
                sim_scores = torch.matmul(q_outputs, torch.transpose(p_outputs, 0, 1))

                # The positive passage of question i sits on the diagonal. Size
                # the targets from the actual batch so that a smaller final
                # batch does not break the loss.
                targets = torch.arange(sim_scores.size(0), device=sim_scores.device)

                loss = F.nll_loss(F.log_softmax(sim_scores, dim=1), targets)
                losses += loss.item()
                if step % 100 == 0 :
                    print(f'{epoch}epoch loss: {losses/(step+1)}')

                # Scale so that accumulated gradients average over the group.
                (loss / accumulation_steps).backward()

                is_last_batch = (step + 1) == steps_per_epoch
                if (step + 1) % accumulation_steps == 0 or is_last_batch:
                    optimizer.step()
                    scheduler.step()
                    self.q_encoder.zero_grad()
                    self.p_encoder.zero_grad()

        return self.p_encoder, self.q_encoder

In [12]:
# p_encoder.cpu()
# q_encoder.cpu()
# del p_encoder
# del q_encoder
# torch.cuda.empty_cache()

## Make Q_Encoder & P_Embedding

In [7]:
# Hyper-parameters for training the dense retriever.
args = TrainingArguments(
    output_dir="dense_retireval",
    evaluation_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    num_train_epochs=10,
    weight_decay=0.01
)
# Pretrained checkpoint used for the tokenizer and for both encoders.
model_checkpoint = "klue/bert-base"

# If encoders created earlier are still loaded, comment these lines out before continuing (CUDA ...)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Passage and question encoders start from the same checkpoint but are separate models.
p_encoder = BertEncoder.from_pretrained(model_checkpoint).to(args.device)
q_encoder = BertEncoder.from_pretrained(model_checkpoint).to(args.device)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.bi

In [8]:
# Retriever는 아래와 같이 사용할 수 있도록 코드를 짜봅시다.
retriever = DenseRetrieval(
    args=args,
    dataset=train_dataset,
    tokenizer=tokenizer,
    p_encoder=p_encoder,
    q_encoder=q_encoder
)
p_encoder, q_encoder = retriever.train()

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=247.0, style=ProgressStyle(description_wi…

0epoch loss: 45.39480209350586
0epoch loss: 3.00825394583073
0epoch loss: 1.9717257092266682


Epoch:  10%|█         | 1/10 [04:12<37:50, 252.23s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=247.0, style=ProgressStyle(description_wi…

1epoch loss: 1.717771053314209
1epoch loss: 0.35106341000577196
1epoch loss: 0.34898104852665474


Epoch:  20%|██        | 2/10 [08:24<33:37, 252.23s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=247.0, style=ProgressStyle(description_wi…

2epoch loss: 0.5439592003822327
2epoch loss: 0.21938291860347853
2epoch loss: 0.19432256530144423


Epoch:  30%|███       | 3/10 [12:36<29:26, 252.29s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=247.0, style=ProgressStyle(description_wi…

3epoch loss: 0.36584410071372986
3epoch loss: 0.10669050270133643
3epoch loss: 0.10537736134520218


Epoch:  40%|████      | 4/10 [16:49<25:13, 252.29s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=247.0, style=ProgressStyle(description_wi…

4epoch loss: 0.035409167408943176
4epoch loss: 0.07716580288655281
4epoch loss: 0.07500652995924013


Epoch:  50%|█████     | 5/10 [21:01<21:01, 252.29s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=247.0, style=ProgressStyle(description_wi…

5epoch loss: 0.06326474994421005
5epoch loss: 0.039368913037959474
5epoch loss: 0.03613295675147051


Epoch:  60%|██████    | 6/10 [25:13<16:48, 252.24s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=247.0, style=ProgressStyle(description_wi…

6epoch loss: 0.05519494414329529
6epoch loss: 0.03370060929098532
6epoch loss: 0.028651823582196756


Epoch:  70%|███████   | 7/10 [29:25<12:36, 252.21s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=247.0, style=ProgressStyle(description_wi…

7epoch loss: 0.030617132782936096
7epoch loss: 0.020839041321769492
7epoch loss: 0.023398075871245875


Epoch:  80%|████████  | 8/10 [33:37<08:24, 252.18s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=247.0, style=ProgressStyle(description_wi…

8epoch loss: 0.007376573514193296
8epoch loss: 0.02252028875566136
8epoch loss: 0.02108692257272775


Epoch:  90%|█████████ | 9/10 [37:50<04:12, 252.21s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=247.0, style=ProgressStyle(description_wi…

9epoch loss: 0.06962314248085022
9epoch loss: 0.02114876175928297
9epoch loss: 0.022255618022535967


Epoch: 100%|██████████| 10/10 [42:02<00:00, 252.22s/it]







In [9]:
# Read the Wikipedia dump: a JSON object whose values each carry a "text" field.
with open('/opt/ml/data/wikipedia_documents.json', "r", encoding="utf-8") as f:
    wiki = json.load(f)

# De-duplicate the passage texts; dict.fromkeys keeps first-seen order.
corpus = list(
    dict.fromkeys([v["text"] for v in wiki.values()])
)  # a set is avoided because its ordering changes between runs

In [10]:
# The trained encoders are also reachable as retriever.p_encoder / retriever.q_encoder.
# Embed every passage of the corpus with the trained passage encoder.
p_encoder.eval()
# Tokenized inputs follow the encoder's own device instead of assuming 'cuda'.
p_encoder_device = next(p_encoder.parameters()).device

p_emb_chunks = []
with torch.no_grad() :
    for passage in tqdm(corpus) :
        p_inputs = tokenizer(passage, padding='max_length', truncation=True, return_tensors='pt').to(p_encoder_device)
        p_emb = p_encoder(**p_inputs).to('cpu')  # (1, emb_dim)
        p_emb_chunks.append(p_emb)

# Concatenate along the passage axis so the result stays 2-D, shape
# (num_passages, emb_dim), even for a one-passage corpus. The later
# torch.mm(q_emb, p_embs.T) needs a matrix, and squeeze() would flatten it.
p_embs = torch.cat(p_emb_chunks, dim=0)

HBox(children=(FloatProgress(value=0.0, max=56737.0), HTML(value='')))




In [11]:
import pickle

# Persist the passage-embedding tensor so retrieval can reuse it without re-encoding the corpus.
file_path = '/opt/ml/custom/passage_embedding.bin'
with open(file_path, 'wb') as file :
    pickle.dump(p_embs, file)

In [12]:
# Move the passage encoder off the GPU, drop this name and release cached CUDA memory.
# NOTE(review): `retriever` still keeps its own reference to the encoder.
p_encoder.cpu()
del p_encoder
torch.cuda.empty_cache()

In [13]:
# Serialize the whole question-encoder object (not only a state_dict).
torch.save(q_encoder, '/opt/ml/custom/q_encoder.pt')

## Get Relevant Document

In [44]:
# Same corpus construction as in the embedding section, repeated so the
# retrieval section can be run on its own.
with open('/opt/ml/data/wikipedia_documents.json', "r", encoding="utf-8") as f:
    wiki = json.load(f)

# De-duplicate the passage texts; dict.fromkeys keeps first-seen order.
corpus = list(
    dict.fromkeys([v["text"] for v in wiki.values()])
)  # a set is avoided because its ordering changes between runs

In [6]:
# Checkpoint whose tokenizer is used to encode the queries.
model_checkpoint = "klue/bert-base"

# If encoders created earlier are still loaded, comment this out before continuing (CUDA ...)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [19]:
import pickle

# Reload the passage embeddings written by the embedding section.
# NOTE: pickle.load can execute arbitrary code, so only open files this
# notebook produced itself.
with open('/opt/ml/custom/passage_embedding.bin', 'rb') as file :
    p_embs = pickle.load(file)

In [20]:
# Reload the question-encoder object saved with torch.save. This unpickles the
# file, so only load checkpoints from a trusted source.
q_encoder = torch.load('/opt/ml/custom/q_encoder.pt')

### 1개 확인

In [21]:
# First validation question for the single-query check, and the first two for the batched check.
query = dataset['validation']['question'][0]
queries = dataset['validation']['question'][:2]

In [22]:
# Embed the single query without tracking gradients and bring the result back to the CPU.
with torch.no_grad() :
    q_encoder.eval()
    q_seqs_val = tokenizer([query], padding="max_length", truncation=True, return_tensors="pt").to(device)
    q_emb = q_encoder(**q_seqs_val).to('cpu')

In [32]:
# Dot-product score of the query embedding against every passage embedding.
dot_prod_scores = torch.mm(q_emb, p_embs.T)
dot_prod_scores

tensor([[159.0737, 155.1720, 158.0787,  ..., 156.8211, 155.6962, 156.6308]])

In [33]:
# Recompute the scores and order the passage indices from highest to lowest score.
dot_prod_scores = torch.mm(q_emb, p_embs.T)
rank = torch.argsort(dot_prod_scores, dim=1, descending=True)

In [34]:
# Display the passage indices, best match first.
rank

tensor([[ 5694,  6509,  9728,  ..., 20579,  4294,  4291]])

### 여러개

In [35]:
# Same as the single-query check, but embedding the list `queries` in one batch.
with torch.no_grad() :
    q_encoder.eval()
    q_seqs_val = tokenizer(queries, padding="max_length", truncation=True, return_tensors="pt").to(device)
    q_emb = q_encoder(**q_seqs_val).to('cpu')

In [47]:
# Score every query against every passage, then sort each row in descending order.
dot_prod_scores = torch.mm(q_emb, p_embs.T)
sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

In [49]:
# torch.sort returns (sorted values, original indices).
scores = sort_result[0]
ranks = sort_result[1]

In [None]:
# Print the top-k retrieved passages for the first validation question next to
# its ground-truth context. Row 0 of `scores` / `ranks` belongs to `query`,
# because `queries` starts with that same question.
k = 5
print("[Search query]\n", query, "\n")
print("[Ground truth passage]")
print(dataset['validation']['context'][0], "\n")

for i in range(k):
  print("Top-%d passage with score %.4f" % (i+1, scores[0].squeeze()[i]))
  print(corpus[ranks[0][i]])

### 연산 시작

In [14]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1, batch_size=32) :
    """Return the top-k passages for each query by dot-product similarity.

    Args:
        queries: one query string or a sequence of query strings.
        q_encoder: question encoder; called with the tokenizer output and
            expected to return one embedding row per query.
        p_embs: 2-D tensor of passage embeddings, one row per passage.
        k: number of passages to keep per query. If k exceeds the number of
            passages, all of them are returned.
        batch_size: number of queries encoded per forward pass (must be
            positive). Chunking keeps memory bounded for large query sets.

    Returns:
        Tuple `(result_scores, result_indices)`: two lists with one 1-D tensor
        per query, holding the k best scores and the matching passage row
        indices, best first.
    """
    # A bare string would otherwise be sliced character by character below.
    if isinstance(queries, str) :
        queries = [queries]
    else :
        queries = list(queries)
    if not queries :
        return [], []

    # Send the inputs to wherever the encoder lives.
    encoder_device = next(q_encoder.parameters()).device

    q_encoder.eval()
    q_emb_chunks = []
    with torch.no_grad() :
        for start in range(0, len(queries), batch_size) :
            q_seqs_val = tokenizer(
                queries[start:start + batch_size],
                padding='max_length', truncation=True, return_tensors='pt'
            ).to(encoder_device)
            q_emb_chunks.append(q_encoder(**q_seqs_val).to('cpu'))
    q_emb = torch.cat(q_emb_chunks, dim=0)

    # (num_queries, emb_dim) x (emb_dim, num_passages)
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    scores, ranks = torch.sort(dot_prod_scores, dim=1, descending=True)

    result_scores = [row[:k] for row in scores]
    result_indices = [row[:k] for row in ranks]

    return result_scores, result_indices

In [16]:
# Retrieve the 20 best-scoring passages for every validation question.
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 20)

In [18]:
# Build one record per validation example, pairing the question with its retrieved passages.
total = []
for idx, example in enumerate(
    tqdm(dataset['validation'], desc="Dense retrieval: ")
):
    # Convert the index tensor to plain ints so that the DataFrame and CSV
    # hold a readable list of ids rather than a `tensor([...])` repr.
    passage_ids = [int(pid) for pid in doc_indices[idx]]
    tmp = {
        # The query and its id.
        "question": example["question"],
        "id": example["id"],
        # Ids of the retrieved passages and their texts joined by a space.
        "context_id": passage_ids,
        "context": " ".join(
            [corpus[pid] for pid in passage_ids]
        ),
    }
    if "context" in example.keys() and "answers" in example.keys():
        # Validation data also carries the ground-truth context and answers.
        tmp["original_context"] = example["context"]
        tmp["answers"] = example["answers"]
    total.append(tmp)

cqas = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [22]:
# Row indices whose ground-truth context occurs verbatim in the retrieved text.
correct_length = [
    i
    for i in range(len(cqas))
    if cqas['original_context'][i] in cqas['context'][i]
]

In [25]:
# Fraction of validation questions whose ground-truth context was retrieved.
print(len(correct_length) / len(dataset['validation']))

0.5666666666666667


In [26]:
# Write the retrieval results to CSV without the DataFrame index column.
cqas.to_csv('/opt/ml/custom/valid_dpr.csv', index = False)