In [1]:
from elasticsearch import Elasticsearch

from konlpy.tag import Kkma, Mecab

from datasets import load_from_disk

import time
import json
from contextlib import contextmanager

### 참고
https://amboulouma.com/elasticsearch-python

https://gh402.tistory.com/51

https://jvvp.tistory.com/1152

In [2]:
# Load wikipedia json data
with open('../../data/wikipedia_documents.json', 'r', encoding='utf-8') as f:
    wiki = json.load(f)

wiki_contents = list(dict.fromkeys([v['text'] for v in wiki.values()]))

In [3]:
# Check wikipedia data
print('Original data length:', len(wiki.keys()))
print('W/o replicated data length:', len(wiki_contents))

Original data length: 60613
W/o replicated data length: 56737


In [4]:
es = Elasticsearch('http://localhost:30001')

In [5]:
# Test Elasticsearch Connection
!curl -XGET localhost:30001

{
  "name" : "iCfG3OW",
  "cluster_name" : "elasticsearch",
  "cluster_uuid" : "ne-ovptgRn-Gbd331afwZw",
  "version" : {
    "number" : "5.4.3",
    "build_hash" : "eed30a8",
    "build_date" : "2017-06-22T00:34:03.743Z",
    "build_snapshot" : false,
    "lucene_version" : "6.5.1"
  },
  "tagline" : "You Know, for Search"
}


In [6]:
# Check Cluster status
!curl -XGET localhost:30001/_cat/health?v

epoch      timestamp cluster       status node.total node.data shards pri relo init unassign pending_tasks max_task_wait_time active_shards_percent
1651647195 06:53:15  elasticsearch yellow          1         1     10  10    0    0       10             0                  -                 50.0%


In [6]:
# Get Index LIst
!curl -XGET localhost:30001/_cat/indices?v

health status index      uuid                   pri rep docs.count docs.deleted store.size pri.store.size
yellow open   wiki_nouns skZkXThZTbaypb3JarTADg   5   1      56737            0     95.3mb         95.3mb
yellow open   wiki       ufIaLl2KQGKrA5iL6hGx8Q   5   1      56737            0    189.3mb        189.3mb


In [9]:
# Create Index (이미 해당 index 가 존재하면 에러남)
index = 'wiki_nouns'
if not es.indices.exists(index):
    es.indices.create(index=index)

In [23]:
# Get Index LIst  ->  생성한 index 가 보여야 함!
!curl -XGET localhost:30001/_cat/indices?v

health status index      uuid                   pri rep docs.count docs.deleted store.size pri.store.size
yellow open   wiki_nouns skZkXThZTbaypb3JarTADg   5   1          0            0       650b           650b
yellow open   wiki       ufIaLl2KQGKrA5iL6hGx8Q   5   1      56737            0    189.3mb        189.3mb


In [16]:
# Index 삭제
# es.indices.delete(index='wiki')

{'acknowledged': True}

In [7]:
# Insert a data (Index should be existed!)
# Only Do Once!

kkma = Kkma()
mecab = Mecab()

# for idx, text in enumerate(wiki_contents):
#     # body = {'text': ' '.join(kkma.nouns(text))}
#     body = {'text': ' '.join(mecab.nouns(text))}
#     # body = {'text': wiki_contents[idx]}
#     es.index(index=index, doc_type="news", id=idx+1, body=body)
#     print(f'current: {idx}', end='\r')


In [10]:
# Get the data
es.get(index=index, doc_type='news', id=len(wiki_contents))

{'_index': 'wiki_nouns',
 '_type': 'news',
 '_id': '56737',
 '_version': 1,
 'found': True,
 '_source': {'text': '협약 부당 노동 행위 제도 규율 협약 조 반 노동조합 차별 행위 보호 규정 노동조합 가입 노동조합 탈퇴 것 조건 고용 황견계약 노동조합원 노동조합 활동 이유 이익 조치 것 보호 규정 조 노동자 단체 사용자 단체 사이 상호 간 간섭 보호 규정 사용 사용 단체 지배 하 둘 목적 노동자 단체 설립 지원 노동자 단체 재정 밖 방법 지원 것 간섭 행위 노동 조건 단체 협약 규율 사용 사용 단체 노동자 단체 사이 자발 교섭 기구 발전 이용 촉진 규정'}}

In [13]:
# Delete a data
# es.delete(index=index, id=1)

In [14]:
# Search the data by Query (Main !!!!)
# body = {
#     'query': {
#         'match': {
#             'text': '유엔 국제 기구'
#         }
#     }
# }

# res = es.search(index=index, body=body)

In [11]:
org_dataset = load_from_disk('../../data/train_dataset')

org_dataset

DatasetDict({
    train: Dataset({
        features: ['__index_level_0__', 'answers', 'context', 'document_id', 'id', 'question', 'title'],
        num_rows: 3952
    })
    validation: Dataset({
        features: ['__index_level_0__', 'answers', 'context', 'document_id', 'id', 'question', 'title'],
        num_rows: 240
    })
})

In [12]:
@contextmanager
def timer(name):
    t0 = time.time()
    yield
    print(f"[{name}] done in {time.time() - t0:.3f} s")

In [21]:
# Inference for Train Validation

for i in range(1, 21):
    with timer(f'TOP K: {i}'):
        TOPK = i
        doc_scores = []
        doc_indices = []
        for j in range(len(org_dataset['validation'])):
            # By match
            body = {
                'size': TOPK,
                'query': {
                    'match': {
                        # 'text': ' '.join(kkma.nouns(org_dataset['validation']['question'][j]))
                        'text': ' '.join(mecab.nouns(org_dataset['validation']['question'][j]))
                    }
                }
            }

            
            # By match_phrase  =>  Fail
            # body = {
            #     'size': TOPK,
            #     'query': {
            #         'match_phrase': {
            #             # 'text': ' '.join(kkma.nouns(org_dataset['validation']['question'][j]))
            #             'text': ' '.join(mecab.nouns(org_dataset['validation']['question'][j]))
            #         }
            #     }
            # }


            # By terms  =>  not better than match
            # body = {
            #     'size': TOPK,
            #     'query': {
            #         'bool': {
            #             'should': [],
            #             'minimum_should_match': 3
            #         }
            #     }
            # }
            # for noun in mecab.nouns(org_dataset['validation']['question'][j]):
            #     body['query']['bool']['should'].append({'term': {'text': noun}})

            res = es.search(index=index, body=body)

            a_result_scores = []
            a_result_indices = []

            for item in res['hits']['hits']:
                a_result_scores.append(item['_score'])
                a_result_indices.append(int(item['_id'])-1)

            doc_scores.append(a_result_scores)
            doc_indices.append(a_result_indices)


        # Context Accuracy

        # correct = 0
        # for idx, doc_indice in enumerate(doc_indices):
        #     for jdx, indice in enumerate(doc_indice):
        #         if org_dataset['validation']['context'][idx] == wiki_contents[indice]:
        #             correct += 1

        # Label Accuracy  =>  more reasonable ?
        correct = 0
        # answers = [] #####
        for idx, doc_indice in enumerate(doc_indices):
            # answer = [] #####
            for jdx, indice in enumerate(doc_indice):
                if org_dataset['validation']['answers'][idx]['text'][0] in wiki_contents[indice]:
                    correct += 1
                    # answer.append(indice) #####
            # answers.append(answer)
        
        print(f"Total Validation Score: {correct/len(org_dataset['validation'])*100}%")

Total Validation Score: 72.5%
[TOP K: 1] done in 1.116 s
Total Validation Score: 99.16666666666667%
[TOP K: 2] done in 2.030 s
Total Validation Score: 112.08333333333333%
[TOP K: 3] done in 2.361 s
Total Validation Score: 120.0%
[TOP K: 4] done in 2.954 s
Total Validation Score: 128.33333333333334%
[TOP K: 5] done in 3.575 s
Total Validation Score: 136.66666666666666%
[TOP K: 6] done in 4.190 s
Total Validation Score: 142.91666666666666%
[TOP K: 7] done in 4.708 s
Total Validation Score: 149.16666666666666%
[TOP K: 8] done in 5.315 s
Total Validation Score: 156.66666666666666%
[TOP K: 9] done in 5.961 s
Total Validation Score: 160.41666666666669%
[TOP K: 10] done in 6.771 s
Total Validation Score: 164.58333333333331%
[TOP K: 11] done in 7.124 s
Total Validation Score: 166.25%
[TOP K: 12] done in 7.734 s
Total Validation Score: 172.5%
[TOP K: 13] done in 8.288 s
Total Validation Score: 175.0%
[TOP K: 14] done in 8.863 s
Total Validation Score: 179.16666666666669%
[TOP K: 15] done in 9.4

In [23]:
# for query_idx, infer_idxs in enumerate(answers):
#     print('QUERY:', org_dataset['validation']['question'][query_idx])
#     print('Label:', org_dataset['validation']['answers'][query_idx]['text'][0])
#     print()
#     print('True Passage')
#     print(org_dataset['validation']['context'][query_idx])

#     for infer_idx in infer_idxs:
#         print('-'*200)
#         print('Retrieval Passage')
#         print(wiki_contents[infer_idx])
        
#     print('*'*200)
#     print()

In [14]:
# org_dataset_test = load_from_disk('../../data/test_dataset')
org_dataset_train = load_from_disk('../../data/train_dataset')

# org_dataset_test
org_dataset_train

DatasetDict({
    train: Dataset({
        features: ['__index_level_0__', 'answers', 'context', 'document_id', 'id', 'question', 'title'],
        num_rows: 3952
    })
    validation: Dataset({
        features: ['__index_level_0__', 'answers', 'context', 'document_id', 'id', 'question', 'title'],
        num_rows: 240
    })
})

In [15]:
for i in range(20, 21):
    with timer(f'TOP K: {i}'):
        TOPK = i
        doc_scores = []
        doc_indices = []
        # for j in range(len(org_dataset_test['validation'])):
        for j in range(len(org_dataset_train['validation'])):
            # By match
            body = {
                'size': TOPK,
                'query': {
                    'match': {
                        # 'text': ' '.join(mecab.nouns(org_dataset_test['validation']['question'][j]))
                        'text': ' '.join(mecab.nouns(org_dataset_train['validation']['question'][j]))
                    }
                }
            }

            res = es.search(index=index, body=body)

            a_result_scores = []
            a_result_indices = []

            for item in res['hits']['hits']:
                a_result_scores.append(item['_score'])
                a_result_indices.append(int(item['_id'])-1)

            doc_scores.append(a_result_scores)
            doc_indices.append(a_result_indices)

print(f'{i} Done!', end='\r')


[TOP K: 20] done in 0.683 s
20 Done!

In [16]:
str_doc_scores = [str(i) for i in doc_scores]
str_doc_scores

['[17.909676, 16.451714, 16.428705, 15.916275, 15.378147, 14.625661, 13.647047, 13.139832, 13.125618, 12.824602, 12.749788, 12.631439, 12.121863, 11.939099, 11.645405, 11.575977, 11.437759, 11.365505, 11.353939, 11.31154]',
 '[57.87849, 35.12496, 32.3009, 31.522078, 30.303717, 30.075878, 30.033165, 29.819874, 29.652443, 29.585495, 27.88575, 27.751038, 27.634846, 27.57086, 27.566616, 27.499495, 25.483482, 25.39822, 25.174664, 25.042997]',
 '[20.423029, 16.481031, 16.07983, 16.07983, 15.691839, 13.614856, 13.57544, 13.387859, 13.320342, 13.238969, 13.238066, 12.903227, 12.885366, 12.836143, 12.7414, 12.717378, 12.661571, 12.566406, 12.467582, 12.429216]',
 '[31.51464, 19.859695, 18.082796, 17.589697, 17.416138, 16.529747, 15.4989805, 15.454755, 15.250239, 14.019678, 13.997004, 13.730276, 12.613743, 12.169538, 10.990624, 10.8478155, 10.77043, 10.239005, 10.086279, 10.019099]',
 '[16.800262, 16.584232, 16.421316, 16.146107, 15.977791, 15.1496725, 15.112888, 14.867394, 14.299231, 14.198876,

In [17]:
str_doc_indices = [str(i) for i in doc_indices]
str_doc_indices

['[11263, 4459, 5294, 14210, 42209, 40128, 5477, 21869, 40586, 14736, 11019, 52624, 32833, 34920, 45213, 20964, 35047, 34919, 20780, 42779]',
 '[47950, 18112, 19569, 48035, 33924, 19571, 33922, 19572, 19799, 48248, 48208, 33921, 19801, 46537, 46739, 19568, 47955, 50048, 19805, 47951]',
 '[11891, 24363, 746, 4678, 15328, 41051, 17741, 51074, 50158, 51813, 19955, 51373, 4637, 33191, 49074, 532, 49730, 40200, 27412, 27278]',
 '[55663, 55662, 55660, 25395, 55654, 55652, 55655, 55664, 55661, 25794, 1968, 55665, 36348, 1807, 55653, 42293, 46029, 56651, 1808, 23645]',
 '[29185, 42802, 56385, 42793, 9588, 30960, 17725, 42686, 15393, 15138, 113, 3802, 44778, 42794, 41553, 47705, 23299, 34831, 42799, 10485]',
 '[14493, 23335, 21809, 21388, 42394, 29019, 20402, 45905, 15623, 31761, 34847, 16776, 8783, 19317, 55224, 12301, 17334, 53595, 51496, 41865]',
 '[16171, 50041, 15887, 16170, 15582, 21267, 50039, 14065, 14063, 32579, 10421, 50376, 34577, 39468, 19217, 13600, 45363, 23254, 5179, 16173]',
 '[

In [18]:
import pandas as pd

elastic_retrieval = pd.DataFrame(str_doc_indices, columns=['indices'])
elastic_retrieval['scores'] = str_doc_scores
elastic_retrieval

Unnamed: 0,indices,scores
0,"[11263, 4459, 5294, 14210, 42209, 40128, 5477,...","[17.909676, 16.451714, 16.428705, 15.916275, 1..."
1,"[47950, 18112, 19569, 48035, 33924, 19571, 339...","[57.87849, 35.12496, 32.3009, 31.522078, 30.30..."
2,"[11891, 24363, 746, 4678, 15328, 41051, 17741,...","[20.423029, 16.481031, 16.07983, 16.07983, 15...."
3,"[55663, 55662, 55660, 25395, 55654, 55652, 556...","[31.51464, 19.859695, 18.082796, 17.589697, 17..."
4,"[29185, 42802, 56385, 42793, 9588, 30960, 1772...","[16.800262, 16.584232, 16.421316, 16.146107, 1..."
...,...,...
235,"[49447, 49449, 49451, 49454, 49452, 48213, 137...","[36.84846, 31.860357, 27.574036, 27.06894, 26...."
236,"[15026, 8387, 43925, 15031, 15210, 46521, 1073...","[15.405445, 14.848118, 14.640722, 14.384999, 1..."
237,"[4693, 772, 30863, 4461, 2388, 4993, 2866, 537...","[30.163343, 29.678434, 22.06234, 15.006094, 14..."
238,"[44764, 51615, 51613, 52504, 44765, 51614, 227...","[37.915394, 37.86438, 34.95318, 32.83405, 30.9..."
