In [None]:
import csv
import json
import os
import pickle
import random

import numpy as np
import pandas as pd
from tqdm import trange, tqdm
from transformers import BertTokenizer, BertModel

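###### Download the TriviaQA RC dataset from https://nlp.cs.washington.edu/triviaqa/data/triviaqa-rc.tar.gz and extract it into the working directory (the next cell expects `evidence/wikipedia/` and `qa/`).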

In [None]:
# Locations inside the extracted TriviaQA RC archive (Wikipedia domain), built
# from path components so the platform separator is used.
evidence_dir = os.path.join("evidence", "wikipedia")
qa_train_path = os.path.join("qa", "wikipedia-train.json")
qa_dev_path = os.path.join("qa", "wikipedia-dev.json")
qa_test_path = os.path.join("qa", "wikipedia-test-without-answers.json")

def txt2title(text):
    """Turn an evidence file name such as ``Foo_Bar.txt`` into the title ``Foo Bar``.

    The extension (if any) is removed and every underscore becomes a space.
    """
    stem = os.path.splitext(text)[0]
    return ' '.join(stem.split('_'))

# Read every Wikipedia evidence file once and give it a sequential string id.
# full_doc maps file name -> (docid, title, body); the QA files reference
# evidence by file name, so the file name is the key.
MAX_DOC_CHARS = 10240  # keep only the first 10240 characters of each article

full_doc = {}
global_index = 0
# os.listdir returns entries in arbitrary order; sort so the sequential docids
# are identical on every machine and every run.
for file_name in sorted(os.listdir(evidence_dir)):
    if not file_name.endswith(".txt"):
        continue
    # Titles need not be unique here (e.g. "A_B.txt" and "A B.txt" give the same
    # title); duplicate titles are dropped later via drop_duplicates('title').
    title = txt2title(file_name)
    # Evidence text is assumed to be UTF-8; read it explicitly rather than with
    # the platform's locale encoding.
    with open(os.path.join(evidence_dir, file_name), encoding='utf-8') as f:
        # Flatten to a single line: the output is a tab/newline-delimited table.
        body = f.read().replace('\n', ' ').replace('\t', ' ').replace('\r', ' ')[:MAX_DOC_CHARS]
    full_doc[file_name] = (str(global_index), title, body)
    global_index += 1

lines = ['\t'.join(v) + '\n' for v in full_doc.values()]
# Written as UTF-8 because it is read back with pd.read_csv(..., encoding='utf-8').
with open("trivia_qa_fulldoc.csv", "w", encoding='utf-8') as f:
    f.writelines(lines)


# Convert each QA split into "question<TAB>comma-separated docids" lines, where the
# docids are the ones assigned to the evidence files in full_doc.
qa_splits = (
    (qa_train_path, "trivia_qa_train.csv"),
    (qa_dev_path, "trivia_qa_dev.csv"),
    (qa_test_path, "trivia_qa_test.csv"),
)
for in_file, out_file in qa_splits:
    with open(in_file, encoding='utf-8') as f:
        raw = json.load(f)

    lines = []
    for item in raw["Data"]:
        # A tab or newline inside a question would split/shift the TSV row, so
        # flatten them the same way document bodies are flattened.
        query = item["Question"].replace('\n', ' ').replace('\t', ' ').replace('\r', ' ')
        # Raises KeyError if a referenced evidence file was not loaded into full_doc.
        doc_ids = [full_doc[doc["Filename"]][0] for doc in item["EntityPages"]]
        lines.append(query + "\t" + ','.join(doc_ids) + "\n")

    # UTF-8 to match the pd.read_csv(..., encoding='utf-8') readers.
    with open(out_file, "w", encoding='utf-8') as f:
        f.writelines(lines)

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def lower(x):
    """Normalize text through the uncased BERT tokenizer.

    Tokenizes ``x`` with the ``bert-base-uncased`` tokenizer, keeps at most the
    first 512 word pieces and decodes them back into a string. Because the
    vocabulary is uncased the result is lower-cased; decoding may also change
    spacing around punctuation.

    Args:
        x: Value to normalize, normally a title string.

    Returns:
        The decoded string, or ``x`` unchanged when it is not a string (e.g. a
        NaN produced by pandas) or when the tokenizer fails on it.
    """
    if not isinstance(x, str):
        return x
    try:
        word_pieces = tokenizer.tokenize(x)[:512]
        piece_ids = tokenizer.convert_tokens_to_ids(word_pieces)
        return tokenizer.decode(piece_ids)
    except Exception:
        # Best-effort normalization: keep the raw value rather than abort the
        # whole column. (A bare `except:` would also swallow KeyboardInterrupt.)
        return x

In [None]:
# The QA files are plain, unquoted TSV. Disable quote handling so questions that
# contain '"' are read verbatim instead of being parsed as quoted fields, read
# every column as text (docid holds comma-separated ids), and keep literal
# strings such as "NA"/"None" instead of turning them into NaN.
df_train = pd.read_csv('trivia_qa_train.csv',
                       names=["query", "docid"],
                       encoding='utf-8', header=None, sep='\t',
                       quoting=csv.QUOTE_NONE, dtype=str, keep_default_na=False)
df_train

In [None]:
# Same unquoted TSV layout as the train split: read verbatim, all columns as
# text, no NaN conversion of literal strings.
df_val = pd.read_csv('trivia_qa_dev.csv',
                     names=["query", "docid"],
                     encoding='utf-8', header=None, sep='\t',
                     quoting=csv.QUOTE_NONE, dtype=str, keep_default_na=False)
df_val

In [None]:
# Same unquoted TSV layout as the train split: read verbatim, all columns as
# text, no NaN conversion of literal strings.
df_test = pd.read_csv('trivia_qa_test.csv',
                      names=["query", "docid"],
                      encoding='utf-8', header=None, sep='\t',
                      quoting=csv.QUOTE_NONE, dtype=str, keep_default_na=False)
df_test

In [None]:
# trivia_qa_fulldoc.csv is unquoted TSV (docid, title, content). Disable quote
# parsing so titles/bodies containing '"' cannot swallow neighbouring fields or
# lines, and keep titles such as "NaN"/"None"/"NA" (and empty bodies) as
# strings instead of converting them to NaN.
df_full = pd.read_csv('trivia_qa_fulldoc.csv',
                      names=["docid", "title", "content"],
                      encoding='utf-8', header=None, sep='\t',
                      quoting=csv.QUOTE_NONE, keep_default_na=False,
                      dtype={"docid": "int64", "title": str, "content": str})
df_full

In [None]:
# Normalize every title through the BERT tokenizer, then keep only the first
# document for each normalized title.
normalized_titles = df_full['title'].map(lower)
df_full['title'] = normalized_titles
df_drop_title = (
    df_full
    .drop_duplicates(subset='title')
    .reset_index(drop=True)
)

In [None]:
# Inspect the title-deduplicated document table.
df_drop_title

In [None]:
# Give each distinct normalized title a dense id: its row position in df_drop_title.
title_doc_id = {
    title: position
    for position, title in enumerate(tqdm(df_drop_title['title']))
}

# Map every original docid to the dense id of its (normalized) title, so
# documents whose titles collapsed together share one id.
origin_new_id = {
    original_id: title_doc_id[title]
    for original_id, title in tqdm(zip(df_full['docid'], df_full['title']), total=len(df_full))
}

In [None]:
## Doc pool: write the deduplicated documents to Trivia_doc_content.tsv

In [None]:
# Re-inspect the deduplicated documents before the doc pool is written.
df_drop_title

In [None]:
# Write the deduplicated document pool, one document per line with columns:
# original docid, dense (title-based) id, title, content, "title content".
# The `with` block closes (and therefore fully flushes) the file before the
# shell scripts in later cells read it.
with open("Trivia_doc_content.tsv", 'w') as file_pool:
    for i in trange(len(df_drop_title)):
        original_id = df_drop_title['docid'][i]
        title = str(df_drop_title['title'][i])
        content = str(df_drop_title['content'][i])
        file_pool.write('\t'.join([
            str(original_id),
            str(origin_new_id[original_id]),
            title,
            content,
            # Separate title and content with a space (as the 'tc' column does);
            # plain concatenation glued the last title word to the first body word.
            title + ' ' + content,
        ]) + '\n')

## Generate BERT embeddings for each document

In [None]:
## Execute the following command to get the BERT embedding pkl files
## Use 4 GPUs
# NOTE(review): the argument presumably sets the number of GPUs/shards; it should
# match the 4 shard files (..._{num}.pkl, num in 0..3) merged in the next cell.
!./bert/Trivia_bert.sh 4

In [None]:
# Merge the per-GPU embedding shards, and the document ids aligned with them,
# in shard order.
NUM_BERT_SHARDS = 4  # must match the argument passed to Trivia_bert.sh
output_bert_base_tensor = []
output_bert_base_id_tensor = []
for num in trange(NUM_BERT_SHARDS):
    # NOTE: pickle.load can execute arbitrary code; only load shards this pipeline produced.
    with open(f'bert/pkl/Trivia_output_tensor_512_content_{num}.pkl', 'rb') as f:
        output_bert_base_tensor.extend(pickle.load(f))
    with open(f'bert/pkl/Trivia_output_tensor_512_content_{num}_id.pkl', 'rb') as f:
        output_bert_base_id_tensor.extend(pickle.load(f))

assert len(output_bert_base_tensor) == len(output_bert_base_id_tensor), \
    "embedding shards and id shards are misaligned"

# One line per document: id, five empty columns, the 'en' tag and the
# '|'-joined embedding values (presumably the layout the k-means script expects).
with open("bert/Trivia_doc_content_embedding_bert_512.tsv", 'w') as embedding_file:
    for doc_id, doc_tensor in zip(output_bert_base_id_tensor, output_bert_base_tensor):
        embedding = '|'.join(str(elem) for elem in doc_tensor)
        embedding_file.write('\t'.join([str(doc_id), '', '', '', '', '', 'en', embedding]) + '\n')

## Apply hierarchical k-means to the document embeddings to generate semantic IDs

In [None]:
## Execute the following command to get the k-means ids of the documents
# NOTE(review): presumably clusters bert/Trivia_doc_content_embedding_bert_512.tsv and
# writes kmeans/IDMapping_Trivia_bert_512_k30_c30_seed_7.pkl, which the next cell
# loads — confirm in the script.
!./kmeans/kmeans_Trivia.sh

In [None]:
# Load the k-means assignment (docid -> sequence of cluster labels, presumably the
# path through the cluster hierarchy).
# NOTE: pickle.load executes code from the file; only load files this pipeline produced.
with open('kmeans/IDMapping_Trivia_bert_512_k30_c30_seed_7.pkl', 'rb') as f:
    kmeans_trivia_doc_dict = pickle.load(f)

# docid (as str) -> semantic id string, e.g. labels [3, 17, 0] -> "3-17-0".
new_kmeans_trivia_doc_dict_512 = {
    str(old_docid): '-'.join(map(str, cluster_labels))
    for old_docid, cluster_labels in kmeans_trivia_doc_dict.items()
}

# Same mapping keyed by int docid, for lookups with integer ids.
new_kmeans_trivia_doc_dict_512_int_key = {
    int(key): semantic_id
    for key, semantic_id in new_kmeans_trivia_doc_dict_512.items()
}

## Query Generation

In [None]:
## Execute the following command to generate queries for the documents
## Use 4 GPUs
# NOTE(review): the argument presumably sets the number of GPUs/shards; it should
# match the 4 shard files (..._{num}.pkl, num in 0..3) merged in the next cell.
!./qg/Trivia_qg.sh 4

In [None]:
## merge parallel results
output_bert_base_tensor_qg = []
output_bert_base_id_tensor_qg = []
for num in trange(4):
    with open(f'qg/pkl/Trivia_output_tensor_512_content_64_15_{num}.pkl', 'rb') as f:
        data = pickle.load(f)
    f.close()
    output_bert_base_tensor_qg.extend(data)

    with open(f'qg/pkl/Trivia_output_tensor_512_content_64_15_{num}_id.pkl', 'rb') as f:
        data = pickle.load(f)
    f.close()
    output_bert_base_id_tensor_qg.extend(data)

In [None]:
# Group the generated queries by their id, keeping generation order within each group.
qg_dict = {}
for position in trange(len(output_bert_base_tensor_qg)):
    group_key = output_bert_base_id_tensor_qg[position]
    qg_dict.setdefault(group_key, []).append(output_bert_base_tensor_qg[position])

## Generate training data

In [None]:
# query -> list of original docids. Queries spanning several lines are skipped
# because they would break the one-row-per-line TSV; a repeated query keeps the
# ids of its last occurrence.
train_query_docid = {}
for i in trange(len(df_train)):
    query = df_train['query'][i]
    if '\n' not in query:
        train_query_docid[query] = [int(elem) for elem in df_train['docid'][i].split(',')]

# train.tsv: one row per (query, relevant document) pair with columns
# query, original docid, dense id, k-means semantic id.
with open("train.tsv", 'w') as file_train:
    for query, doc_ids in tqdm(train_query_docid.items()):
        for id_ori in doc_ids:
            new_id = origin_new_id[id_ori]
            file_train.write('\t'.join([query, str(id_ori), str(new_id), new_kmeans_trivia_doc_dict_512_int_key[int(new_id)]]) + '\n')

In [None]:
# query -> list of original docids for the dev split (multi-line queries skipped;
# a repeated query keeps the ids of its last occurrence).
val_query_docid = {}
for i in trange(len(df_val)):
    query = df_val['query'][i]
    if '\n' not in query:
        val_query_docid[query] = [int(elem) for elem in df_val['docid'][i].split(',')]

# dev.tsv: one row per query; when a query has several relevant documents, each
# id column holds a comma-joined list in the same order.
with open("dev.tsv", 'w') as file_val:
    for query, doc_ids in tqdm(val_query_docid.items()):
        new_ids = [origin_new_id[id_ori] for id_ori in doc_ids]
        id_ori_ = ','.join(str(id_ori) for id_ori in doc_ids)
        new_id_ = ','.join(str(new_id) for new_id in new_ids)
        kmeans_ = ','.join(new_kmeans_trivia_doc_dict_512_int_key[int(new_id)] for new_id in new_ids)
        file_val.write('\t'.join([query, id_ori_, new_id_, kmeans_]) + '\n')

In [None]:
# query -> list of original docids for the test split (multi-line queries skipped;
# a repeated query keeps the ids of its last occurrence).
test_query_docid = {}
for i in trange(len(df_test)):
    query = df_test['query'][i]
    if '\n' not in query:
        # Key by the *test* question. Reading df_val here paired dev questions
        # with test document ids.
        test_query_docid[query] = [int(elem) for elem in df_test['docid'][i].split(',')]

# test.tsv: one row per query; when a query has several relevant documents, each
# id column holds a comma-joined list in the same order.
with open("test.tsv", 'w') as file_test:
    for query, doc_ids in tqdm(test_query_docid.items()):
        new_ids = [origin_new_id[id_ori] for id_ori in doc_ids]
        id_ori_ = ','.join(str(id_ori) for id_ori in doc_ids)
        new_id_ = ','.join(str(new_id) for new_id in new_ids)
        kmeans_ = ','.join(new_kmeans_trivia_doc_dict_512_int_key[int(new_id)] for new_id in new_ids)
        file_test.write('\t'.join([query, id_ori_, new_id_, kmeans_]) + '\n')

In [None]:
QG_NUM: int = 15  # maximum number of generated queries written per document to trivia_512_qg.tsv

In [None]:
# trivia_512_qg.tsv: generated queries as extra training rows with columns
# query, docid, k-means semantic id. At most QG_NUM queries are kept per document.
with open("trivia_512_qg.tsv", 'w') as qg_file:
    for queryid, generated_queries in tqdm(qg_dict.items()):
        # Look the semantic id up once per document rather than once per query.
        semantic_id = new_kmeans_trivia_doc_dict_512_int_key[int(queryid)]
        for query in generated_queries[:QG_NUM]:
            # str() so non-str ids cannot make '\t'.join raise a TypeError.
            qg_file.write('\t'.join([query, str(queryid), semantic_id]) + '\n')

In [None]:
# Attach the dense id and the k-means semantic id to every deduplicated document,
# build the "title content" text, and export (text, docid, dense id, semantic id)
# with no header and no index. The column name 'kmeas_id' is kept as spelled.
df_drop_title = df_drop_title.assign(
    new_id=lambda frame: frame['docid'].map(origin_new_id),
    kmeas_id=lambda frame: frame['new_id'].map(new_kmeans_trivia_doc_dict_512_int_key),
    tc=lambda frame: frame['title'] + ' ' + frame['content'],
)

df_drop_title_ = df_drop_title[['tc', 'docid', 'new_id', 'kmeas_id']]

df_drop_title_.to_csv('trivia_title_cont.tsv', sep='\t', header=None, index=False, encoding='utf-8')

In [None]:
# Inspect df_drop_title again, now including the new_id, kmeas_id and tc columns.
df_drop_title

In [None]:
# Document augmentation: sample random 64-word windows from each document and
# write them as pseudo-queries with the same columns as train.tsv:
# text, original docid, dense id, k-means semantic id.
DOC_AUG_WINDOW = 64             # words per sampled window
DOC_AUG_BASE_SAMPLES = 10       # windows sampled from every document
DOC_AUG_WORDS_PER_EXTRA = 3000  # one extra window per 3000 words beyond the first 3000
DOC_AUG_SEED = 7                # arbitrary fixed seed so the output is reproducible

# trivia_title_cont.tsv columns: "title content", original docid, dense id, k-means id.
queryid_oldid_dict = {}  # dense id (str) -> original docid (str)
bertid_oldid_dict = {}   # dense id (str) -> k-means semantic id (str)
map_file = "trivia_title_cont.tsv"
# Read as UTF-8 because that is how to_csv wrote it.
with open(map_file, 'r', encoding='utf-8') as f:
    for line in f:
        # Strip the line terminator so it does not end up inside the semantic id.
        query, queryid, oldid, bert_k30_c30 = line.rstrip('\n').split("\t")
        queryid_oldid_dict[oldid] = queryid
        bertid_oldid_dict[oldid] = bert_k30_c30

# Dedicated generator: reproducible, and leaves the global `random` state untouched.
rng = random.Random(DOC_AUG_SEED)
train_file = "Trivia_doc_content.tsv"
# Trivia_doc_content.tsv columns: original docid, dense id, title, content, "title content".
with open(train_file, 'r') as f, open("trivia_doc_aug.tsv", 'w') as doc_aug_file:
    for line in f:
        _, docid, _, _, content = line.rstrip('\n').split("\t")
        words = content.split(' ')
        add_num = max(0, len(words) - DOC_AUG_WORDS_PER_EXTRA) / DOC_AUG_WORDS_PER_EXTRA
        queryid = queryid_oldid_dict[docid]
        bert_k30_c30 = bertid_oldid_dict[docid]
        for _ in range(DOC_AUG_BASE_SAMPLES + int(add_num)):
            # A window may start anywhere; windows near the end are simply shorter.
            begin = rng.randrange(0, len(words))
            doc_aug = ' '.join(words[begin:begin + DOC_AUG_WINDOW]).replace('\n', ' ')
            doc_aug_file.write('\t'.join([doc_aug, str(queryid), str(docid), str(bert_k30_c30)]) + '\n')