Merge branch 'master' into ondisk_dataset
drivanov committed Feb 5, 2024
2 parents 7f46f22 + a2e1c79 commit 44b3c79
Showing 24 changed files with 987 additions and 237 deletions.
1 change: 1 addition & 0 deletions docs/source/api/python/dgl.graphbolt.rst
@@ -187,6 +187,7 @@ Utilities
etype_tuple_to_str
isin
seed
index_select
expand_indptr
add_reverse_edges
exclude_seed_edges
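As a quick reference for the newly documented index_select utility listed above: the sketch below is an assumption about its usage (the signature and CPU fallback behavior are not taken from this diff), not an authoritative example.

import torch
import dgl.graphbolt as gb

x = torch.arange(10).reshape(5, 2)
idx = torch.tensor([0, 3])
# Assumed to behave like torch.index_select along dim 0 for plain CPU tensors;
# the pinned-memory/GPU fast path is what the utility primarily targets.
print(gb.index_select(x, idx))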
12 changes: 12 additions & 0 deletions examples/multigpu/graphbolt/node_classification.py
@@ -284,6 +284,12 @@ def run(rank, world_size, args, devices, dataset):
hidden_size = 256
out_size = num_classes

if args.gpu_cache_size > 0:
dataset.feature._features[("node", None, "feat")] = gb.GPUCachedFeature(
dataset.feature._features[("node", None, "feat")],
args.gpu_cache_size,
)

# Create GraphSAGE model. It should be copied onto a GPU as a replica.
model = SAGE(in_size, hidden_size, out_size).to(device)
model = DDP(model)
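The block above wraps the node feature in a gb.GPUCachedFeature so frequently read rows are served from GPU memory. A minimal stand-alone sketch of the same idea (the tensor shapes, the TorchBasedFeature backing store, and the CUDA device are illustrative assumptions, not taken from this example):

import torch
import dgl.graphbolt as gb

backing = gb.TorchBasedFeature(torch.randn(1000, 128).pin_memory())  # CPU-resident features
cached = gb.GPUCachedFeature(backing, 256)  # keep up to 256 feature rows in GPU memory
ids = torch.arange(32, device="cuda")
feats = cached.read(ids)  # misses fall through to the backing feature, hits stay on the GPU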
@@ -381,6 +387,12 @@ def parse_args():
parser.add_argument(
"--num-workers", type=int, default=0, help="The number of processes."
)
parser.add_argument(
"--gpu-cache-size",
type=int,
default=0,
help="The capacity of the GPU cache, the number of features to store.",
)
parser.add_argument(
"--mode",
default="pinned-cuda",
38 changes: 16 additions & 22 deletions examples/sampling/graphbolt/link_prediction.py
@@ -79,14 +79,10 @@ def forward(self, blocks, x):
hidden_x = F.relu(hidden_x)
return hidden_x

def inference(self, graph, features, dataloader, device):
def inference(self, graph, features, dataloader, storage_device):
"""Conduct layer-wise inference to get all the node embeddings."""
feature = features.read("node", None, "feat")

buffer_device = torch.device("cpu")
# Enable pin_memory for faster CPU to GPU data transfer if the
# model is running on a GPU.
pin_memory = buffer_device != device
pin_memory = storage_device == "pinned"
buffer_device = torch.device("cpu" if pin_memory else storage_device)

print("Start node embedding inference.")
for layer_idx, layer in enumerate(self.layers):
@@ -99,17 +95,17 @@ def inference(self, graph, features, dataloader, device):
device=buffer_device,
pin_memory=pin_memory,
)
feature = feature.to(device)
for step, data in tqdm.tqdm(enumerate(dataloader)):
x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
for data in tqdm.tqdm(dataloader):
# len(blocks) = 1
hidden_x = layer(data.blocks[0], data.node_features["feat"])
if not is_last_layer:
hidden_x = F.relu(hidden_x)
# By design, our seed nodes are contiguous.
y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to(
buffer_device, non_blocking=True
)
feature = y
if not is_last_layer:
features.update("node", None, "feat", y)

return y
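With this refactor, inference reads minibatch features through the FeatureFetcher and writes each layer's output back into the feature store, so the next layer's fetch returns the previous layer's embeddings. A minimal sketch of that read/update round-trip, using gb.TorchBasedFeature as an assumed stand-in for the store entry updated above:

import torch
import dgl.graphbolt as gb

feat = gb.TorchBasedFeature(torch.randn(5, 4))  # 5 nodes, 4-dim input features
hidden = torch.randn(5, 8)  # embeddings produced by layer k
feat.update(hidden)  # the fetch for layer k + 1 now returns these
assert feat.read().shape == (5, 8)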

@@ -185,7 +181,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# [Role]:
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.sample_neighbor(
graph, args.fanout if is_train else [-1]
)

############################################################################
# [Input]:
@@ -213,12 +211,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# A FeatureFetcher object to fetch node features.
# [Role]:
# Initialize a feature fetcher for fetching features of the sampled
# subgraphs. This step is skipped in evaluation/inference because features
# are updated as a whole during it, thus storing features in minibatch is
# unnecessary.
# subgraphs.
############################################################################
if is_train:
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

############################################################################
# [Input]:
@@ -286,15 +281,12 @@ def evaluate(args, model, graph, features, all_nodes_set, valid_set, test_set):
model.eval()
evaluator = Evaluator(name="ogbl-citation2")

# Since we need to use all neighborhoods for evaluation, we set the fanout
# to -1.
args.fanout = [-1]
dataloader = create_dataloader(
args, graph, features, all_nodes_set, is_train=False
)

# Compute node embeddings for the entire graph.
node_emb = model.inference(graph, features, dataloader, args.device)
node_emb = model.inference(graph, features, dataloader, args.storage_device)
results = []

# Loop over both validation and test sets.
@@ -340,6 +332,8 @@ def train(args, model, graph, features, train_set):

total_loss += loss.item()
if step + 1 == args.early_stop:
# Early stopping requires a new dataloader to reset its state.
dataloader = create_dataloader(args, graph, features, train_set)
break

end_epoch_time = time.time()
29 changes: 11 additions & 18 deletions examples/sampling/graphbolt/node_classification.py
@@ -131,11 +131,9 @@ def create_dataloader(
# A FeatureFetcher object to fetch node features.
# [Role]:
# Initialize a feature fetcher for fetching features of the sampled
# subgraphs. This step is skipped in inference because features are updated
# as a whole during it, thus storing features in minibatch is unnecessary.
# subgraphs.
############################################################################
if job != "infer":
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

############################################################################
# [Step-5]:
@@ -194,14 +192,10 @@ def forward(self, blocks, x):
hidden_x = self.dropout(hidden_x)
return hidden_x

def inference(self, graph, features, dataloader, device):
def inference(self, graph, features, dataloader, storage_device):
"""Conduct layer-wise inference to get all the node embeddings."""
feature = features.read("node", None, "feat")

buffer_device = torch.device("cpu")
# Enable pin_memory for faster CPU to GPU data transfer if the
# model is running on a GPU.
pin_memory = buffer_device != device
pin_memory = storage_device == "pinned"
buffer_device = torch.device("cpu" if pin_memory else storage_device)

for layer_idx, layer in enumerate(self.layers):
is_last_layer = layer_idx == len(self.layers) - 1
@@ -213,19 +207,18 @@ def inference(self, graph, features, dataloader, device):
device=buffer_device,
pin_memory=pin_memory,
)
feature = feature.to(device)

for step, data in tqdm(enumerate(dataloader)):
x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
for data in tqdm(dataloader):
# len(blocks) = 1
hidden_x = layer(data.blocks[0], data.node_features["feat"])
if not is_last_layer:
hidden_x = F.relu(hidden_x)
hidden_x = self.dropout(hidden_x)
# By design, our output nodes are contiguous.
y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to(
buffer_device
)
feature = y
if not is_last_layer:
features.update("node", None, "feat", y)

return y

@@ -245,7 +238,7 @@ def layerwise_infer(
num_workers=args.num_workers,
job="infer",
)
pred = model.inference(graph, features, dataloader, args.device)
pred = model.inference(graph, features, dataloader, args.storage_device)
pred = pred[test_set._items[0]]
label = test_set._items[1].to(pred.device)

12 changes: 6 additions & 6 deletions graphbolt/src/cuda/gpu_cache.cu
@@ -43,20 +43,19 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
auto missing_keys =
torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
cuda::CopyScalar<size_t> missing_len;
auto stream = cuda::GetCurrentStream();
auto allocator = cuda::GetAllocator();
auto missing_len_device = allocator.AllocateStorage<size_t>(1);
cache_->Query(
reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
values.data_ptr<float>(),
reinterpret_cast<uint64_t *>(missing_index.data_ptr()),
reinterpret_cast<key_t *>(missing_keys.data_ptr()), missing_len.get(),
stream);
reinterpret_cast<key_t *>(missing_keys.data_ptr()),
missing_len_device.get(), cuda::GetCurrentStream());
values = values.view(torch::kByte)
.slice(1, 0, num_bytes_)
.view(dtype_)
.view(shape_);
// Synchronize the stream so that missing_len can be read safely on the host.
stream.synchronize();
cuda::CopyScalar<size_t> missing_len(missing_len_device.get());
missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));
missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len));
return std::make_tuple(values, missing_index, missing_keys);
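The reworked Query writes the number of missing keys into device memory and synchronizes the stream before reading it on the host. A rough PyTorch analog of that pattern, purely illustrative and assuming a CUDA device is available:

import torch

missing_len_device = torch.tensor([3], device="cuda")  # scalar produced by a GPU kernel
missing_len_host = torch.empty(1, dtype=torch.int64, pin_memory=True)
missing_len_host.copy_(missing_len_device, non_blocking=True)  # asynchronous D2H copy
torch.cuda.current_stream().synchronize()  # the host value is valid only after this point
num_missing = int(missing_len_host.item())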
@@ -79,6 +78,7 @@ void GpuCache::Replace(torch::Tensor keys, torch::Tensor values) {
"Values should have the correct dimensions.");
TORCH_CHECK(
values.scalar_type() == dtype_, "Values should have the correct dtype.");
if (keys.numel() == 0) return;
keys = keys.to(torch::kLong);
torch::Tensor float_values;
if (num_bytes_ % sizeof(float) != 0) {
85 changes: 83 additions & 2 deletions graphbolt/src/fused_csc_sampling_graph.cc
@@ -97,12 +97,21 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
}
if (node_attributes.has_value()) {
for (const auto& pair : node_attributes.value()) {
TORCH_CHECK(pair.value().size(0) == indptr.size(0) - 1);
TORCH_CHECK(
pair.value().size(0) == indptr.size(0) - 1,
"Expected node_attribute.size(0) and num_nodes to be equal, "
"but node_attribute.size(0) was ",
pair.value().size(0), ", and num_nodes was ", indptr.size(0) - 1,
".");
}
}
if (edge_attributes.has_value()) {
for (const auto& pair : edge_attributes.value()) {
TORCH_CHECK(pair.value().size(0) == indices.size(0));
TORCH_CHECK(
pair.value().size(0) == indices.size(0),
"Expected edge_attribute.size(0) and num_edges to be equal, "
"but edge_attribute.size(0) was ",
pair.value().size(0), ", and num_edges was ", indices.size(0), ".");
}
}
return c10::make_intrusive<FusedCSCSamplingGraph>(
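The friendlier TORCH_CHECK messages above surface when a graph is constructed with attribute tensors whose first dimension does not match the node or edge count. A hedged Python-level sketch of triggering the node-attribute check (the factory name and keyword arguments are assumed from the graphbolt API, not taken from this diff):

import torch
import dgl.graphbolt as gb

indptr = torch.tensor([0, 2, 3])  # CSC indptr for 2 nodes
indices = torch.tensor([1, 0, 0])  # 3 edges
bad_feat = torch.zeros(5, 4)  # 5 rows, but the graph has only 2 nodes

# Expected to raise with the new message comparing node_attribute.size(0) and num_nodes.
g = gb.fused_csc_sampling_graph(indptr, indices, node_attributes={"feat": bad_feat})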
@@ -810,12 +819,71 @@ torch::Tensor TemporalMask(
return mask;
}

/**
* @brief Fast path for temporal sampling without probability. It is used when
* the number of neighbors is large. It randomly samples neighbors and checks
* their timestamps, and the pick succeeds if `fanout` valid neighbors are
* found within kTriedThreshold trials.
*/
std::pair<bool, std::vector<int64_t>> FastTemporalPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indices, int64_t fanout,
bool replace, const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors) {
constexpr int64_t kTriedThreshold = 1000;
auto timestamp = utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset);
std::vector<int64_t> sampled_edges;
sampled_edges.reserve(fanout);
std::set<int64_t> sampled_edge_set;
int64_t sample_count = 0;
int64_t tried = 0;
while (sample_count < fanout && tried < kTriedThreshold) {
int64_t edge_id =
RandomEngine::ThreadLocal()->RandInt(offset, offset + num_neighbors);
++tried;
if (!replace && sampled_edge_set.count(edge_id) > 0) {
continue;
}
if (node_timestamp.has_value()) {
int64_t neighbor_id =
utils::GetValueByIndex<int64_t>(csc_indices, edge_id);
if (utils::GetValueByIndex<int64_t>(
node_timestamp.value(), neighbor_id) >= timestamp)
continue;
}
if (edge_timestamp.has_value() &&
utils::GetValueByIndex<int64_t>(edge_timestamp.value(), edge_id) >=
timestamp) {
continue;
}
if (!replace) {
sampled_edge_set.insert(edge_id);
}
sampled_edges.push_back(edge_id);
sample_count++;
}
if (sample_count < fanout) {
return {false, {}};
}
return {true, sampled_edges};
}
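For readers skimming the C++, here is a compact Python rendering of the same rejection-sampling idea; the names and the flat timestamp list are illustrative, and the C++ above remains the authoritative version:

import random

def fast_temporal_pick(seed_time, neighbor_times, fanout, replace, tried_threshold=1000):
    picked, seen, tried = [], set(), 0
    while len(picked) < fanout and tried < tried_threshold:
        tried += 1
        idx = random.randrange(len(neighbor_times))
        if not replace and idx in seen:
            continue  # already drawn and sampling without replacement
        if neighbor_times[idx] >= seed_time:
            continue  # neighbor is not visible at the seed's timestamp
        if not replace:
            seen.add(idx)
        picked.append(idx)
    if len(picked) < fanout:
        return False, []  # give up; the caller falls back to the exact mask-based path
    return True, picked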

int64_t TemporalNumPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout,
bool replace, const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors) {
constexpr int64_t kFastPathThreshold = 1000;
if (num_neighbors > kFastPathThreshold && !probs_or_mask.has_value()) {
// TODO: Currently we use the fast path both in TemporalNumPick and
// TemporalPick. We may only sample once in TemporalNumPick and use the
// sampled edges in TemporalPick to avoid sampling twice.
auto [success, sampled_edges] = FastTemporalPick(
seed_timestamp, csc_indics, fanout, replace, node_timestamp,
edge_timestamp, seed_offset, offset, num_neighbors);
if (success) return sampled_edges.size();
}
auto mask = TemporalMask(
utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indics,
probs_or_mask, node_timestamp, edge_timestamp,
@@ -1183,6 +1251,19 @@ int64_t TemporalPick(
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, SamplerArgs<S> args,
PickedType* picked_data_ptr) {
constexpr int64_t kFastPathThreshold = 1000;
if (S == SamplerType::NEIGHBOR && num_neighbors > kFastPathThreshold &&
!probs_or_mask.has_value()) {
auto [success, sampled_edges] = FastTemporalPick(
seed_timestamp, csc_indices, fanout, replace, node_timestamp,
edge_timestamp, seed_offset, offset, num_neighbors);
if (success) {
for (size_t i = 0; i < sampled_edges.size(); ++i) {
picked_data_ptr[i] = static_cast<PickedType>(sampled_edges[i]);
}
return sampled_edges.size();
}
}
auto mask = TemporalMask(
utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indices,
probs_or_mask, node_timestamp, edge_timestamp,
9 changes: 5 additions & 4 deletions python/dgl/convert.py
@@ -1,4 +1,5 @@
"""Module for converting graph from/to other object."""

from collections import defaultdict
from collections.abc import Mapping

@@ -296,9 +297,9 @@ def heterograph(data_dict, num_nodes_dict=None, idtype=None, device=None):
>>> g = dgl.heterograph(data_dict)
>>> g
Graph(num_nodes={'game': 5, 'topic': 3, 'user': 4},
num_edges={('user', 'follows', 'user'): 2, ('user', 'follows', 'topic'): 2,
num_edges={('user', 'follows', 'topic'): 2, ('user', 'follows', 'user'): 2,
('user', 'plays', 'game'): 2},
metagraph=[('user', 'user', 'follows'), ('user', 'topic', 'follows'),
metagraph=[('user', 'topic', 'follows'), ('user', 'user', 'follows'),
('user', 'game', 'plays')])
Explicitly specify the number of nodes for each node type in the graph.
@@ -1810,11 +1811,11 @@ def to_networkx(
... ('user', 'follows', 'topic'): (torch.tensor([1, 1]), torch.tensor([1, 2])),
... ('user', 'plays', 'game'): (torch.tensor([0, 3]), torch.tensor([3, 4]))
... })
... g.ndata['n'] = {
>>> g.ndata['n'] = {
... 'game': torch.zeros(5, 1),
... 'user': torch.ones(4, 1)
... }
... g.edata['e'] = {
>>> g.edata['e'] = {
... ('user', 'follows', 'user'): torch.zeros(2, 1),
... 'plays': torch.ones(2, 1)
... }
