In [1]:
import sys
import os
# Point the Hugging Face cache root at /hdd/. This assignment is placed before
# the `datasets` import below so the library sees it when it is first imported.
os.environ['HF_HOME']='/hdd/'
# NOTE(review): `save_to_disk` is normally a Dataset *method*, not a top-level
# name exported by `datasets` — confirm this import succeeds on the installed version.
from datasets import load_dataset, save_to_disk

/hdd/


In [2]:
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional

import torch
from datasets import Features, Sequence, Value, load_dataset

import faiss
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    HfArgumentParser,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
)


# Logger named after this module/notebook.
logger = logging.getLogger(__name__)

# Globally switch off autograd tracking for everything that follows.
torch.set_grad_enabled(False)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:

# Load the pretrained DPR context encoder (moved to `device`) and its matching
# fast tokenizer; both come from the same checkpoint name.
DPR_CTX_ENCODER_NAME = 'facebook/dpr-ctx_encoder-single-nq-base'
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(DPR_CTX_ENCODER_NAME)
ctx_encoder = DPRContextEncoder.from_pretrained(DPR_CTX_ENCODER_NAME).to(device=device)

# Schema of the embedded passage dataset. Declaring the embeddings as float32
# is optional, but stores them at half the size of float64.
new_features = Features(
    {
        "text": Value("string"),
        "title": Value("string"),
        "wikipedia_id": Value("string"),
        "embeddings": Sequence(Value("float32")),
    }
)

In [4]:
# Sanity check: show whether torch can see a CUDA device in this kernel.
torch.cuda.is_available()

True

In [5]:
# Load the 'full' split of the `kilt_wikipedia` dataset, passing /hdd/kilt/ as data_dir.
kilt_wiki = load_dataset("kilt_wikipedia", data_dir='/hdd/kilt/', split='full')

Using custom data configuration default
Reusing dataset kilt_wikipedia (/hdd/datasets/kilt_wikipedia/default/0.0.0/a48aa8d021c82ff4e2210a596893076073305e17e3125949291227be54e42b9b)


In [6]:
# Display the dataset summary (feature names and row count).
kilt_wiki

Dataset({
    features: ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories', 'wikidata_info', 'history'],
    num_rows: 5903530
})

In [7]:
def split_text(text: str, n=100, character=" ") -> List[str]:
    """Split ``text`` into passages of at most ``n`` ``character``-separated pieces.

    The text is cut at every ``n``-th occurrence of ``character``; each passage
    is re-joined with ``character`` and stripped of surrounding whitespace.
    Passages that are empty after stripping are dropped, so empty or
    whitespace-only input yields ``[]`` rather than a blank passage.

    Args:
        text: The text to split.
        n: Maximum number of pieces per passage. Must be at least 1.
        character: Separator used both to split the text and to re-join pieces.

    Returns:
        The non-empty passages, in their original order.

    Raises:
        ValueError: If ``n`` is smaller than 1. (A negative ``n`` would otherwise
            silently produce no passages at all.)
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    pieces = text.split(character)
    passages = (character.join(pieces[start : start + n]).strip() for start in range(0, len(pieces), n))
    # Skip blank passages so they are never embedded or indexed downstream.
    return [passage for passage in passages if passage]


def split_documents(documents: dict) -> dict:
    """Explode a batch of articles into one row per passage.

    Args:
        documents: A column-oriented batch with ``wikipedia_title``, ``text``
            and ``wikipedia_id`` lists. Each non-``None`` ``text`` entry is
            indexed with ``"paragraph"`` to obtain the strings that are joined
            with single spaces.

    Returns:
        A batch with ``title``, ``text`` and ``wikipedia_id`` lists, one entry
        per passage. Articles whose ``text`` is ``None`` contribute no rows.
    """
    passages = {"title": [], "text": [], "wikipedia_id": []}
    rows = zip(documents["wikipedia_title"], documents["text"], documents["wikipedia_id"])
    for doc_title, doc_text, doc_id in rows:
        if doc_text is None:
            continue
        # The article body arrives as a list of paragraphs; flatten it to one string.
        chunks = split_text(" ".join(doc_text["paragraph"]))
        # Every passage inherits the title and id of its source article.
        passages["wikipedia_id"].extend([doc_id] * len(chunks))
        passages["text"].extend(chunks)
        passages["title"].extend([doc_title] * len(chunks))
    return passages


def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    """Encode a batch of passages into DPR embeddings.

    Args:
        documents: A batch with parallel ``title`` and ``text`` lists.
        ctx_encoder: The DPR context encoder that produces the embeddings.
        ctx_tokenizer: The tokenizer matching ``ctx_encoder``.

    Returns:
        A dict whose ``embeddings`` entry is the encoder's pooled output for
        the batch, converted to a NumPy array on the CPU.
    """
    # Titles and texts are tokenized as pairs, truncated, and padded to the
    # longest sequence in this batch.
    encoded = ctx_tokenizer(
        documents["title"],
        documents["text"],
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )
    # Only the token ids are handed to the encoder, on the module-level `device`.
    token_ids = encoded["input_ids"].to(device=device)
    encoder_output = ctx_encoder(token_ids, return_dict=True)
    pooled = encoder_output.pooler_output
    return {"embeddings": pooled.detach().cpu().numpy()}

In [2]:
# List the raw column names; the passage-splitting step passes these as remove_columns.
# Requires `kilt_wiki` to already be defined in the running kernel (run the load cell first).
kilt_wiki.column_names

NameError: name 'kilt_wiki' is not defined

In [9]:
# Build the passage-level dataset from the raw KILT Wikipedia articles.
# Step 1: split every article into passages. The original article columns are
# removed because the returned batch has a different number of rows.
passages = kilt_wiki.map(split_documents, batched=True, remove_columns=kilt_wiki.column_names, num_proc=8)
# Step 2: embed the passages with DPR in batches of 100 and store the result
# with the `new_features` schema (float32 embeddings).
dataset = passages.map(
    partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
    batched=True,
    batch_size=100,
    features=new_features,
)











HBox(children=(FloatProgress(value=0.0, max=291673.0), HTML(value='')))




In [10]:
# Display the embedded passage dataset summary (feature names and row count).
dataset

Dataset({
    features: ['text', 'title', 'wikipedia_id', 'embeddings'],
    num_rows: 29167229
})

In [11]:
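# Optional checkpoint: uncomment to persist the embedded passages before building the FAISS index.
# NOTE(review): "/hdd/kilt" is also the data_dir the raw dump is loaded from — consider a separate directory.
# dataset.save_to_disk("/hdd/kilt")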

In [None]:
# DPR ranks passages by inner product, so build the index with that metric;
# without metric_type the default flat index would rank by L2 distance instead.
dataset.add_faiss_index(column='embeddings', metric_type=faiss.METRIC_INNER_PRODUCT)

HBox(children=(FloatProgress(value=0.0, max=29168.0), HTML(value='')))

In [1]:
# Display the dataset summary again.
# Requires `dataset` to already be defined in the running kernel (run the cells above first).
dataset

NameError: name 'dataset' is not defined