support creating in-memory shared list from worker process #120

Closed
wants to merge 1 commit into from
118 changes: 118 additions & 0 deletions mobile_cv/torch/tests/utils_pytorch/test_shareables.py
@@ -0,0 +1,118 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import time
import unittest

import mobile_cv.torch.utils_pytorch.comm as comm
import mobile_cv.torch.utils_pytorch.distributed_helper as dh
import numpy as np
from mobile_cv.common.misc.py import PicklableWrapper
from mobile_cv.torch.utils_pytorch.shareables import (
share_numpy_array_locally,
SharedList,
)


class TestInMemorySharedNumpyArray(unittest.TestCase):
@dh.launch_deco(num_processes=2)
def test_shared_numpy_array(self):
# create numpy array on master process
if comm.get_rank() == 0:
data = np.array([1, 2, 3], dtype=np.dtype("float32"))
else:
data = None

# share the numpy array for all processes
data, shm_ref = share_numpy_array_locally(data)
self.assertEqual(data.shape, (3,))
self.assertEqual(data[0], 1.0)
self.assertEqual(data[1], 2.0)
self.assertEqual(data[2], 3.0)
comm.synchronize()


def _check_and_modify(self, shared_lst):
# the shared list should be available on all ranks
self.assertEqual(len(shared_lst), 3)
self.assertEqual(shared_lst[0], "old")
self.assertEqual(shared_lst[1], 2)
self.assertEqual(shared_lst[2], (3,))

    # make sure other ranks have finished the above checks before rank N-1 modifies the list
comm.synchronize()
# modify the list from rank N-1
if comm.get_rank() == comm.get_world_size() - 1:
shared_lst[0] = "new"
self.assertEqual(shared_lst[0], "new")
        # setting an object of a different size is illegal
with self.assertRaises(ValueError):
shared_lst[0] = "long enough string"
self.assertEqual(shared_lst[0], "new")
    # make sure rank N-1 has modified the list before rank 0 does the next check
comm.synchronize()

# now the list should be updated on rank 0 as well since they're shared
self.assertEqual(shared_lst[0], "new")


class TestInMemoryShareables(unittest.TestCase):
def test_shared_list_single_process(self):
lst = ["old", 2, (3,)]
shared_lst = SharedList(lst, _allow_inplace_update=True)
del lst
_check_and_modify(self, shared_lst)

@dh.launch_deco(num_processes=2)
def test_shared_list_shared_among_peers(self):
"""
        This test mimics the case where one GPU worker creates a large dataset and
        wants to share it with the other GPU workers without copying the memory.
"""
# only create the list from rank 0
if comm.get_rank() == 0:
lst = ["old", 2, (3,)]
else:
lst = "whatever, this won't be used"

# create the shared list
shared_lst = SharedList(lst, _allow_inplace_update=True)
del lst

# now the list should be available on all ranks
_check_and_modify(self, shared_lst)

def test_shared_list_pass_to_child_processes(self):
"""
        This test mimics the case where a large dataset is created on a GPU worker and
        then passed to the data loader workers without copying the memory.
"""
lst = ["old", 2, (3,)]
parent_proc_shared_lst = SharedList(lst, _allow_inplace_update=True)
del lst

# launch child processes and pass the shared object
dh.launch(
_check_and_modify,
num_processes_per_machine=2,
backend="GLOO",
args=(PicklableWrapper(self), parent_proc_shared_lst),
)

        # since the child processes share the object with the parent, it should be modified here as well
self.assertEqual(parent_proc_shared_lst[0], "new")

@dh.launch_deco(num_processes=2)
def test_unevenly_close(self):
"""
        Test that the SharedList can still be accessed from a non-master process after
        the master process has finished using it.
"""
lst = SharedList([1, 2, 3])
        # creation of SharedList already handles synchronization, so we don't need to call it here
rank = comm.get_rank()
        time.sleep(rank)  # mimic a workload where the master process finishes first
        # rank 1 can still access the list even after it has been deleted from rank 0
        x = lst[0]
        print(f"rank {rank} got: {x}")
        self.assertEqual(x, 1)
del lst
176 changes: 176 additions & 0 deletions mobile_cv/torch/utils_pytorch/shareables.py
@@ -0,0 +1,176 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
import os
import pickle
from multiprocessing import shared_memory
from typing import Any, List, Tuple, Union

import mobile_cv.torch.utils_pytorch.comm as comm
import numpy as np

logger = logging.getLogger(__name__)


class _SharedMemoryRef(object):
"""Deal with the clean up of shared memory"""

def __init__(self, shm: shared_memory.SharedMemory, owner_pid: int):
self.shm = shm
self.owner_pid = owner_pid

def __del__(self):

ppwwyyxx: Python does not guarantee to always call __del__ upon exit - would that cause a memory leak then?

There is a SharedMemoryManager in Python that uses ResourceTracker - which starts a separate process whose only job is to reliably free the shm. I haven't used it, but it may be relevant.

Author: @ppwwyyxx I found that SharedMemoryManager is a bit slow to start (it starts a new process) from a GPU worker process, whereas SharedMemory is instant, so I ended up not using it. Wondering what makes SharedMemoryManager more reliable for freeing the shm - does it use refcounting, such that the memory is freed once all references to the shm go out of scope? I'm under the impression that SharedMemory can do the same: with the current implementation, if you call x, _ = share_numpy_array_locally(x) and remove the __del__, accessing x raises a segmentation fault because the shm has gone out of scope and its refcount is 0.

ppwwyyxx (Nov 1, 2022): The main problem with refcounting is that it relies on executing Python code. When Python exits ungracefully, refcounting doesn't work and __del__ won't be called. ResourceTracker starts a separate Python process, so it is not affected by crashes in other Python processes.

Took another look at SharedMemory. It seems that when you create a SharedMemory, it already uses ResourceTracker: https://github.com/python/cpython/blob/c0859743d9ad3bbd4c021200f4162cfeadc0c17a/Lib/multiprocessing/shared_memory.py#L120 , which will start a new process. So there is probably no leak.

It seems SharedMemoryManager starts more processes than ResourceTracker - maybe that's why it's slow.

Author: That makes sense. I ran some more benchmarks (multiple evaluations during training, each of which creates a test loader); there seems to be no memory leak.

self.shm.close() # all instances should call close()
if os.getpid() == self.owner_pid:
self.shm.unlink() # destroy the underlying shared memory block
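
For reference, a minimal sketch of the SharedMemoryManager route mentioned in the review thread above (standard-library API; this PR does not use it):

from multiprocessing.managers import SharedMemoryManager

import numpy as np

smm = SharedMemoryManager()
smm.start()  # the manager runs in its own process and tracks every block it creates
shm = smm.SharedMemory(size=12)
arr = np.ndarray((3,), dtype=np.float32, buffer=shm.buf)
arr[:] = [1.0, 2.0, 3.0]
# ... other processes can attach via shm.name ...
smm.shutdown()  # releases all blocks created through this manager
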
Comment on lines +25 to +26:

Reviewer: It looks like if the master's refcount goes to 0 first, but workers are still using the data, a bad memory access could happen.

Not sure if this can happen. Maybe it can happen if the master throws an exception while the other workers are still running.

Author: I read a comment saying it is safe to call shm.unlink() before all handles are closed. I tried the following and it doesn't crash:

lst = SharedList([1, 2, 3])
time.sleep(comm.get_rank())  # mimic a workload where the master process finishes first
print(lst[0])
del lst

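For context, a minimal single-process sketch of the POSIX behavior being relied on here (editor's illustration): unlink() only removes the name of the block, and an existing mapping stays valid until close():

from multiprocessing import shared_memory

shm = shared_memory.SharedMemory(create=True, size=4)
shm.buf[0] = 42
shm.unlink()       # removes the name; the existing mapping stays valid on POSIX
print(shm.buf[0])  # still readable: prints 42
shm.close()        # now the memory can actually be released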


def share_numpy_array_locally(
data: Union[np.ndarray, None],
) -> Tuple[np.ndarray, _SharedMemoryRef]:
"""
    Helper function to create a memory-shared numpy array.

    Args:
        data: the original data; it must be provided by the local master rank and
            must be None on all other local ranks.
    Returns:
        new_data: a shared numpy array equal to the one provided by the local master
            rank. Note that the memory of the numpy array might still be copied when
            passing it to a child process; in that case it is better to use `shm`.
        shm: the underlying shared memory; the caller needs to hold this object in
            order to prevent it from being garbage collected.
"""

if not isinstance(data, (np.ndarray, type(None))):
raise TypeError(f"Unexpected data type: {type(data)}")

new_data = None
shm = None
master_rank_pid = None

if comm.get_local_rank() == 0:
if not isinstance(data, np.ndarray):
raise ValueError(
f"Data must be provided from local master rank (rank: {comm.get_rank()}"
)
# create a new shared memory using the original data
shm = shared_memory.SharedMemory(create=True, size=data.nbytes)
master_rank_pid = os.getpid()
logger.info(f"Moving data to shared memory ({shm}) ...")
new_data = np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf)
new_data[:] = data[:]
shared_data_info = (data.shape, data.dtype, shm.name, master_rank_pid)
# maybe release the memory held by the original data?
else:
if data is not None:
raise ValueError(
f"Data must be None for non local master rank (rank: {comm.get_rank()}"
)
shared_data_info = None

    # gather the shared memory info from all ranks and take the local master's entry
shared_data_info_list = comm.all_gather(shared_data_info)
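    # e.g. with local_size=4, ranks 4-7 all map to local_master_rank=4
    # (this assumes ranks are assigned contiguously per machine)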
local_master_rank = (
comm.get_rank() // comm.get_local_size()
) * comm.get_local_size()
shared_data_info = shared_data_info_list[local_master_rank]
assert shared_data_info is not None

# create new data from shared memory
    if comm.get_local_rank() != 0:
shape, dtype, name, master_rank_pid = shared_data_info
shm = shared_memory.SharedMemory(name=name)
logger.info(f"Attaching to the existing shared memory ({shm}) ...")
new_data = np.ndarray(shape, dtype=dtype, buffer=shm.buf)

# synchronize before returning to make sure data are usable
comm.synchronize()
assert isinstance(new_data, np.ndarray)
assert isinstance(shm, shared_memory.SharedMemory)
assert isinstance(master_rank_pid, int)
return new_data, _SharedMemoryRef(shm, master_rank_pid)
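
A minimal usage sketch (editor's illustration; it assumes a multi-process launch such as the dh.launch_deco setup used in the tests above):

import numpy as np

import mobile_cv.torch.utils_pytorch.comm as comm
from mobile_cv.torch.utils_pytorch.shareables import share_numpy_array_locally

data = np.arange(8, dtype=np.float32) if comm.get_local_rank() == 0 else None
shared, shm_ref = share_numpy_array_locally(data)
# keep `shm_ref` alive for as long as `shared` is used; every local rank now reads
# the same underlying memory without copying it
assert shared[3] == 3.0
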


class SharedList(object):
"""
List-like read-only object shared between all (local) processes, backed by
multiprocessing.shared_memory (requires Python 3.8+).
"""

def __init__(
self, lst: Union[List[Any], Any], *, _allow_inplace_update: bool = False
):
"""
        Args:
            lst: a list of picklable objects; only the value passed on the local
                master rank is used, other ranks may pass any placeholder.
            _allow_inplace_update: if True, allow `__setitem__` to overwrite an
                element in place with a same-sized pickled object.
        """

self._allow_inplace_update = _allow_inplace_update

def _serialize(data):
buffer = pickle.dumps(data, protocol=-1)
return np.frombuffer(buffer, dtype=np.uint8)

        if comm.get_local_rank() == 0:
            logger.info(
                "Serializing {} elements to byte tensors and concatenating them all ...".format(
                    len(lst)
                )
            )
lst = [_serialize(x) for x in lst]
addr = np.asarray([len(x) for x in lst], dtype=np.int64)
addr = np.cumsum(addr)
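            # addr[i] is the (exclusive) end offset of element i in the flat byte buffer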
lst = np.concatenate(lst)
logger.info(
"Serialized dataset takes {:.2f} MiB".format(len(lst) / 1024**2)
)
else:
addr = None
lst = None

logger.info("Moving serialized dataset to shared memory ...")
        # keep the returned shared memory reference to prevent it from being GC-ed
_, self._lst_shm_ref = share_numpy_array_locally(lst)
self._addr, self._addr_shm_ref = share_numpy_array_locally(addr)
logger.info("Finished moving to shared memory")

def _calculate_addr_range(self, idx: int) -> Tuple[int, int]:
start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
end_addr = self._addr[idx].item()
return start_addr, end_addr

def __len__(self):
return len(self._addr)

def __getitem__(self, idx):
start_addr, end_addr = self._calculate_addr_range(idx)
# @lint-ignore PYTHONPICKLEISBAD
return pickle.loads(self._lst_shm_ref.shm.buf[start_addr:end_addr])

def __setitem__(self, idx, value):
        # Normally the user shouldn't update the stored data since this class is
        # designed to be read-only. In rare cases where the user knows the new data
        # has exactly the same serialized size, updating it in place can be useful.
if not self._allow_inplace_update:
raise RuntimeError("Update item from SharedList is not allowed!")
# NOTE: Currently user should be responsible for dealing with race-condition.
start_addr, end_addr = self._calculate_addr_range(idx)
nbytes = end_addr - start_addr
new_bytes = pickle.dumps(value, protocol=-1)
if len(new_bytes) != nbytes:
raise ValueError(
f"Can't replace the original object ({nbytes} bytes) with one that has"
f" different size ({len(new_bytes)} bytes)!"
)
self._lst_shm_ref.shm.buf[start_addr:end_addr] = new_bytes
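
For illustration, the addressing scheme used by SharedList in isolation, without the shared-memory backing (editor's sketch; the helper `get` below is ours, not part of the diff):

import pickle

import numpy as np

items = ["old", 2, (3,)]
blobs = [np.frombuffer(pickle.dumps(x, protocol=-1), dtype=np.uint8) for x in items]
addr = np.cumsum([len(b) for b in blobs])  # exclusive end offset of each element
flat = np.concatenate(blobs)               # one contiguous byte buffer

def get(idx):
    start = 0 if idx == 0 else int(addr[idx - 1])
    return pickle.loads(flat[start:int(addr[idx])].tobytes())

assert get(0) == "old" and get(1) == 2 and get(2) == (3,)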


class SharedDict(object):
"""
Dict-like read-only object shared between all (local) processes, backed by
multiprocessing.shared_memory (requires Python 3.8+).
"""

# TODO: we can support dict in a similar way if needed
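
One possible direction for the TODO above (editor's sketch only; the class below is hypothetical and not part of this PR): back a read-only dict with two SharedLists and rebuild a small key-to-index map on every rank.

class _SharedDictSketch(object):
    """Hypothetical read-only dict backed by SharedList; illustration only."""

    def __init__(self, dct):
        is_local_master = comm.get_local_rank() == 0
        keys = list(dct.keys()) if is_local_master else "unused"
        values = list(dct.values()) if is_local_master else "unused"
        self._keys = SharedList(keys)
        self._values = SharedList(values)
        # keys are assumed to be small; each rank unpickles them once to build the index
        self._index = {self._keys[i]: i for i in range(len(self._keys))}

    def __len__(self):
        return len(self._keys)

    def __getitem__(self, key):
        return self._values[self._index[key]]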