Merge branch 'master' into gb_cuda_examples2
mfbalin committed Jan 31, 2024
2 parents 1240697 + 942b17a commit e81a62b
Showing 29 changed files with 1,474 additions and 397 deletions.
2 changes: 2 additions & 0 deletions Jenkinsfile
@@ -451,6 +451,8 @@ pipeline {
steps {
unit_test_linux('tensorflow', 'cpu')
}
// Tensorflow is deprecated.
when { expression { false } }
}
}
post {
46 changes: 17 additions & 29 deletions examples/multigpu/graphbolt/node_classification.py
@@ -89,9 +89,7 @@ def create_dataloader(
features,
itemset,
device,
drop_last=False,
shuffle=True,
drop_uneven_inputs=False,
is_train,
):
############################################################################
# [HIGHLIGHT]
@@ -122,9 +120,9 @@ def create_dataloader(
datapipe = gb.DistributedItemSampler(
item_set=itemset,
batch_size=args.batch_size,
drop_last=drop_last,
shuffle=shuffle,
drop_uneven_inputs=drop_uneven_inputs,
drop_last=is_train,
shuffle=is_train,
drop_uneven_inputs=is_train,
)
############################################################################
# [Note]:
@@ -187,7 +185,7 @@ def train(
epoch_start = time.time()

model.train()
total_loss = torch.tensor(0, dtype=torch.float).to(device)
total_loss = torch.tensor(0, dtype=torch.float, device=device)
########################################################################
# (HIGHLIGHT) Use Join Context Manager to solve uneven input problem.
#
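For readers skimming this diff, a minimal sketch of the highlighted pattern follows (torch.distributed.algorithms.join.Join); the model, dataloader, and optimizer names below are placeholders rather than this example's actual variables, and the minibatch attribute access assumes GraphBolt-style MiniBatch objects.

```python
# Sketch only: Join lets ranks whose dataloaders run out of batches early keep
# shadowing the collectives issued by ranks that still have data, so
# DistributedDataParallel training does not hang on uneven inputs.
import torch.nn.functional as F
from torch.distributed.algorithms.join import Join

with Join([model]):  # `model` is assumed to be a DistributedDataParallel module
    for minibatch in dataloader:
        optimizer.zero_grad()
        out = model(minibatch.blocks, minibatch.node_features["feat"])
        loss = F.cross_entropy(out, minibatch.labels)
        loss.backward()
        optimizer.step()
```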
@@ -227,20 +225,17 @@ def train(
loss.backward()
optimizer.step()

total_loss += loss
total_loss += loss.detach()

# Evaluate the model.
if rank == 0:
print("Validating...")
acc = (
evaluate(
rank,
model,
valid_dataloader,
num_classes,
device,
)
/ world_size
acc = evaluate(
rank,
model,
valid_dataloader,
num_classes,
device,
)
########################################################################
# (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
@@ -252,14 +247,13 @@ def train(
dist.reduce(tensor=acc, dst=0)
total_loss /= step + 1
dist.reduce(tensor=total_loss, dst=0)
dist.barrier()

epoch_end = time.time()
if rank == 0:
print(
f"Epoch {epoch:05d} | "
f"Average Loss {total_loss.item() / world_size:.4f} | "
f"Accuracy {acc.item():.4f} | "
f"Accuracy {acc.item() / world_size:.4f} | "
f"Time {epoch_end - epoch_start:.4f}"
)
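For context on the reduction changes in this hunk (dividing by world_size once on rank 0, after dist.reduce, rather than inside each worker), here is a small self-contained sketch of the averaging pattern; the metric value and variable names are stand-ins, not taken from the example.

```python
# Sketch: every rank holds a local metric tensor; dist.reduce sums them onto
# rank 0 (SUM is the default reduce op), and only rank 0 divides by world_size.
import torch
import torch.distributed as dist

def report_mean_metric(local_metric, device, rank, world_size):
    metric = torch.tensor(local_metric, device=device)
    dist.reduce(tensor=metric, dst=0)  # result is only guaranteed on dst
    if rank == 0:
        print(f"mean metric: {metric.item() / world_size:.4f}")
```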

@@ -301,29 +295,23 @@ def run(rank, world_size, args, devices, dataset):
dataset.feature,
train_set,
device,
drop_last=False,
shuffle=True,
drop_uneven_inputs=False,
is_train=True,
)
valid_dataloader = create_dataloader(
args,
dataset.graph,
dataset.feature,
valid_set,
device,
drop_last=False,
shuffle=False,
drop_uneven_inputs=False,
is_train=False,
)
test_dataloader = create_dataloader(
args,
dataset.graph,
dataset.feature,
test_set,
device,
drop_last=False,
shuffle=False,
drop_uneven_inputs=False,
is_train=False,
)

# Model training.
@@ -354,7 +342,7 @@ def run(rank, world_size, args, devices, dataset):
/ world_size
)
dist.reduce(tensor=test_acc, dst=0)
dist.barrier()
torch.cuda.synchronize()
if rank == 0:
print(f"Test Accuracy {test_acc.item():.4f}")

5 changes: 3 additions & 2 deletions graphbolt/include/graphbolt/cuda_ops.h
@@ -113,15 +113,16 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
* given nodes and their indptr values.
*
* @param indptr The indptr tensor.
* @param nodes The nodes to read from indptr
* @param nodes The nodes to read from indptr. If not provided, assumed to be
* equal to torch.arange(indptr.size(0) - 1).
*
* @return Tuple of tensors with values:
* (indptr[nodes + 1] - indptr[nodes], indptr[nodes]), the returned indegrees
* tensor (first one) has size nodes.size(0) + 1 so that calling ExclusiveCumSum
* on it gives the output indptr.
*/
std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
torch::Tensor indptr, torch::Tensor nodes);
torch::Tensor indptr, torch::optional<torch::Tensor> nodes);

/**
* @brief Given the compacted sub_indptr tensor, edge type tensor and
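To make the SliceCSCIndptr doc comment above concrete, here is a plain-torch sketch of its semantics (illustrative only, not the CUDA implementation); the None default mirrors the new torch::optional signature.

```python
# Reference semantics of SliceCSCIndptr, sketched with torch ops.
import torch

def slice_csc_indptr_reference(indptr, nodes=None):
    if nodes is None:  # the new optional-`nodes` behaviour
        nodes = torch.arange(indptr.size(0) - 1, device=indptr.device)
    in_degree = indptr[nodes + 1] - indptr[nodes]  # degree of each requested node
    sliced_indptr = indptr[nodes]                  # offset of each node in `indices`
    # Padded to nodes.size(0) + 1 so that an exclusive cumulative sum over it
    # produces the output indptr of the sliced graph (the pad value is unused).
    return torch.cat([in_degree, in_degree.new_zeros(1)]), sliced_indptr
```

An exclusive cumulative sum over the first returned tensor then yields [0, d0, d0 + d1, ...], i.e. the compacted output indptr described in the comment.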
9 changes: 5 additions & 4 deletions graphbolt/include/graphbolt/cuda_sampling_ops.h
@@ -19,7 +19,8 @@ namespace ops {
*
* @param indptr Index pointer array of the CSC.
* @param indices Indices array of the CSC.
* @param nodes The nodes from which to sample neighbors.
* @param nodes The nodes from which to sample neighbors. If not provided,
* assumed to be equal to torch.arange(indptr.size(0) - 1).
* @param fanouts The number of edges to be sampled for each node with or
* without considering edge types.
* - When the length is 1, it indicates that the fanout applies to all
@@ -49,9 +50,9 @@ namespace ops {
* the sampled graph's information.
*/
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids,
torch::Tensor indptr, torch::Tensor indices,
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt);

5 changes: 3 additions & 2 deletions graphbolt/include/graphbolt/fused_csc_sampling_graph.h
@@ -286,7 +286,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @brief Sample neighboring edges of the given nodes and return the induced
* subgraph.
*
* @param nodes The nodes from which to sample neighbors.
* @param nodes The nodes from which to sample neighbors. If not provided,
* assumed to be equal to torch.arange(NumNodes()).
* @param fanouts The number of edges to be sampled for each node with or
* without considering edge types.
* - When the length is 1, it indicates that the fanout applies to all
@@ -317,7 +318,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* the sampled graph's information.
*/
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const;

60 changes: 39 additions & 21 deletions graphbolt/src/cuda/index_select_csc_impl.cu
@@ -41,18 +41,18 @@ struct AlignmentFunc {
}
};

template <typename indptr_t, typename indices_t>
template <typename indptr_t, typename indices_t, typename coo_rows_t>
__global__ void _CopyIndicesAlignedKernel(
const indptr_t edge_count, const int64_t num_nodes,
const indptr_t* const indptr, const indptr_t* const output_indptr,
const indptr_t edge_count, const indptr_t* const indptr,
const indptr_t* const output_indptr,
const indptr_t* const output_indptr_aligned, const indices_t* const indices,
indices_t* const output_indices, const int64_t* const perm) {
const coo_rows_t* const coo_aligned_rows, indices_t* const output_indices,
const int64_t* const perm) {
indptr_t idx = static_cast<indptr_t>(blockIdx.x) * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;

while (idx < edge_count) {
const auto permuted_row_pos =
cuda::UpperBound(output_indptr_aligned, num_nodes, idx) - 1;
const auto permuted_row_pos = coo_aligned_rows[idx];
const auto row_pos = perm ? perm[permuted_row_pos] : permuted_row_pos;
const auto out_row = output_indptr[row_pos];
const auto d = output_indptr[row_pos + 1] - out_row;
@@ -97,7 +97,8 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
torch::empty(num_nodes + 1, options.dtype(indptr_scalar_type));

auto output_indptr_aligned =
allocator.AllocateStorage<indptr_t>(num_nodes + 1);
torch::empty(num_nodes + 1, options.dtype(indptr_scalar_type));
auto output_indptr_aligned_ptr = output_indptr_aligned.data_ptr<indptr_t>();

{
// Returns the actual and modified_indegree as a pair, the
@@ -106,7 +107,7 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
auto modified_in_degree = thrust::make_transform_iterator(
iota, AlignmentFunc<indptr_t, indices_t>{in_degree, perm, num_nodes});
auto output_indptr_pair = thrust::make_zip_iterator(
output_indptr.data_ptr<indptr_t>(), output_indptr_aligned.get());
output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr);
thrust::tuple<indptr_t, indptr_t> zero_value{};
// Compute the prefix sum over actual and modified indegrees.
CUB_CALL(
@@ -121,25 +122,42 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
output_size = static_cast<indptr_t>(edge_count);
}
// Copy the modified number of edges.
auto edge_count_aligned =
cuda::CopyScalar{output_indptr_aligned.get() + num_nodes};
auto edge_count_aligned_ =
cuda::CopyScalar{output_indptr_aligned_ptr + num_nodes};
const int64_t edge_count_aligned = static_cast<indptr_t>(edge_count_aligned_);

// Allocate output array with actual number of edges.
torch::Tensor output_indices =
torch::empty(output_size.value(), options.dtype(indices.scalar_type()));
const dim3 block(BLOCK_SIZE);
const dim3 grid(
(static_cast<indptr_t>(edge_count_aligned) + BLOCK_SIZE - 1) /
BLOCK_SIZE);
const dim3 grid((edge_count_aligned + BLOCK_SIZE - 1) / BLOCK_SIZE);

// Find the smallest integer type to store the coo_aligned_rows tensor.
const int num_bits = cuda::NumberOfBits(num_nodes);
std::array<int, 4> type_bits = {8, 15, 31, 63};
const auto type_index =
std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
type_bits.begin();
std::array<torch::ScalarType, 5> types = {
torch::kByte, torch::kInt16, torch::kInt32, torch::kLong, torch::kLong};
auto coo_dtype = types[type_index];

// Perform the actual copying, of the indices array into
// output_indices in an aligned manner.
CUDA_KERNEL_CALL(
_CopyIndicesAlignedKernel, grid, block, 0,
static_cast<indptr_t>(edge_count_aligned), num_nodes, sliced_indptr,
output_indptr.data_ptr<indptr_t>(), output_indptr_aligned.get(),
reinterpret_cast<indices_t*>(indices.data_ptr()),
reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm);
auto coo_aligned_rows = ExpandIndptrImpl(
output_indptr_aligned, coo_dtype, torch::nullopt, edge_count_aligned);

AT_DISPATCH_INTEGRAL_TYPES(
coo_dtype, "UVAIndexSelectCSCCopyIndicesCOO", ([&] {
using coo_rows_t = scalar_t;
// Perform the actual copying, of the indices array into
// output_indices in an aligned manner.
CUDA_KERNEL_CALL(
_CopyIndicesAlignedKernel, grid, block, 0,
static_cast<indptr_t>(edge_count_aligned_), sliced_indptr,
output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr,
reinterpret_cast<indices_t*>(indices.data_ptr()),
coo_aligned_rows.data_ptr<coo_rows_t>(),
reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm);
}));
return {output_indptr, output_indices};
}

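Two helper steps in the revised kernel launch above may be easier to follow in plain torch: expanding the aligned indptr into per-edge COO row ids (the role ExpandIndptrImpl plays here, replacing the per-edge cuda::UpperBound search of the old kernel) and choosing the narrowest integer dtype able to hold those row ids. The sketch below is illustrative only; in particular it assumes cuda::NumberOfBits returns the bit width needed to represent num_nodes.

```python
# Illustrative Python equivalents of the two helpers used above.
import bisect
import torch

def expand_indptr_reference(indptr, dtype=torch.int64):
    # Repeat each row id i (indptr[i+1] - indptr[i]) times, giving the COO row
    # index of every edge; the copy kernel can then look up its row directly
    # instead of binary-searching output_indptr_aligned once per edge.
    degrees = indptr[1:] - indptr[:-1]
    rows = torch.arange(indptr.numel() - 1, device=indptr.device, dtype=dtype)
    return torch.repeat_interleave(rows, degrees)

def pick_coo_dtype(num_nodes):
    # Mirror of the lower_bound-based selection above: the smallest dtype whose
    # usable bits cover row ids up to num_nodes (uint8, int16, int32, int64).
    num_bits = int(num_nodes).bit_length()  # assumed meaning of cuda::NumberOfBits
    type_bits = [8, 15, 31, 63]
    types = [torch.uint8, torch.int16, torch.int32, torch.int64, torch.int64]
    return types[bisect.bisect_left(type_bits, num_bits)]
```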
