Skip to content

Commit

Permalink
partial fix for issue #833
Browse files Browse the repository at this point in the history
  • Loading branch information
ofrei committed Apr 27, 2018
1 parent 6f25dc5 commit 32f02ba
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions python/artm/artm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pandas import DataFrame
from six import iteritems, string_types
from six.moves import range, zip
from multiprocessing.pool import ThreadPool
from multiprocessing.pool import ThreadPool, ApplyResult
from copy import deepcopy
import tqdm

Expand Down Expand Up @@ -92,11 +92,11 @@ def _topic_selection_regularizer_func(self, regularizers):


class ArtmThreadPool(object):
def __init__(self):
self._pool = ThreadPool(processes=1)
def __init__(self, async=True):
self._pool = ThreadPool(processes=1) if async else None

def apply_async(self, func, args):
return self._pool.apply_async(func, args)
return self._pool.apply_async(func, args) if self._pool else func(*args)

def __deepcopy__(self, memo):
return self
Expand Down Expand Up @@ -194,7 +194,7 @@ def __init__(self, num_topics=None, topic_names=None, num_processors=None, class
self._theta_columns_naming = 'id'
self._seed = -1
self._show_progress_bars = show_progress_bars
self._pool = ArtmThreadPool()
self._pool = ArtmThreadPool(async=show_progress_bars)

if topic_names is not None:
self._topic_names = topic_names
Expand Down Expand Up @@ -523,15 +523,11 @@ def seed(self, seed):
else:
self._seed = seed

@show_progress_bars.setter
def show_progress_bars(self, show_progress_bars):
if not isinstance(show_progress_bars, bool):
raise IOError('show_progress_bars should be bool')
else:
self._show_progress_bars = show_progress_bars

# ========== PRIVATE METHODS ==========
def _wait_for_batches_processed(self, async_result, num_batches):
if not(isinstance(async_result, ApplyResult)):
return async_result

import warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
Expand Down

0 comments on commit 32f02ba

Please sign in to comment.