# Elasticsearch와 BM25를 이용해 Sparse Ensemble 구현하기

In [1]:
from rank_bm25 import BM25Okapi, BM25Plus
from elasticsearch import Elasticsearch

from transformers import AutoTokenizer
import os
import json
import pickle
import numpy as np
import pandas as pd
import re
from tqdm.auto import tqdm

from datasets import load_from_disk

## 전처리 함수 및 데이터 로드

In [2]:
def preprocess_retrieval(corpus):
    corpus = corpus.replace("\\n", "")
    corpus = re.sub(f"[^- ㄱ-ㅎㅏ-ㅣ가-힣0-9a-zA-Zぁ-ゔァ-ヴー々〆〤一-龥]", " ", corpus)
    corpus = ' '.join(corpus.split())
    return corpus

In [3]:
with open("../data/wikipedia_documents.json", "r", encoding="utf-8") as f:
    wiki = json.load(f)
contexts = list(dict.fromkeys([v["text"] for v in wiki.values()]))

In [4]:
train_dataset = load_from_disk("../data/train_dataset")
train_context, valid_context = [], []
train_query, valid_query = [], []
for data in tqdm(train_dataset['train']):
    train_context.append(preprocess_retrieval(data['context']))
    train_query.append(preprocess_retrieval(data['question']))
for data in tqdm(train_dataset['validation']):
    valid_context.append(preprocess_retrieval(data['context']))
    valid_query.append(preprocess_retrieval(data['question']))

HBox(children=(FloatProgress(value=0.0, max=3952.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=240.0), HTML(value='')))




In [5]:
train_ids = list(range(len(train_context)))
valid_ids = list(range(len(valid_context)))

In [6]:
contexts = [preprocess_retrieval(corpus) for corpus in tqdm(contexts)]

HBox(children=(FloatProgress(value=0.0, max=56737.0), HTML(value='')))




## 1. BM25 함수

In [7]:
def make_bm25(contexts, model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenized_wiki = [tokenizer.tokenize(corpus) for corpus in tqdm(contexts)]
    bm25 = BM25Okapi(tqdm(tokenized_wiki))
    return bm25, tokenizer

In [8]:
def retrieval_bm25(model, query, contexts, tokenizer, topk=1):
    tokenized_query = tokenizer.tokenize(query)
    results = model.get_scores(tokenized_query)
    sorted_result = np.argsort(results)[::-1]
    doc_score = results[sorted_result].tolist()[:topk]
    doc_indices = sorted_result.tolist()[:topk]
    return doc_score, doc_indices

In [9]:
my_bm25, tokenizer = make_bm25(contexts=contexts, model_name="monologg/koelectra-base-v3-discriminator")

HBox(children=(FloatProgress(value=0.0, max=56737.0), HTML(value='')))

Token indices sequence length is longer than the specified maximum sequence length for this model (1033 > 512). Running this sequence through the model will result in indexing errors





HBox(children=(FloatProgress(value=0.0, max=56737.0), HTML(value='')))




In [11]:
right, wrong = 0, 0
for i in tqdm(range(len(valid_query))):
    results = my_bm25.get_top_n(tokenizer.tokenize(valid_query[i]), contexts, n=3)
    if valid_context[i] in results:
        right += 1
    else:
        wrong += 1
print(f"Total Length : {right+wrong}")
print(f"Accuracy : {100*right/(right+wrong):.2f}")

HBox(children=(FloatProgress(value=0.0, max=240.0), HTML(value='')))


Total Length : 240
Accuracy : 82.50


In [12]:
right, wrong = 0, 0
for i in tqdm(range(len(valid_query))):
    scores, indices = retrieval_bm25(my_bm25, valid_query[i], contexts, tokenizer, 3)
    predict = [contexts[idx] for idx in indices]
    if valid_context[i] in predict:
        right += 1
    else:
        wrong += 1
print(f"Total Length : {right+wrong}")
print(f"Accuracy : {100*right/(right+wrong):.2f}")

HBox(children=(FloatProgress(value=0.0, max=240.0), HTML(value='')))


Total Length : 240
Accuracy : 82.50


### 새로 발견한 중요한 사실
### get_top_n()과 get_scores()에서 동일한 결과를 얻으려면 입력들이 모두 전처리 들어간 상태여야함

## 2. Elasticsearch 함수

In [10]:
def make_elasticsearch(contexts, index_name):
    os.system("service elasticsearch start")
    INDEX_NAME = index_name

    INDEX_SETTINGS = {"settings" : {"index":{"analysis":{"analyzer":{"korean":{"type":"custom",
                                            "tokenizer":"nori_tokenizer","filter": [ "shingle" ],}}}}},
      "mappings": {"properties" : {"context" : {"type" : "text","analyzer": "korean","search_analyzer": "korean"},}}}
    
    DOCS = {}
    for i in tqdm(range(len(contexts))):
        DOCS[i] = {'context':contexts[i]}
        
    try:
        es.transport.close()
    except:
        pass
    es = Elasticsearch()
    
    if es.indices.exists(INDEX_NAME):
        es.indices.delete(index=INDEX_NAME)
    es.indices.create(index=INDEX_NAME, body=INDEX_SETTINGS)
    
    for doc_id, doc in tqdm(DOCS.items()):
        es.index(index=INDEX_NAME,  id=doc_id, body=doc)
        
    return es

In [13]:
es = make_elasticsearch(contexts, "wiki_index")

 * Starting Elasticsearch Server
 * Already running.
   ...done.


HBox(children=(FloatProgress(value=0.0, max=56737.0), HTML(value='')))




  if es.indices.exists(INDEX_NAME):
  es.indices.create(index=INDEX_NAME, body=INDEX_SETTINGS)


HBox(children=(FloatProgress(value=0.0, max=56737.0), HTML(value='')))

  es.index(index=INDEX_NAME,  id=doc_id, body=doc)





In [14]:
def retrieval_es(model, query, index_name, topk=1):
    try:
        res = model.search(index=index_name, q=query, size=topk)
    except:
        mod_q = query.replace("%", " ")
        res = model.search(index=index_name, q=mod_q, size=topk)
    
    doc_score = [float(res['hits']['hits'][idx]['_score']) for idx in range(topk)]
    doc_indices = [int(res['hits']['hits'][idx]['_id']) for idx in range(topk)]
    return doc_score, doc_indices

In [15]:
right, wrong = 0, 0
for i in tqdm(range(len(valid_query))):
    scores, indices = retrieval_es(es, valid_query[i], "wiki_index", 10)
    predict = [contexts[idx] for idx in indices]
    if valid_context[i] in predict:
        right += 1
    else:
        wrong += 1
print(f"Total Length : {right+wrong}")
print(f"Accuracy : {100*right/(right+wrong):.2f}")

HBox(children=(FloatProgress(value=0.0, max=240.0), HTML(value='')))


Total Length : 240
Accuracy : 92.08


In [16]:
scores, indices = retrieval_es(es, valid_query[0], "wiki_index", 3)
scores, type(scores[0])

([26.301352, 25.264828, 24.809431], float)

## 3. Ensemble 함수

In [18]:
# my_bm25, tokenizer = make_bm25(contexts=contexts, model_name="monologg/koelectra-base-v3-discriminator")
my_bm, tokenizer2 = make_bm25(contexts=contexts, model_name="klue/bert-base")

HBox(children=(FloatProgress(value=0.0, max=56737.0), HTML(value='')))

Token indices sequence length is longer than the specified maximum sequence length for this model (1003 > 512). Running this sequence through the model will result in indexing errors





HBox(children=(FloatProgress(value=0.0, max=56737.0), HTML(value='')))




In [68]:
from collections import Counter

def normalize_score(scores):
    total = sum(scores)
    normalized = [sc/total for sc in scores]
    return normalized

def mixed_scores(id1, sc1, id2, sc2, id3, sc3):
    # sc1, sc2, sc3 = normalize_score(sc1), normalize_score(sc2), normalize_score(sc3)
    cross = list(set(id1 + id2 + id3))
    board1, board2, board3 = list(zip(id1, sc1)), list(zip(id2, sc2)), list(zip(id3, sc3))
    total_board = [[c, 0] for c in cross]
    for i in range(len(board1)):
        for idx in range(len(total_board)):
            if total_board[idx][0] == board1[i][0]:
                total_board[idx][1] += board1[i][1]
                break
    for i in range(len(board2)):
        for idx in range(len(total_board)):
            if total_board[idx][0] == board2[i][0]:
                total_board[idx][1] += board2[i][1]
                break
    for i in range(len(board3)):
        for idx in range(len(total_board)):
            if total_board[idx][0] == board3[i][0]:
                total_board[idx][1] += board3[i][1]
                break
    indices = [idx[0] for idx in total_board]
    scores = [score[1] for score in total_board]
    return indices, scores

In [69]:
def ensemble(es, bm1, bm2, tokenizer1, tokenizer2, contexts, texts, queries, topk=1, index_name="wiki_index"):
    final_indicies, final_scores = [], []
    len_topk = []
    right, wrong = 0, 0
    for i in tqdm(range(len(queries))):
        k = 10
        scores1, indicies1 = retrieval_bm25(bm1, queries[i], contexts, tokenizer1, topk=topk)
        scores2, indicies2 = retrieval_bm25(bm2, queries[i], contexts, tokenizer2, topk=topk)
        scores3, indicies3 = retrieval_es(es, queries[i], index_name, topk=topk)
        final_id, final_score = mixed_scores(indicies1, scores1, indicies2, scores2, indicies3, scores3)
        final_indicies.append(final_id)
        final_scores.append(final_score)
        # print(final_id, final_score)
        board = list(zip(final_id, final_score))
        board = sorted(board, key = lambda x : -x[1])
        if len(final_id) < k:
            k = len(final_id)
        predict = [contexts[idx] for idx in final_id[:k]]
        len_topk.append(len(predict))
        if texts[i] in predict:
            right += 1
        else:
            wrong += 1
    return board, right/(right+wrong), len_topk

In [70]:
final_indicies, acc, len_topk10 = ensemble(es = es,
                                      bm1 = my_bm25,
                                      bm2 = my_bm,
                                      tokenizer1 = tokenizer,
                                      tokenizer2 = tokenizer2,
                                      contexts = contexts,
                                      texts = valid_context,
                                      queries = valid_query,
                                      topk = 10)

print(f"Ensemble Acc. : {100*acc:.2f}%")
print(f"Min : {min(len_topk10)}")
print(f"Max : {max(len_topk10)}")
print(f"Avg : {sum(len_topk10)/len(len_topk10):.2f}")

HBox(children=(FloatProgress(value=0.0, max=240.0), HTML(value='')))


Ensemble Acc. : 63.75%
Min : 10
Max : 10
Avg : 10.00


In [71]:
final_indicies, acc, len_topk20 = ensemble(es = es,
                                      bm1 = my_bm25,
                                      bm2 = my_bm,
                                      tokenizer1 = tokenizer,
                                      tokenizer2 = tokenizer2,
                                      contexts = contexts,
                                      texts = valid_context,
                                      queries = valid_query,
                                      topk = 20)

print(f"Ensemble Acc. : {100*acc:.2f}%")
print(f"Min : {min(len_topk20)}")
print(f"Max : {max(len_topk20)}")
print(f"Avg : {sum(len_topk20)/len(len_topk20):.2f}")

HBox(children=(FloatProgress(value=0.0, max=240.0), HTML(value='')))


Ensemble Acc. : 35.00%
Min : 10
Max : 10
Avg : 10.00


In [72]:
final_indicies, acc, len_topk30 = ensemble(es = es,
                                      bm1 = my_bm25,
                                      bm2 = my_bm,
                                      tokenizer1 = tokenizer,
                                      tokenizer2 = tokenizer2,
                                      contexts = contexts,
                                      texts = valid_context,
                                      queries = valid_query,
                                      topk = 30)

print(f"Ensemble Acc. : {100*acc:.2f}%")
print(f"Min : {min(len_topk30)}")
print(f"Max : {max(len_topk30)}")
print(f"Avg : {sum(len_topk30)/len(len_topk30):.2f}")

HBox(children=(FloatProgress(value=0.0, max=240.0), HTML(value='')))


Ensemble Acc. : 24.58%
Min : 10
Max : 10
Avg : 10.00


In [73]:
final_indicies, acc, len_topk50 = ensemble(es = es,
                                      bm1 = my_bm25,
                                      bm2 = my_bm,
                                      tokenizer1 = tokenizer,
                                      tokenizer2 = tokenizer2,
                                      contexts = contexts,
                                      texts = valid_context,
                                      queries = valid_query,
                                      topk = 50)

print(f"Ensemble Acc. : {100*acc:.2f}%")
print(f"Min : {min(len_topk50)}")
print(f"Max : {max(len_topk50)}")
print(f"Avg : {sum(len_topk50)/len(len_topk50):.2f}")

HBox(children=(FloatProgress(value=0.0, max=240.0), HTML(value='')))


Ensemble Acc. : 11.67%
Min : 10
Max : 10
Avg : 10.00


In [67]:
final_indicies, acc, len_topk5 = ensemble(es = es,
                                      bm1 = my_bm25,
                                      bm2 = my_bm,
                                      tokenizer1 = tokenizer,
                                      tokenizer2 = tokenizer2,
                                      contexts = contexts,
                                      texts = valid_context,
                                      queries = valid_query,
                                      topk = 5)

print(f"Ensemble Acc. : {100*acc:.2f}%")
print(f"Min : {min(len_topk5)}")
print(f"Max : {max(len_topk5)}")
print(f"Avg : {sum(len_topk5)/len(len_topk5):.2f}")

HBox(children=(FloatProgress(value=0.0, max=240.0), HTML(value='')))


Ensemble Acc. : 25.83%
Min : 2
Max : 2
Avg : 2.00
