From c864c91065abcd71b4610a9803fc68e1be8b11d2 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Thu, 18 Jan 2024 01:51:18 -0500
Subject: [PATCH 1/5] [GraphBolt][CUDA] Inplace pin memory for Graph and
 TorchFeatureStore (#6962)

---
 .../impl/fused_csc_sampling_graph.py          | 37 ++++++++++++++++++-
 .../impl/torch_based_feature_store.py         | 31 +++++++++++++++-
 .../impl/test_fused_csc_sampling_graph.py     |  4 ++
 .../impl/test_torch_based_feature_store.py    |  3 ++
 4 files changed, 72 insertions(+), 3 deletions(-)

diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
index 8f026b3c5095..09df1e9e5799 100644
--- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
@@ -34,6 +34,17 @@ def __init__(
     ):
         super().__init__()
         self._c_csc_graph = c_csc_graph
+        self._is_inplace_pinned = set()
+
+    def __del__(self):
+        # torch.Tensor.pin_memory() is not an in-place operation. To make it
+        # truly in-place, we need to use cudaHostRegister. Then, we need to use
+        # cudaHostUnregister to unpin the tensor in the destructor.
+        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
+        for tensor in self._is_inplace_pinned:
+            assert (
+                torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0
+            )
 
     @property
     def total_num_nodes(self) -> int:
@@ -974,9 +985,33 @@ def _pin(x):
 
     def pin_memory_(self):
         """Copy `FusedCSCSamplingGraph` to the pinned memory in-place."""
+        # torch.Tensor.pin_memory() is not an in-place operation. To make it
+        # truly in-place, we need to use cudaHostRegister. Then, we need to use
+        # cudaHostUnregister to unpin the tensor in the destructor.
+        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
+        cudart = torch.cuda.cudart()
 
         def _pin(x):
-            return x.pin_memory() if hasattr(x, "pin_memory") else x
+            if hasattr(x, "pin_memory_"):
+                x.pin_memory_()
+            elif (
+                isinstance(x, torch.Tensor)
+                and not x.is_pinned()
+                and x.device.type == "cpu"
+            ):
+                assert (
+                    x.is_contiguous()
+                ), "Tensor pinning is only supported for contiguous tensors."
+                assert (
+                    cudart.cudaHostRegister(
+                        x.data_ptr(), x.numel() * x.element_size(), 0
+                    )
+                    == 0
+                )
+
+                self._is_inplace_pinned.add(x)
+
+            return x
 
         self._apply_to_members(_pin)
 
diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py
index 3952eb0a84b4..af77912ec9d5 100644
--- a/python/dgl/graphbolt/impl/torch_based_feature_store.py
+++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py
@@ -83,6 +83,17 @@ def __init__(self, torch_feature: torch.Tensor, metadata: Dict = None):
         # Make sure the tensor is contiguous.
         self._tensor = torch_feature.contiguous()
         self._metadata = metadata
+        self._is_inplace_pinned = set()
+
+    def __del__(self):
+        # torch.Tensor.pin_memory() is not an in-place operation. To make it
+        # truly in-place, we need to use cudaHostRegister. Then, we need to use
+        # cudaHostUnregister to unpin the tensor in the destructor.
+        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
+        for tensor in self._is_inplace_pinned:
+            assert (
+                torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0
+            )
 
     def read(self, ids: torch.Tensor = None):
         """Read the feature by index.
@@ -169,14 +180,30 @@ def metadata(self):
 
     def pin_memory_(self):
         """In-place operation to copy the feature to pinned memory."""
-        self._tensor = self._tensor.pin_memory()
+        # torch.Tensor.pin_memory() is not an in-place operation. To make it
+        # truly in-place, we need to use cudaHostRegister. Then, we need to use
+        # cudaHostUnregister to unpin the tensor in the destructor.
+        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
+        x = self._tensor
+        if not x.is_pinned() and x.device.type == "cpu":
+            assert (
+                x.is_contiguous()
+            ), "Tensor pinning is only supported for contiguous tensors."
+            assert (
+                torch.cuda.cudart().cudaHostRegister(
+                    x.data_ptr(), x.numel() * x.element_size(), 0
+                )
+                == 0
+            )
+
+            self._is_inplace_pinned.add(x)
 
     def to(self, device):  # pylint: disable=invalid-name
         """Copy `TorchBasedFeature` to the specified device."""
         # copy.copy is a shallow copy so it does not copy tensor memory.
         self2 = copy.copy(self)
         if device == "pinned":
-            self2.pin_memory_()
+            self2._tensor = self2._tensor.pin_memory()
         else:
             self2._tensor = self2._tensor.to(device)
         return self2
diff --git a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
index b2f240e6279b..cb4035c62ba7 100644
--- a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+++ b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
@@ -1601,10 +1601,14 @@ def test_csc_sampling_graph_to_device(device):
 def test_csc_sampling_graph_to_pinned_memory():
     # Construct FusedCSCSamplingGraph.
     graph = create_fused_csc_sampling_graph()
+    ptr = graph.csc_indptr.data_ptr()
 
     # Copy to pinned_memory in-place.
     graph.pin_memory_()
 
+    # Check if pinning is truly in-place.
+    assert graph.csc_indptr.data_ptr() == ptr
+
     is_graph_on_device_type(graph, "cpu")
     is_graph_pinned(graph)
 
diff --git a/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py b/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py
index be4b43b79461..ff7aa8f912e6 100644
--- a/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py
+++ b/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py
@@ -221,6 +221,9 @@ def test_torch_based_pinned_feature(dtype, idtype, shape):
     feature = gb.TorchBasedFeature(tensor)
     feature.pin_memory_()
 
+    # Check if pinning is truly in-place.
+    assert feature._tensor.data_ptr() == tensor.data_ptr()
+
     # Test read entire pinned feature, the result should be on cuda.
     assert torch.equal(feature.read(), test_tensor_cuda)
     assert feature.read().is_cuda

From 173257b37aa893c46ba4fd1543acec347b965eaa Mon Sep 17 00:00:00 2001
From: yxy235 <77922129+yxy235@users.noreply.github.com>
Date: Thu, 18 Jan 2024 15:58:44 +0800
Subject: [PATCH 2/5] [GraphBolt] Add `seeds` to MiniBatch. (#6968)

Co-authored-by: Ubuntu
---
 python/dgl/graphbolt/minibatch.py             |  26 +++
 .../pytorch/graphbolt/impl/test_minibatch.py  |  11 +-
 .../pytorch/graphbolt/test_integration.py     |  18 +-
 .../pytorch/graphbolt/test_item_sampler.py    | 186 ++++++++++++++++++
 4 files changed, 231 insertions(+), 10 deletions(-)

diff --git a/python/dgl/graphbolt/minibatch.py b/python/dgl/graphbolt/minibatch.py
index ec7ead0c36a8..145ef55550d2 100644
--- a/python/dgl/graphbolt/minibatch.py
+++ b/python/dgl/graphbolt/minibatch.py
@@ -54,6 +54,32 @@ class MiniBatch:
       value should be corresponding labels to given 'seed_nodes' or
       'node_pairs'.
     """
 
+    seeds: Union[
+        torch.Tensor,
+        Dict[str, torch.Tensor],
+    ] = None
+    """
+    Representation of seed items utilized in node classification, link
+    prediction, and hyperlink tasks.
+    - If `seeds` is a tensor: it indicates that the seeds originate from a
+      homogeneous graph. It can be either a 1-dimensional or 2-dimensional
+      tensor:
+        - 1-dimensional tensor: Each element directly represents a seed node
+          within the graph.
+        - 2-dimensional tensor: Each row designates a seed item, which can
+          encompass various entities such as edges, hyperlinks, or other graph
+          components depending on the specific context.
+    - If `seeds` is a dictionary: it indicates that the seeds originate from a
+      heterogeneous graph. The keys should be an edge or node type, and the
+      value should be a tensor, which can be either a 1-dimensional or
+      2-dimensional tensor:
+        - 1-dimensional tensor: Each element directly represents a seed node
+          of the given type within the graph.
+        - 2-dimensional tensor: Each row designates a seed item of the given
+          type, which can encompass various entities such as edges, hyperlinks,
+          or other graph components depending on the specific context.
+    """
+
     negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
     """
     Representation of negative samples for the head nodes in the link
diff --git a/tests/python/pytorch/graphbolt/impl/test_minibatch.py b/tests/python/pytorch/graphbolt/impl/test_minibatch.py
index 79e75df6bb56..fed1ea2d0105 100644
--- a/tests/python/pytorch/graphbolt/impl/test_minibatch.py
+++ b/tests/python/pytorch/graphbolt/impl/test_minibatch.py
@@ -58,7 +58,8 @@ def test_minibatch_representation_homo():
     # Test minibatch without data.
     minibatch = gb.MiniBatch()
     expect_result = str(
-        """MiniBatch(seed_nodes=None,
+        """MiniBatch(seeds=None,
+          seed_nodes=None,
           sampled_subgraphs=None,
           positive_node_pairs=None,
           node_pairs_with_labels=None,
           node_pairs=None,
           node_features=None,
           negative_srcs=None,
           negative_node_pairs=None,
           negative_dsts=None,
           labels=None,
           input_nodes=None,
           indexes=None,
           edge_features=None,
           compacted_node_pairs=None,
           compacted_negative_srcs=None,
           compacted_negative_dsts=None,
           blocks=None,
        )"""
     )
     result = str(minibatch)
     assert result == expect_result, print(expect_result, result)
 
     # Test minibatch with all attributes.
minibatch = gb.MiniBatch( node_pairs=csc_formats, @@ -93,7 +94,8 @@ def test_minibatch_representation_homo(): compacted_negative_dsts=compacted_negative_dsts, ) expect_result = str( - """MiniBatch(seed_nodes=None, + """MiniBatch(seeds=None, + seed_nodes=None, sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]), indices=tensor([0, 1, 2, 2, 1, 2]), ), @@ -242,7 +244,8 @@ def test_minibatch_representation_hetero(): compacted_negative_dsts=compacted_negative_dsts, ) expect_result = str( - """MiniBatch(seed_nodes={'B': tensor([10, 15])}, + """MiniBatch(seeds=None, + seed_nodes={'B': tensor([10, 15])}, sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]), indices=tensor([0, 1, 1]), ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]), diff --git a/tests/python/pytorch/graphbolt/test_integration.py b/tests/python/pytorch/graphbolt/test_integration.py index bea5f234869c..20c0083f7e3b 100644 --- a/tests/python/pytorch/graphbolt/test_integration.py +++ b/tests/python/pytorch/graphbolt/test_integration.py @@ -60,7 +60,8 @@ def test_integration_link_prediction(): ) expected = [ str( - """MiniBatch(seed_nodes=None, + """MiniBatch(seeds=None, + seed_nodes=None, sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]), indices=tensor([0, 4]), ), @@ -116,7 +117,8 @@ def test_integration_link_prediction(): )""" ), str( - """MiniBatch(seed_nodes=None, + """MiniBatch(seeds=None, + seed_nodes=None, sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]), indices=tensor([4, 1, 0]), ), @@ -172,7 +174,8 @@ def test_integration_link_prediction(): )""" ), str( - """MiniBatch(seed_nodes=None, + """MiniBatch(seeds=None, + seed_nodes=None, sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]), indices=tensor([1, 0]), ), @@ -276,7 +279,8 @@ def test_integration_node_classification(): ) expected = [ str( - """MiniBatch(seed_nodes=None, + """MiniBatch(seeds=None, + seed_nodes=None, sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]), indices=tensor([4, 1, 0, 1]), ), @@ -317,7 +321,8 @@ def test_integration_node_classification(): )""" ), str( - """MiniBatch(seed_nodes=None, + """MiniBatch(seeds=None, + seed_nodes=None, sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]), indices=tensor([0, 2]), ), @@ -356,7 +361,8 @@ def test_integration_node_classification(): )""" ), str( - """MiniBatch(seed_nodes=None, + """MiniBatch(seeds=None, + seed_nodes=None, sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]), indices=tensor([0, 2]), ), diff --git a/tests/python/pytorch/graphbolt/test_item_sampler.py b/tests/python/pytorch/graphbolt/test_item_sampler.py index 3308053fea2c..6f4d072c0157 100644 --- a/tests/python/pytorch/graphbolt/test_item_sampler.py +++ b/tests/python/pytorch/graphbolt/test_item_sampler.py @@ -376,6 +376,86 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last): assert torch.all(negs_ids[:-1, 1] <= negs_ids[1:, 1]) is not shuffle +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.parametrize("drop_last", [True, False]) +def test_ItemSet_seeds(batch_size, shuffle, drop_last): + # Node pairs. 
+ num_ids = 103 + seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3) + item_set = gb.ItemSet(seeds, names="seeds") + item_sampler = gb.ItemSampler( + item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last + ) + seeds_ids = [] + for i, minibatch in enumerate(item_sampler): + assert minibatch.seeds is not None + assert isinstance(minibatch.seeds, torch.Tensor) + assert minibatch.labels is None + is_last = (i + 1) * batch_size >= num_ids + if not is_last or num_ids % batch_size == 0: + expected_batch_size = batch_size + else: + if not drop_last: + expected_batch_size = num_ids % batch_size + else: + assert False + assert minibatch.seeds.shape == (expected_batch_size, 3) + # Verify seeds match. + assert torch.equal(minibatch.seeds[:, 0] + 1, minibatch.seeds[:, 1]) + assert torch.equal(minibatch.seeds[:, 1] + 1, minibatch.seeds[:, 2]) + # Archive batch. + seeds_ids.append(minibatch.seeds) + seeds_ids = torch.cat(seeds_ids) + assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle + assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle + assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.parametrize("drop_last", [True, False]) +def test_ItemSet_seeds_labels(batch_size, shuffle, drop_last): + # Node pairs and labels + num_ids = 103 + seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3) + labels = seeds[:, 0] + item_set = gb.ItemSet((seeds, labels), names=("seeds", "labels")) + item_sampler = gb.ItemSampler( + item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last + ) + seeds_ids = [] + labels = [] + for i, minibatch in enumerate(item_sampler): + assert minibatch.seeds is not None + assert isinstance(minibatch.seeds, torch.Tensor) + assert minibatch.labels is not None + label = minibatch.labels + assert len(minibatch.seeds) == len(label) + is_last = (i + 1) * batch_size >= num_ids + if not is_last or num_ids % batch_size == 0: + expected_batch_size = batch_size + else: + if not drop_last: + expected_batch_size = num_ids % batch_size + else: + assert False + assert minibatch.seeds.shape == (expected_batch_size, 3) + assert len(label) == expected_batch_size + # Verify seeds and labels match. + assert torch.equal(minibatch.seeds[:, 0] + 1, minibatch.seeds[:, 1]) + assert torch.equal(minibatch.seeds[:, 1] + 1, minibatch.seeds[:, 2]) + # Archive batch. + seeds_ids.append(minibatch.seeds) + labels.append(label) + seeds_ids = torch.cat(seeds_ids) + labels = torch.cat(labels) + assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle + assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle + assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle + assert torch.all(labels[:-1] <= labels[1:]) is not shuffle + + def test_append_with_other_datapipes(): num_ids = 100 batch_size = 4 @@ -723,6 +803,112 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last): assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.parametrize("drop_last", [True, False]) +def test_ItemSetDict_seeds(batch_size, shuffle, drop_last): + # Node pairs. 
+    num_ids = 103
+    total_pairs = 2 * num_ids
+    seeds_like = torch.arange(0, num_ids * 3).reshape(-1, 3)
+    seeds_follow = torch.arange(num_ids * 3, num_ids * 6).reshape(-1, 3)
+    seeds_dict = {
+        "user:like:item": gb.ItemSet(seeds_like, names="seeds"),
+        "user:follow:user": gb.ItemSet(seeds_follow, names="seeds"),
+    }
+    item_set = gb.ItemSetDict(seeds_dict)
+    item_sampler = gb.ItemSampler(
+        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
+    )
+    seeds_ids = []
+    for i, minibatch in enumerate(item_sampler):
+        assert isinstance(minibatch, gb.MiniBatch)
+        assert minibatch.seeds is not None
+        assert minibatch.labels is None
+        is_last = (i + 1) * batch_size >= total_pairs
+        if not is_last or total_pairs % batch_size == 0:
+            expected_batch_size = batch_size
+        else:
+            if not drop_last:
+                expected_batch_size = total_pairs % batch_size
+            else:
+                assert False
+        seeds_lst = []
+        for _, seeds in minibatch.seeds.items():
+            assert isinstance(seeds, torch.Tensor)
+            seeds_lst.append(seeds)
+        seeds_lst = torch.cat(seeds_lst)
+        assert seeds_lst.shape == (expected_batch_size, 3)
+        seeds_ids.append(seeds_lst)
+        assert torch.equal(seeds_lst[:, 0] + 1, seeds_lst[:, 1])
+        assert torch.equal(seeds_lst[:, 1] + 1, seeds_lst[:, 2])
+    seeds_ids = torch.cat(seeds_ids)
+    assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle
+    assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle
+    assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle
+
+
+@pytest.mark.parametrize("batch_size", [1, 4])
+@pytest.mark.parametrize("shuffle", [True, False])
+@pytest.mark.parametrize("drop_last", [True, False])
+def test_ItemSetDict_seeds_labels(batch_size, shuffle, drop_last):
+    # Node pairs and labels
+    num_ids = 103
+    total_ids = 2 * num_ids
+    seeds_like = torch.arange(0, num_ids * 3).reshape(-1, 3)
+    seeds_follow = torch.arange(num_ids * 3, num_ids * 6).reshape(-1, 3)
+    seeds_dict = {
+        "user:like:item": gb.ItemSet(
+            (seeds_like, seeds_like[:, 0]),
+            names=("seeds", "labels"),
+        ),
+        "user:follow:user": gb.ItemSet(
+            (seeds_follow, seeds_follow[:, 0]),
+            names=("seeds", "labels"),
+        ),
+    }
+    item_set = gb.ItemSetDict(seeds_dict)
+    item_sampler = gb.ItemSampler(
+        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
+    )
+    seeds_ids = []
+    labels = []
+    for i, minibatch in enumerate(item_sampler):
+        assert isinstance(minibatch, gb.MiniBatch)
+        assert minibatch.seeds is not None
+        assert minibatch.labels is not None
+        is_last = (i + 1) * batch_size >= total_ids
+        if not is_last or total_ids % batch_size == 0:
+            expected_batch_size = batch_size
+        else:
+            if not drop_last:
+                expected_batch_size = total_ids % batch_size
+            else:
+                assert False
+        seeds_lst = []
+        label = []
+        for _, seeds in minibatch.seeds.items():
+            assert isinstance(seeds, torch.Tensor)
+            seeds_lst.append(seeds)
+        for _, v_label in minibatch.labels.items():
+            label.append(v_label)
+        seeds_lst = torch.cat(seeds_lst)
+        label = torch.cat(label)
+        assert seeds_lst.shape == (expected_batch_size, 3)
+        assert len(label) == expected_batch_size
+        seeds_ids.append(seeds_lst)
+        labels.append(label)
+        assert torch.equal(seeds_lst[:, 0] + 1, seeds_lst[:, 1])
+        assert torch.equal(seeds_lst[:, 1] + 1, seeds_lst[:, 2])
+        assert torch.equal(seeds_lst[:, 0], label)
+    seeds_ids = torch.cat(seeds_ids)
+    labels = torch.cat(labels)
+    assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle
+    assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle
+    assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle
+    assert torch.all(labels[:-1] <= labels[1:]) is not shuffle
+
+
 def distributed_item_sampler_subprocess(
     proc_id,
     nprocs,

From 528b041c51aae91afb7b40c031010f24cfcd3cf8 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Thu, 18 Jan 2024 02:58:54 -0500
Subject: [PATCH 3/5] [GraphBolt][CUDA] Overlap feature fetcher (#6954)

---
 graphbolt/CMakeLists.txt                    |   1 +
 graphbolt/src/cuda/index_select_impl.cu     |  13 +-
 graphbolt/src/cuda/max_uva_threads.cc       |  15 +++
 graphbolt/src/cuda/max_uva_threads.h        |  24 ++++
 graphbolt/src/python_binding.cc             |   6 +
 python/dgl/graphbolt/dataloader.py          | 117 +++++++++++++++++-
 python/dgl/graphbolt/feature_fetcher.py     |  71 ++++++++---
 .../pytorch/graphbolt/test_dataloader.py    |  45 ++++++-
 8 files changed, 266 insertions(+), 26 deletions(-)
 create mode 100644 graphbolt/src/cuda/max_uva_threads.cc
 create mode 100644 graphbolt/src/cuda/max_uva_threads.h

diff --git a/graphbolt/CMakeLists.txt b/graphbolt/CMakeLists.txt
index 8ffbc2a5cf82..8f377292e368 100644
--- a/graphbolt/CMakeLists.txt
+++ b/graphbolt/CMakeLists.txt
@@ -52,6 +52,7 @@ file(GLOB BOLT_SRC ${BOLT_DIR}/*.cc)
 if(USE_CUDA)
   file(GLOB BOLT_CUDA_SRC
     ${BOLT_DIR}/cuda/*.cu
+    ${BOLT_DIR}/cuda/*.cc
   )
   list(APPEND BOLT_SRC ${BOLT_CUDA_SRC})
   if(DEFINED ENV{CUDAARCHS})
diff --git a/graphbolt/src/cuda/index_select_impl.cu b/graphbolt/src/cuda/index_select_impl.cu
index af2c9fe96a24..389d2430f227 100644
--- a/graphbolt/src/cuda/index_select_impl.cu
+++ b/graphbolt/src/cuda/index_select_impl.cu
@@ -10,6 +10,7 @@
 #include
 
 #include "./common.h"
+#include "./max_uva_threads.h"
 #include "./utils.h"
 
 namespace graphbolt {
@@ -122,17 +123,23 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
   if (aligned_feature_size == 1) {
     // Use a single thread to process each output row to avoid wasting threads.
     const int num_threads = cuda::FindNumThreads(return_len);
-    const int num_blocks = (return_len + num_threads - 1) / num_threads;
+    const int num_blocks =
+        (std::min(return_len, cuda::max_uva_threads.value_or(1 << 20)) +
+         num_threads - 1) /
+        num_threads;
     CUDA_KERNEL_CALL(
         IndexSelectSingleKernel, num_blocks, num_threads, 0, input_ptr,
         input_len, index_sorted_ptr, return_len, ret_ptr, permutation_ptr);
   } else {
-    dim3 block(512, 1);
+    constexpr int BLOCK_SIZE = 512;
+    dim3 block(BLOCK_SIZE, 1);
     while (static_cast<int64_t>(block.x) >= 2 * aligned_feature_size) {
       block.x >>= 1;
       block.y <<= 1;
     }
-    const dim3 grid((return_len + block.y - 1) / block.y);
+    const dim3 grid(std::min(
+        (return_len + block.y - 1) / block.y,
+        cuda::max_uva_threads.value_or(1 << 20) / BLOCK_SIZE));
     if (aligned_feature_size * sizeof(DType) <= GPU_CACHE_LINE_SIZE) {
       // When feature size is smaller than GPU cache line size, use unaligned
       // version for less SM usage, which is more resource efficient.
diff --git a/graphbolt/src/cuda/max_uva_threads.cc b/graphbolt/src/cuda/max_uva_threads.cc
new file mode 100644
index 000000000000..de8a1fffc023
--- /dev/null
+++ b/graphbolt/src/cuda/max_uva_threads.cc
@@ -0,0 +1,15 @@
+/**
+ *  Copyright (c) 2023 by Contributors
+ *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
+ * @file cuda/max_uva_threads.cc
+ * @brief Max uva threads variable setter function.
+ */
+#include "./max_uva_threads.h"
+
+namespace graphbolt {
+namespace cuda {
+
+void set_max_uva_threads(int64_t count) { max_uva_threads = count; }
+
+}  // namespace cuda
+}  // namespace graphbolt
diff --git a/graphbolt/src/cuda/max_uva_threads.h b/graphbolt/src/cuda/max_uva_threads.h
new file mode 100644
index 000000000000..b33718a7bcb5
--- /dev/null
+++ b/graphbolt/src/cuda/max_uva_threads.h
@@ -0,0 +1,24 @@
+/**
+ *  Copyright (c) 2023 by Contributors
+ *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
+ * @file cuda/max_uva_threads.h
+ * @brief Max uva threads variable declaration.
+ */
+#ifndef GRAPHBOLT_MAX_UVA_THREADS_H_
+#define GRAPHBOLT_MAX_UVA_THREADS_H_
+
+#include <cstdint>
+#include <optional>
+
+namespace graphbolt {
+namespace cuda {
+
+/** @brief Set a limit on the number of CUDA threads for UVA accesses. */
+inline std::optional<int64_t> max_uva_threads;
+
+void set_max_uva_threads(int64_t count);
+
+}  // namespace cuda
+}  // namespace graphbolt
+
+#endif  // GRAPHBOLT_MAX_UVA_THREADS_H_
diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc
index 44b6306d890d..d44bccf66cc9 100644
--- a/graphbolt/src/python_binding.cc
+++ b/graphbolt/src/python_binding.cc
@@ -9,6 +9,9 @@
 #include
 #include
 
+#ifdef GRAPHBOLT_USE_CUDA
+#include "./cuda/max_uva_threads.h"
+#endif
 #include "./index_select.h"
 #include "./random.h"
 
@@ -75,6 +78,9 @@ TORCH_LIBRARY(graphbolt, m) {
   m.def("index_select", &ops::IndexSelect);
   m.def("index_select_csc", &ops::IndexSelectCSC);
   m.def("set_seed", &RandomEngine::SetManualSeed);
+#ifdef GRAPHBOLT_USE_CUDA
+  m.def("set_max_uva_threads", &cuda::set_max_uva_threads);
+#endif
 }
 
 }  // namespace sampling
diff --git a/python/dgl/graphbolt/dataloader.py b/python/dgl/graphbolt/dataloader.py
index a708f867dc0e..d19778494847 100644
--- a/python/dgl/graphbolt/dataloader.py
+++ b/python/dgl/graphbolt/dataloader.py
@@ -1,5 +1,8 @@
 """Graph Bolt DataLoaders"""
 
+from queue import Queue
+
+import torch
 import torch.utils.data
 import torchdata.dataloader2.graph as dp_utils
 import torchdata.datapipes as dp
@@ -35,6 +38,62 @@ def _find_and_wrap_parent(
     )
 
 
+class EndMarker(dp.iter.IterDataPipe):
+    """Used to mark the end of a datapipe and is a no-op."""
+
+    def __init__(self, datapipe):
+        self.datapipe = datapipe
+
+    def __iter__(self):
+        for data in self.datapipe:
+            yield data
+
+
+class Bufferer(dp.iter.IterDataPipe):
+    """Buffers items before yielding them.
+
+    Parameters
+    ----------
+    datapipe : DataPipe
+        The data pipeline.
+    buffer_size : int, optional
+        The size of the buffer which stores the fetched samples. If data
+        coming from the datapipe has latency spikes, consider passing a
+        higher value. Default is 2.
+    """
+
+    def __init__(self, datapipe, buffer_size=2):
+        self.datapipe = datapipe
+        if buffer_size <= 0:
+            raise ValueError(
+                "'buffer_size' is required to be a positive integer."
+            )
+        self.buffer = Queue(buffer_size)
+
+    def __iter__(self):
+        for data in self.datapipe:
+            if not self.buffer.full():
+                self.buffer.put(data)
+            else:
+                return_data = self.buffer.get()
+                self.buffer.put(data)
+                yield return_data
+        while not self.buffer.empty():
+            yield self.buffer.get()
+
+
+class Awaiter(dp.iter.IterDataPipe):
+    """Calls the wait function of all items."""
+
+    def __init__(self, datapipe):
+        self.datapipe = datapipe
+
+    def __iter__(self):
+        for data in self.datapipe:
+            data.wait()
+            yield data
+
+
 class MultiprocessingWrapper(dp.iter.IterDataPipe):
     """Wraps a datapipe with multiprocessing.
 
@@ -64,6 +123,14 @@ def __iter__(self):
         yield from self.dataloader
 
 
+# There needs to be a single instance of the uva_stream; if it is created
+# multiple times, it leads to multiple CUDA memory pools and memory leaks.
+def _get_uva_stream():
+    if not hasattr(_get_uva_stream, "stream"):
+        _get_uva_stream.stream = torch.cuda.Stream(priority=-1)
+    return _get_uva_stream.stream
+
+
 class DataLoader(torch.utils.data.DataLoader):
     """Multiprocessing DataLoader.
 
@@ -84,9 +151,26 @@ class DataLoader(torch.utils.data.DataLoader):
         If True, the data loader will not shut down the worker processes after
         a dataset has been consumed once. This allows to maintain the workers
         instances alive.
+    overlap_feature_fetch : bool, optional
+        If True, the data loader will overlap the UVA feature fetcher
+        operations with the rest of the operations by using an alternative
+        CUDA stream. Default is True.
+    max_uva_threads : int, optional
+        Limits the number of CUDA threads used for UVA copies so that the rest
+        of the computations can run simultaneously with it. Setting it too
+        high will limit the amount of overlap, while setting it too low may
+        cause the PCI-e bandwidth to not get fully utilized. The manually
+        tuned default is 6144, meaning around 3-4 Streaming Multiprocessors.
     """
 
-    def __init__(self, datapipe, num_workers=0, persistent_workers=True):
+    def __init__(
+        self,
+        datapipe,
+        num_workers=0,
+        persistent_workers=True,
+        overlap_feature_fetch=True,
+        max_uva_threads=6144,
+    ):
         # Multiprocessing requires two modifications to the datapipe:
         #
         # 1. Insert a stage after ItemSampler to distribute the
         #    minibatches evenly across processes.
         # 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe
         #    of the FeatureFetcher with a multiprocessing PyTorch DataLoader.
 
+        datapipe = EndMarker(datapipe)
         datapipe_graph = dp_utils.traverse_dps(datapipe)
         datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
 
@@ -122,7 +207,35 @@ def __init__(self, datapipe, num_workers=0, persistent_workers=True):
             persistent_workers=persistent_workers,
         )
 
-        # (3) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
+        # (3) Overlap UVA feature fetching by buffering and using an alternative
+        # stream.
+        if (
+            overlap_feature_fetch
+            and num_workers == 0
+            and torch.cuda.is_available()
+        ):
+            torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)
+            feature_fetchers = dp_utils.find_dps(
+                datapipe_graph,
+                FeatureFetcher,
+            )
+            for feature_fetcher in feature_fetchers:
+                feature_fetcher.stream = _get_uva_stream()
+            _find_and_wrap_parent(
+                datapipe_graph,
+                datapipe_adjlist,
+                EndMarker,
+                Bufferer,
+                buffer_size=2,
+            )
+            _find_and_wrap_parent(
+                datapipe_graph,
+                datapipe_adjlist,
+                EndMarker,
+                Awaiter,
+            )
+
+        # (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
         # data pipeline up to the CopyTo operation to run in a separate thread.
         _find_and_wrap_parent(
             datapipe_graph,
diff --git a/python/dgl/graphbolt/feature_fetcher.py b/python/dgl/graphbolt/feature_fetcher.py
index 9f1796362991..7b9ed6817afb 100644
--- a/python/dgl/graphbolt/feature_fetcher.py
+++ b/python/dgl/graphbolt/feature_fetcher.py
@@ -2,6 +2,8 @@
 
 from typing import Dict
 
+import torch
+
 from torch.utils.data import functional_datapipe
 
 from .base import etype_tuple_to_str
@@ -52,8 +54,9 @@ def __init__(
         self.feature_store = feature_store
         self.node_feature_keys = node_feature_keys
         self.edge_feature_keys = edge_feature_keys
+        self.stream = None
 
-    def _read(self, data):
+    def _read_data(self, data, stream):
         """
         Fill in the node/edge features field in data.
 
@@ -77,6 +80,12 @@ def _read(self, data):
         ) or isinstance(self.edge_feature_keys, Dict)
         # Read Node features.
         input_nodes = data.node_ids()
+
+        def record_stream(tensor):
+            if stream is not None and tensor.is_cuda:
+                tensor.record_stream(stream)
+            return tensor
+
         if self.node_feature_keys and input_nodes is not None:
             if is_heterogeneous:
                 for type_name, feature_names in self.node_feature_keys.items():
@@ -86,19 +95,23 @@ def _read(self, data):
                     for feature_name in feature_names:
                         node_features[
                             (type_name, feature_name)
-                        ] = self.feature_store.read(
-                            "node",
-                            type_name,
-                            feature_name,
-                            nodes,
+                        ] = record_stream(
+                            self.feature_store.read(
+                                "node",
+                                type_name,
+                                feature_name,
+                                nodes,
+                            )
                         )
             else:
                 for feature_name in self.node_feature_keys:
-                    node_features[feature_name] = self.feature_store.read(
-                        "node",
-                        None,
-                        feature_name,
-                        input_nodes,
+                    node_features[feature_name] = record_stream(
+                        self.feature_store.read(
+                            "node",
+                            None,
+                            feature_name,
+                            input_nodes,
+                        )
                     )
         # Read Edge features.
         if self.edge_feature_keys and num_layers > 0:
@@ -124,19 +137,37 @@ def _read(self, data):
                     for feature_name in feature_names:
                         edge_features[i][
                             (type_name, feature_name)
-                        ] = self.feature_store.read(
-                            "edge", type_name, feature_name, edges
+                        ] = record_stream(
+                            self.feature_store.read(
+                                "edge", type_name, feature_name, edges
+                            )
                         )
             else:
                 for feature_name in self.edge_feature_keys:
-                    edge_features[i][
-                        feature_name
-                    ] = self.feature_store.read(
-                        "edge",
-                        None,
-                        feature_name,
-                        original_edge_ids,
+                    edge_features[i][feature_name] = record_stream(
+                        self.feature_store.read(
+                            "edge",
+                            None,
+                            feature_name,
+                            original_edge_ids,
+                        )
                     )
         data.set_node_features(node_features)
         data.set_edge_features(edge_features)
         return data
+
+    def _read(self, data):
+        current_stream = None
+        if self.stream is not None:
+            current_stream = torch.cuda.current_stream()
+            self.stream.wait_stream(current_stream)
+        with torch.cuda.stream(self.stream):
+            data = self._read_data(data, current_stream)
+            if self.stream is not None:
+                event = torch.cuda.current_stream().record_event()
+
+                def _wait():
+                    event.wait()
+
+                data.wait = _wait
+            return data
diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py
index b59260cf654e..9485c78b92b1 100644
--- a/tests/python/pytorch/graphbolt/test_dataloader.py
+++ b/tests/python/pytorch/graphbolt/test_dataloader.py
@@ -1,9 +1,11 @@
+import unittest
+
 import backend as F
 
 import dgl
 import dgl.graphbolt
+import pytest
 import torch
-import torch.multiprocessing as mp
 
 from . import gb_test_utils
@@ -37,3 +39,44 @@ def test_DataLoader():
         num_workers=4,
     )
     assert len(list(dataloader)) == N // B
+
+
+@unittest.skipIf(
+    F._default_context_str != "gpu",
+    reason="This test requires the GPU.",
+)
+@pytest.mark.parametrize("overlap_feature_fetch", [True, False])
+def test_gpu_sampling_DataLoader(overlap_feature_fetch):
+    N = 40
+    B = 4
+    itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
+    graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True).to(
+        F.ctx()
+    )
+    features = {}
+    keys = [("node", None, "a"), ("node", None, "b")]
+    features[keys[0]] = dgl.graphbolt.TorchBasedFeature(
+        torch.randn(200, 4, pin_memory=True)
+    )
+    features[keys[1]] = dgl.graphbolt.TorchBasedFeature(
+        torch.randn(200, 4, pin_memory=True)
+    )
+    feature_store = dgl.graphbolt.BasicFeatureStore(features)
+
+    datapipe = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
+    datapipe = datapipe.copy_to(F.ctx(), extra_attrs=["seed_nodes"])
+    datapipe = dgl.graphbolt.NeighborSampler(
+        datapipe,
+        graph,
+        fanouts=[torch.LongTensor([2]) for _ in range(2)],
+    )
+    datapipe = dgl.graphbolt.FeatureFetcher(
+        datapipe,
+        feature_store,
+        ["a", "b"],
+    )
+
+    dataloader = dgl.graphbolt.DataLoader(
+        datapipe, overlap_feature_fetch=overlap_feature_fetch
+    )
+    assert len(list(dataloader)) == N // B

From f6db850d751355858e007fb17859f1f4ec2bd48e Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Thu, 18 Jan 2024 02:59:35 -0500
Subject: [PATCH 4/5] [GraphBolt][CUDA] Add native `GPUCachedFeature` instead
 of using DGL (#6939)

---
 CMakeLists.txt                                |   9 ++
 graphbolt/CMakeLists.txt                      |   5 +
 graphbolt/build.bat                           |   4 +-
 graphbolt/build.sh                            |   2 +-
 graphbolt/src/cuda/gpu_cache.cu               | 108 ++++++++++++++++++
 graphbolt/src/cuda/gpu_cache.h                |  66 +++++++++++
 graphbolt/src/python_binding.cc               |  10 ++
 python/dgl/graphbolt/impl/__init__.py         |   1 +
 python/dgl/graphbolt/impl/gpu_cache.py        |  53 +++++++++
 .../dgl/graphbolt/impl/gpu_cached_feature.py  |  24 ++--
 .../graphbolt/impl/test_gpu_cached_feature.py |  46 +++++---
 11 files changed, 296 insertions(+), 32 deletions(-)
 create mode 100644 graphbolt/src/cuda/gpu_cache.cu
 create mode 100644 graphbolt/src/cuda/gpu_cache.h
 create mode 100644 python/dgl/graphbolt/impl/gpu_cache.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7815b1459549..3661ca9b86a4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -530,6 +530,10 @@ if(BUILD_GRAPHBOLT)
   string(REPLACE ";" "\\;" CUDA_ARCHITECTURES_ESCAPED "${CUDA_ARCHITECTURES}")
   file(TO_NATIVE_PATH ${CMAKE_CURRENT_BINARY_DIR} BINDIR)
   file(TO_NATIVE_PATH ${CMAKE_COMMAND} CMAKE_CMD)
+  if(USE_CUDA)
+    get_target_property(GPU_CACHE_INCLUDE_DIRS gpu_cache INCLUDE_DIRECTORIES)
+  endif(USE_CUDA)
+  string(REPLACE ";" "\\;" GPU_CACHE_INCLUDE_DIRS_ESCAPED "${GPU_CACHE_INCLUDE_DIRS}")
   if(MSVC)
     file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/graphbolt/build.bat BUILD_SCRIPT)
     add_custom_target(
@@ -540,6 +544,7 @@ if(BUILD_GRAPHBOLT)
         CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
         USE_CUDA=${USE_CUDA}
         BINDIR=${BINDIR}
+        GPU_CACHE_INCLUDE_DIRS="${GPU_CACHE_INCLUDE_DIRS_ESCAPED}"
         CFLAGS=${CMAKE_C_FLAGS}
         CXXFLAGS=${CMAKE_CXX_FLAGS}
         CUDAARCHS="${CUDA_ARCHITECTURES_ESCAPED}"
@@ -557,6 +562,7 @@ if(BUILD_GRAPHBOLT)
         CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
         USE_CUDA=${USE_CUDA}
         BINDIR=${CMAKE_CURRENT_BINARY_DIR}
+        GPU_CACHE_INCLUDE_DIRS="${GPU_CACHE_INCLUDE_DIRS_ESCAPED}"
         CFLAGS=${CMAKE_C_FLAGS}
         CXXFLAGS=${CMAKE_CXX_FLAGS}
         CUDAARCHS="${CUDA_ARCHITECTURES_ESCAPED}"
@@ -565,4 +571,7 @@ if(BUILD_GRAPHBOLT)
       DEPENDS ${BUILD_SCRIPT}
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/graphbolt) endif(MSVC) + if(USE_CUDA) + add_dependencies(graphbolt gpu_cache) + endif(USE_CUDA) endif(BUILD_GRAPHBOLT) diff --git a/graphbolt/CMakeLists.txt b/graphbolt/CMakeLists.txt index 8f377292e368..b932dc640550 100644 --- a/graphbolt/CMakeLists.txt +++ b/graphbolt/CMakeLists.txt @@ -76,6 +76,11 @@ if(USE_CUDA) "../third_party/cccl/thrust" "../third_party/cccl/cub" "../third_party/cccl/libcudacxx/include") + + message(STATUS "Use HugeCTR gpu_cache for graphbolt with INCLUDE_DIRS $ENV{GPU_CACHE_INCLUDE_DIRS}.") + target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE $ENV{GPU_CACHE_INCLUDE_DIRS}) + target_link_directories(${LIB_GRAPHBOLT_NAME} PRIVATE ${GPU_CACHE_BUILD_DIR}) + target_link_libraries(${LIB_GRAPHBOLT_NAME} gpu_cache) get_property(archs TARGET ${LIB_GRAPHBOLT_NAME} PROPERTY CUDA_ARCHITECTURES) message(STATUS "CUDA_ARCHITECTURES for graphbolt: ${archs}") diff --git a/graphbolt/build.bat b/graphbolt/build.bat index 59df5ddb5109..b54411c60095 100755 --- a/graphbolt/build.bat +++ b/graphbolt/build.bat @@ -11,7 +11,7 @@ IF x%1x == xx GOTO single FOR %%X IN (%*) DO ( DEL /S /Q * - "%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release -DPYTHON_INTERP=%%X .. -G "Visual Studio 16 2019" || EXIT /B 1 + "%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release -DPYTHON_INTERP=%%X .. -G "Visual Studio 16 2019" || EXIT /B 1 msbuild graphbolt.sln /m /nr:false || EXIT /B 1 COPY /Y Release\*.dll "%BINDIR%\graphbolt" || EXIT /B 1 ) @@ -21,7 +21,7 @@ GOTO end :single DEL /S /Q * -"%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release .. -G "Visual Studio 16 2019" || EXIT /B 1 +"%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release .. -G "Visual Studio 16 2019" || EXIT /B 1 msbuild graphbolt.sln /m /nr:false || EXIT /B 1 COPY /Y Release\*.dll "%BINDIR%\graphbolt" || EXIT /B 1 diff --git a/graphbolt/build.sh b/graphbolt/build.sh index 3906eda2ede3..dd4564e9cb1e 100755 --- a/graphbolt/build.sh +++ b/graphbolt/build.sh @@ -12,7 +12,7 @@ else CPSOURCE=*.so fi -CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DUSE_CUDA=$USE_CUDA" +CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DUSE_CUDA=$USE_CUDA -DGPU_CACHE_BUILD_DIR=$BINDIR" echo $CMAKE_FLAGS if [ $# -eq 0 ]; then diff --git a/graphbolt/src/cuda/gpu_cache.cu b/graphbolt/src/cuda/gpu_cache.cu new file mode 100644 index 000000000000..0a47bbbddc18 --- /dev/null +++ b/graphbolt/src/cuda/gpu_cache.cu @@ -0,0 +1,108 @@ +/** + * Copyright (c) 2023 by Contributors + * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek) + * @file cuda/gpu_cache.cu + * @brief GPUCache implementation on CUDA. 
+ */
+#include <numeric>
+
+#include "./common.h"
+#include "./gpu_cache.h"
+
+namespace graphbolt {
+namespace cuda {
+
+GpuCache::GpuCache(const std::vector<int64_t> &shape, torch::ScalarType dtype) {
+  TORCH_CHECK(shape.size() >= 2, "Shape must at least have 2 dimensions.");
+  const auto num_items = shape[0];
+  const int64_t num_feats =
+      std::accumulate(shape.begin() + 1, shape.end(), 1ll, std::multiplies<>());
+  const int element_size =
+      torch::empty(1, torch::TensorOptions().dtype(dtype)).element_size();
+  num_bytes_ = num_feats * element_size;
+  num_float_feats_ = (num_bytes_ + sizeof(float) - 1) / sizeof(float);
+  cache_ = std::make_unique<gpu_cache_t>(
+      (num_items + bucket_size - 1) / bucket_size, num_float_feats_);
+  shape_ = shape;
+  shape_[0] = -1;
+  dtype_ = dtype;
+  device_id_ = cuda::GetCurrentStream().device_index();
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
+    torch::Tensor keys) {
+  TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
+  TORCH_CHECK(
+      keys.device().index() == device_id_,
+      "Keys should be on the correct CUDA device.");
+  TORCH_CHECK(keys.sizes().size() == 1, "Keys should be a 1D tensor.");
+  keys = keys.to(torch::kLong);
+  auto values = torch::empty(
+      {keys.size(0), num_float_feats_}, keys.options().dtype(torch::kFloat));
+  auto missing_index =
+      torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
+  auto missing_keys =
+      torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
+  cuda::CopyScalar<size_t> missing_len;
+  auto stream = cuda::GetCurrentStream();
+  cache_->Query(
+      reinterpret_cast<key_t *>(keys.data_ptr()), keys.size(0),
+      values.data_ptr<float>(),
+      reinterpret_cast<uint64_t *>(missing_index.data_ptr()),
+      reinterpret_cast<key_t *>(missing_keys.data_ptr()), missing_len.get(),
+      stream);
+  values = values.view(torch::kByte)
+               .slice(1, 0, num_bytes_)
+               .view(dtype_)
+               .view(shape_);
+  // To safely read missing_len, we synchronize
+  stream.synchronize();
+  missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));
+  missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len));
+  return std::make_tuple(values, missing_index, missing_keys);
+}
+
+void GpuCache::Replace(torch::Tensor keys, torch::Tensor values) {
+  TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
+  TORCH_CHECK(
+      keys.device().index() == device_id_,
+      "Keys should be on the correct CUDA device.");
+  TORCH_CHECK(values.device().is_cuda(), "Values should be on a CUDA device.");
+  TORCH_CHECK(
+      values.device().index() == device_id_,
+      "Values should be on the correct CUDA device.");
+  TORCH_CHECK(
+      keys.size(0) == values.size(0),
+      "The first dimensions of keys and values must match.");
+  TORCH_CHECK(
+      std::equal(shape_.begin() + 1, shape_.end(), values.sizes().begin() + 1),
+      "Values should have the correct dimensions.");
+  TORCH_CHECK(
+      values.scalar_type() == dtype_, "Values should have the correct dtype.");
+  keys = keys.to(torch::kLong);
+  torch::Tensor float_values;
+  if (num_bytes_ % sizeof(float) != 0) {
+    float_values = torch::empty(
+        {values.size(0), num_float_feats_},
+        values.options().dtype(torch::kFloat));
+    float_values.view(torch::kByte)
+        .slice(1, 0, num_bytes_)
+        .copy_(values.view(torch::kByte).view({values.size(0), -1}));
+  } else {
+    float_values = values.view(torch::kByte)
+                       .view({values.size(0), -1})
+                       .view(torch::kFloat)
+                       .contiguous();
+  }
+  cache_->Replace(
+      reinterpret_cast<key_t *>(keys.data_ptr()), keys.size(0),
+      float_values.data_ptr<float>(), cuda::GetCurrentStream());
+}
+
+c10::intrusive_ptr<GpuCache> GpuCache::Create(
+    const std::vector<int64_t> &shape, torch::ScalarType dtype) {
+  return c10::make_intrusive<GpuCache>(shape, dtype);
+}
+
+}  // namespace cuda
+}  // namespace graphbolt
diff --git a/graphbolt/src/cuda/gpu_cache.h b/graphbolt/src/cuda/gpu_cache.h
new file mode 100644
index 000000000000..ce9455618a69
--- /dev/null
+++ b/graphbolt/src/cuda/gpu_cache.h
@@ -0,0 +1,66 @@
+/**
+ *  Copyright (c) 2023 by Contributors
+ *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
+ * @file cuda/gpu_cache.h
+ * @brief Header file of HugeCTR gpu_cache wrapper.
+ */
+
+#ifndef GRAPHBOLT_GPU_CACHE_H_
+#define GRAPHBOLT_GPU_CACHE_H_
+
+#include <torch/custom_class.h>
+#include <torch/torch.h>
+
+#include <limits>
+#include <nv_gpu_cache.hpp>
+
+namespace graphbolt {
+namespace cuda {
+
+class GpuCache : public torch::CustomClassHolder {
+  using key_t = long long;
+  constexpr static int set_associativity = 2;
+  constexpr static int WARP_SIZE = 32;
+  constexpr static int bucket_size = WARP_SIZE * set_associativity;
+  using gpu_cache_t = ::gpu_cache::gpu_cache<
+      key_t, uint64_t, std::numeric_limits<key_t>::max(), set_associativity,
+      WARP_SIZE>;
+
+ public:
+  /**
+   * @brief Constructor for the GpuCache struct.
+   *
+   * @param shape The shape of the GPU cache.
+   * @param dtype The datatype of items to be stored.
+   */
+  GpuCache(const std::vector<int64_t>& shape, torch::ScalarType dtype);
+
+  GpuCache() = default;
+
+  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Query(
+      torch::Tensor keys);
+
+  void Replace(torch::Tensor keys, torch::Tensor values);
+
+  static c10::intrusive_ptr<GpuCache> Create(
+      const std::vector<int64_t>& shape, torch::ScalarType dtype);
+
+ private:
+  std::vector<int64_t> shape_;
+  torch::ScalarType dtype_;
+  std::unique_ptr<gpu_cache_t> cache_;
+  int64_t num_bytes_;
+  int64_t num_float_feats_;
+  torch::DeviceIndex device_id_;
+};
+
+// The cu file in HugeCTR gpu cache uses unsigned int and long long.
+// Changing to int64_t results in a mismatch of template arguments.
+static_assert(
+    sizeof(long long) == sizeof(int64_t),
+    "long long and int64_t need to have the same size.");  // NOLINT
+
+}  // namespace cuda
+}  // namespace graphbolt
+
+#endif  // GRAPHBOLT_GPU_CACHE_H_
diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc
index d44bccf66cc9..9e926b672bd8 100644
--- a/graphbolt/src/python_binding.cc
+++ b/graphbolt/src/python_binding.cc
@@ -15,6 +15,10 @@
 #include "./index_select.h"
 #include "./random.h"
 
+#ifdef GRAPHBOLT_USE_CUDA
+#include "./cuda/gpu_cache.h"
+#endif
+
 namespace graphbolt {
 namespace sampling {
 
@@ -70,6 +74,12 @@ TORCH_LIBRARY(graphbolt, m) {
         g->SetState(state);
         return g;
       });
+#ifdef GRAPHBOLT_USE_CUDA
+  m.class_<cuda::GpuCache>("GpuCache")
+      .def("query", &cuda::GpuCache::Query)
+      .def("replace", &cuda::GpuCache::Replace);
+  m.def("gpu_cache", &cuda::GpuCache::Create);
+#endif
   m.def("fused_csc_sampling_graph", &FusedCSCSamplingGraph::Create);
   m.def(
      "load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory);
diff --git a/python/dgl/graphbolt/impl/__init__.py b/python/dgl/graphbolt/impl/__init__.py
index fa984439094d..943eb4ecca2c 100644
--- a/python/dgl/graphbolt/impl/__init__.py
+++ b/python/dgl/graphbolt/impl/__init__.py
@@ -1,6 +1,7 @@
 """Implementation of GraphBolt."""
 from .basic_feature_store import *
 from .fused_csc_sampling_graph import *
+from .gpu_cache import *
 from .gpu_cached_feature import *
 from .in_subgraph_sampler import *
 from .legacy_dataset import *
diff --git a/python/dgl/graphbolt/impl/gpu_cache.py b/python/dgl/graphbolt/impl/gpu_cache.py
new file mode 100644
index 000000000000..7c07e7c52a0b
--- /dev/null
+++ b/python/dgl/graphbolt/impl/gpu_cache.py
@@ -0,0 +1,53 @@
+"""HugeCTR gpu_cache wrapper for graphbolt."""
+import torch
+
+
+class GPUCache(object):
+    """High-level wrapper for GPU embedding cache"""
+
+    def __init__(self, cache_shape, dtype):
+        major, _ = torch.cuda.get_device_capability()
+        assert (
+            major >= 7
+        ), "GPUCache is supported only on CUDA compute capability >= 70 (Volta)."
+        self._cache = torch.ops.graphbolt.gpu_cache(cache_shape, dtype)
+        self.total_miss = 0
+        self.total_queries = 0
+
+    def query(self, keys):
+        """Queries the GPU cache.
+
+        Parameters
+        ----------
+        keys : Tensor
+            The keys to query the GPU cache with.
+
+        Returns
+        -------
+        tuple(Tensor, Tensor, Tensor)
+            A tuple containing (values, missing_indices, missing_keys) where
+            values[missing_indices] corresponds to cache misses that should be
+            filled by querying another source with missing_keys.
+        """
+        self.total_queries += keys.shape[0]
+        values, missing_index, missing_keys = self._cache.query(keys)
+        self.total_miss += missing_keys.shape[0]
+        return values, missing_index, missing_keys
+
+    def replace(self, keys, values):
+        """Inserts key-value pairs into the GPU cache using the Least-Recently
+        Used (LRU) algorithm to remove old key-value pairs if it is full.
+
+        Parameters
+        ----------
+        keys : Tensor
+            The keys to insert into the GPU cache.
+        values : Tensor
+            The values to insert into the GPU cache.
+ """ + self._cache.replace(keys, values) + + @property + def miss_rate(self): + """Returns the cache miss rate since creation.""" + return self.total_miss / self.total_queries diff --git a/python/dgl/graphbolt/impl/gpu_cached_feature.py b/python/dgl/graphbolt/impl/gpu_cached_feature.py index 8986c79b1d4a..0be929ba4abf 100644 --- a/python/dgl/graphbolt/impl/gpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/gpu_cached_feature.py @@ -1,10 +1,10 @@ """GPU cached feature for GraphBolt.""" import torch -from dgl.cuda import GPUCache - from ..feature_store import Feature +from .gpu_cache import GPUCache + __all__ = ["GPUCachedFeature"] @@ -52,10 +52,7 @@ def __init__(self, fallback_feature: Feature, cache_size: int): self.cache_size = cache_size # Fetching the feature dimension from the underlying feature. feat0 = fallback_feature.read(torch.tensor([0])) - self.item_shape = (-1,) + feat0.shape[1:] - feat0 = torch.reshape(feat0, (1, -1)) - self.flat_shape = (-1, feat0.shape[1]) - self._feature = GPUCache(cache_size, feat0.shape[1]) + self._feature = GPUCache((cache_size,) + feat0.shape[1:], feat0.dtype) def read(self, ids: torch.Tensor = None): """Read the feature by index. @@ -75,15 +72,12 @@ def read(self, ids: torch.Tensor = None): The read feature. """ if ids is None: - return self._fallback_feature.read().to("cuda") - keys = ids.to("cuda") - values, missing_index, missing_keys = self._feature.query(keys) + return self._fallback_feature.read() + values, missing_index, missing_keys = self._feature.query(ids) missing_values = self._fallback_feature.read(missing_keys).to("cuda") - missing_values = missing_values.reshape(self.flat_shape) - values = values.to(missing_values.dtype) values[missing_index] = missing_values self._feature.replace(missing_keys, missing_values) - return torch.reshape(values, self.item_shape) + return values def size(self): """Get the size of the feature. 
@@ -114,10 +108,8 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None): size = min(self.cache_size, value.shape[0]) self._feature.replace( torch.arange(0, size, device="cuda"), - value[:size].to("cuda").reshape(self.flat_shape), + value[:size].to("cuda"), ) else: self._fallback_feature.update(value, ids) - self._feature.replace( - ids.to("cuda"), value.to("cuda").reshape(self.flat_shape) - ) + self._feature.replace(ids, value) diff --git a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py index 85c8666a6a7b..d251701cdaf9 100644 --- a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py +++ b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py @@ -2,34 +2,53 @@ import backend as F +import pytest import torch from dgl import graphbolt as gb @unittest.skipIf( - F._default_context_str != "gpu", - reason="GPUCachedFeature requires a GPU.", + F._default_context_str != "gpu" + or torch.cuda.get_device_capability()[0] < 7, + reason="GPUCachedFeature requires a Volta or later generation NVIDIA GPU.", ) -def test_gpu_cached_feature(): - a = torch.tensor([[1, 2, 3], [4, 5, 6]]).to("cuda").float() - b = torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]]).to("cuda").float() +@pytest.mark.parametrize( + "dtype", + [ + torch.bool, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + ], +) +def test_gpu_cached_feature(dtype): + a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype, pin_memory=True) + b = torch.tensor( + [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype, pin_memory=True + ) feat_store_a = gb.GPUCachedFeature(gb.TorchBasedFeature(a), 2) feat_store_b = gb.GPUCachedFeature(gb.TorchBasedFeature(b), 1) # Test read the entire feature. - assert torch.equal(feat_store_a.read(), a) - assert torch.equal(feat_store_b.read(), b) + assert torch.equal(feat_store_a.read(), a.to("cuda")) + assert torch.equal(feat_store_b.read(), b.to("cuda")) # Test read with ids. assert torch.equal( feat_store_a.read(torch.tensor([0]).to("cuda")), - torch.tensor([[1.0, 2.0, 3.0]]).to("cuda"), + torch.tensor([[1, 2, 3]], dtype=dtype).to("cuda"), ) assert torch.equal( feat_store_b.read(torch.tensor([1, 1]).to("cuda")), - torch.tensor([[[4.0, 5.0], [6.0, 7.0]], [[4.0, 5.0], [6.0, 7.0]]]).to( + torch.tensor([[[4, 5], [6, 7]], [[4, 5], [6, 7]]], dtype=dtype).to( "cuda" ), ) @@ -40,18 +59,19 @@ def test_gpu_cached_feature(): # Test update the entire feature. feat_store_a.update( - torch.tensor([[0.0, 1.0, 2.0], [3.0, 5.0, 2.0]]).to("cuda") + torch.tensor([[0, 1, 2], [3, 5, 2]], dtype=dtype).to("cuda") ) assert torch.equal( feat_store_a.read(), - torch.tensor([[0.0, 1.0, 2.0], [3.0, 5.0, 2.0]]).to("cuda"), + torch.tensor([[0, 1, 2], [3, 5, 2]], dtype=dtype).to("cuda"), ) # Test update with ids. feat_store_a.update( - torch.tensor([[2.0, 0.0, 1.0]]).to("cuda"), torch.tensor([0]).to("cuda") + torch.tensor([[2, 0, 1]], dtype=dtype).to("cuda"), + torch.tensor([0]).to("cuda"), ) assert torch.equal( feat_store_a.read(), - torch.tensor([[2.0, 0.0, 1.0], [3.0, 5.0, 2.0]]).to("cuda"), + torch.tensor([[2, 0, 1], [3, 5, 2]], dtype=dtype).to("cuda"), ) From 78fa316aee6f9fe1f158659a61a8626fc14df58a Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Thu, 18 Jan 2024 03:19:07 -0500 Subject: [PATCH 5/5] [GraphBolt][CUDA] Modify multiGPU example to use GPU sampling. 
(#6961) --- .../multigpu/graphbolt/node_classification.py | 39 ++++++++++++------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/examples/multigpu/graphbolt/node_classification.py b/examples/multigpu/graphbolt/node_classification.py index 2d3344ce301c..5ef93311fe55 100644 --- a/examples/multigpu/graphbolt/node_classification.py +++ b/examples/multigpu/graphbolt/node_classification.py @@ -126,9 +126,6 @@ def create_dataloader( shuffle=shuffle, drop_uneven_inputs=drop_uneven_inputs, ) - datapipe = datapipe.sample_neighbor(graph, args.fanout) - datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"]) - ############################################################################ # [Note]: # datapipe.copy_to() / gb.CopyTo() @@ -137,8 +134,14 @@ def create_dataloader( # [Output]: # A CopyTo object copying data in the datapipe to a specified device.\ ############################################################################ - datapipe = datapipe.copy_to(device) - dataloader = gb.DataLoader(datapipe, num_workers=args.num_workers) + if not args.cpu_sampling: + datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"]) + datapipe = datapipe.sample_neighbor(graph, args.fanout) + datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"]) + if args.cpu_sampling: + datapipe = datapipe.copy_to(device) + + dataloader = gb.DataLoader(datapipe, args.num_workers) # Return the fully-initialized DataLoader object. return dataloader @@ -272,15 +275,18 @@ def run(rank, world_size, args, devices, dataset): rank=rank, ) - graph = dataset.graph - features = dataset.feature + # Pin the graph and features to enable GPU access. + if not args.cpu_sampling: + dataset.graph.pin_memory_() + dataset.feature.pin_memory_() + train_set = dataset.tasks[0].train_set valid_set = dataset.tasks[0].validation_set test_set = dataset.tasks[0].test_set args.fanout = list(map(int, args.fanout.split(","))) num_classes = dataset.tasks[0].metadata["num_classes"] - in_size = features.size("node", None, "feat")[0] + in_size = dataset.feature.size("node", None, "feat")[0] hidden_size = 256 out_size = num_classes @@ -291,8 +297,8 @@ def run(rank, world_size, args, devices, dataset): # Create data loaders. train_dataloader = create_dataloader( args, - graph, - features, + dataset.graph, + dataset.feature, train_set, device, drop_last=False, @@ -301,8 +307,8 @@ def run(rank, world_size, args, devices, dataset): ) valid_dataloader = create_dataloader( args, - graph, - features, + dataset.graph, + dataset.feature, valid_set, device, drop_last=False, @@ -311,8 +317,8 @@ def run(rank, world_size, args, devices, dataset): ) test_dataloader = create_dataloader( args, - graph, - features, + dataset.graph, + dataset.feature, test_set, device, drop_last=False, @@ -387,6 +393,11 @@ def parse_args(): parser.add_argument( "--num-workers", type=int, default=0, help="The number of processes." ) + parser.add_argument( + "--cpu-sampling", + action="store_true", + help="Disables GPU sampling and utilizes the CPU for dataloading.", + ) return parser.parse_args()
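
-- 
Editor's note (appended after the series, not part of any patch): a minimal
usage sketch of the GPUCache wrapper introduced in patch 4, mirroring the
query/replace pattern that GPUCachedFeature.read() applies on a cache miss.
The fallback tensor, its sizes, and the dtype are illustrative assumptions,
and it is assumed the wrapper is re-exported at the graphbolt namespace via
impl/__init__.py; running it requires a CUDA build of GraphBolt and an NVIDIA
GPU of compute capability >= 70 (Volta or later).

    import torch
    from dgl import graphbolt as gb

    # Hypothetical host-side source of truth for the feature rows.
    fallback = torch.randn(1000, 16, pin_memory=True)

    # Cache up to 64 rows of shape (16,) on the GPU.
    cache = gb.GPUCache((64, 16), fallback.dtype)

    keys = torch.randint(0, 1000, (32,), device="cuda")
    values, missing_index, missing_keys = cache.query(keys)

    # Cache misses are read from the fallback source, patched into the
    # result, and inserted back into the cache; LRU eviction makes room
    # when the cache is full.
    missing_values = fallback[missing_keys.cpu()].to("cuda")
    values[missing_index] = missing_values
    cache.replace(missing_keys, missing_values)

    print(f"miss rate so far: {cache.miss_rate:.2f}")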