In [13]:
#obvious packages
import numpy as np
import pandas as pd
#text-processing packages
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neighbors import KDTree
#sparsematrix-packages
from scipy.sparse import csc_matrix #will give sparse
from scipy.sparse.linalg import norm  #efficient norm comps, fro is standard
from scipy.linalg.interpolative import estimate_rank #estimates rank of sparse matrix
from scipy.sparse.linalg import svds #fast svd for sparse matrices.
#other
from scipy import spatial #contains a good method for tree-search KDtree
#plotting stuff
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#%matplotlib notebook

In [14]:
# Load the preprocessed corpus, drop incomplete rows and build a term-count matrix.
data = pd.read_csv('preprocess.csv',  encoding ='ISO-8859-1')
data = data.dropna()
data = data.reset_index()  # drop=False: the pre-dropna row labels are kept in an 'index' column

vectorizer = CountVectorizer()

# X: rows are documents of data["text_pp"], columns are vocabulary terms.
# Raw counts here; swapping in TfidfVectorizer would give tf-idf weights instead.
X = vectorizer.fit_transform(data["text_pp"])
X = X.astype("double")  # cast integer counts to float64 for the SVD

# get_feature_names() was removed in scikit-learn 1.2; use get_feature_names_out()
# when available and keep a plain list so indexing/len/`in` behave as before.
if hasattr(vectorizer, "get_feature_names_out"):
    words = list(vectorizer.get_feature_names_out())
else:
    words = vectorizer.get_feature_names()

In [15]:
def do_svd(k, matrix=None):
    """Truncated SVD of the term-document matrix for latent semantic analysis.

    Factorises ``matrix.T`` (terms x documents) as ``u @ diag(s) @ v``,
    keeping ``k`` singular triplets.

    Parameters
    ----------
    k : int
        Number of singular values/vectors to keep; ``svds`` requires
        ``0 < k < min(matrix.shape)``.
    matrix : sparse or dense 2-D array, optional
        Document-term matrix (rows = documents, columns = terms).
        Defaults to the notebook-global ``X``.

    Returns
    -------
    coordinates : ndarray, shape (n_documents, k)
        Document coordinates in the latent space (rows of ``v.T``), each row
        scaled to unit Euclidean length; rows with zero norm are left as zeros.
    u : ndarray, shape (n_terms, k)
        Left singular vectors of ``matrix.T``.
    sigmainv : ndarray, shape (k, k)
        Diagonal matrix of reciprocal singular values, used to fold a term
        vector ``q`` into the latent space as ``sigmainv @ u.T @ q``.
    """
    if matrix is None:
        matrix = X  # notebook-global document-term matrix

    u, s, v = svds(matrix.T, k=k)

    # NOTE(review): a zero singular value would produce inf here.
    sigmainv = np.diag(1.0 / s)

    coordinates = v.T
    norms = np.linalg.norm(coordinates, axis=1, keepdims=True)
    # Leave zero rows as zeros instead of producing 0/0 -> NaN.
    coordinates = coordinates / np.where(norms == 0, 1.0, norms)

    return coordinates, u, sigmainv

def transform_query(query, sigmainv, u, vec=None):
    """Fold a free-text query into the normalised latent (LSA) space.

    The query is vectorised with the same fitted vectorizer as the corpus,
    projected as ``sigmainv @ u.T @ q`` and scaled to unit length.

    Parameters
    ----------
    query : str
        Text to project. It is tokenised with the vectorizer's own analyzer,
        so case and punctuation are handled exactly as for the corpus.
    sigmainv : ndarray, shape (k, k)
        Reciprocal singular values, as returned by ``do_svd``.
    u : ndarray, shape (n_terms, k)
        Left singular vectors, as returned by ``do_svd``.
    vec : fitted vectorizer, optional
        Object providing ``build_analyzer()``, ``transform()`` and
        ``vocabulary_``. Defaults to the notebook-global ``vectorizer``.

    Returns
    -------
    ndarray, shape (1, k)
        Unit-length latent representation of the query.

    Raises
    ------
    ValueError
        If the query yields no tokens, contains a token outside the
        vocabulary, or projects onto the zero vector.
    """
    if vec is None:
        vec = vectorizer  # notebook-global fitted CountVectorizer

    # Tokenise the same way the vectorizer does, rather than splitting on ' ',
    # so that e.g. "Hillary" or repeated spaces are not rejected spuriously.
    tokens = vec.build_analyzer()(query)
    if not tokens:
        raise ValueError('Query contains no recognisable words')

    missing = [t for t in tokens if t not in vec.vocabulary_]  # dict lookup, O(1)
    if missing:
        raise ValueError('One of the words is not in dictionary: ' + ', '.join(missing))

    counts = vec.transform([query]).astype("double").toarray()  # shape (1, n_terms)

    projected = (sigmainv @ u.T @ counts.T).T
    length = np.linalg.norm(projected)
    if length == 0:
        raise ValueError('Query projects onto the zero vector in the latent space')

    return projected / length


Compute the truncated SVD, then the word-coordinate matrix and a KD-tree over those coordinates.

In [16]:
# Number of latent dimensions kept by the truncated SVD.
N_COMPONENTS = 100

coordinates, u, sigmainv = do_svd(N_COMPONENTS)

n_words = len(words)

# Fold every vocabulary term into the latent space as a one-word query:
# row i is (sigmainv @ u.T @ e_i).T, i.e. row i of u @ sigmainv (sigmainv is
# diagonal, hence symmetric). Computing it directly avoids materialising a dense
# n_words x n_words identity matrix, which costs O(n_words**2) memory.
word_coords = u @ sigmainv
word_norms = np.linalg.norm(word_coords, axis=1, keepdims=True)
# Leave zero rows as zeros instead of producing 0/0 -> NaN.
word_coords = word_coords / np.where(word_norms == 0, 1.0, word_norms)

# Rows are unit length, so Euclidean distance is monotone in cosine similarity
# (||a - b||^2 = 2 - 2 cos(a, b)); nearest neighbours are the most cosine-similar words.
word_tree = KDTree(word_coords, metric = 'euclidean')

Find the k nearest neighbours of a given word in the latent space and print them (the query word itself comes back first, at distance 0).

In [18]:
# Look up the nearest neighbours of a single word in the latent space.
word = 'hillary'
tword = transform_query(word, sigmainv, u)

# k=11: the word's own coordinates are stored in the tree, so it normally comes
# back first at distance 0, leaving 10 real neighbours.
nearest_dist, nearest_ind = word_tree.query(tword, k = 11)

print(*(words[index] for index in nearest_ind[0]), sep='\n')

hillary
sleeping
wheres
sigh
slogan
cake
rosie
ate
donnell
lgbt
threaten
