In [22]:
%%writefile solution_final.py
import os
import json
import nltk
import torch
import faiss
import string
import langdetect
import numpy as np
from flask import (Flask,
                   request,)
from typing import (Dict,
                    List,
                    Tuple,)


app = Flask(__name__)


class Solution(object):
    def __init__(self):
        self.documents = None
        self.index = None
        torch.set_grad_enabled(False)
        
    def load_resources(self):
        self.vocab = json.load(open(os.environ['VOCAB_PATH'],
                               mode='r',
                               encoding='utf-8'))
        
        state_dict = torch.load(os.environ['EMB_PATH_KNRM'])
        self.emb_knrm_shape = state_dict['weight'].shape
        self.emb_knrm = torch.nn.Embedding.from_pretrained(state_dict['weight'],
                                                           freeze=True,
                                                           padding_idx=0)

#         self.emb_glove = list()
#         with open(os.environ['EMB_PATH_GLOVE'], mode='r') as file:
#             for line in file:
#                 self.emb_glove.append(line.split()[0])

        self.mlp_knrm = torch.load(os.environ['MLP_PATH'])
        
        global is_ready
        is_ready = True
    
    def _preprocess(self, input_str: str) -> str:
        table = str.maketrans(string.punctuation,
                              ' '*len(string.punctuation))
        return (input_str
                .translate(table)
                .lower())
    
    def _filter_glove_tokens(self, tokens: List[str]) -> List[str]:
        return [t for t in tokens if t in self.emb_glove]
    
    def _get_tokens(self, input_str: str) -> List[str]:
        return nltk.word_tokenize(self._preprocess(input_str))
    
    def _get_tokens_ids(self,
                        input_str: str,
                        filter_glove: bool = False) -> List[int]:
        tokens = self._get_tokens(input_str)
        if filter_glove:
            tokens = self._filter_glove_tokens(tokens)
        return [self.vocab.get(t, self.vocab['OOV']) 
                for t in tokens]
        
    def update_index(self, documents: Dict[str, str]) -> int:
        self.documents = documents
        
        tokens_ids = list()
        for d in self.documents:
            ids = self._get_tokens_ids(self.documents[d],
                                       filter_glove=False)
            tokens_ids.append(ids)
            
        vectors = list()
        for ids in tokens_ids:
            embs = self.emb_knrm(torch.LongTensor(ids))
            vectors.append(embs
                           .mean(axis=0)
                           .numpy())
        vectors = np.array(vectors)
            
        self.index = faiss.IndexFlatL2(vectors.shape[1])
        self.index = faiss.IndexIDMap(self.index)
        self.index.add_with_ids(vectors,
                                np.array([int(i) for i in self.documents]))
        
        return self.index.ntotal
    
    def search(self, query: str, k: int = 10) -> List[Tuple[str, str]]:
        query_ids = self._get_tokens_ids(query, filter_glove=False)
        
        query_emb = self.emb_knrm(torch.LongTensor(query_ids)).mean(axis=0)
        query_emb = (query_emb
                     .numpy()
                     .reshape(-1, self.emb_knrm_shape[1]))
        
        _, document_ids = self.index.search(query_emb, k)
        
        return [(str(i), self.documents[str(i)])
                for i in document_ids.reshape(-1)]
        

solution = Solution()
is_ready = False


@app.route('/ping')
def ping():
    if not is_ready:
        return {'status': 'not ready'}
    return {'status': 'ok'}


@app.route('/query', methods=['POST'])
def query():
    if solution.index is None:
        return {'status': 'FAISS is not initialized!'}
    
    content = json.loads(request.json)
    queries = content['queries']
    
    results = list()
    for q in queries:
        if langdetect.detect(q) == 'en':
            candidates = solution.search(q)
            
            results.append(candidates)
        else:
            results.append(None)
    
    return {'lang_check': [True if r is not None else False 
                           for r in results],
            'suggestions': results,}


@app.route('/update_index', methods=['POST'])
def update_index():
    content = json.loads(request.json)
    documents = content['documents']
    
    return {'status': 'ok',
            'index_size': solution.update_index(documents)}


solution.load_resources()



Overwriting solution_final.py


In [23]:
!curl 'http://127.0.0.1:11000/ping'

{
  "status": "ok"
}


In [3]:
import json
import requests
import pandas as pd
from IPython.core.display import HTML

In [4]:
dev_df = pd.read_csv('../5/data/QQP/dev.tsv',
                     sep='\t',
                     error_bad_lines=False,
                     dtype=object)
dev_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 40430 entries, 0 to 40429
Data columns (total 6 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   id            40430 non-null  object
 1   qid1          40430 non-null  object
 2   qid2          40430 non-null  object
 3   question1     40430 non-null  object
 4   question2     40430 non-null  object
 5   is_duplicate  40430 non-null  object
dtypes: object(6)
memory usage: 1.9+ MB


In [24]:
body = json.dumps({'documents': {i[0]: i[1] 
                                 for i in dev_df[['qid1', 'question1']].values.tolist()}})
res = requests.post('http://127.0.0.1:11000/update_index',
                    json=body)
HTML(res.content.decode('utf-8'))

In [25]:
body = json.dumps({'queries': ['Проверка на язык',
                               'Why are African-Americans so beautiful?']},
#                   ensure_ascii=False,
                 )
res = requests.post('http://127.0.0.1:11000/query',
                    json=body)
HTML(res.content.decode('utf-8'))

In [4]:
import torch
state_dict = torch.load('/home/jupyter-v.pashentsev-2/2.matching/5/embedings.pickle')
emb = torch.nn.Embedding.from_pretrained(state_dict['weight'])
# emb.load_state_dict(state_dict)
emb

Embedding(87164, 50)

In [5]:
emb(torch.LongTensor([3, 1])).shape

torch.Size([2, 50])

In [7]:
res = emb(torch.LongTensor([3, 1]))
res.numpy()

array([[ 6.1850e-01,  6.4254e-01, -4.6552e-01,  3.7570e-01,  7.4838e-01,
         5.3739e-01,  2.2239e-03, -6.0577e-01,  2.6408e-01,  1.1703e-01,
         4.3722e-01,  2.0092e-01, -5.7859e-02, -3.4589e-01,  2.1664e-01,
         5.8573e-01,  5.3919e-01,  6.9490e-01, -1.5618e-01,  5.5830e-02,
        -6.0515e-01, -2.8997e-01, -2.5594e-02,  5.5593e-01,  2.5356e-01,
        -1.9612e+00, -5.1381e-01,  6.9096e-01,  6.6246e-02, -5.4224e-02,
         3.7871e+00, -7.7403e-01, -1.2689e-01, -5.1465e-01,  6.6705e-02,
        -3.2933e-01,  1.3483e-01,  1.9049e-01,  1.3812e-01, -2.1503e-01,
        -1.6573e-02,  3.1200e-01, -3.3189e-01, -2.6001e-02, -3.8203e-01,
         1.9403e-01, -1.2466e-01, -2.7557e-01,  3.0899e-01,  4.8497e-01],
       [ 1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,
         1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,
         1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,
         1.0000e+00,  1.0000e+00,  1.0000e+00,  1.

In [97]:
import faiss
index = faiss.IndexFlatL2(50)
index.add(res.detach().numpy())
index.ntotal

2

In [98]:
index.search(x=np.full((1, 50), 1, dtype=np.float32), k=2)

(array([[ 0.      , 65.641846]], dtype=float32), array([[1, 0]]))

In [44]:
import faiss
import numpy as np

matrix = np.random.randn(100, 10, 50)
index = faiss.IndexFlatL2()
index.add(matrix[0])

AssertionError: 

In [51]:
import torch
torch.LongTensor([0, 1, 2] + [3] * 1)

tensor([0, 1, 2, 3])

In [184]:
{'a': 1, 'b': 2}.items()[0]

TypeError: 'dict_items' object is not subscriptable