In [227]:
import numpy as np
from scipy import sparse
from typing import List, Dict
from tqdm import tqdm
from tabulate import tabulate
import pandas
from typing import Tuple

In [11]:
from analysis import compute_tf_idf

In [100]:
from scipy.sparse.linalg import svds

In [12]:
from analysis import NameSpace

In [13]:
# Location of the raw text file that the loader below reads the articles from.
txt_path = './data/rmrb.txt'

In [14]:
from loader import read_articles_from_txt

In [15]:
# Load the corpus: `names` and `contents` are parallel sequences (one entry per
# article found in the file). The loader logs every article it finds.
names, contents = read_articles_from_txt(txt_path)

opening ./data/rmrb.txt
Found: 中共中央印发中国共产党地方组织选举工作条例 , length: 510
Found: 中共中央致电祝贺朝鲜劳动党八大召开 , length: 477
Found: 农业科技进步贡献率超60%农业农村现代化迈上新台阶 , length: 1647
Found: 提高新时代地方党组织选举质量的制度保证 , length: 1335
Found: 全国宣传部长会议在京召开王沪宁出席并讲话 , length: 843
Found: 通海蔬菜远销海外 , length: 1172
Found: 复兴号高寒动车组亮相 , length: 189
Found: 胡春华强调立足新发展阶段推动农民工工作取得更大成就 , length: 513
Found: 国办印发《意见》进一步优化地方政务服务便民热线 , length: 1080
Found: 民生欢歌，旋律更高昂 , length: 2196
Found: 推动住房和城乡建设事业高质量发展 , length: 2406
Found 11 articles.


In [16]:
from cut import make_articles_from_contents

In [17]:
# Turn each (name, content) pair into an Article object; a progress bar is
# shown while the articles are processed.
articles = make_articles_from_contents(names, contents)

100%|██████████| 11/11 [00:00<00:00, 12.03it/s]


In [18]:
# Sanity check: show every Article with its name and number of terms.
articles

[Article(name=中共中央印发中国共产党地方组织选举工作条例, n_terms=311),
 Article(name=中共中央致电祝贺朝鲜劳动党八大召开, n_terms=281),
 Article(name=农业科技进步贡献率超60%农业农村现代化迈上新台阶, n_terms=903),
 Article(name=提高新时代地方党组织选举质量的制度保证, n_terms=772),
 Article(name=全国宣传部长会议在京召开王沪宁出席并讲话, n_terms=468),
 Article(name=通海蔬菜远销海外, n_terms=664),
 Article(name=复兴号高寒动车组亮相, n_terms=103),
 Article(name=胡春华强调立足新发展阶段推动农民工工作取得更大成就, n_terms=278),
 Article(name=国办印发《意见》进一步优化地方政务服务便民热线, n_terms=606),
 Article(name=民生欢歌，旋律更高昂, n_terms=1271),
 Article(name=推动住房和城乡建设事业高质量发展, n_terms=1389)]

In [19]:
# NameSpace collects the articles and later provides both the term-document
# matrix and the term -> index mapping (`ns.term_index`) used for queries.
ns = NameSpace()

In [20]:
# Register all loaded articles with the namespace.
ns.add_articles(articles)

In [21]:
# Term-document matrix for the registered articles.
# NOTE(review): downstream code treats rows as documents and columns as
# terms -- confirm against NameSpace.get_term_document_matrix.
doc_matrix = ns.get_term_document_matrix()

In [23]:
# Re-weight the term-document matrix with TF-IDF.
tf_idf = compute_tf_idf(doc_matrix)

In [101]:
# Store the TF-IDF matrix in SciPy CSR sparse format before the truncated SVD.
tf_idf = sparse.csr_matrix(tf_idf)

In [110]:
# Truncated SVD of the TF-IDF matrix. With its default solver, svds needs
# k < min(tf_idf.shape), so min(shape) - 1 keeps as many singular triplets as
# it allows.
u, s, vh = svds(tf_idf, k=min(tf_idf.shape)-1)

In [232]:

class Searcher:
    """Cosine-similarity search in the latent space of a truncated SVD.

    The decomposed matrix is treated as documents x terms: rows of ``u`` are
    documents and columns of ``vh`` are indexed by term. Documents are placed
    at ``u @ diag(s)`` and a bag-of-words query is projected with ``vh.T``, so
    both live in the same latent coordinate system.
    """

    def __init__(
        self,
        svd_u: np.ndarray,
        svd_s: np.ndarray,
        svd_vh: np.ndarray
    ) -> None:
        """Store copies of the SVD factors and precompute document coordinates.

        Args:
            svd_u: Left singular vectors, one row per document.
            svd_s: 1-D array of singular values.
            svd_vh: Right singular vectors, one column per term.
        """
        self.u = svd_u.copy()
        # Kept as a full diagonal matrix so it can be used directly in products.
        self.s = np.diag(svd_s)
        self.vh = svd_vh.copy()

        # One row of latent coordinates per document.
        self.doc_coords = self.u @ self.s

    @staticmethod
    def _normalize_rows(rows: np.ndarray) -> np.ndarray:
        """Scale every row to unit L2 norm; all-zero rows are left as zeros."""
        norms = np.linalg.norm(rows, axis=1, keepdims=True)
        # Dividing by a zero norm would give NaN; dividing by 1 keeps the row
        # at zero, so its cosine with anything is 0.
        norms[norms == 0] = 1
        return rows / norms

    def pairwise_cosine_similarities(self, xs: np.ndarray, ys: np.ndarray) -> np.ndarray:
        """Return the cosine similarity between every row of xs and of ys.

        Args:
            xs: 2-D array, one vector per row.
            ys: 2-D array with the same number of columns as ``xs``.

        Returns:
            Array of shape ``(xs.shape[0], ys.shape[0])``. Entries that involve
            an all-zero row are 0 rather than NaN.
        """
        return self._normalize_rows(xs) @ self._normalize_rows(ys).T

    def sort_index_by_cosine_similarity(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Rank all documents by cosine similarity to a single query vector.

        Args:
            x: Query in latent coordinates, given as a 1-D array, a row vector
                or a column vector.

        Returns:
            ``(article_indexes, cosines)``, both ordered from the most to the
            least similar document.
        """
        x = np.atleast_2d(x)
        # A column vector has more rows than columns: turn it into a row.
        if x.shape[0] > x.shape[1]:
            x = x.T

        cosines = self.pairwise_cosine_similarities(x, self.doc_coords).flatten()
        article_indexes = np.arange(0, self.doc_coords.shape[0])
        # Negate so that argsort (ascending) yields descending similarity.
        sort_indexes = np.argsort(-cosines)

        return (
            article_indexes[sort_indexes],
            cosines[sort_indexes],
        )

    def make_query(self, terms: List[str], term_index: Dict[str, int]) -> np.ndarray:
        """Project a bag of query terms into the latent space.

        Terms that are missing from ``term_index`` are ignored. (Mapping them
        to a default index would silently add an unrelated term to the query.)

        Args:
            terms: Query terms.
            term_index: Mapping from term to its column index in ``vh``.

        Returns:
            Array of shape ``(1, k)`` with the query's latent coordinates; all
            zeros when none of the terms is known.
        """
        known_indexes = np.array(
            [term_index[term] for term in terms if term in term_index],
            dtype=np.intp,
        )
        # Binary indicator vector over the vocabulary.
        query_row = np.zeros(shape=(1, self.vh.shape[1],), dtype=np.float32)
        query_row[0, known_indexes] = 1
        query_row = query_row @ self.vh.T

        return query_row

In [233]:
# Same truncated SVD call as in the earlier cell, re-run here so that u, s and
# vh are freshly defined right before the Searcher is built.
u, s, vh = svds(tf_idf, k=min(tf_idf.shape)-1)

In [234]:
# Build the latent-space search index from the SVD factors.
searcher = Searcher(u, s, vh)

In [237]:
# Project a three-term query into the latent space, using the namespace's
# term -> index mapping.
q = searcher.make_query(['农机', '小麦', '收成'], ns.term_index)

In [238]:
# Rank all articles against the query: returns (article indexes, cosine
# similarities), most similar first.
searcher.sort_index_by_cosine_similarity(q)

(array([ 2,  3,  8,  7,  6,  1,  5, 10,  4,  9,  0]),
 array([ 9.9106801e-01,  7.4726786e-03,  8.2469080e-05,  1.4682300e-05,
         5.4128468e-06, -7.7346340e-06, -2.0742737e-05, -3.1686504e-05,
        -2.1699129e-04, -5.8697467e-04, -5.4438477e-03], dtype=float32))