# dense retrieval

작성일자: 210119\
작성자: 조진욱\
목표: 
1. 2-1 에서 Sparse 임베딩 모델(TFIDF) 대신 Dense 임베딩 모델(DPR) 을 사용해보자
2. FAISS, datasets 패키지를 이용해서 retrieval 과정을 만들어보자
순서: 
2-1과 동일

비고:
1. DPR, faiss 와 datasets 를 그대로 사용한 예시라서 매우 쉽게 짜여져있음
교육을 위해서는 어느정도 구현을 할 수 있도록 만들어 둬야함. retrieval.py 에 코드 작성중

TO DO:
1. batch 과정이 이해 어려울 것 같으면 처음엔 batch 1 로 
2. 마지막에 2-1 의 retrieval 결과와 성능을 비교할 예정

In [1]:
import json
import pandas as pd
from tqdm.notebook import tqdm
import os

In [2]:
import faiss

In [3]:
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast

In [4]:
def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    """Compute the DPR embeddings of document passages"""
    input_ids = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )["input_ids"]
    embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
    return {"embeddings": embeddings.detach().cpu().numpy()}

In [5]:
import torch
from functools import partial
device = "cuda" if torch.cuda.is_available() else "cpu"
dpr_ctx_encoder_model_name = "facebook/dpr-ctx_encoder-multiset-base"
rag_model_name = "facebook/rag-sequence-nq"
batch_size = 16

## dpr 용 csv 만들기

In [7]:
import config as cfg
with open(f'{cfg.data_dir}/dev_context.json', 'r') as reader:
    input_data = json.load(reader)['data']

row_list = []
count = 1
for entry in input_data:
    title = entry['title']
    paragraphs = entry['paragraphs']
    for _ in range(len(paragraphs[0])):
        context_text = paragraphs[0][f'context{count}'] # 데이터의 구성이 context0, 1, 2... 이렇게 구성
        count += 1
        temp = {
            'title': title,
            'text': context_text,
        }
        row_list.append(temp)
df = pd.DataFrame(row_list)
df.to_csv(f"{cfg.dense_dir}/dev_tc.csv", sep='\t', index=False)

## Step1 datasets 패키지 이용해서 데이터 로드 및 데이터셋 구축
https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files

In [8]:
from datasets import Features, Sequence, Value, load_dataset
from typing import List, Optional

In [9]:
dataset = load_dataset(
    "csv", data_files=[f"{cfg.dense_dir}/dev_tc.csv"], split="train", delimiter="\t", column_names=["title", "text"]
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1657.0, style=ProgressStyle(description…




Using custom data configuration default


Downloading and preparing dataset csv/default-e496534a7f729b25 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/csv/default-e496534a7f729b25/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-e496534a7f729b25/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2. Subsequent calls will reuse this data.


In [10]:
def split_text(text: str, n=100, character=" ") -> List[str]:
    """Split the text every ``n``-th occurrence of ``character``"""
    text = text.split(character)
    return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]

def split_documents(documents: dict) -> dict:
    """Split documents into passages"""
    titles, texts = [], []
    for title, text in zip(documents["title"], documents["text"]):
        if text is not None:
            for passage in split_text(text):
                titles.append(title if title is not None else "")
                texts.append(passage)
    return {"title": titles, "text": texts}

dataset = dataset.map(split_documents, batched=True, num_proc=4)







dataset 객체가 context encoder가 만든 임베딩을 저장해둠

In [11]:
ctx_encoder = DPRContextEncoder.from_pretrained(dpr_ctx_encoder_model_name).to(device=device)
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(dpr_ctx_encoder_model_name)
new_features = Features(
    {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
)  # optional, save as float32 instead of float64 to save space

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=492.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=438018293.0, style=ProgressStyle(descri…




In [12]:
dataset = dataset.map(
    partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
    batched=True,
    batch_size=batch_size,
    features=new_features,
)

HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




In [13]:
passages_path = os.path.join(cfg.dense_dir, "dpr_dataset")
dataset.save_to_disk(passages_path)

## Step2 faiss 로 데이터셋을 index, 저장해두기

In [14]:
# === index config ===
dim = 768 # The dimension of the embeddings to pass to the HNSW Faiss index.
num = 128 # The number of bi-directional links created for every new element during the HNSW index construction.

index = faiss.IndexHNSWFlat(dim, num, faiss.METRIC_INNER_PRODUCT)
dataset.add_faiss_index("embeddings", custom_index=index)

HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




Dataset({
    features: ['text', 'title', 'embeddings'],
    num_rows: 1466
})

In [15]:
# save the index
index_path = os.path.join(cfg.dense_dir, "dpr_dataset_hnsw_index.faiss")
dataset.get_index("embeddings").save(index_path)
# dataset.load_faiss_index("embeddings", index_path)  # to reload the index

## Step3 저장해둔 dataset 과 index를 가지고 retrieve하기

In [24]:
from transformers import (RagRetriever,
DPRConfig,
DPRQuestionEncoderTokenizer,
DPRQuestionEncoder)

In [25]:
retriever = RagRetriever.from_pretrained(
    rag_model_name, index_name="custom", indexed_dataset=dataset
)

encoder_model_name = "facebook/dpr-question_encoder-single-nq-base"
config = DPRConfig.from_pretrained(encoder_model_name)
question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
    encoder_model_name, config=config)
encoder = DPRQuestionEncoder.from_pretrained(encoder_model_name)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=493.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=437986065.0, style=ProgressStyle(descri…




In [29]:
# question = "Where did Greece culture begin?"
# question = "What do neuroanatomists study?" # 10761 doc expected, 55th title (train)
question = "Who produced American Idol?" # 19 Entertainment
input_ids = question_encoder_tokenizer(question, return_tensors="pt")["input_ids"]

In [30]:
question_hidden_states = encoder(input_ids)[0]

In [31]:
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(question_hidden_states.cpu().detach().to(torch.float32).numpy(),
                                                              n_docs=3)

In [32]:
retrieved_doc_embeds.shape

(1, 3, 768)

In [33]:
doc_ids

array([[  6,   1, 172]])

In [40]:
doc_ids.tolist()[0]

[6, 1, 172]

In [34]:
doc_dicts[0].keys()

dict_keys(['embeddings', 'text', 'title'])

In [35]:
doc_dicts[0]['title']

['American_Idol', 'American_Idol', 'American_Idol']

In [36]:
doc_dicts[0]['text']

['American Idol was based on the British show Pop Idol created by Simon Fuller, which was in turn inspired by the New Zealand television singing competition Popstars. Television producer Nigel Lythgoe saw it in Australia and helped bring it over to Britain. Fuller was inspired by the idea from Popstars of employing a panel of judges to select singers in audition. He then added other elements, such as telephone voting by the viewing public (which at the time was already in use in shows such as the Eurovision Song Contest), the drama of backstories and real-life soap opera unfolding in',
 'American Idol is an American singing competition series created by Simon Fuller and produced by 19 Entertainment, and is distributed by FremantleMedia North America. It began airing on Fox on June 11, 2002, as an addition to the Idols format based on the British series Pop Idol and has since become one of the most successful shows in the history of American television. The concept of the series is to f

In [37]:
for attr, val in index.__dict__.items():
    print(f"{attr} = {val}")

this = <Swig Object of type 'faiss::IndexHNSWFlat *' at 0x7f3d10336750>
