In [1]:
# Copyright (c) 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.


import numpy as np
import time
import inspect

from scipy.sparse import csr_matrix
from sklearn.datasets import fetch_20newsgroups
from sklearn.neighbors import LSHForest
from sklearn.feature_extraction import DictVectorizer

In [2]:
import pysparnn
import pysparnn_utils

In [3]:
!wget https://www.cs.cmu.edu/~./enron/enron_mail_20150507.tgz
_ = !tar -xzvf enron_mail_20150507.tgz

--2016-04-09 18:22:44--  https://www.cs.cmu.edu/~./enron/enron_mail_20150507.tgz
Resolving www.cs.cmu.edu (www.cs.cmu.edu)... 128.2.217.13
Connecting to www.cs.cmu.edu (www.cs.cmu.edu)|128.2.217.13|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 443254787 (423M) [application/x-tar]
Saving to: ‘enron_mail_20150507.tgz’


2016-04-09 18:32:02 (777 KB/s) - ‘enron_mail_20150507.tgz’ saved [443254787/443254787]



In [3]:
import os
import sys


def load_tokenized_docs(root='maildir'):
    """Read every file under ``root`` and return one list of tokens per file.

    Args:
        root: directory to walk recursively (the extracted Enron ``maildir``).

    Returns:
        A list with one entry per non-empty file: the file's contents split on
        whitespace.

    Notes:
        * Files are opened in binary mode. Under Python 2 (POSIX) this is the
          same as text mode. Under Python 3 it avoids UnicodeDecodeError on
          non-UTF-8 mail, and ``bytes.split()`` splits on ASCII whitespace just
          as Python 2's ``str.split()`` does.
        * ``src.read().split()`` yields the same tokens as the previous
          ``' '.join(src.readlines()).split()``. ``readlines`` keeps each
          newline, so the joining spaces only add extra whitespace.
        * Directories and files are visited in sorted order. This makes the
          document order, and therefore the integer ids later used as index
          payloads, reproducible across machines and filesystems.
        * Files that contain no tokens (empty or whitespace-only) are skipped.
          Such a file would become an all-zero feature row, and cosine
          similarity is undefined for a zero vector.
    """
    tokenized = []
    for folder, subs, files in os.walk(root):
        # os.walk (top-down) descends into `subs` in whatever order the list
        # has after this in-place sort.
        subs.sort()
        for filename in sorted(files):
            with open(os.path.join(folder, filename), 'rb') as src:
                tokens = src.read().split()
            if tokens:
                tokenized.append(tokens)
    return tokenized


docs = load_tokenized_docs('maildir')

In [4]:
# Corpus statistics: number of documents, mean tokens per document, and the
# vocabulary size. print(...) with a single argument prints the same text under
# Python 2 and 3, and matches the print(...) style used by the later cells.
print('Num docs: {}'.format(len(docs)))
print('Avg doc length: {}'.format(np.mean([len(x) for x in docs])))
words = set()
for doc in docs:
    words.update(doc)
print('Num unique words: {}'.format(len(words)))

Num docs: 517401
Avg doc length: 329.550878332
Num unique words: 2584811


In [5]:
# Show the source of the PySparNN wrapper that is benchmarked below
# (print(...) form so the cell also runs under Python 3).
print(inspect.getsource(pysparnn_utils.PySparNNTextSearch))

class PySparNNTextSearch:
    """Text search over binary bag-of-words vectors using a pysparnn ClusterIndex.

    Each document is a sequence of tokens. It is encoded as a {token: 1}
    presence dict, vectorized with a DictVectorizer fitted on the indexing
    corpus, cast to an int CSR matrix, and indexed with pysparnn's
    UnitCosineDistance.
    """

    def __init__(self, docs, datas):
        """Fit the vocabulary on ``docs`` and build the cluster index.

        Args:
            docs: iterable of token sequences to index.
            datas: per-document payloads handed to pysparnn.ClusterIndex
                alongside the features (the notebook passes integer doc ids).
        """
        self.dv = DictVectorizer()
        bags = self._bags_of_words(docs)
        self.dv.fit(bags)
        self.cp = pysparnn.ClusterIndex(self._encode(bags), datas,
                                        pysparnn.matrix_distance.UnitCosineDistance)

    @staticmethod
    def _bags_of_words(docs):
        """Turn each token sequence into a {token: 1} presence dict."""
        return [dict.fromkeys(tokens, 1) for tokens in docs]

    def _encode(self, bags):
        """Vectorize presence dicts with the fitted vocabulary as an int CSR matrix."""
        return csr_matrix(self.dv.transform(bags), dtype=int)

    def search(self, docs, k=1, min_distance=None, max_distance=None, k_clusters=1, return_metric=False):
        """Encode ``docs`` with the fitted vocabulary and query the cluster index.

        All keyword arguments are forwarded unchanged to ClusterIndex.search,
        and its result is returned as-is.
        """
        query = self._encode(self._bags_of_words(docs))
        return self.cp.search(query,
                              k=k,
                              min_distance=min_distance,
                              max_distance=max_distance,
                              k_clusters=k_clusters,
                              return_metric=return_metric)



In [6]:
# Build the PySparNN index over the whole corpus, using each document's
# position as its payload, and print the wall-clock build time in seconds.
t0 = time.time()
text_search = pysparnn_utils.PySparNNTextSearch(docs, datas=range(len(docs)))
print(time.time() - t0)

445.963020802


In [7]:
# Benchmark the PySparNN index against the corpus it was built from and
# report the median per-query time and median accuracy returned by
# identity_benchmark (presumably each document is used as its own query).
snn_time, snn_accuracy = pysparnn_utils.identity_benchmark(text_search, docs)
print('\n'.join([
    'PySparNN median time per query: {0}'.format(snn_time),
    'PySparNN median accuracy: {0}'.format(snn_accuracy),
]))

PySparNN median time per query: 0.0131492972374
PySparNN median accuracy: 1.0


In [8]:
# Show the source of the LSHForest baseline wrapper benchmarked below
# (print(...) form so the cell also runs under Python 3).
print(inspect.getsource(pysparnn_utils.LSHForestSearch))

class LSHForestSearch:
    """Baseline text search using scikit-learn's LSHForest on bag-of-words vectors.

    Documents are encoded the same way as for PySparNN ({token: 1} presence
    dicts -> DictVectorizer). The forest is configured with n_estimators=1,
    n_candidates=1 and n_neighbors=1.
    """

    def __init__(self, docs):
        """Fit the vocabulary on ``docs`` and fit the LSH forest on their vectors."""
        self.lshf = LSHForest(n_estimators=1, n_candidates=1,
                              n_neighbors=1)
        self.dv = DictVectorizer()
        bags = self._bags_of_words(docs)
        self.dv.fit(bags)
        # Features go to LSHForest exactly as DictVectorizer returns them (no
        # int CSR cast as in the PySparNN wrapper); the author noted floats
        # are faster here.
        self.lshf.fit(self.dv.transform(bags))

    @staticmethod
    def _bags_of_words(docs):
        """Turn each token sequence into a {token: 1} presence dict."""
        return [dict.fromkeys(tokens, 1) for tokens in docs]

    def search(self, docs):
        """Return neighbour indices (no distances) for each query document.

        kneighbors is called without n_neighbors, so the constructor's
        n_neighbors=1 presumably applies -- confirm against sklearn's docs.
        """
        query = self.dv.transform(self._bags_of_words(docs))
        return self.lshf.kneighbors(query, return_distance=False)



In [9]:
# Build the LSHForest baseline over the same corpus and print the
# wall-clock build time in seconds.
t0 = time.time()
lsh_search = pysparnn_utils.LSHForestSearch(docs=docs)
print(time.time() - t0)

226.671472788


In [10]:
# Run the same identity benchmark against the LSHForest baseline and report
# its median per-query time and median accuracy.
lsh_time, lsh_accuracy = pysparnn_utils.identity_benchmark(lsh_search, docs)
print('\n'.join([
    'LSH median time per query: {0}'.format(lsh_time),
    'LSH median accuracy: {0}'.format(lsh_accuracy),
]))

LSH median time per query: 0.0226684308052
LSH median accuracy: 1.0


In [11]:
# Ratio of LSHForest's median query time to PySparNN's: a value above 1 means
# LSHForest is that many times slower per query.
lsh_over_snn_ratio = lsh_time / snn_time
lsh_over_snn_ratio

1.7239271723767795