In [258]:
import numpy as np
import scipy as sp
import pickle
import json
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import normalize

In [253]:
def cosine_similarity(q, A):
    """Cosine similarity between a query vector and every column of ``A``.

    Parameters
    ----------
    q : np.ndarray
        Query of length ``n``, either 1-D or an ``(n, 1)`` column.
    A : np.ndarray
        ``(n, m)`` matrix whose columns are compared with ``q``. Intended for
        dense arrays: on scipy sparse *matrices* ``**`` is a matrix power, not
        an element-wise square.

    Returns
    -------
    np.ndarray
        ``m`` similarities (shape ``(m,)`` for a 1-D ``q``, ``(1, m)`` for a
        column ``q``). Pairs involving a zero vector yield 0 instead of NaN.
    """
    dot_product = q.T @ A
    q_norm = np.sqrt((q ** 2).sum())
    A_norms = np.sqrt((A ** 2).sum(axis=0))
    # A zero-norm query or column gives 0/0; silence the warning here and let
    # nan_to_num map the resulting NaN to 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        similarities = dot_product / (q_norm * A_norms)
    return np.nan_to_num(similarities)

In [320]:
def cosine_similarity_normalized(q, A):
    """Cosine similarity for inputs that are already unit-length.

    Skips the norm division done by ``cosine_similarity``: when the query
    ``q`` and every column of ``A`` have L2 norm 1, the plain dot product
    ``q.T @ A`` *is* the cosine similarity. Any operands supporting ``.T`` and
    ``@`` work (dense arrays, scipy sparse matrices).
    """
    q_transposed = q.T
    return q_transposed @ A

In [163]:
def tf_idf(matrix):
    """Weight each term row of a term-by-document matrix by its IDF.

    Row ``i`` is multiplied by ``log(n_docs / df_i)``, where ``n_docs`` is the
    number of columns and ``df_i`` is the number of non-zero entries in row
    ``i``. Rows with no non-zero entries stay all-zero. The input is not
    modified.

    Parameters
    ----------
    matrix : scipy sparse matrix or np.ndarray
        ``(n_terms, n_docs)`` term-by-document weights (e.g. raw counts).

    Returns
    -------
    Same kind as ``matrix`` (sparse results keep the input's sparse format),
    with a floating dtype so the IDF weights are never truncated.
    """
    n_docs = matrix.shape[1]
    doc_freq = np.asarray((matrix != 0).sum(axis=1)).ravel()
    idf = np.zeros(doc_freq.shape, dtype=float)
    present = doc_freq > 0
    idf[present] = np.log(n_docs / doc_freq[present])
    if sp.sparse.issparse(matrix):
        # Left-multiplying by diag(idf) scales all rows in one sparse product and
        # promotes integer counts to float (in-place row assignment truncated them).
        return (sp.sparse.diags(idf) @ matrix).asformat(matrix.format)
    # np.multiply is element-wise even for np.matrix inputs.
    return np.multiply(matrix, idf[:, np.newaxis])

In [324]:
class SearchEngine:
    """Ranked retrieval over a TF-IDF term-by-document matrix, optionally in an LSI space.

    ``__init__`` reads four files from the current working directory:

    * ``terms_by_doc.npz`` - sparse matrix; rows are terms, columns are documents.
    * ``documents.json`` - indexed by column position to get a document
      identifier (URLs in the example output).
    * ``terms.json`` - maps a stemmed term to its row index.
    * ``stop_words.txt`` - one stop word per line.
    """

    def __init__(self):
        terms_by_doc = sp.sparse.load_npz("terms_by_doc.npz")
        with open("documents.json", "r") as f:
            self.documents = json.load(f)
        with open("terms.json", "r") as f:
            self.terms = json.load(f)
        with open("stop_words.txt", 'r') as sw_file:
            self.stop_words = set(sw_file.read().splitlines())
        self.stemmer = PorterStemmer()
        # TfidfTransformer expects (documents x terms), hence the transpose in and out.
        # Its default norm='l2' gives every document column unit length, so a dot
        # product with a unit-length query is the cosine similarity.
        tfidf_transformer = TfidfTransformer()
        self.tfidf_matrix = tfidf_transformer.fit_transform(terms_by_doc.T).T
        # k -> (unit-length document columns of shape (k, n_docs), fitted TruncatedSVD)
        self._svd_cache = {}

    def _reduce_svd_matrix(self, k):
        """Fit a rank-``k`` TruncatedSVD on the (documents x terms) TF-IDF matrix.

        Returns ``(reduced_matrix, svd)``: the ``(n_docs, k)`` document
        embeddings and the fitted model, which is reused to project queries.
        """
        # TODO: persist the fitted model and reduced matrix to disk so they
        # survive kernel restarts (in-memory reuse is handled by _reduced_space).
        svd = TruncatedSVD(n_components=k)
        reduced_matrix = svd.fit_transform(self.tfidf_matrix.T)
        return reduced_matrix, svd

    def _reduced_space(self, k):
        """Return ``(doc_matrix, svd)`` for rank ``k``, fitting the SVD only on first use.

        ``doc_matrix`` has shape ``(k, n_docs)`` with unit-length columns.
        """
        if k not in self._svd_cache:
            reduced_matrix, svd = self._reduce_svd_matrix(k)
            # Normalise each document (row) so a dot product is a cosine similarity.
            doc_matrix = normalize(reduced_matrix, norm='l2', axis=1).T
            self._svd_cache[k] = (doc_matrix, svd)
        return self._svd_cache[k]

    def _query_vector(self, search_terms):
        """Turn free text into a unit-length ``(n_terms, 1)`` sparse query vector.

        Lower-cases and tokenises the text, drops stop words, stems, and keeps
        only alphabetic stems that are in the vocabulary. Each kept term is
        weighted by how often it occurs. Returns ``None`` when nothing survives.
        """
        words = word_tokenize(search_terms.lower())
        stemmed_words = [self.stemmer.stem(word) for word in words if word not in self.stop_words]
        filtered_terms = [word for word in stemmed_words
                          if word not in self.stop_words and word.isalpha() and word in self.terms]
        if not filtered_terms:
            return None
        rows = [self.terms[term] for term in filtered_terms]
        # Duplicate (row, 0) entries are summed, so a repeated term counts twice.
        counts = sp.sparse.csr_matrix(
            (np.ones(len(rows)), (rows, np.zeros(len(rows), dtype=int))),
            shape=(len(self.terms), 1),
        )
        # Normalise once, after all counts are in, with L2 to match the document columns.
        return normalize(counts, norm='l2', axis=0)

    def search(self, search_terms: str, k: int = 0, top_n: int = 20):
        """Return the ``top_n`` best-matching documents for ``search_terms``.

        Parameters
        ----------
        search_terms : str
            Free-text query.
        k : int, default 0
            0 ranks in the full TF-IDF space; a positive value ranks in a
            rank-``k`` TruncatedSVD (LSI) space, fitted on first use of that
            ``k`` and reused afterwards.
        top_n : int, default 20
            Maximum number of results returned.

        Returns
        -------
        list of (document, score)
            Sorted by decreasing cosine similarity. Empty (after printing
            "Invalid query") when no usable query term remains.
        """
        query = self._query_vector(search_terms)
        if query is None:
            print("Invalid query")
            return []
        if k == 0:
            fit = cosine_similarity_normalized(query, self.tfidf_matrix).toarray()[0]
        else:
            doc_matrix, svd = self._reduced_space(k)
            reduced_query = normalize(svd.transform(query.T), norm='l2', axis=1).T
            fit = cosine_similarity_normalized(reduced_query, doc_matrix)[0]
        result_indices = np.argsort(-fit)[:top_n]
        return [(self.documents[idx], fit[idx]) for idx in result_indices]

In [335]:
# Example query in a rank-50 LSI space (k=0 would rank in the raw TF-IDF space).
se = SearchEngine()
res = se.search("computer science major", k=50)
# One ranked hit per line instead of one unreadable list repr.
for rank, (url, score) in enumerate(res, start=1):
    print(f"{rank:2d}. {score:.4f}  {url}")

[('https://simple.wikipedia.org/wiki/Theoretical_computer_science', np.float64(0.9708594104158843)), ('https://simple.wikipedia.org/wiki/Computer_science', np.float64(0.9308094391745394)), ('https://simple.wikipedia.org/wiki/Computing', np.float64(0.9308094391745394)), ('https://simple.wikipedia.org/wiki/Debugging', np.float64(0.9218291368252772)), ('https://simple.wikipedia.org/wiki/Computability_theory', np.float64(0.9211531323586261)), ('https://simple.wikipedia.org/wiki/Theory_of_computation', np.float64(0.9188093142757258)), ('https://simple.wikipedia.org/wiki/Computation', np.float64(0.908246275554424)), ('https://simple.wikipedia.org/wiki/Framework', np.float64(0.9037145838094329)), ('https://simple.wikipedia.org/wiki/Distributed_computing', np.float64(0.9017982083727221)), ('https://simple.wikipedia.org/wiki/Computer_vision', np.float64(0.8994008017079984)), ('https://simple.wikipedia.org/wiki/Computer_security', np.float64(0.8968233826706269)), ('https://simple.wikipedia.org/w