In [None]:
import json
from datasets import load_from_disk
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch

In [None]:
# Read the wiki passage dump into memory. The entries are looked up by
# stringified integer keys ('0', '1', ...) further down in the notebook.
with open('../data/preprocess_wiki.json', encoding='utf-8') as f:
    wiki = json.load(f)

# Display the first entry to inspect the record layout.
wiki['0']

In [None]:
# Flatten the wiki dict into a list of passage texts. Position i holds the text
# stored under key str(i), so list positions double as document indices.
corpus = [
    wiki[str(doc_idx)]['text']
    for doc_idx in range(len(wiki))
]

In [None]:
# Sanity check: number of passages collected into the retrieval corpus.
len(corpus)

In [None]:
# Cross-encoder that scores (query, passage) pairs for the re-ranking step in `search`.
# NOTE(review): 'klue/bert-base' looks like a base pretrained checkpoint rather than a
# fine-tuned re-ranker — confirm its scoring head is trained before trusting cross-scores.
cross_encoder = CrossEncoder('klue/bert-base')

In [None]:
# Bi-encoder used for dense retrieval; its maximum input sequence length is set to 384.
bi_encoder = SentenceTransformer('Huffon/sentence-klue-roberta-base')
bi_encoder.max_seq_length = 384

# Embed every passage once up front so that each query only needs a similarity
# search against the precomputed tensor.
corpus_embeddings = bi_encoder.encode(
    corpus,
    show_progress_bar=True,
    convert_to_tensor=True,
)

In [None]:
from tqdm.autonotebook import tqdm
import numpy as np
import pickle

# Load cached retrieval results (presumably from Elasticsearch, judging by the
# file name). The evaluation loop indexes them by question text and then by rank.
# NOTE(review): pickle.load can execute arbitrary code — only load this file if
# it comes from a trusted source.
# NOTE(review): the absolute /opt/ml path ties this cell to one machine.
with open('/opt/ml/elastic_valid_500.bin','rb') as f:
    elastic_valid = pickle.load(f)

In [None]:
def search(query, k, candidate_k=100):
    """Retrieve passages for ``query`` with the bi-encoder, then re-rank them with the cross-encoder.

    Args:
        query: Question text to search for.
        k: Number of corpus ids to return in each ranking.
        candidate_k: Minimum number of bi-encoder candidates handed to the
            cross-encoder. The pool is widened to ``k`` when ``k`` is larger,
            so a large ``k`` is no longer silently capped at the pool size.

    Returns:
        A tuple ``(bi_encoder_retrieval, cross_encoder_retrieval)`` holding the
        top-``k`` ``corpus_id`` values ordered by bi-encoder score and by
        cross-encoder score respectively. Both lists are empty when the
        semantic search yields no hits.
    """
    ##### Semantic Search #####
    # Encode the query using the bi-encoder and find potentially relevant passages.
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # Follow the corpus embeddings' device instead of hard-coding CUDA, so the
    # function also works on CPU-only machines and never mixes devices.
    question_embedding = question_embedding.to(corpus_embeddings.device)
    hits = util.semantic_search(
        question_embedding, corpus_embeddings, top_k=max(k, candidate_k)
    )[0]  # hits for the first (and only) query
    if not hits:
        return [], []

    ##### Re-Ranking #####
    # Score every retrieved passage against the query with the cross-encoder.
    cross_inp = [[query, corpus[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score

    # Top-k ids by bi-encoder similarity.
    by_bi_score = sorted(hits, key=lambda hit: hit['score'], reverse=True)
    bi_encoder_retrieval = [hit['corpus_id'] for hit in by_bi_score[:k]]

    # Top-k ids by cross-encoder score. Sorting the bi-encoder-ordered list keeps
    # ties in cross-score resolved by bi-encoder rank (sorted() is stable).
    by_cross_score = sorted(by_bi_score, key=lambda hit: hit['cross-score'], reverse=True)
    cross_encoder_retrieval = [hit['corpus_id'] for hit in by_cross_score[:k]]

    return bi_encoder_retrieval, cross_encoder_retrieval

In [None]:
# 'train' split of the dataset saved at ../data/new_train_dataset; its first
# question is used below for a quick trial run of `search`.
td = load_from_disk('../data/new_train_dataset')['train']

In [None]:
# Trial run: search with the first training question and k=5, then show how many
# ids the bi-encoder ranking returned (a = bi-encoder ids, b = cross-encoder ids).
a, b = search(td[0]['question'], 5)
len(a)


In [None]:
# 'validation' split of the dataset at /opt/ml/data/train_dataset; its
# 'document_id' column serves as retrieval ground truth in the evaluation loop.
# NOTE(review): absolute path — this cell only runs where /opt/ml/data exists.
origin_vd = load_from_disk('/opt/ml/data/train_dataset')['validation']
# Number of validation examples.
len(origin_vd['document_id'])

In [None]:
# Retrieve the top-500 corpus passages for every validation question in one batch.
question = origin_vd['question']
question_embedding = bi_encoder.encode(question, convert_to_tensor=True)
# Follow the corpus embeddings' device instead of hard-coding CUDA, so the cell
# also runs on CPU-only machines and never mixes devices.
question_embedding = question_embedding.to(corpus_embeddings.device)
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=500)

# Map each question text to its ranked hit list. A question that appears more
# than once keeps the hits of its last occurrence.
st_v = dict(zip(question, hits))

In [None]:
# Persist the validation retrieval results so other notebooks/scripts can reuse them.
with open('sentence_transformer_valid.pickle', mode='wb') as f:
    pickle.dump(st_v, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Retrieve the top-500 corpus passages for every question of the test dataset
# (stored under its 'validation' split).
test_dataset = load_from_disk('../data/test_dataset')['validation']

question = test_dataset['question']
question_embedding = bi_encoder.encode(question, convert_to_tensor=True)
# Follow the corpus embeddings' device instead of hard-coding CUDA, so the cell
# also runs on CPU-only machines and never mixes devices.
question_embedding = question_embedding.to(corpus_embeddings.device)
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=500)

# Map each question text to its ranked hit list. A question that appears more
# than once keeps the hits of its last occurrence.
st_t = dict(zip(question, hits))

In [None]:
# Persist the test retrieval results so other notebooks/scripts can reuse them.
with open('sentence_transformer_test.pickle', mode='wb') as f:
    pickle.dump(st_t, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Number of hits stored for the first validation question (expected to match top_k above).
len(st_v[list(st_v)[0]])

In [None]:
# Number of hits stored for the first test question (expected to match top_k above).
len(st_t[list(st_t)[0]])

In [None]:
# Compare top-k retrieval accuracy of the cached `elastic_valid` results, the
# bi-encoder ranking, and the cross-encoder re-ranking on the validation questions.
valid_dataset = load_from_disk("/opt/ml/data/new_train_dataset/validation")
query = valid_dataset['question']
context = valid_dataset['context']

# Ground-truth document ids come from `origin_vd` and are matched to `query` by
# row position, so both validation sets must line up row for row.
# NOTE(review): only the lengths are verified here — confirm that both sets hold
# the same questions in the same order, otherwise the accuracies are meaningless.
ground_truth_ids = origin_vd['document_id']
assert len(ground_truth_ids) == len(query), (
    f"validation sets are misaligned: {len(ground_truth_ids)} ground-truth ids "
    f"vs {len(query)} questions"
)

top_k_list = [20]

for top_k in top_k_list:
    es_acc = 0
    bi_encoder_acc = 0
    cross_encoder_acc = 0

    for i in tqdm(range(len(query))):
        q = query[i]
        ground_truth = ground_truth_ids[i]

        bi_encoder_top_k, cross_encoder_top_k = search(q, top_k)

        # The cache may hold fewer than top_k results for a question; clamp the
        # range instead of raising IndexError.
        es_candidates = elastic_valid[q]
        es_top_k = [es_candidates[j] for j in range(min(top_k, len(es_candidates)))]

        if ground_truth in es_top_k:
            es_acc += 1
        if ground_truth in bi_encoder_top_k:
            bi_encoder_acc += 1
        if ground_truth in cross_encoder_top_k:
            cross_encoder_acc += 1

    print('score_top_k : ', top_k)
    print('es ACC : ', es_acc / len(query))
    print('bi-encoder ACC :',bi_encoder_acc / len(query))
    print('cross-encoder ACC :',cross_encoder_acc / len(query))

    print()