In [1]:
import string
from collections import Counter
from typing import Dict, List, Tuple, Union, Callable
from tqdm.auto import tqdm
import json

import nltk

import numpy as np
import math
import pandas as pd
import torch
import torch.nn.functional as F

from langdetect import detect
import faiss
import os

In [2]:
import sys
# Add the parent directory to the module search path so the project-local
# `lib` package imported below can be resolved
# (assumes `lib/` sits one level above this notebook — TODO confirm).
sys.path.append('../')

from lib.ranking import Model

In [3]:
import warnings
# Suppress every warning for the remainder of the kernel session.
# NOTE(review): this is a blanket filter, so deprecation and numerical
# warnings raised by later cells are hidden as well.
warnings.filterwarnings('ignore')

In [4]:
# nltk.download()  # one-time setup: uncomment to open the NLTK downloader if NLTK data is missing

In [5]:
# Root folder with the input data, relative to this notebook.
PATH = '../../../ranking and matching/final_project/data'

# Locations derived from PATH: the QQP directory and the GloVe vectors file.
glue_qqp_dir = f'{PATH}/QQP'
glove_path = f'{PATH}/glove.6B.50d.txt'

In [32]:
# Hyperparameters for the training run, gathered in one place and forwarded
# to Model as keyword arguments.
knrm_train_config = dict(
    min_token_occurancies=1,
    knrm_kernel_num=21,
    knrm_out_mlp=[10, 5],
    dataloader_bs=1024,
    freeze_knrm_embeddings=False,
    train_lr=0.01,
    change_train_loader_ep=10,
    n_training=12000,
)

model = Model(glove_path, glue_qqp_dir, **knrm_train_config)
model.initialize_training()

In [33]:
%%time
# model training

model.train(30)

state_mlp = model.model.mlp.state_dict()
torch.save(state_mlp, open('../artifacts/knrm/knrm_mlp.bin', 'wb'))

state_emb = model.model.embeddings.state_dict()
torch.save(state_emb, open('../artifacts/knrm/knrm_emb.bin', 'wb'))

with open('../artifacts/knrm/vocab.json', 'w') as f:
    json.dump(model.vocab, f)

0 0.5653442646777981
5 0.8986491445062015
10 0.9291317588355423
15 0.939217092741588
20 0.9452734667065634
25 0.9491157158382607
29 0.9507637693795155
CPU times: user 6min 51s, sys: 1min 35s, total: 8min 26s
Wall time: 5min 3s


### Check: rebuild the model from the saved artifacts and run a sample query

In [34]:
from lib.index import Selection

In [35]:
# Pool both question columns of the training frame into a single Series and
# keep each distinct text once.
question_columns = ['text_left', 'text_right']
df = (
    model.glue_train_df
    .loc[:, question_columns]
    .stack()
    .drop_duplicates()
)

# Key every remaining text by its running position, stored as a string.
all_documents = dict((str(doc_id), text) for doc_id, text in enumerate(df.values))
print(f'n of documents: {len(all_documents)}')

# Persist the id -> text mapping as JSON.
with open('../artifacts/all_documents.json', 'w') as out_file:
    json.dump(all_documents, out_file)

n of documents: 493874


In [36]:
# Publish the artifact locations through environment variables and read them
# back, so the rest of the cell only depends on the environment.
os.environ["EMB_PATH_KNRM"] = '../artifacts/knrm/knrm_emb.bin'
os.environ["MLP_PATH"] = '../artifacts/knrm/knrm_mlp.bin'
os.environ["EMB_PATH_GLOVE"] = glove_path
os.environ["VOCAB_PATH"] = '../artifacts/knrm/vocab.json'

EMB_PATH_KNRM = os.environ["EMB_PATH_KNRM"]
MLP_PATH = os.environ["MLP_PATH"]
EMB_PATH_GLOVE = os.environ["EMB_PATH_GLOVE"]
VOCAB_PATH = os.environ["VOCAB_PATH"]

# Load the vocabulary saved at training time. `get_simmilar_questions` reads
# it as the global `vocab`; without this the name is never defined and the
# query cell fails with NameError on a fresh kernel.
with open(VOCAB_PATH) as f:
    vocab = json.load(f)

# NOTE(review): the GloVe path is passed for both positional arguments, while
# the training cell passes the QQP directory second — confirm this is intended
# for inference-only use.
model = Model(EMB_PATH_GLOVE,
              EMB_PATH_GLOVE,
              min_token_occurancies = 1,
              knrm_kernel_num = 21,
              knrm_out_mlp = [10, 5],
              train_lr = 0.01,
              )

# Restore the trained weights: the embedding matrix is the 'weight' entry of
# the saved embedding state dict; the MLP state dict is passed through as is.
state_dict = torch.load(EMB_PATH_KNRM)
emb_matrix = state_dict['weight']
mlp_state_dict = torch.load(MLP_PATH)
model.build_model_with_pretrained_weights(mlp_state_dict, emb_matrix)

In [37]:
%%time
# Build the candidate-retrieval index over every document, configured with
# agg_method='mean' and metric='l2'
# (presumably mean-pooled GloVe vectors under L2 distance — confirm in lib.index).
selection = Selection(glove_path, agg_method='mean', metric='l2')
selection.init_index(all_documents)

CPU times: user 29.6 s, sys: 793 ms, total: 30.4 s
Wall time: 30.9 s


In [38]:
def get_simmilar_questions(query, top_k=10):
    """Return the best-scoring candidate questions for `query`.

    Candidates come from the global `selection` index and are re-scored with
    the global `model`; the global `vocab` is passed to
    `model.prediction_data`.

    Args:
        query: the question text to search for.
        top_k: maximum number of results to return (default 10, the value
            that used to be hard-coded).

    Returns:
        A pair `(texts_out, ids_out)`: candidate texts and their ids, ordered
        by descending model score and truncated to `top_k` entries.
    """
    ids, texts = selection.search(query)
    if len(texts) == 0:
        # No candidates were retrieved, so there is nothing to score.
        return [], []

    model_input = model.prediction_data(query, texts, vocab)
    res = model.knrm.predict(model_input)

    # Positions of the candidates sorted by descending score; convert to
    # plain ints so they index `texts` / `ids` whatever container type those are.
    _, order = torch.sort(res.flatten(), descending=True)
    top_positions = order[:top_k].tolist()

    texts_out = [texts[i] for i in top_positions]
    ids_out = [ids[i] for i in top_positions]
    return texts_out, ids_out

# Run the retrieval + ranking pipeline on one sample question.
query = 'What are the most lovely self help book I should read?'
texts_out, ids_out = get_simmilar_questions(query)

In [163]:
# Show the returned questions, one per line, in ranked order.
for question in texts_out:
    print(question)

What are the top self help books I should read?
What are the best self-help books you've ever read?
What are some books I should read this summer?
What book have you re-read the most, and why?
What are the books to be read for self improvement?
What are the best books one should read?
What are the best books that one should must read?
What are some good book that are a must read?
What is the most important book you have ever read?
What are some of the best novels everyone should read?
