In [1]:
import spacy
from spacy.lang.en.stop_words import STOP_WORDS
from sklearn.neighbors import NearestNeighbors
import re
import numpy as np
import json
import random

In [2]:
# Relative path of the dataset file; sample_news_dataset() reads it line by line
# and parses each line as a separate JSON object.
DATASET_PATH = "data/News_Category_Dataset_v2.json"

In [3]:
def sample_news_dataset(k=5000, search_for="headline", path=None):
    """Return ``k`` randomly chosen field values from a JSON-lines dataset.

    Every non-blank line of the file is parsed as one JSON object and the
    value stored under ``search_for`` is collected; ``k`` of those values are
    then drawn without replacement.

    Args:
        k: Number of values to sample.
        search_for: Key to extract from each JSON record.
        path: File to read. When ``None`` (the default) the module-level
            ``DATASET_PATH`` is used.

    Returns:
        list: ``k`` extracted values in random order.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``k`` is larger
            than the number of records read.
        KeyError: If a record has no ``search_for`` key.
    """
    if path is None:
        # Resolved at call time so a later reassignment of DATASET_PATH is honoured.
        path = DATASET_PATH
    # The context manager guarantees the handle is closed even if a line fails
    # to parse; an explicit encoding keeps the read platform-independent.
    with open(path, "r", encoding="utf-8") as dataset_file:
        news = [
            json.loads(line)[search_for]
            for line in dataset_file
            if line.strip()  # tolerate blank / trailing empty lines
        ]
    return random.sample(news, k)

In [4]:
# Seed the global RNG so the same headlines are drawn on every
# Restart Kernel -> Run All; without it each run samples a different subset.
# NOTE: outputs saved from an earlier, unseeded run will differ after re-running.
SEED = 42
random.seed(SEED)
documents = sample_news_dataset()

In [5]:
class DocumentSearch:
    """Nearest-neighbour lookup over documents embedded with spaCy word vectors.

    A document is lower-cased, reduced to its ``\\w+`` words, optionally
    stripped of stop words, and represented by the sum of the vectors of the
    remaining tokens. A ``NearestNeighbors`` index is fitted on those
    embeddings so the closest stored documents to a new text can be printed.
    """

    def __init__(self, embedding_dim=300, neighbors=2, algorithm="ball_tree"):
        """Load the spaCy model and configure the neighbour index.

        Args:
            embedding_dim: Size of the accumulator used for a document
                embedding; token vectors are added into a zeros array of this
                length, so it has to be compatible with ``token.vector``.
            neighbors: Number of neighbours the index is configured to return.
            algorithm: Passed through as ``NearestNeighbors(algorithm=...)``.
        """
        self.nlp = spacy.load("en_core_web_md")
        # Defined up front so retrive_document() before fit() fails with an
        # IndexError instead of an AttributeError.
        self.documents = []
        self.document_embeddings = {}
        # A set gives constant-time membership tests in process().
        self.stopwords = set(STOP_WORDS)
        self.embedding_dim = embedding_dim
        self.neighbors = neighbors
        self.neigh = NearestNeighbors(n_neighbors=neighbors, algorithm=algorithm)

    def process(self, doc, remove_stopwords=True):
        """Normalise ``doc`` to a space-joined string of lower-case words.

        Args:
            doc: Raw text.
            remove_stopwords: When true, words found in ``self.stopwords``
                are dropped.

        Returns:
            str: The cleaned text (empty if no words remain).
        """
        words = re.findall(r"\w+", doc.lower())
        if remove_stopwords:
            words = [word for word in words if word not in self.stopwords]
        return " ".join(words)

    def tokenize(self, doc):
        """Clean ``doc`` with :meth:`process` and run it through ``self.nlp``."""
        return self.nlp(self.process(doc))

    def create_document_embedding(self, doc):
        """Return the sum of the token vectors of the cleaned ``doc``.

        The result is an array of length ``self.embedding_dim``; it stays all
        zeros when the cleaned document yields no tokens.
        """
        document_embedding = np.zeros(self.embedding_dim)
        for token in self.tokenize(doc):
            document_embedding += token.vector
        return document_embedding

    def fit_nearest_neighbors(self):
        """Fit ``self.neigh`` on the stored embeddings, in insertion order."""
        X = np.array(list(self.document_embeddings.values()))
        self.neigh.fit(X)

    def fit(self, documents):
        """Embed ``documents`` and (re)build the neighbour index.

        Any embeddings from a previous call are discarded first, so that row
        ``i`` of the index always refers to ``self.documents[i]``.

        Args:
            documents: Iterable of text documents.

        Raises:
            ValueError: If ``documents`` is empty.
        """
        docs = list(documents)
        if not docs:
            raise ValueError("fit() requires at least one document.")
        self.documents = docs
        # Start from a clean slate: without this a second fit() would keep
        # stale vectors whose indices no longer match self.documents.
        self.document_embeddings = {}
        print(f"{len(self.documents)} documents attached.")
        for index, document in enumerate(self.documents):
            self.document_embeddings[index] = self.create_document_embedding(document)
        self.fit_nearest_neighbors()
        print("Document embeddings created.")

    def retrive_document(self, index):
        """Return the fitted document stored at position ``index``."""
        return self.documents[index]

    def print_function(self, distances, indicies):
        """Print each returned neighbour with its distance.

        Args:
            distances: Distances of the neighbours, one per result.
            indicies: Positions of those neighbours in ``self.documents``.
        """
        print("Similar Documents")
        print("-------------------------------------")
        print("\n")
        # Iterate over what was actually returned rather than assuming
        # exactly ``self.neighbors`` entries.
        for rank, (distance, index) in enumerate(zip(distances, indicies), start=1):
            print(f"Document {rank}")
            print(f"Distance from the document: {distance}")
            print("\n")
            print(self.retrive_document(index))
            print("-------------------------------------")
            print("\n")

    def search(self, new_document):
        """Embed ``new_document`` and print its nearest fitted documents."""
        new_document_embedding = self.create_document_embedding(new_document)
        distances, indices = self.neigh.kneighbors([new_document_embedding])
        # kneighbors returns one row per query; a single query was sent.
        self.print_function(distances[0], indices[0])

In [6]:
# Build the searcher with its defaults (embedding_dim=300, neighbors=2,
# algorithm="ball_tree"); the constructor loads the "en_core_web_md" spaCy model.
doc_search = DocumentSearch()

In [7]:
# Embed every sampled document and fit the nearest-neighbour index on the result.
doc_search.fit(documents)

5000 documents attached.
Document embeddings created.


In [8]:
# Query text to compare against the fitted documents.
new_document = "Donald Trump Tweet Rescues 'Roseanne' And Family In Season Finale"

In [9]:
# Prints the nearest fitted documents and their distances; search() returns nothing.
doc_search.search(new_document)

Similar Documents
-------------------------------------


Document 1
Distance from the document: 18.90370558596096


Just A Reminder, Donald Trump Actually Could Win The Election
-------------------------------------


Document 2
Distance from the document: 19.43576575830994


"Evangelicals Love Donald Trump!" Wait, What?
-------------------------------------


