In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
# Put the notebook's parent directory on the import path (only once, so
# re-running the cell does not add duplicates). Presumably this is what makes
# the project's `models` package importable below -- confirm.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
# Key of the article dump to load: it is passed to the loader below, which
# interpolates it into the dataset file name.
article_date = '2018-12-21'

In [None]:
import ast
import warnings
from datetime import datetime

import pandas as pd
from faculty import datasets
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors

from models.article import Article

In [None]:
def load_article_from_datasets(foo):
    """Load the articles stored in one tab-separated dataset file.

    Reads ``/input/article_content/{foo}.csv`` through ``datasets.open`` and
    builds one ``Article`` per row from the columns ``article_url``,
    ``article_title``, ``article_description``, ``source_id``,
    ``published_at``, ``article_uuid``, ``named_entities`` and
    ``raw_content``.

    Loading is best-effort: a row that cannot be turned into an ``Article``
    (missing column, ``named_entities`` that is not a valid Python literal,
    constructor failure, ...) is skipped. Skipped rows are reported with a
    single summary warning instead of being dropped silently.

    Args:
        foo: File name stem of the dataset file; in this notebook it is the
            article date string, e.g. ``'2018-12-21'``.

    Returns:
        list: The ``Article`` objects that were built, in file order.
    """
    path = f'/input/article_content/{foo}.csv'
    with datasets.open(path) as f:
        df = pd.read_csv(f, sep='\t', encoding='utf-8')

    articles = []
    skipped = []  # (row index, exception) for every row that was dropped
    for index, row in df.iterrows():
        try:
            articles.append(Article(
                row['article_url'],
                row['article_title'],
                row['article_description'],
                row['source_id'],
                row['published_at'],
                row['article_uuid'],
                # Stored as the text of a Python literal; literal_eval parses
                # it without executing arbitrary code.
                ast.literal_eval(row['named_entities']),
                None,
                row['raw_content'],
            ))
        except Exception as exc:
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt / SystemExit still stop the notebook.
            skipped.append((index, exc))

    if skipped:
        first_index, first_error = skipped[0]
        warnings.warn(
            f'Skipped {len(skipped)} of {len(df)} rows from {path}; '
            f'first failure at row {first_index}: {first_error!r}',
            stacklevel=2,
        )
    return articles

In [None]:
# Load the articles stored for `article_date`; rows that cannot be turned into
# an Article are skipped by the loader, so this list may be shorter than the file.
articles = load_article_from_datasets(article_date)

In [None]:
# Find the loaded articles whose named entities are closest to a test article:
# join each article's named entities into one string, TF-IDF encode the
# strings, then query a nearest-neighbour index with the test article's row.
test_url = 'https://www.bbc.co.uk/news/world-us-canada-46657393'
test_article = Article(test_url, '', '', '', datetime.now())

# One space-joined entity string per article. The test article is appended
# last so that its row index in the matrix is known.
named_entities_list = [' '.join(article.named_entities) for article in articles]
named_entities_list.append(' '.join(test_article.named_entities))
test_index = len(named_entities_list) - 1

# TF-IDF matrix, fitted on the loaded articles and the test article together
tfidf_vectorizer = TfidfVectorizer()
tfidf_matrix = tfidf_vectorizer.fit_transform(named_entities_list)

# Fit KNN. Asking for more neighbours than there are rows makes kneighbors()
# fail, so cap the request at the number of rows.
n_neighbors = min(10, len(named_entities_list))
nbrs = NearestNeighbors(n_neighbors=n_neighbors)
nbrs.fit(tfidf_matrix)

# Predict
test_row = tfidf_matrix.getrow(test_index)
distances, indices = nbrs.kneighbors(test_row)

# Format predictions. The query row is part of the fitted data, so it normally
# comes back as its own nearest neighbour -- but on a distance tie it is not
# guaranteed to be first. Drop it by index instead of by position; otherwise
# `articles[test_index]` would be out of range.
neighbours = [
    (distance, idx)
    for distance, idx in zip(distances.flatten(), indices.flatten())
    if idx != test_index
][:n_neighbors - 1]
similar_articles = [articles[idx] for _, idx in neighbours]

df = pd.DataFrame({
    'distance': [distance for distance, _ in neighbours],
    'titles': [article.title for article in similar_articles],
    'named_entities': [article.named_entities for article in similar_articles],
    'url': [article.url for article in similar_articles],
})
# None means "do not truncate cell contents"; the old -1 spelling is
# deprecated and rejected by current pandas releases.
pd.set_option('display.max_colwidth', None)
df