Skip to content

Commit

Permalink
Merge branch 'master' into gb_cuda_examples2
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jan 16, 2024
2 parents d9843f6 + 053c822 commit 005ff4f
Show file tree
Hide file tree
Showing 12 changed files with 355 additions and 175 deletions.
39 changes: 28 additions & 11 deletions examples/sampling/graphbolt/link_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,16 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
shuffle=is_train,
)

############################################################################
# [Input]:
# 'device': The device to copy the data to.
# [Output]:
# A CopyTo object to copy the data to the specified device. Copying here
# ensures that the rest of the operations run on the GPU.
############################################################################
if args.storage_device != "cpu":
datapipe = datapipe.copy_to(device=args.device)

############################################################################
# [Input]:
# 'args.neg_ratio': Specify the ratio of negative to positive samples.
Expand Down Expand Up @@ -216,7 +226,8 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# [Output]:
# A CopyTo object to copy the data to the specified device.
############################################################################
datapipe = datapipe.copy_to(device=args.device)
if args.storage_device == "cpu":
datapipe = datapipe.copy_to(device=args.device)

############################################################################
# [Input]:
Expand Down Expand Up @@ -304,11 +315,11 @@ def train(args, model, graph, features, train_set):
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
dataloader = create_dataloader(args, graph, features, train_set)

for epoch in tqdm.trange(args.epochs):
for epoch in range(args.epochs):
model.train()
total_loss = 0
start_epoch_time = time.time()
for step, data in enumerate(dataloader):
for step, data in tqdm.tqdm(enumerate(dataloader)):
# Get node pairs with labels for loss calculation.
compacted_pairs, labels = data.node_pairs_with_labels

Expand Down Expand Up @@ -366,24 +377,30 @@ def parse_args():
help="Whether to exclude reverse edges during sampling. Default: 1",
)
parser.add_argument(
"--device",
default="cpu",
choices=["cpu", "cuda"],
help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
"--mode",
default="pinned-cuda",
choices=["cpu-cpu", "cpu-cuda", "pinned-cuda", "cuda-cuda"],
help="Dataset storage placement and Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
)
return parser.parse_args()


def main(args):
if not torch.cuda.is_available():
args.device = "cpu"
print(f"Training in {args.device} mode.")
args.mode = "cpu-cpu"
print(f"Training in {args.mode} mode.")
args.storage_device, args.device = args.mode.split("-")
args.device = torch.device(args.device)

# Load and preprocess dataset.
print("Loading data")
dataset = gb.BuiltinDataset("ogbl-citation2").load()
graph = dataset.graph
features = dataset.feature

# Move the dataset to the selected storage.
graph = dataset.graph.to(args.storage_device)
features = dataset.feature.to(args.storage_device)

train_set = dataset.tasks[0].train_set
args.fanout = list(map(int, args.fanout.split(",")))

Expand Down
48 changes: 32 additions & 16 deletions examples/sampling/graphbolt/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ def create_dataloader(

############################################################################
# [Step-2]:
# self.copy_to()
# [Input]:
# 'device': The device to copy the data to.
# 'extra_attrs': The extra attributes to copy.
# [Output]:
# A CopyTo object to copy the data to the specified device. Copying here
# ensures that the rest of the operations run on the GPU.
############################################################################
if args.storage_device != "cpu":
datapipe = datapipe.copy_to(device=device, extra_attrs=["seed_nodes"])

############################################################################
# [Step-3]:
# self.sample_neighbor()
# [Input]:
# 'graph': The network topology for sampling.
Expand All @@ -109,7 +122,7 @@ def create_dataloader(
)

############################################################################
# [Step-3]:
# [Step-4]:
# self.fetch_feature()
# [Input]:
# 'features': The node features.
Expand All @@ -125,17 +138,18 @@ def create_dataloader(
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

############################################################################
# [Step-4]:
# [Step-5]:
# self.copy_to()
# [Input]:
# 'device': The device to copy the data to.
# [Output]:
# A CopyTo object to copy the data to the specified device.
############################################################################
datapipe = datapipe.copy_to(device=device)
if args.storage_device == "cpu":
datapipe = datapipe.copy_to(device=device)

############################################################################
# [Step-5]:
# [Step-6]:
# gb.DataLoader()
# [Input]:
# 'datapipe': The datapipe object to be used for data loading.
Expand Down Expand Up @@ -259,7 +273,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
job="evaluate",
)

for step, data in tqdm(enumerate(dataloader)):
for step, data in tqdm(enumerate(dataloader), "Evaluating"):
x = data.node_features["feat"]
y.append(data.labels)
y_hats.append(model(data.blocks, x))
Expand Down Expand Up @@ -289,7 +303,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
t0 = time.time()
model.train()
total_loss = 0
for step, data in enumerate(dataloader):
for step, data in tqdm(enumerate(dataloader), "Training"):
# The input features from the source nodes in the first layer's
# computation graph.
x = data.node_features["feat"]
Expand Down Expand Up @@ -349,28 +363,30 @@ def parse_args():
" identical with the number of layers in your model. Default: 10,10,10",
)
parser.add_argument(
"--device",
default="cpu",
choices=["cpu", "cuda"],
help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
"--mode",
default="pinned-cuda",
choices=["cpu-cpu", "cpu-cuda", "pinned-cuda", "cuda-cuda"],
help="Dataset storage placement and Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
)
return parser.parse_args()


def main(args):
if not torch.cuda.is_available():
args.device = "cpu"
print(f"Training in {args.device} mode.")
args.mode = "cpu-cpu"
print(f"Training in {args.mode} mode.")
args.storage_device, args.device = args.mode.split("-")
args.device = torch.device(args.device)

# Load and preprocess dataset.
print("Loading data...")
dataset = gb.BuiltinDataset("ogbn-products").load()

graph = dataset.graph
# Currently the neighbor-sampling process can only be done on the CPU,
# therefore there is no need to copy the graph to the GPU.
features = dataset.feature
# Move the dataset to the selected storage.
graph = dataset.graph.to(args.storage_device)
features = dataset.feature.to(args.storage_device)

train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set
Expand Down
14 changes: 10 additions & 4 deletions examples/sampling/graphbolt/quickstart/link_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@
############################################################################
# (HIGHLIGHT) Create a single process dataloader with dgl graphbolt package.
############################################################################
def create_dataloader(dateset, device, is_train=True):
def create_dataloader(dataset, device, is_train=True):
# The second of two tasks in the dataset is link prediction.
task = dataset.tasks[1]
itemset = task.train_set if is_train else task.test_set

# Sample seed edges from the itemset.
datapipe = gb.ItemSampler(itemset, batch_size=256)

# Copy the mini-batch to the designated device for sampling and training.
datapipe = datapipe.copy_to(device)

if is_train:
# Sample negative edges for the seed edges.
datapipe = datapipe.sample_uniform_negative(
Expand All @@ -47,9 +50,6 @@ def create_dataloader(dateset, device, is_train=True):
dataset.feature, node_feature_keys=["feat"]
)

# Copy the mini-batch to the designated device for training.
datapipe = datapipe.copy_to(device)

# Initiate the dataloader for the datapipe.
return gb.DataLoader(datapipe)

Expand Down Expand Up @@ -158,6 +158,12 @@ def train(model, dataset, device):
print("Loading data...")
dataset = gb.BuiltinDataset("cora").load()

# If a CUDA device is selected, we pin the graph and the features so that
# the GPU can access them.
if device == torch.device("cuda:0"):
dataset.graph.pin_memory_()
dataset.feature.pin_memory_()

in_size = dataset.feature.size("node", None, "feat")[0]
model = GraphSAGE(in_size).to(device)

Expand Down
14 changes: 10 additions & 4 deletions examples/sampling/graphbolt/quickstart/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
############################################################################
# (HIGHLIGHT) Create a single process dataloader with dgl graphbolt package.
############################################################################
def create_dataloader(dateset, itemset, device):
def create_dataloader(dataset, itemset, device):
# Sample seed nodes from the itemset.
datapipe = gb.ItemSampler(itemset, batch_size=16)

# Copy the mini-batch to the designated device for sampling and training.
datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])

# Sample neighbors for the seed nodes.
datapipe = datapipe.sample_neighbor(dataset.graph, fanouts=[4, 2])

Expand All @@ -25,9 +28,6 @@ def create_dataloader(dateset, itemset, device):
dataset.feature, node_feature_keys=["feat"]
)

# Copy the mini-batch to the designated device for training.
datapipe = datapipe.copy_to(device)

# Initiate the dataloader for the datapipe.
return gb.DataLoader(datapipe)

Expand Down Expand Up @@ -119,6 +119,12 @@ def train(model, dataset, device):
print("Loading data...")
dataset = gb.BuiltinDataset("cora").load()

# If a CUDA device is selected, we pin the graph and the features so that
# the GPU can access them.
if device == torch.device("cuda:0"):
dataset.graph.pin_memory_()
dataset.feature.pin_memory_()

in_size = dataset.feature.size("node", None, "feat")[0]
out_size = dataset.tasks[0].metadata["num_classes"]
model = GCN(in_size, out_size).to(device)
Expand Down
6 changes: 4 additions & 2 deletions examples/sampling/graphbolt/rgcn/hetero_rgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,14 +424,15 @@ def evaluate(
else:
evaluator = MAG240MEvaluator()

num_etype = len(g.num_edges)
data_loader = create_dataloader(
name,
g,
features,
item_set,
device,
batch_size=4096,
fanouts=[25, 10],
fanouts=[torch.full((num_etype,), 25), torch.full((num_etype,), 10)],
shuffle=False,
num_workers=num_workers,
)
Expand Down Expand Up @@ -485,14 +486,15 @@ def train(
print("Start to train...")
category = "paper"

num_etype = len(g.num_edges)
data_loader = create_dataloader(
name,
g,
features,
train_set,
device,
batch_size=1024,
fanouts=[25, 10],
fanouts=[torch.full((num_etype,), 25), torch.full((num_etype,), 10)],
shuffle=True,
num_workers=num_workers,
)
Expand Down
16 changes: 15 additions & 1 deletion python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,21 @@ def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
def _to(x):
return x.to(device) if hasattr(x, "to") else x

return self._apply_to_members(_to)
def _pin(x):
return x.pin_memory() if hasattr(x, "pin_memory") else x

# Create a copy of self.
self2 = fused_csc_sampling_graph(
self.csc_indptr,
self.indices,
self.node_type_offset,
self.type_per_edge,
self.node_type_to_id,
self.edge_type_to_id,
self.node_attributes,
self.edge_attributes,
)
return self2._apply_to_members(_pin if device == "pinned" else _to)

def pin_memory_(self):
"""Copy `FusedCSCSamplingGraph` to the pinned memory in-place."""
Expand Down
Loading

0 comments on commit 005ff4f

Please sign in to comment.