Skip to content

Commit

Permalink
Merge pull request #301 from mdekstrand/tweak/shm-test
Browse files Browse the repository at this point in the history
Disable SharedMemory on Windows and improve testing
  • Loading branch information
mdekstrand committed Feb 11, 2022
2 parents ea900a1 + 14b28c9 commit 1b3504c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
8 changes: 6 additions & 2 deletions lenskit/sharing/shm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import sys
import logging
import pickle
from . import sharing_mode, PersistedModel

try:
import multiprocessing.shared_memory as shm
SHM_AVAILABLE = True
SHM_AVAILABLE = sys.platform != 'win32'
except ImportError:
SHM_AVAILABLE = False

Expand Down Expand Up @@ -89,7 +90,7 @@ def close(self, unlink=True):
self.buffers = None
if self.memory is not None:
self.memory.close()
if self.is_owner:
if unlink and self.is_owner and self.is_owner != 'transfer':
self.memory.unlink()
self.is_owner = False
self.memory = None
Expand All @@ -112,3 +113,6 @@ def __setstate__(self, state):
if self.is_owner:
_log.debug('opening shared buffers after ownership transfer')
self._open()

def __del__(self):
self.close(False)
15 changes: 10 additions & 5 deletions tests/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import logging
import multiprocessing as mp
import numpy as np
import pytest

from lenskit.util.parallel import invoker, proc_count, run_sp, is_worker, is_mp_worker
from lenskit.util.test import set_env_var
from lenskit.util.random import get_root_seed
from lenskit.sharing import persist_binpickle
from lenskit.sharing import persist, SHM_AVAILABLE

from pytest import mark, raises, approx

Expand Down Expand Up @@ -77,9 +78,9 @@ def _sp_matmul(a1, a2, *, fail=False):
return a1 @ a2


def _sp_matmul_p(a1, a2, *, fail=False):
def _sp_matmul_p(a1, a2, *, method=None, fail=False):
_log.info('in worker process')
return persist_binpickle(a1 @ a2).transfer()
return persist(a1 @ a2, method=method).transfer()


def test_run_sp():
Expand All @@ -98,11 +99,15 @@ def test_run_sp_fail():
run_sp(_sp_matmul, a1, a2, fail=True)


def test_run_sp_persist():
@pytest.mark.parametrize('method', [None, 'binpickle', 'shm'])
def test_run_sp_persist(method):
if method == 'shm' and not SHM_AVAILABLE:
pytest.skip('SHM backend not available')

a1 = np.random.randn(100, 100)
a2 = np.random.randn(100, 100)

res = run_sp(_sp_matmul_p, a1, a2)
res = run_sp(_sp_matmul_p, a1, a2, method=method)
try:
assert res.is_owner
assert np.all(res.get() == a1 @ a2)
Expand Down

0 comments on commit 1b3504c

Please sign in to comment.