In [1]:
# Install ElasticSearch
!pip install Elasticsearch
!service elasticsearch start

 * Starting Elasticsearch Server
 * Already running.
   ...done.


In [2]:
# import library

from elasticsearch import Elasticsearch
import json
import tqdm
import pandas as pd

In [3]:
class Elastic_retriever():

    def __init__(self, wiki_json_path, hundred = False, index_name = None):
        self.wiki_list = [] # saves text and document id only
        
        # load wikipedia data and drop needless informations 
        with open(wiki_json_path, "r", encoding = "utf-8") as f:
            wiki = json.load(f)

            if hundred:
                len_wiki = 100
            else:
                len_wiki = len(wiki)

            for ind in range(len_wiki):
                temp_wiki = wiki[str(ind)]
                self.wiki_list.append({"text": temp_wiki["text"], "document_id" : temp_wiki["document_id"]})

        del wiki # for memory usage
        self.es = Elasticsearch("localhost:9200")

        if index_name is not None:
            self.index_name = index_name
        else:
            self.index_name = 'klue_mrc_wikipedia_index'

    def _create_indice(self, index_config = None):

        if index_config is None:
            index_config = {
                "settings": {
                    "analysis": {
                        "analyzer": {
                            "standard_analyzer": {
                                "type": "standard"
                            }
                        }
                    }
                },
                "mappings": {
                    "dynamic": "strict", 
                    "properties": {
                        "document_id": {"type": "long",},
                        "text": {"type": "text", "analyzer": "standard_analyzer"}
                        }
                    }
                }

        if self.es.indices.exists(index=self.index_name):
            self.es.indices.delete(index=self.index_name)

        self.es.indices.create(index=self.index_name, body=index_config, ignore=400)

    def _populate_index(self):
        
        for i in tqdm.tqdm(range(len(self.wiki_list))):
            self.es.index(index = self.index_name, id = i, body = self.wiki_list[i])
        
    def config_and_index(self, index_name = None, index_config = None):
        self._create_indice(index_name, index_config)
        self._populate_index(index_name)

    def search(self, query, num_return, index_name = None):
        if index_name is None:
            index_name = 'klue_mrc_wikipedia_index'
        answer = self.es.search(index=index_name, q = query, size = num_return)
        return answer

In [6]:
wiki_json_path = "/opt/ml/code/preprocessed_json_v3.json"

retriever = Elastic_retriever(wiki_json_path)
retriever._create_indice(index_config = index_config)
retriever._populate_index()

  self.es.indices.create(index=self.index_name, body=index_config, ignore=400)
  self.es.index(index = self.index_name, id = i, body = self.wiki_list[i])
100%|██████████| 60613/60613 [04:36<00:00, 219.18it/s]


In [5]:
index_config = {
        "settings": {
            "analysis": {
                "filter":{
                    "my_stop_filter": { 
                        "type" : "stop",
                        "stopwords_path" : "stop_words.txt" # /etc/elastic안에 txt파일이 존재해야 댑니다
                    }
                },
                "analyzer": {
                    "nori_analyzer": {
                        "type": "custom",
                        "tokenizer": "nori_tokenizer", # 노리 형태소 깔아야대는데 에러나면 맨위에 참고해서 깔기
                        "decompound_mode": "discard",
                        "filter" : ["my_stop_filter"]# 위에서 정의한 stopword
                    }
                }
            }
        },
        "mappings": {
            "dynamic": "strict", # 먼지 잘모르겟
            "properties": {
                "document_id": {"type": "long",},
                "text": {"type": "text", "analyzer": "nori_analyzer"}
                }
            }
        }

In [8]:
def convert(train_path, valid_path):
    # make a list of dictionary
    total_data = [] # {context, question, document_id}

    train_df = pd.read_csv(train_path, index_col = 0)
    valid_df = pd.read_csv(valid_path, index_col = 0)

    total_df = pd.concat([train_df, valid_df])

    for i in range(len(total_df)):
        temp_data = total_df.iloc[i]
        total_data.append({"text":temp_data.context, "question":temp_data.question, "document_id": temp_data.document_id})

    return total_data

train_data = "/opt/ml/code/train_dataset_no_tilde.csv"
valid_data = "/opt/ml/code/valid_dataset_no_tilde.csv"

total_data = convert(train_data, valid_data)

In [66]:
import tqdm
import re

def show_the_result(retriever, total_data):

    results = [0]*21

    total_data_len = len(total_data)
    
    for ind in tqdm.tqdm(range(total_data_len)):

        temp_data = total_data[ind]
        query = re.sub("~","-", temp_data["question"])
        query = re.sub("/","", query)
        document_id = temp_data["document_id"]

        hit_ones = retriever.search(query, 20)["hits"]["hits"]

        if hit_ones: #만약 검출이 되었다면
            result = [hit_one["_source"]["document_id"] for hit_one in hit_ones]
            
            if document_id in result:
                found_index = result.index(document_id)
                results[found_index] +=1 
            else:
                results[-1] += 1

    return results

In [76]:
results = show_the_result(retriever, total_data)

100%|██████████| 4192/4192 [00:47<00:00, 88.25it/s]


In [68]:
def pretty_result(result):
    total = sum(result)
    top_1 = result[0]
    top_5 = sum(result[:2])
    top_10 = sum(result[:3])
    top_20 = sum(result[:4])

    print(f"===Retrieval Result===\n")
    print(f"top 1 : {top_1*100/total}%")
    print(f"top 5 : {top_5*100/total}%")
    print(f"top 10 : {top_10*100/total}%")
    print(f"top 20 : {top_20*100/total}%")
    print(f"failed to predict : {result[-1]*100/total}%")

In [77]:
pretty_result(results)

===Retrieval Result===

top 1 : 54.87950369840134%
top 5 : 85.65974707706991%
top 10 : 88.69005010737294%
top 20 : 91.86351706036746%
failed to predict : 8.136482939632545%
