In [1]:
import os

# Sentence-embedding model used to encode the JAQKET questions; it should be the same
# model the passages in INDEX_NAME were encoded with.
MODEL_NAME = "intfloat/multilingual-e5-small"

# Root of the local checkout holding the prebuilt FAISS indexes (the path suggests a clone
# of the hotchpotch/wikipedia-passages-jawiki-embeddings dataset repo). Set the
# JAWIKI_EMBEDDINGS_DIR environment variable to point elsewhere instead of editing this
# cell. os.path.join(..., "") guarantees a trailing separator, because the index path is
# built by plain string concatenation (WORKING_DIR + INDEX_NAME).
WORKING_DIR = os.path.join(
    os.environ.get(
        "JAWIKI_EMBEDDINGS_DIR",
        "/home/hotchpotch/src/huggingface.co/datasets/hotchpotch/wikipedia-passages-jawiki-embeddings/",
    ),
    "",
)

# Hugging Face dataset/config holding the Wikipedia passages. FAISS ids are used directly
# as row positions in this dataset, so it must be the corpus the index was built from.
WIKIPEDIA_DS = "singletongue/wikipedia-utils"
WIKIPEDIA_DS_NAME = "passages-c400-jawiki-20230403"
# DS_NAME = 'hotchpotch/wikipedia-ja-20231030'

# FAISS index file, relative to WORKING_DIR. Other index variants that can be swapped in:
INDEX_NAME = "faiss_indexes/passages-c400-jawiki-20230403/multilingual-e5-small-passage/index_m96_mbit8_nlist512.faiss"

# INDEX_NAME = 'faiss_indexes/passages-c400-jawiki-20230403/multilingual-e5-small-passage/index_m8_mbit8_nlist512.faiss'

# INDEX_NAME = 'faiss_indexes/passages-c400-jawiki-20230403/multilingual-e5-small-passage/index_flat_l2.faiss'

In [2]:
# Alternative configuration for the larger multilingual-e5-large model and its index.
# Inactive as written; uncomment to override the values set in the cell above.
# MODEL_NAME = 'intfloat/multilingual-e5-large'

# WORKING_DIR = '/home/hotchpotch/src/huggingface.co/datasets/hotchpotch/wikipedia-passages-jawiki-embeddings/'

# WIKIPEDIA_DS = 'singletongue/wikipedia-utils'
# WIKIPEDIA_DS_NAME = 'passages-c400-jawiki-20230403'
# # DS_NAME = 'hotchpotch/wikipedia-ja-20231030'
# INDEX_NAME = 'faiss_indexes/passages-c400-jawiki-20230403/multilingual-e5-large-passage/index_m64_mbit8_nlist512.faiss'

In [3]:
from datasets.download import DownloadManager
from datasets import load_dataset

# Passage corpus that the FAISS ids point into: find_label_by_indexes looks up
# ds[faiss_id] to get each retrieved passage's "title" and "text".
ds = load_dataset(WIKIPEDIA_DS, WIKIPEDIA_DS_NAME, split="train")
# dm = DownloadManager()
# index_pass  = dm.download(f"https://huggingface.co/datasets/{DS_NAME}/resolve/main/{INDEX_NAME}")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from dataclasses import dataclass
import json
import urllib.request

# JAQKET v1 dev question files (aio_01 release). load_jaqket_v1_dev reads each one
# line by line as JSON.
jaqket_v1_dev_urls = [
    "https://jaqket.s3.ap-northeast-1.amazonaws.com/data/aio_01/dev1_questions.json",
    "https://jaqket.s3.ap-northeast-1.amazonaws.com/data/aio_01/dev2_questions.json",
]


# jaqket v1
@dataclass
class JaqketQuestionV1:
    """One JAQKET v1 multiple-choice question.

    Instances are built by ``load_jaqket_v1_dev`` from one JSON line of the source
    file, plus a ``label`` field derived there.
    """

    qid: str  # question id, as given in the source JSON
    question: str  # question text; this is what gets embedded for retrieval
    answer_entity: str  # gold answer; must appear in answer_candidates
    label: int  # position of answer_entity within answer_candidates
    answer_candidates: list[str]  # candidates the prediction picks from
    original_question: str  # kept as given in the source JSON


def load_jaqket_v1_dev(urls):
    """Download JAQKET v1 question files and parse them into ``JaqketQuestionV1`` records.

    Each URL must serve JSON Lines: one JSON object per line, whose keys are the
    fields of ``JaqketQuestionV1`` except ``label``. ``label`` is derived here as the
    position of ``answer_entity`` within ``answer_candidates``.

    Args:
        urls: iterable of URLs, anything ``urllib.request.urlopen`` accepts.

    Returns:
        list[JaqketQuestionV1]: questions from all URLs, in file order.

    Raises:
        ValueError: if a question's ``answer_entity`` is not among its
            ``answer_candidates``.
    """
    res = []
    for url in urls:
        with urllib.request.urlopen(url) as f:
            # The body is JSON Lines: parse each line separately, skipping blank lines
            # such as a trailing newline (json.loads fails on empty input).
            data = [json.loads(line.decode("utf-8")) for line in f if line.strip()]
        for d in data:
            candidates = d["answer_candidates"]
            # list.index raises rather than returning -1, so test membership first
            # in order to report a descriptive error.
            if d["answer_entity"] not in candidates:
                raise ValueError(
                    f"answer_entity not found in answer_candidates: {d['answer_entity']}, {d['answer_candidates']}"
                )
            d["label"] = candidates.index(d["answer_entity"])
            res.append(JaqketQuestionV1(**d))
    return res


# Fetch and parse both dev files (requires network access to the URLs above).
jaqket_v1_dev = load_jaqket_v1_dev(jaqket_v1_dev_urls)

In [5]:
from hf_hub_ctranslate2 import CT2SentenceTransformer

# SentenceTransformer-style encoder backed by CTranslate2 (hf_hub_ctranslate2), on the GPU.
# compute_type="int8_float16" is a CTranslate2 quantization mode (int8 weights, float16
# compute); the accuracy notes at the end of this notebook compare it with float16 and int8.
MODEL = CT2SentenceTransformer(
    MODEL_NAME,
    compute_type="int8_float16",
    device="cuda",
)
# NOTE(review): presumably the maximum input length in tokens (longer inputs truncated);
# confirm against hf_hub_ctranslate2.
MODEL.max_seq_length = 512

In [6]:
import faiss

# Load the prebuilt FAISS index. Judging by the file name it is IVF-PQ
# (m=96 sub-quantizers, 8 bits each, nlist=512).
index = faiss.read_index(WORKING_DIR + INDEX_NAME)
# Number of stored vectors. FAISS ids are later used as row positions in `ds`, so this
# is expected to equal len(ds). NOTE(review): not asserted here.
index.ntotal

5555583

In [7]:
# IVF search breadth: probe 128 of the index's inverted lists per query (nlist=512 per
# INDEX_NAME). Higher values trade speed for recall; this only matters for IVF indexes.
index.nprobe = 128

In [8]:
import numpy as np

# e5 models expect "query: " / "passage: " input prefixes. The index directory name
# ("...-passage") indicates the passages were encoded with the passage prefix, so queries
# get the query prefix. The prefix is a property of the encoder, so it is keyed off
# MODEL_NAME rather than the index path.
prefix = ""

if "-e5-" in MODEL_NAME:
    prefix = "query: "


def texts_to_embs(texts, prefix=prefix) -> np.ndarray:
    """Encode texts with MODEL into L2-normalised embeddings.

    Args:
        texts: iterable of strings to encode.
        prefix: string prepended to every text before encoding. Defaults to the
            model-specific query prefix chosen above.

    Returns:
        np.ndarray: one embedding row per input text, as returned by ``MODEL.encode``.
        Rows are normalised, so L2-distance and inner-product rankings agree.
    """
    texts = [prefix + text for text in texts]
    return MODEL.encode(texts, normalize_embeddings=True)  # type: ignore

In [9]:
# Embed every dev question (question text only; candidates are not embedded).
question_embs = texts_to_embs([q.question for q in jaqket_v1_dev])

In [10]:
# (number of questions, embedding dimension)
question_embs.shape

(1992, 384)

In [11]:
import time


def faiss_search_by_embs(embs, faiss_index=None, top_k=5):
    """Run a nearest-neighbour search and print how long it took.

    Args:
        embs: query embeddings, one row per query (passed straight to
            ``faiss_index.search``).
        faiss_index: index to search. ``None`` (the default) uses the notebook-level
            ``index``, looked up at call time, so reloading ``index`` in an earlier cell
            takes effect without redefining this function.
        top_k: number of neighbours to return per query.

    Returns:
        tuple: ``(D, I)`` distances and ids from ``faiss_index.search``. FAISS pads
        ``I`` with -1 when it finds fewer than ``top_k`` neighbours.
    """
    if faiss_index is None:
        faiss_index = index
    # perf_counter is monotonic and high-resolution, so it is the right clock for
    # measuring durations; time.time() can jump with wall-clock adjustments.
    start_time = time.perf_counter()
    D, I = faiss_index.search(embs, top_k)
    elapsed = time.perf_counter() - start_time
    print(f"search time: {elapsed}")
    return D, I


# Top-5 neighbours per question: `scores` holds FAISS distances, `indexes` holds the
# passage ids, which are used as row positions in `ds`.
scores, indexes = faiss_search_by_embs(question_embs)

search time: 7.370864152908325


In [12]:
def find_label_by_indexes(idxs, jaqket: "JaqketQuestionV1", wiki_ds):
    """Pick the answer candidate best supported by the retrieved passages.

    Two passes over the passages, in retrieval order:
    1. the first passage whose title exactly equals a candidate wins;
    2. otherwise, the first passage whose text contains a candidate wins
       (candidates are checked in their listed order within each passage).

    Args:
        idxs: passage ids from a FAISS search (row positions in ``wiki_ds``).
        jaqket: question whose ``answer_candidates`` are matched.
        wiki_ds: indexable passage collection whose rows have "title" and "text".

    Returns:
        int: index into ``jaqket.answer_candidates``, or -1 when nothing matches.
    """
    # FAISS pads results with -1 when it finds fewer than top_k neighbours. A negative
    # id would silently wrap around to the last passage, so drop it. Each remaining
    # row is fetched once and reused by both passes.
    rows = [wiki_ds[idx] for idx in idxs if idx >= 0]
    candidates = jaqket.answer_candidates

    # Pass 1: exact title match.
    for row in rows:
        title = row["title"]
        for j, candidate in enumerate(candidates):
            if candidate == title:
                return j
    # XXX: for a RAG use case, would it be better to keep scoring the remaining
    # passages here instead of stopping at the first match?

    # Pass 2: the candidate appears somewhere in the passage text.
    for row in rows:
        text = row["text"]
        for j, candidate in enumerate(candidates):
            if candidate in text:
                return j
    return -1


def predict_by_indexes(indexes, jaqket_ds, wiki_ds):
    """Predict one answer-candidate label per question from its retrieved passages.

    Row ``i`` of ``indexes`` (a numpy row of passage ids) is paired with question
    ``i`` of ``jaqket_ds``. Pairing stops at the shorter of the two, as ``zip`` does.
    Each label comes from ``find_label_by_indexes`` and is -1 when no candidate matched.
    """
    return [
        find_label_by_indexes(id_row.tolist(), question, wiki_ds)
        for id_row, question in zip(indexes, jaqket_ds)
    ]


# One predicted candidate index per dev question (-1 where no passage matched any candidate).
pred_labels = predict_by_indexes(indexes, jaqket_v1_dev, ds)

In [13]:
# Share of questions where no candidate matched (predicted label -1)
pred_labels.count(-1) / len(pred_labels)

0.1179718875502008

In [14]:
# Gold labels: position of the correct answer within each question's answer_candidates.
labels = [q.label for q in jaqket_v1_dev]

In [15]:
# Show accuracy: the fraction of questions whose predicted candidate equals the gold label.
# Unmatched questions (predicted label -1) always count as wrong, since gold labels are >= 0.
from sklearn.metrics import accuracy_score

accuracy_score(labels, pred_labels)

# Accuracies recorded from earlier runs with other settings (presumably: CTranslate2
# compute_type, top_k, and the index's PQ sub-quantizer count m).
# NOTE(review): m64/m48 and m32/m24 report identical scores; worth confirming that the
# index file really was swapped between those runs.

# ct2 + float16
# 0.6621485943775101
# ct2 + int8_float16
# 0.6651606425702812
# ct2 + int8
# 0.6651606425702812

# k=5
# 0.6731927710843374


# k=5, m64
# 0.6616465863453815

# k=5, m48
# 0.6616465863453815

# k=5, m32
# 0.588855421686747

# k=5, m24
# 0.588855421686747

0.6651606425702812