In [1]:
from __future__ import annotations

import streamlit as st
import os

import faiss
from sentence_transformers import SentenceTransformer
import torch
import pandas as pd
import os
from time import time
from datasets.download import DownloadManager
from datasets import load_dataset  # type: ignore
import time
from dataclasses import dataclass
import json
import urllib.request

# URLs of the JAQKET v1 dev question files; each one is fetched and parsed
# line by line by load_jaqket_v1_dev.
JAQKET_V1_DEV_URLS = [
    "https://jaqket.s3.ap-northeast-1.amazonaws.com/data/aio_01/dev1_questions.json",
    "https://jaqket.s3.ap-northeast-1.amazonaws.com/data/aio_01/dev2_questions.json",
]


# Dataset path passed to `load_dataset` for the Japanese Wikipedia passages.
WIKIPEDIA_JA_DS = "singletongue/wikipedia-utils"
# Dataset config name; also used as a directory component of the FAISS index path.
# NOTE(review): "JS" in this identifier looks like a typo for "JA" — confirm before renaming.
WIKIPEDIA_JS_DS_NAME = "passages-c400-jawiki-20230403"
# Hugging Face dataset repository from which get_faiss_index downloads index files.
WIKIPEDIA_JA_EMB_DS = "hotchpotch/wikipedia-passages-jawiki-embeddings"

# Embedding model id -> number N used in the index file name `index_IVF2048_PQ{N}.faiss`.
# presumably the product-quantization setting of each prebuilt index — TODO confirm
EMB_MODEL_PQ = {
    "intfloat/multilingual-e5-small": 96,
    "intfloat/multilingual-e5-base": 192,
    "intfloat/multilingual-e5-large": 256,
    "cl-nagoya/sup-simcse-ja-base": 192,
    "pkshatech/GLuCoSE-base-ja": 192,
}

# All supported embedding model ids, in the order declared above.
EMB_MODEL_NAMES = list(EMB_MODEL_PQ.keys())

# presumably the valid values for `e5_query_or_passage` — not referenced in the visible code
E5_QUERY_TYPES = [
    "passage",
    "query",
]

# Disable tokenizer parallelism via the TOKENIZERS_PARALLELISM environment variable.
# presumably to avoid the Hugging Face tokenizers fork/parallelism warning — TODO confirm
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def get_model(name: str, max_seq_length=512):
    """Instantiate a SentenceTransformer on the best available device.

    Device preference is CUDA first, then Apple MPS, then CPU.

    Args:
        name: Model identifier handed to ``SentenceTransformer``.
        max_seq_length: Value assigned to the model's ``max_seq_length``
            attribute after loading.

    Returns:
        The loaded ``SentenceTransformer`` instance.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        # MPS is only probed when CUDA is unavailable.
        target_device = "mps" if torch.backends.mps.is_available() else "cpu"

    encoder = SentenceTransformer(name, device=target_device)
    encoder.max_seq_length = max_seq_length
    return encoder


def get_wikija_ds(name: str = WIKIPEDIA_JS_DS_NAME):
    """Load the ``train`` split of the Japanese Wikipedia passages dataset.

    Args:
        name: Dataset config name passed to ``load_dataset`` together with
            the ``WIKIPEDIA_JA_DS`` dataset path.

    Returns:
        Whatever ``load_dataset`` returns for ``split="train"``.
    """
    return load_dataset(path=WIKIPEDIA_JA_DS, name=name, split="train")


def get_faiss_index(
    index_name: str,
    ja_emb_ds: str = WIKIPEDIA_JA_EMB_DS,
    name=WIKIPEDIA_JS_DS_NAME,
    nprobe: int = 256,
):
    """Download a prebuilt FAISS index file and load it.

    The file is fetched from
    ``https://huggingface.co/datasets/{ja_emb_ds}/resolve/main/faiss_indexes/{name}/{index_name}``
    through ``DownloadManager`` and then read with ``faiss.read_index``.

    Args:
        index_name: Path of the index file below ``faiss_indexes/{name}/``.
        ja_emb_ds: Hugging Face dataset repository hosting the index files.
        name: Directory component identifying the passages dataset config.
        nprobe: Value assigned to the loaded index's ``nprobe`` attribute
            (for FAISS IVF indexes, the number of inverted lists visited per
            query). Defaults to 256, the value that was previously hard-coded.

    Returns:
        The loaded FAISS index with ``nprobe`` set.
    """
    target_path = f"faiss_indexes/{name}/{index_name}"
    index_url = (
        f"https://huggingface.co/datasets/{ja_emb_ds}/resolve/main/{target_path}"
    )
    index_local_path = DownloadManager().download(index_url)
    index = faiss.read_index(index_local_path)
    index.nprobe = nprobe
    return index


def texts_to_embs(model, texts: list[str], prefix: str):
    """Encode texts into embeddings after prepending a fixed prefix.

    Args:
        model: Object exposing ``encode(texts, normalize_embeddings=...,
            show_progress_bar=...)``.
        texts: Texts to embed.
        prefix: String prepended to every text before encoding (may be empty).

    Returns:
        Whatever ``model.encode`` returns; it is called with
        ``normalize_embeddings=True`` and ``show_progress_bar=True``.
    """
    prefixed_texts = []
    for raw_text in texts:
        prefixed_texts.append(prefix + raw_text)
    return model.encode(
        prefixed_texts,
        normalize_embeddings=True,
        show_progress_bar=True,
    )


def faiss_search_by_embs(faiss_index, embs, top_k=5):
    """Run a nearest-neighbour search and print how long it took.

    Args:
        faiss_index: Object exposing ``search(embs, k)`` that returns a pair.
        embs: Query embeddings handed to ``search`` unchanged.
        top_k: Number of neighbours requested per query.

    Returns:
        The two values returned by ``faiss_index.search``, as a tuple.
    """
    started_at = time.time()
    distances, neighbour_ids = faiss_index.search(embs, top_k)
    elapsed = time.time() - started_at
    print(f"faiss search time: {elapsed}")
    return distances, neighbour_ids

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# jaqket v1
@dataclass
class JaqketQuestionV1:
    """One JAQKET v1 question record.

    Instances are built by ``load_jaqket_v1_dev`` from one JSON object per
    input line, with the derived ``label`` field added.
    """

    # Question identifier (presumably unique per question — TODO confirm).
    qid: str
    # Question text; this is the field embedded for retrieval.
    question: str
    # Gold answer; expected to be one of `answer_candidates`.
    answer_entity: str
    # Position of `answer_entity` within `answer_candidates`.
    label: int
    # Candidate answers for this question.
    answer_candidates: list[str]
    # presumably the question text before preprocessing — TODO confirm against the dataset docs
    original_question: str


def load_jaqket_v1_dev(urls):
    """Download JAQKET v1 dev questions and attach the gold label index.

    Each URL is read as JSON Lines (one JSON object per line). For every
    record, ``label`` is set to the position of ``answer_entity`` within
    ``answer_candidates`` before the record is converted into a
    ``JaqketQuestionV1``.

    Args:
        urls: Iterable of URLs pointing at JSON Lines question files.

    Returns:
        A list of ``JaqketQuestionV1``, in URL order and then file order.

    Raises:
        ValueError: If a record's ``answer_entity`` is not one of its
            ``answer_candidates``.
    """
    res = []
    for url in urls:
        with urllib.request.urlopen(url) as f:
            # The response is consumed one line at a time; blank lines
            # (e.g. a trailing newline) are skipped instead of failing to parse.
            data = [json.loads(line.decode("utf-8")) for line in f if line.strip()]
        for d in data:
            # `list.index` raises ValueError when the item is missing (it never
            # returns -1), so the lookup failure is translated into a
            # descriptive error here.
            try:
                d["label"] = d["answer_candidates"].index(d["answer_entity"])
            except ValueError:
                raise ValueError(
                    f"answer_entity not found in answer_candidates: {d['answer_entity']}, {d['answer_candidates']}"
                ) from None
            res.append(JaqketQuestionV1(**d))
    return res

In [3]:
# Experiment settings.
# Embedding model id; must be a key of EMB_MODEL_PQ (looked up when building the index name).
emb_model_name = "intfloat/multilingual-e5-large"
# Only used for "-e5-" models: suffix appended to the index directory name.
e5_query_or_passage = "passage"
# Number of nearest passages retrieved per question.
top_k = 5

In [4]:
# Load the Japanese Wikipedia passages; retrieved passage ids are used as row indexes into `ds`.
ds = get_wikija_ds()

In [5]:
# Download and parse all JAQKET v1 dev questions.
jaqket_v1_dev = load_jaqket_v1_dev(JAQKET_V1_DEV_URLS)
# Uncomment to run on only the first 100 questions:
# jaqket_v1_dev = jaqket_v1_dev[0:100]

In [6]:
# Load the embedding model selected in the settings cell.
model = get_model(emb_model_name)
# Same value that get_model already assigns by default (max_seq_length=512).
model.max_seq_length = 512

In [7]:
# Short model name = last path component of the model id.
model_short_name = emb_model_name.split("/")[-1]
is_e5_model = "-e5-" in emb_model_name

# For E5 models the index directory name carries the passage/query suffix.
index_emb_model_name = (
    f"{model_short_name}-{e5_query_or_passage}" if is_e5_model else model_short_name
)
# The search-side prefix is always "query: " for E5 models, even when the
# indexed data was embedded as "passage"; other models use no prefix.
search_text_prefix = "query: " if is_e5_model else ""

emb_model_pq = EMB_MODEL_PQ[emb_model_name]
index_name = f"{index_emb_model_name}/index_IVF2048_PQ{emb_model_pq}.faiss"
faiss_index = get_faiss_index(index_name)

In [8]:
# Embed every question text, prepending `search_text_prefix` to each one.
question_embs = texts_to_embs(
    model, texts=[q.question for q in jaqket_v1_dev], prefix=search_text_prefix
)
# Display the shape of the resulting embeddings as the cell output.
question_embs.shape  # type: ignore

Batches: 100%|██████████| 63/63 [00:02<00:00, 22.69it/s]


(1992, 1024)

In [9]:
# Retrieve the top_k nearest passages per question; `scores, indexes` is the
# pair returned by `faiss_index.search` (presumably distances and passage row ids — TODO confirm).
scores, indexes = faiss_search_by_embs(faiss_index, question_embs, top_k=top_k)

faiss search time: 9.45790147781372


In [10]:
def find_label_by_indexes(idxs, jaqket: "JaqketQuestionV1", wiki_ds):
    """Pick the answer candidate supported by the retrieved passages.

    Two passes are made over the retrieved passages, in ranking order:

    1. Return the first candidate that exactly equals a passage title.
    2. Otherwise return the first candidate that occurs as a substring of a
       passage text.

    Args:
        idxs: Retrieved passage row ids, best match first. Negative ids are
            ignored: FAISS fills missing neighbours with -1, and using that
            as a row index would silently select the last passage.
        jaqket: Question whose ``answer_candidates`` are matched.
        wiki_ds: Indexable collection whose rows expose ``"title"`` and
            ``"text"``.

    Returns:
        The index of the matched candidate within
        ``jaqket.answer_candidates``, or -1 when no candidate matches.
    """
    candidates = jaqket.answer_candidates

    # Rows are fetched lazily during the title pass and kept for the text
    # pass, so each passage is read from the dataset at most once.
    rows = []
    for idx in idxs:
        if idx < 0:
            continue
        row = wiki_ds[idx]
        rows.append(row)
        title = row["title"]
        # First, check whether the title exactly matches one of the candidates.
        for j, candidate in enumerate(candidates):
            if candidate == title:
                return j
        # XXX: considering RAG use cases, should the remaining passages be
        # evaluated here as well?

    for row in rows:
        text = row["text"]
        # Next, check whether one of the candidates appears in the passage text.
        for j, candidate in enumerate(candidates):
            if candidate in text:
                return j
    return -1


def predict_by_indexes(indexes, jaqket_ds, wiki_ds):
    """Predict a candidate label for every question from its retrieved passages.

    Args:
        indexes: Per-question collections of retrieved passage ids; each item
            must provide ``.tolist()``.
        jaqket_ds: Questions, paired positionally with ``indexes``.
        wiki_ds: Passage collection forwarded to ``find_label_by_indexes``.

    Returns:
        A list with one predicted label per (ids, question) pair, in order.
    """
    return [
        find_label_by_indexes(passage_ids.tolist(), question, wiki_ds)
        for passage_ids, question in zip(indexes, jaqket_ds)
    ]

In [11]:
# Map each question's retrieved passages to a predicted candidate index (-1 when nothing matches).
pred_labels = predict_by_indexes(indexes, jaqket_v1_dev, ds)

In [12]:
# Fraction of predictions equal to -1, i.e. questions for which no answer
# candidate was found in the retrieved passages (a ratio, despite the name).
not_found_index_count = sum(label == -1 for label in pred_labels) / len(pred_labels)
print(not_found_index_count)

0.05973895582329317


In [13]:
# Show the accuracy. A prediction of -1 (no candidate found) never equals a
# gold label, because gold labels are list positions, so it counts as wrong.
# NOTE(review): this import would be easier to find in the top import cell.
from sklearn.metrics import accuracy_score

labels = [q.label for q in jaqket_v1_dev]
print(accuracy_score(labels, pred_labels))

0.7484939759036144
