Skip to content

Commit

Permalink
Revert "Allow passing sparse matrices to _parse_n_wd() (#835)" (#836)
Browse files Browse the repository at this point in the history
This reverts commit abee373.
  • Loading branch information
MelLain committed Aug 11, 2017
1 parent abee373 commit 721c242
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 124 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ before_install:
# we need latest pip to work with only-binary option
- pip install -U pip
- pip install -U pytest pep8
- pip install -U numpy scipy pandas tqdm --only-binary numpy scipy pandas
- pip install -U numpy pandas tqdm --only-binary numpy pandas
- pip install protobuf==3.0.0
# configure ccache
# code from https://github.com/urho3d/Urho3D/blob/master/.travis.yml
Expand Down
2 changes: 1 addition & 1 deletion appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ install:
- conda config --set always_yes yes --set changeps1 no
- conda update -q conda
- conda info -a
- conda install numpy scipy pandas pytest
- conda install numpy pandas pytest
- conda install -c conda-forge tqdm

# scripts to run before build
Expand Down
54 changes: 18 additions & 36 deletions python/artm/batches_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,50 +223,32 @@ def __reset_batch():
batch.id = str(uuid.uuid4())
return batch, {}

try:
from scipy.sparse.base import spmatrix
except ImportError:
spmatrix = tuple()

os.mkdir(self._target_folder)
global_vocab, global_n = {}, 0.0
batch, batch_vocab = __reset_batch()
try:
n_wd_T = n_wd.T
except AttributeError:
raise TypeError("Expected a transposable matrix, got {}".format(type(n_wd)))
for item_id, column in enumerate(n_wd_T):
for item_id, column in enumerate(n_wd.T):
item = batch.item.add()
item.id = item_id
for key in global_vocab.keys():
global_vocab[key][2] = False # all tokens haven't appeared in this item yet

if isinstance(column, np.matrix):
enum = enumerate(np.squeeze(np.asarray(column), axis=0))
elif isinstance(column, np.ndarray):
enum = enumerate(column)
elif isinstance(column, spmatrix):
nnz = column.nonzero()[1]
enum = zip(nnz, np.squeeze(column[0, nnz].toarray(), axis=0))
else:
raise TypeError("Unsupported column type: %s" % type(column))
for token_id, value in enum:
if value <= GLOB_EPS:
continue
token = vocab[token_id]
if token not in global_vocab:
global_vocab[token] = [0, 0, False] # token_tf, token_df, appeared in this item

global_vocab[token][0] += value
global_vocab[token][1] += 0 if global_vocab[token][2] else 1
global_n += value

if token not in batch_vocab:
batch_vocab[token] = len(batch.token)
batch.token.append(token)

item.token_id.append(batch_vocab[token])
item.token_weight.append(float(value))
col = column if isinstance(column, type(np.zeros([0]))) else column.tolist()[0]
for token_id, value in enumerate(col):
if value > GLOB_EPS:
token = vocab[token_id]
if token not in global_vocab:
global_vocab[token] = [0, 0, False] # token_tf, token_df, appeared in this item

global_vocab[token][0] += value
global_vocab[token][1] += 0 if global_vocab[token][2] else 1
global_n += value

if token not in batch_vocab:
batch_vocab[token] = len(batch.token)
batch.token.append(token)

item.token_id.append(batch_vocab[token])
item.token_weight.append(float(value))

if ((item_id + 1) % self._batch_size == 0 and item_id != 0) or ((item_id + 1) == n_wd.shape[1]):
filename = os.path.join(self._target_folder, '{}.batch'.format(batch.id))
Expand Down
108 changes: 23 additions & 85 deletions python/tests/artm/test_batches_utils.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,40 @@
# Copyright 2017, Additive Regularization of Topic Models.

from contextlib import contextmanager
import shutil
import glob
import tempfile
import os
import numpy
from scipy.sparse import csr_matrix
import pytest

from six.moves import range

import artm


def test_func():
data_path = os.environ.get('BIGARTM_UNITTEST_DATA')
# constants
num_uci_batches = 4
n_wd = numpy.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7], [4, 5, 6, 7, 8]])
n_wd_sparse = csr_matrix(numpy.array([[1, 2, 3, 0, 0], [2, 0, 0, 0, 6], [0, 0, 5, 6, 7], [4, 5, 0, 0, 8]]))
vocab = {0: 'test', 1: 'artm', 2: 'python', 3: 'batch'}
num_n_wd_batches = 3
n_wd_num_tokens = n_wd.shape[0]
dictionary_name = 'dict.txt'
n_wd_tokens_list = ['test', 'python', 'artm', 'batch']
n_wd_token_tf_list = ['15.0', '25.0', '20.0', '30.0']
n_wd_sparse_token_tf_list = ['18.0', '17.0', '6.0', '8.0']
n_wd_token_df_list = [str(float(n_wd.shape[1])) + '\n'] * n_wd.shape[0]
n_wd_sparse_token_df_list = ['2.0\n', '3.0\n'] # doc freq

# test_bow_uci
batches_directory = tempfile.mkdtemp()
data_path = os.environ.get('BIGARTM_UNITTEST_DATA')
batches_folder = tempfile.mkdtemp()
try:
uci_batch_vectorizer = artm.BatchVectorizer(data_path=data_path,
data_format='bow_uci',
collection_name='kos',
target_folder=batches_directory)
target_folder=batches_folder)

assert len(glob.glob(os.path.join(batches_directory, '*.batch'))) == num_uci_batches
assert len(glob.glob(os.path.join(batches_folder, '*.batch'))) == num_uci_batches
assert len(uci_batch_vectorizer.batches_list) == num_uci_batches


dictionary = uci_batch_vectorizer.dictionary
model = artm.ARTM(num_topics=10, dictionary=dictionary)
model.scores.add(artm.PerplexityScore(name='perplexity', dictionary=dictionary))
Expand All @@ -61,28 +56,26 @@ def test_func():

del in_memory_batch_vectorizer

batch_batch_vectorizer = artm.BatchVectorizer(data_path=batches_directory, data_format='batches')

batch_batch_vectorizer = artm.BatchVectorizer(data_path=batches_folder, data_format='batches')
assert len(batch_batch_vectorizer.batches_list) == num_uci_batches
finally:
shutil.rmtree(batches_directory)

# test_bow_uci():
uci_batch_vectorizer = artm.BatchVectorizer(data_path=data_path,
data_format='bow_uci',
collection_name='kos')

temp_target_folder = uci_batch_vectorizer._target_folder
assert os.path.isdir(temp_target_folder)
assert len(glob.glob(os.path.join(temp_target_folder, '*.batch'))) == num_uci_batches
uci_batch_vectorizer = artm.BatchVectorizer(data_path=data_path,
data_format='bow_uci',
collection_name='kos')

temp_target_folder = uci_batch_vectorizer._target_folder
assert os.path.isdir(temp_target_folder)
assert len(glob.glob(os.path.join(temp_target_folder, '*.batch'))) == num_uci_batches

uci_batch_vectorizer.__del__()
assert not os.path.isdir(temp_target_folder)

uci_batch_vectorizer.__del__()
assert not os.path.isdir(temp_target_folder)

# test_n_wd():
for matrix in (n_wd, numpy.matrix(n_wd), csr_matrix(n_wd)):
n_wd_batch_vectorizer = artm.BatchVectorizer(data_path=data_path,
data_format='bow_n_wd',
n_wd=matrix,
n_wd=n_wd,
vocabulary=vocab,
batch_size=2)

Expand All @@ -97,7 +90,7 @@ def test_func():
batch.ParseFromString(fin.read())
assert len(batch.item) == 2 or len(batch.item) == 1
assert len(batch.token) == n_wd_num_tokens

n_wd_batch_vectorizer.dictionary.save_text(os.path.join(temp_target_folder, dictionary_name))
assert os.path.isfile(os.path.join(temp_target_folder, dictionary_name))
with open(os.path.join(temp_target_folder, dictionary_name), 'r') as fin:
Expand All @@ -110,7 +103,7 @@ def test_func():
tokens.append(temp[0])
token_tf.append(temp[3])
token_df.append(temp[4])

assert counter == n_wd_num_tokens + 2

# ToDo: we're not able to compare lists directly in Python 3 because of
Expand All @@ -121,60 +114,5 @@ def test_func():

n_wd_batch_vectorizer.__del__()
assert not os.path.isdir(temp_target_folder)

# test_sparse_n_wd():
n_wd_batch_vectorizer = artm.BatchVectorizer(data_path=data_path,
data_format='bow_n_wd',
n_wd=n_wd_sparse,
vocabulary=vocab,
batch_size=2)

temp_target_folder = n_wd_batch_vectorizer._target_folder
assert os.path.isdir(temp_target_folder)
assert len(n_wd_batch_vectorizer.batches_list) == num_n_wd_batches
assert len(glob.glob(os.path.join(temp_target_folder, '*.batch'))) == num_n_wd_batches

for i in range(num_n_wd_batches):
with open(n_wd_batch_vectorizer.batches_ids[i], 'rb') as fin:
batch = artm.messages.Batch()
batch.ParseFromString(fin.read())
assert len(batch.item) == 2 or len(batch.item) == 1
assert 2 <= len(batch.token) <= n_wd_num_tokens

n_wd_batch_vectorizer.dictionary.save_text(os.path.join(temp_target_folder, dictionary_name))
assert os.path.isfile(os.path.join(temp_target_folder, dictionary_name))
with open(os.path.join(temp_target_folder, dictionary_name), 'r') as fin:
counter = 0
tokens, token_tf, token_df = [], [], []
for line in fin:
counter += 1
if counter > 2:
temp = line.split(', ')
tokens.append(temp[0])
token_tf.append(temp[3])
token_df.append(temp[4])

assert counter == n_wd_num_tokens + 2

# ToDo: we're not able to compare lists directly in Python 3 because of
# unknown reasons. This should be fixed
assert set(tokens) == set(n_wd_tokens_list)
assert set(token_tf) == set(n_wd_sparse_token_tf_list)
assert set(token_df) == set(n_wd_sparse_token_df_list)

n_wd_batch_vectorizer.__del__()
assert not os.path.isdir(temp_target_folder)

# test_errors_n_wd():
with pytest.raises(TypeError):
n_wd_batch_vectorizer = artm.BatchVectorizer(data_path=data_path,
data_format='bow_n_wd',
n_wd="a mess",
vocabulary=vocab,
batch_size=2)
with pytest.raises(TypeError):
n_wd_batch_vectorizer = artm.BatchVectorizer(data_path=data_path,
data_format='bow_n_wd',
n_wd=numpy.array([["1", "2"], ["3", "4"]]),
vocabulary=vocab,
batch_size=2)
finally:
shutil.rmtree(batches_folder)
2 changes: 1 addition & 1 deletion python/tests/artm/test_dump_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _assert_score_values_equality(model_1, model_2):
for name in model_1.scores.data.keys():
if name == 'perp' or name == 'sp_theta' or name == 'sp_nwt':
assert sum([abs(x - y) for x, y in zip(model_1.score_tracker[name].value,
model_2.score_tracker[name].value)]) < 0.005
model_2.score_tracker[name].value)]) < 0.001
elif name == 'top_tok':
assert set(model_1.score_tracker[name].last_tokens) == set(model_2.score_tracker[name].last_tokens)
elif name == 'kernel':
Expand Down

0 comments on commit 721c242

Please sign in to comment.