Skip to content

Commit

Permalink
specify project id instead of arbitrary index
Browse files Browse the repository at this point in the history
  • Loading branch information
csiu committed Apr 24, 2017
1 parent a0f28e8 commit afb745e
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/python/sim_doc.py
Expand Up @@ -25,8 +25,8 @@ def get_args():
parser.add_argument('-n', '--num_results', default=None, type=int,
help="Number of similar documents to print in the results")

parser.add_argument('-i', '--index_document0', default=0, type=int,
help="Index of query document")
parser.add_argument('-i', '--document0_id', default=None, type=int,
help="Kickstarter ID of query document")

parser.add_argument('-c', '--cache_dir', default=".",
help="Specify cache dir")
Expand Down Expand Up @@ -92,11 +92,16 @@ def doc_to_string(doc):

return(df['doc_processed'])

def compute_distance(U, i=0, sort=False, top_n=None, metric='euclidean'):
def compute_distance(U, i=None, sort=False, top_n=None, metric='euclidean'):
"""
Compute distance of document U[i] with all documents in U
"""
document0 = np.asmatrix(U[i])
if i != None:
index_document0 = df[df["id"] == i].index.tolist()
else:
index_document0 = 0

document0 = np.asmatrix(U[index_document0])

dist = pairwise_distances(document0, U, metric=metric)
df_dist = pd.DataFrame(np.transpose(dist), columns=["dist"])
Expand All @@ -114,7 +119,7 @@ def compute_distance(U, i=0, sort=False, top_n=None, metric='euclidean'):
if __name__ == '__main__':
args = get_args()
num_singular_values = args.num_singular_values
index_document0 = args.index_document0
document0_id = args.document0_id
num_results = args.num_results
cache_dir = args.cache_dir
verbose = args.verbose
Expand Down Expand Up @@ -145,7 +150,7 @@ def compute_distance(U, i=0, sort=False, top_n=None, metric='euclidean'):
n_iter=5, random_state=5)

if verbose: print("# Computing distances...")
top_n = compute_distance(U, i=index_document0,
top_n = compute_distance(U, i=document0_id,
sort=True, top_n=num_results)

if verbose: print("# Printing results...")
Expand Down

0 comments on commit afb745e

Please sign in to comment.