In [1]:
import faiss
import random
import torch
import numpy as np
import torch.nn.functional as F

from tqdm.auto import tqdm
from pprint import pprint
from torch.utils.data import DataLoader, TensorDataset, SequentialSampler
from datasets import load_from_disk, Dataset
from transformers import (
    AutoTokenizer,
    AdamW, get_linear_schedule_with_warmup,
    TrainingArguments, RobertaModel, RobertaPreTrainedModel, BertModel, BertPreTrainedModel
)


In [21]:
# Answer
class DenseRetrieval_with_Faiss:
    """
    Dual-encoder dense retriever: a passage encoder and a question encoder are
    trained jointly on (question, positive passage, negative passages) groups,
    and the passage embeddings of the wiki corpus are served through a FAISS
    index.

    Negatives are mined per question from the Elasticsearch index attached to
    the wiki dataset: the top-ranked non-gold hits are used as hard negatives
    and the lowest-ranked hits of the top-k as the remaining negatives.
    """

    def __init__(self,
        args,
        dataset,
        tokenizer,
        p_encoder,
        q_encoder,
        num_neg=5,
        hard_neg=1,
        is_trained=False,
        wiki_path="/home/ubuntu/workspace/data/wiki_preprocessed_droped",
    ):
        """
        Store the configuration and, unless the encoders are already trained,
        build the training dataloaders.

        Arguments:
            args:
                Training configuration. The attributes read by this class are
                per_device_train_batch_size, device, weight_decay,
                learning_rate, adam_epsilon, warmup_steps,
                gradient_accumulation_steps and num_train_epochs.
            dataset:
                Training examples exposing 'question', 'context' and
                'document_id'.
            tokenizer:
                Tokenizer shared by both encoders.
            p_encoder / q_encoder:
                Passage and question encoders; called with input_ids,
                attention_mask and token_type_ids.
            num_neg (int, default=5):
                Number of negative passages per question.
            hard_neg (int, default=1):
                How many of those negatives are taken from the top of the
                retrieval ranking.
            is_trained (bool, default=False):
                When True the training dataloaders are not built here.
            wiki_path (str):
                Location of the wiki dataset passed to load_from_disk.
        """

        self.args = args
        self.dataset = dataset
        self.num_neg = num_neg
        self.hard_neg = hard_neg

        self.tokenizer = tokenizer
        self.p_encoder = p_encoder
        self.q_encoder = q_encoder

        self.wiki_dataset = load_from_disk(wiki_path)

        if not is_trained:
            self.prepare_in_batch_negative(num_neg=num_neg, hard_neg=hard_neg)

    @staticmethod
    def _select_negatives(candidates, num_neg, hard_neg):
        """
        Pick exactly num_neg negatives from a ranked candidate list: the first
        hard_neg entries plus the last (num_neg - hard_neg) entries.

        Raises:
            ValueError: if there are fewer than num_neg candidates.
        """
        if len(candidates) < num_neg:
            raise ValueError(
                f"Need at least {num_neg} negative candidates but only "
                f"{len(candidates)} were retrieved; increase k."
            )
        num_easy = num_neg - hard_neg
        hard = list(candidates[:hard_neg])
        # Slice relative to the actual list length: the list is one shorter
        # than k whenever the gold document was among the retrieved hits.
        easy = list(candidates[len(candidates) - num_easy:]) if num_easy > 0 else []
        return hard + easy

    def prepare_in_batch_negative(self,
        dataset=None,
        num_neg=5,
        hard_neg=1,
        k=100,
        tokenizer=None
    ):
        """
        Mine negatives for every training question and build
        self.train_dataloader and self.passage_dataloader.

        Arguments:
            dataset: examples to use; defaults to self.dataset.
            num_neg (int): negatives per question.
            hard_neg (int): negatives taken from the top of the ranking.
            k (int): number of hits requested from Elasticsearch per question.
            tokenizer: defaults to self.tokenizer.

        Raises:
            ValueError: if num_neg < hard_neg, or if a question yields fewer
                than num_neg usable negatives.
        """
        if num_neg < hard_neg:
            raise ValueError('num_neg는 hard_neg보다 커야합니다.')

        # train() reshapes batches with self.num_neg, so keep it in sync with
        # the value the data is actually built with.
        self.num_neg = num_neg
        self.hard_neg = hard_neg

        wiki_datasets = self.wiki_dataset
        wiki_datasets.load_elasticsearch_index("text", host="localhost", port="9200", es_index_name="wikipedia_contexts")
        if dataset is None:
            dataset = self.dataset

        if tokenizer is None:
            tokenizer = self.tokenizer

        # 1. Build the flat passage list: for each question, the positive
        #    passage followed by its num_neg negatives.
        p_with_neg = []

        for example in tqdm(dataset):
            _, retrieved_examples = wiki_datasets.get_nearest_examples("text", example['question'], k=k)
            # Drop hits that belong to the gold document.
            candidates = [
                text
                for doc_id, text in zip(retrieved_examples['document_id'], retrieved_examples['text'])
                if doc_id != example['document_id']
            ]
            p_with_neg.append(example['context'])
            p_with_neg.extend(self._select_negatives(candidates, num_neg, hard_neg))

        # 2. Build the (question, passages) dataset.
        q_seqs = tokenizer(
            dataset["question"],
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        p_seqs = tokenizer(
            p_with_neg,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Regroup the flat passages as (num_questions, num_neg + 1, max_len)
        # so that index 0 of each group is the positive passage.
        max_len = p_seqs["input_ids"].size(-1)
        p_seqs["input_ids"] = p_seqs["input_ids"].view(-1, num_neg+1, max_len)
        p_seqs["attention_mask"] = p_seqs["attention_mask"].view(-1, num_neg+1, max_len)
        p_seqs["token_type_ids"] = p_seqs["token_type_ids"].view(-1, num_neg+1, max_len)

        train_dataset = TensorDataset(
            p_seqs["input_ids"], p_seqs["attention_mask"], p_seqs["token_type_ids"],
            q_seqs["input_ids"], q_seqs["attention_mask"], q_seqs["token_type_ids"]
        )

        self.train_dataloader = DataLoader(
            train_dataset,
            shuffle=True,
            batch_size=self.args.per_device_train_batch_size
        )

        valid_seqs = tokenizer(
            dataset["context"],
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        passage_dataset = TensorDataset(
            valid_seqs["input_ids"],
            valid_seqs["attention_mask"],
            valid_seqs["token_type_ids"]
        )
        self.passage_dataloader = DataLoader(
            passage_dataset,
            batch_size=self.args.per_device_train_batch_size
        )

    def build_faiss(self, num_clusters=16, eval_batch_size=8):
        """
        Encode every unique wiki passage with the passage encoder and build
        the FAISS index stored in self.indexer. The passages are kept in
        self.search_corpus in the same order as the index rows.

        Arguments:
            num_clusters (int, default=16):
                Number of inverted lists of the IVF index.
            eval_batch_size (int, default=8):
                Batch size used while encoding the passages.

        Raises:
            ValueError: if the wiki dataset contains no passages.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # index-to-passage mapping is reproducible across kernels.
        self.search_corpus = list(dict.fromkeys(example['text'] for example in self.wiki_dataset))
        if not self.search_corpus:
            raise ValueError("The wiki dataset contains no passages to index.")

        p_encoder = self.p_encoder
        device = self.args.device

        # Construct the dataloader.
        valid_p_seqs = self.tokenizer(
            self.search_corpus,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        valid_dataset = TensorDataset(
            valid_p_seqs["input_ids"],
            valid_p_seqs["attention_mask"],
            valid_p_seqs["token_type_ids"]
        )
        valid_dataloader = DataLoader(
            valid_dataset,
            sampler=SequentialSampler(valid_dataset),
            batch_size=eval_batch_size
        )

        # Run the passage encoder to get the dense embeddings.
        p_embs = []
        p_encoder.eval()
        with torch.no_grad():
            epoch_iterator = tqdm(
                valid_dataloader,
                desc="Iteration",
                position=0,
                leave=True
            )
            for batch in epoch_iterator:
                input_ids, attention_mask, token_type_ids = (t.to(device) for t in batch)
                outputs = p_encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids
                )
                p_embs.append(outputs.to("cpu").numpy())

        # FAISS expects a float32 matrix.
        p_embs = np.concatenate(p_embs, axis=0).astype(np.float32)
        emb_dim = p_embs.shape[-1]

        quantizer = faiss.IndexFlatL2(emb_dim)
        self.indexer = faiss.IndexIVFScalarQuantizer(
            quantizer,
            quantizer.d,
            num_clusters,
            faiss.METRIC_L2
        )
        self.indexer.train(p_embs)
        self.indexer.add(p_embs)

    def train(self, args=None):
        """
        Jointly train both encoders with a negative log-likelihood loss over
        the dot-product scores of each question against its positive passage
        (index 0) and its negatives.

        Arguments:
            args: training configuration; defaults to self.args.
        """
        if args is None:
            args = self.args

        # The dataloader is skipped when the object is created with
        # is_trained=True; build it on demand instead of failing later.
        if not hasattr(self, "train_dataloader"):
            self.prepare_in_batch_negative(num_neg=self.num_neg, hard_neg=self.hard_neg)

        # Optimizer: no weight decay on biases and LayerNorm weights.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {"params": [p for n, p in self.p_encoder.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
            {"params": [p for n, p in self.p_encoder.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
            {"params": [p for n, p in self.q_encoder.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
            {"params": [p for n, p in self.q_encoder.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            eps=args.adam_epsilon
        )
        t_total = int(len(self.train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total
        )

        # Start training!
        self.p_encoder.zero_grad()
        self.q_encoder.zero_grad()
        torch.cuda.empty_cache()

        train_iterator = tqdm(range(int(args.num_train_epochs)), desc="Epoch")
        for _ in train_iterator:

            with tqdm(self.train_dataloader, unit="batch") as tepoch:
                for batch in tepoch:

                    self.p_encoder.train()
                    self.q_encoder.train()

                    # Use the real size of this batch: the last batch of an
                    # epoch can be smaller than per_device_train_batch_size.
                    cur_batch_size = batch[0].size(0)

                    # The positive passage is always at index 0 of its group.
                    targets = torch.zeros(cur_batch_size, dtype=torch.long, device=args.device)

                    p_inputs = {
                        "input_ids": batch[0].view(cur_batch_size * (self.num_neg + 1), -1).to(args.device),
                        "attention_mask": batch[1].view(cur_batch_size * (self.num_neg + 1), -1).to(args.device),
                        "token_type_ids": batch[2].view(cur_batch_size * (self.num_neg + 1), -1).to(args.device)
                    }

                    q_inputs = {
                        "input_ids": batch[3].to(args.device),
                        "attention_mask": batch[4].to(args.device),
                        "token_type_ids": batch[5].to(args.device)
                    }

                    # (cur_batch_size * (num_neg + 1), emb_dim)
                    p_outputs = self.p_encoder(**p_inputs)
                    # (cur_batch_size, emb_dim)
                    q_outputs = self.q_encoder(**q_inputs)

                    # Similarity of each question with its own passage group.
                    p_outputs_t = torch.transpose(p_outputs.view(cur_batch_size, self.num_neg + 1, -1), 1, 2)
                    q_outputs = q_outputs.view(cur_batch_size, 1, -1)

                    # (cur_batch_size, num_neg + 1)
                    sim_scores = torch.bmm(q_outputs, p_outputs_t).view(cur_batch_size, -1)
                    sim_scores = F.log_softmax(sim_scores, dim=1)

                    loss = F.nll_loss(sim_scores, targets)
                    tepoch.set_postfix(loss=f"{str(loss.item())}")

                    loss.backward()
                    optimizer.step()
                    scheduler.step()

                    self.p_encoder.zero_grad()
                    self.q_encoder.zero_grad()

                    torch.cuda.empty_cache()

                    del p_inputs, q_inputs

    def get_relevant_doc(self, query, k=1):
        """
        Retrieve the passages closest to a question.

        Arguments:
            query (str):
                The raw question text; it is tokenized and encoded here.
            k (int, default=1):
                Number of passages to request from the index.

        Returns:
            (distance_list, doc_list): distances reported by the index and the
            matching passages from self.search_corpus. Fewer than k entries
            are returned when the index finds fewer than k neighbours.

        Raises:
            RuntimeError: if build_faiss() has not been called yet.
        """
        if not hasattr(self, "indexer"):
            raise RuntimeError("build_faiss() must be called before get_relevant_doc().")

        valid_q_seqs = self.tokenizer(
            query,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).to(self.args.device)

        # Use the encoder owned by this object, not a notebook global.
        self.q_encoder.eval()
        with torch.no_grad():
            q_emb = self.q_encoder(**valid_q_seqs).to("cpu").numpy()

        q_emb = q_emb.astype(np.float32)
        D, I = self.indexer.search(q_emb, k)
        distances, indices = D.tolist()[0], I.tolist()[0]

        distance_list, doc_list = [], []
        for distance, idx in zip(distances, indices):
            # FAISS pads missing results with label -1; skipping them avoids
            # silently returning the last passage of the corpus.
            if idx < 0:
                continue
            distance_list.append(distance)
            doc_list.append(self.search_corpus[idx])

        return distance_list, doc_list

In [3]:
class BertEncoder(BertPreTrainedModel):
    """BERT backbone whose forward pass returns only the pooled output."""

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """
        Run the BERT backbone and return its pooled output.

        Arguments:
            input_ids: token ids passed positionally to the backbone.
            attention_mask: optional attention mask forwarded unchanged.
            token_type_ids: optional segment ids forwarded unchanged.

        Returns:
            Element 1 of the backbone output (the pooled output).
        """
        backbone_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Index 1 of the backbone output holds the pooled output.
        return backbone_outputs[1]

In [4]:
class RobertaEncoder(RobertaPreTrainedModel):
    """RoBERTa backbone whose forward pass returns only the pooled output."""

    def __init__(self, config):
        super().__init__(config)

        self.roberta = RobertaModel(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """
        Run the RoBERTa backbone and return its pooled output.

        Arguments:
            input_ids: token ids passed positionally to the backbone.
            attention_mask: optional attention mask forwarded unchanged.
            token_type_ids: optional segment ids forwarded unchanged.

        Returns:
            Element 1 of the backbone output (the pooled output).
        """
        backbone_outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Index 1 of the backbone output holds the pooled output.
        return backbone_outputs[1]

In [5]:
# No evaluation is run in this notebook, so the validation split is merged
# into the training split.
train_dataset = load_from_disk('/home/ubuntu/workspace/data/train_dataset')
train, valid = (train_dataset[split].to_dict() for split in ('train', 'validation'))
for key in train:
    # Append the validation rows of this column after the training rows.
    train[key] += valid[key]
train_dataset = Dataset.from_dict(train)
train_dataset

In [10]:
# Checkpoint shared by the tokenizer and both encoders.
model_checkpoint = "klue/bert-base"

args = TrainingArguments(
    output_dir="dense_retireval",
    evaluation_strategy="epoch",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=3e-4,
    weight_decay=0.01,
)

# If encoders used in earlier cells are still around, comment those out before
# running this cell (CUDA ...).
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Two independent copies of the checkpoint: passage encoder first, then
# question encoder, each moved to the training device.
p_encoder, q_encoder = (
    BertEncoder.from_pretrained(model_checkpoint).to(args.device)
    for _ in range(2)
)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.tr

In [22]:
# The next cell calls retriever.train(), which reads the training dataloader
# that is only built when is_trained is False. Constructing with
# is_trained=True made that cell fail on a fresh kernel (Restart & Run All).
retriever = DenseRetrieval_with_Faiss(
    args=args,
    dataset=train_dataset,
    tokenizer=tokenizer,
    p_encoder=p_encoder,
    q_encoder=q_encoder,
    is_trained=False
)

In [12]:
# Jointly fine-tune the passage and question encoders on the retriever's
# training dataloader.
retriever.train()

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2096 [00:00<?, ?batch/s]

  0%|          | 0/2096 [00:00<?, ?batch/s]

In [23]:
# Encode the unique wiki passages with the passage encoder and build the
# FAISS index used by get_relevant_doc.
retriever.build_faiss()

Iteration:   0%|          | 0/6996 [00:00<?, ?it/s]

In [None]:
# Test: retrieve the 5 nearest passages for a sample question.
query = "도플라밍고를 무찔렀으며 3억베리의 현상금을 가진 인물은?"
results = retriever.get_relevant_doc(query=query, k=5)