added max_token_line_num to parse_lda_topics to limit num read lines
LDAResults now only reads up to the max token hash number
dkrasner committed May 30, 2015
1 parent b0bd1da commit 94bea48
Showing 1 changed file with 13 additions and 5 deletions.
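In outline, the commit threads a new max_token_line_num argument through parse_lda_topics and has LDAResults pass max(sfile_filter.id2token.keys()) for it, so parsing stops once the line number passes the largest token hash. A minimal sketch of the early-exit pattern the diff below introduces (illustrative only; the real function also validates and parses each row into a DataFrame):

    # Sketch of the loop added in this commit; row validation and
    # DataFrame assembly are elided.
    def iter_topic_lines(open_file, max_token_line_num=1e+100):
        for i, line in enumerate(open_file):
            if i > max_token_line_num:
                break  # lines past the max token hash cannot matter
            yield line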
18 changes: 13 additions & 5 deletions rosetta/text/vw_helpers.py
@@ -61,7 +61,8 @@ def parse_varinfo(varinfo_file):
     return varinfo


-def parse_lda_topics(topics_file, num_topics, normalize=True):
+def parse_lda_topics(topics_file, num_topics, max_token_line_num=1e+100,
+                     normalize=True):
     """
     Returns a DataFrame representation of the topics output of an lda VW run.

@@ -71,6 +72,9 @@ def parse_lda_topics(topics_file, num_topics, normalize=True):
         The --readable_model output of a VW lda run
     num_topics : Integer
         The number of topics in every valid row
+    max_token_line_num : Integer
+        Reading of token probabilities from the topics_file will stop after this
+        line number. Useful when you know the max hash value of your tokens.
     normalize : Boolean
         Normalize the rows so that they represent probabilities of topic
         given hash_val
@@ -89,7 +93,9 @@ def parse_lda_topics(topics_file, num_topics, normalize=True):
     # Once we detect that we're in the valid rows, there better not be
     # any exceptions!
     in_valid_rows = False
-    for line in open_file:
+    for i, line in enumerate(open_file):
+        if i > max_token_line_num:
+            break
         try:
             # If this row raises an exception, then it isn't a valid row
             # Sometimes trailing space...that's the reason for split()
@@ -212,7 +218,7 @@ def __init__(
         predictions_file : filepath or buffer
             The -p output of a VW lda run
         num_topics : Integer or None
-            The number of topics in every valid row; if None will infer num
+            The number of topics in every valid row; if None will infer num
             topics from predictions_file
         sfile_filter : filepath, buffer, or loaded text_processors.SFileFilter
             Contains the token2id and id2token mappings
@@ -234,7 +240,9 @@ def __init__(
         self.sfile_frame = sfile_filter.to_frame()

         # Load the topics file
-        topics = parse_lda_topics(topics_file, num_topics, normalize=False)
+        topics = parse_lda_topics(topics_file, num_topics,
+                                  max(sfile_filter.id2token.keys()),
+                                  normalize=False)
         topics = topics.reindex(index=sfile_filter.id2token.keys())
         topics = topics.rename(index=sfile_filter.id2token)

@@ -247,7 +255,7 @@ def __init__(
         self.num_docs = len(predictions)
         self.num_tokens = len(topics)
         self.topics = topics.columns.tolist()
-        self.tokens = topics.index.tolist()
+        self.tokens = topics.index.tolist()
         self.docs = predictions.index.tolist()

         # Check that the topics/docs/token names are unique with no overlap
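For direct callers of parse_lda_topics, the same cap applies. A hedged usage sketch, assuming topics.dat is the --readable_model output of a vw --lda run and using a toy id2token mapping in place of a loaded text_processors.SFileFilter (both hypothetical):

    from rosetta.text.vw_helpers import parse_lda_topics

    # Toy stand-in for sfile_filter.id2token from a loaded SFileFilter.
    id2token = {0: 'cat', 3: 'dog', 7: 'fish'}

    # Lines beyond the largest token hash cannot map back to any tracked
    # token, so cap the read there instead of scanning the whole file.
    topics = parse_lda_topics(
        'topics.dat', num_topics=20,
        max_token_line_num=max(id2token.keys()), normalize=True)
    print(topics.head())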
