Skip to content

Commit

Permalink
Dictionary fix and improvement (#826)
Browse files Browse the repository at this point in the history
* add ability to re-norm values in dictionary while filtering

* fix bug with df_rate (num_items is no longer lost during import and filter)
  • Loading branch information
MelLain committed Jul 29, 2017
1 parent e881ed2 commit 709da7d
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 19 deletions.
14 changes: 10 additions & 4 deletions python/artm/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def save_text(self, dictionary_path, encoding='utf-8'):
"""
dictionary_data = self._master.get_dictionary(self._name)
with codecs.open(dictionary_path, 'w', encoding) as fout:
fout.write(u'name: {}\n'.format(dictionary_data.name))
fout.write(u'name: {} num_items: {}\n'.format(dictionary_data.name,
dictionary_data.num_items_in_collection))
fout.write(u'token, class_id, token_value, token_tf, token_df\n')

for i in range(len(dictionary_data.token)):
Expand All @@ -106,7 +107,9 @@ def load_text(self, dictionary_path, encoding='utf-8'):
self._reset()
dictionary_data = messages.DictionaryData()
with codecs.open(dictionary_path, 'r', encoding) as fin:
dictionary_data.name = fin.readline().split(' ')[1][0: -1]
first_str = fin.readline()[: -1].split(' ')
dictionary_data.name = first_str[1]
dictionary_data.num_items_in_collection = int(first_str[3])
fin.readline() # skip comment line

for line in fin:
Expand Down Expand Up @@ -154,7 +157,7 @@ def gather(self, data_path, cooc_file_path=None, vocab_file_path=None, symmetric
symmetric_cooc_values=symmetric_cooc_values)

def filter(self, class_id=None, min_df=None, max_df=None, min_df_rate=None, max_df_rate=None,
min_tf=None, max_tf=None, max_dictionary_size=None):
min_tf=None, max_tf=None, max_dictionary_size=None, recalculate_value=False):
"""
:Description: filters the BigARTM dictionary of the collection, which\
was already loaded into the lib
Expand All @@ -170,6 +173,8 @@ def filter(self, class_id=None, min_df=None, max_df=None, min_df_rate=None, max_
:param float max_tf: max tf value to pass the filter
:param float max_dictionary_size: give an easy option to limit dictionary size;
rare tokens will be excluded until dictionary reaches given size.
:param bool recalculate_value: recalculate or not value field in dictionary after filtration\
according to new sum of tf values
:Note: the current dictionary will be replaced with filtered
"""
Expand All @@ -182,7 +187,8 @@ def filter(self, class_id=None, min_df=None, max_df=None, min_df_rate=None, max_
max_df_rate=max_df_rate,
min_tf=min_tf,
max_tf=max_tf,
max_dictionary_size=max_dictionary_size)
max_dictionary_size=max_dictionary_size,
recalculate_value=recalculate_value)

def __deepcopy__(self, memo):
return self
Expand Down
5 changes: 5 additions & 0 deletions python/artm/master_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def filter_dictionary(self, dictionary_name=None, dictionary_target_name=None, c
min_df_rate=None, max_df_rate=None,
min_tf=None, max_tf=None,
max_dictionary_size=None,
recalculate_value=None,
args=None):

"""
Expand All @@ -353,6 +354,8 @@ def filter_dictionary(self, dictionary_name=None, dictionary_target_name=None, c
:param float max_tf: max tf value to pass the filter
:param float max_dictionary_size: give an easy option to limit dictionary size;
rare tokens will be excluded until dictionary reaches given size.
:param bool recalculate_value: recalculate or not value field in dictionary after filtration\
according to new sum of tf values
:param args: an instance of FilterDictionaryArgs
"""
filter_args = messages.FilterDictionaryArgs()
Expand All @@ -378,6 +381,8 @@ def filter_dictionary(self, dictionary_name=None, dictionary_target_name=None, c
filter_args.max_tf = max_tf
if max_dictionary_size is not None:
filter_args.max_dictionary_size = max_dictionary_size
if recalculate_value is not None:
filter_args.recalculate_value = recalculate_value

self._lib.ArtmFilterDictionary(self.master_id, filter_args)

Expand Down
50 changes: 41 additions & 9 deletions python/tests/artm/test_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,34 @@ def test_func():

num_tokens = 6906
num_filtered_tokens = 2852
num_rate_filtered_tokens = 122
eps = 1e-5

def _check_num_tokens_in_saved_text_dictionary(file_name, filtered=False):
def _check_num_tokens_in_saved_text_dictionary(file_name, case_type=0):
with open(file_name, 'r') as fin:
fin.readline()
fin.readline()
counter = 0
value = 0.0

for line in fin:
counter += 1
assert counter == (num_tokens if not filtered else num_filtered_tokens)
splitted = line.split(' ')
if len(splitted) == 5:
counter += 1
value += float(splitted[2][: -1])

if case_type == 0:
assert counter == num_tokens
assert abs(value - 1.0) < eps
elif case_type == 1:
assert counter == num_tokens
assert abs(value - 0.0) < eps
elif case_type == 2:
assert counter == num_filtered_tokens
assert abs(value - 1.0) < eps
elif case_type == 3:
assert counter == num_rate_filtered_tokens
assert abs(value - 1.0) < eps

try:
batch_vectorizer = artm.BatchVectorizer(data_path=data_path,
Expand All @@ -34,7 +53,6 @@ def _check_num_tokens_in_saved_text_dictionary(file_name, filtered=False):

dictionary_1 = artm.Dictionary()
dictionary_1.gather(data_path=batches_folder)

dictionary_1.save_text(dictionary_path=os.path.join(batches_folder, 'saved_text_dict_1.txt'))
_check_num_tokens_in_saved_text_dictionary(os.path.join(batches_folder, 'saved_text_dict_1.txt'))

Expand All @@ -55,17 +73,31 @@ def _check_num_tokens_in_saved_text_dictionary(file_name, filtered=False):
dictionary_data.class_id.append('@default_class')
dictionary_data.token_value.append(0.0)
dictionary_data.token_df.append(0.0)
dictionary_data.token_tf.append(0.0)
dictionary_data.token_tf.append(1.0)
f = os.path.join(batches_folder, 'saved_text_dict_3.txt')
dictionary_data.num_items_in_collection = int(open(f).readline()[: -1].split(' ')[3])

dictionary_4 = artm.Dictionary()
dictionary_4.create(dictionary_data=dictionary_data)
dictionary_4.filter()
dictionary_4.save_text(dictionary_path=os.path.join(batches_folder, 'saved_text_dict_4.txt'))
_check_num_tokens_in_saved_text_dictionary(os.path.join(batches_folder, 'saved_text_dict_4.txt'))
_check_num_tokens_in_saved_text_dictionary(os.path.join(batches_folder, 'saved_text_dict_4.txt'), case_type=1)

dictionary_5 = artm.Dictionary()
dictionary_5.load(dictionary_path=os.path.join(batches_folder, 'saved_dict_1.dict'))
dictionary_5.filter(min_df=2, max_df=100, min_tf=1, max_tf=20)
dictionary_5.create(dictionary_data=dictionary_data)
dictionary_5.filter(recalculate_value=True)
dictionary_5.save_text(dictionary_path=os.path.join(batches_folder, 'saved_text_dict_5.txt'))
_check_num_tokens_in_saved_text_dictionary(os.path.join(batches_folder, 'saved_text_dict_5.txt'), filtered=True)
_check_num_tokens_in_saved_text_dictionary(os.path.join(batches_folder, 'saved_text_dict_5.txt'))

dictionary_6 = artm.Dictionary()
dictionary_6.load(dictionary_path=os.path.join(batches_folder, 'saved_dict_1.dict'))
dictionary_6.filter(min_df=2, max_df=100, min_tf=1, max_tf=20, recalculate_value=True)
dictionary_6.save_text(dictionary_path=os.path.join(batches_folder, 'saved_text_dict_6.txt'))
_check_num_tokens_in_saved_text_dictionary(os.path.join(batches_folder, 'saved_text_dict_6.txt'), case_type=2)

dictionary_6.filter(min_df_rate=0.001, max_df_rate=0.002, recalculate_value=True)
dictionary_6.save_text(dictionary_path=os.path.join(batches_folder, 'saved_text_dict_6.txt'))
_check_num_tokens_in_saved_text_dictionary(os.path.join(batches_folder, 'saved_text_dict_6.txt'), case_type=3)
finally:
shutil.rmtree(batches_folder)

4 changes: 2 additions & 2 deletions src/artm/core/dictionary.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ class Dictionary {
const DictionaryEntry* entry(const Token& token) const;
const DictionaryEntry* entry(int index) const;

size_t size() const { return entries_.size(); }
size_t num_items() const { return num_items_in_collection_; }
int size() const { return entries_.size(); }
int num_items() const { return num_items_in_collection_; }
const std::string& name() const { return name_; }
bool has_valid_cooc_state() const;
int64_t ByteSize() const;
Expand Down
19 changes: 15 additions & 4 deletions src/artm/core/dictionary_operations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ std::shared_ptr<Dictionary> DictionaryOperations::Import(const ImportDictionaryA

// part with main dictionary
if (dict_data.token_size() > 0) {
dictionary->SetNumItems(dict_data.num_items_in_collection());
for (int token_id = 0; token_id < dict_data.token_size(); ++token_id) {
dictionary->AddEntry(DictionaryEntry(Token(dict_data.class_id(token_id), dict_data.token(token_id)),
dict_data.token_value(token_id), dict_data.token_tf(token_id), dict_data.token_df(token_id)));
Expand Down Expand Up @@ -445,6 +446,7 @@ std::shared_ptr<Dictionary> DictionaryOperations::Gather(const GatherDictionaryA

std::shared_ptr<Dictionary> DictionaryOperations::Filter(const FilterDictionaryArgs& args, const Dictionary& dict) {
auto dictionary = std::make_shared<Dictionary>(Dictionary(args.dictionary_target_name()));
dictionary->SetNumItems(dict.num_items());

auto& src_entries = dict.entries();
auto& dictionary_token_index = dict.token_index();
Expand All @@ -453,7 +455,7 @@ std::shared_ptr<Dictionary> DictionaryOperations::Filter(const FilterDictionaryA
float size = static_cast<float>(dict.num_items());
std::vector<bool> entries_mask(src_entries.size(), false);
std::vector<float> df_values;
int accepted_tokens_count = 0;
double new_tf_normalizer = 0.0;

for (int entry_index = 0; entry_index < (int64_t) src_entries.size(); entry_index++) {
auto& entry = src_entries[entry_index];
Expand Down Expand Up @@ -483,8 +485,9 @@ std::shared_ptr<Dictionary> DictionaryOperations::Filter(const FilterDictionaryA
}
}

entries_mask[entry_index] = true; // pass all filters
entries_mask[entry_index] = true; // have passed all filters
df_values.push_back(entry.token_df());
new_tf_normalizer += entry.token_tf();
}

// Handle max_dictionary_size
Expand All @@ -493,24 +496,32 @@ std::shared_ptr<Dictionary> DictionaryOperations::Filter(const FilterDictionaryA
std::sort(df_values.begin(), df_values.end(), std::greater<float>());
float min_df_due_to_size = df_values[args.max_dictionary_size()];


for (int entry_index = 0; entry_index < (int64_t) src_entries.size();
entry_index++) {
auto& entry = src_entries[entry_index];
if (entry.token_df() <= min_df_due_to_size) {
entries_mask[entry_index] = false;
new_tf_normalizer -= entry.token_tf();
}
}
}

int accepted_tokens_count = 0;
for (int entry_index = 0; entry_index < (int64_t) src_entries.size(); entry_index++) {
if (!entries_mask[entry_index]) {
continue;
}

// all filters were passed, add token to the new dictionary
auto& entry = src_entries[entry_index];
accepted_tokens_count += 1;
dictionary->AddEntry(entry);
++accepted_tokens_count;
if (args.recalculate_value()) {
float value = static_cast<float>(new_tf_normalizer > 0.0 ? entry.token_tf() / new_tf_normalizer : 0.0);
dictionary->AddEntry({ entry.token(), value, entry.token_tf(), entry.token_df() });
} else {
dictionary->AddEntry(entry);
}

old_index_new_index.insert(std::pair<int, int>(dictionary_token_index.find(entry.token())->second,
accepted_tokens_count - 1));
Expand Down
1 change: 1 addition & 0 deletions src/artm/messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ message FilterDictionaryArgs {
optional float max_tf = 9;

optional int64 max_dictionary_size = 10;
optional bool recalculate_value = 11 [default = false];
}

message GatherDictionaryArgs {
Expand Down

0 comments on commit 709da7d

Please sign in to comment.