-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature][Performance][GPU] Introducing UnifiedTensor for efficient z…
…ero-copy host memory access from GPU (#3086) * Add pytorch-direct version * Initial commit of unified tensor * Merge branch 'master' of https://github.com/davidmin7/dgl * Remove unnecessary things * Fix error message * Fix/Add descriptions * whitespace fix * add unpin * disable IndexSelectCPUFromGPU with no CUDA * add a newline for unified_tensor.py * Apply changes based on feedback * add 'os' module * skip unified tensor unit test for cpu only * Update tests/pytorch/test_unified_tensor.py Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com> * reflect feedback Co-authored-by: shhssdm <shhssdm@gmail.com> Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com> Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
- Loading branch information
1 parent
7e92318
commit 905c0aa
Showing
16 changed files
with
606 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
190 changes: 190 additions & 0 deletions
190
examples/pytorch/graphsage/train_sampling_unified_tensor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
import dgl | ||
import numpy as np | ||
import torch as th | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
import dgl.nn.pytorch as dglnn | ||
import time | ||
import argparse | ||
import tqdm | ||
|
||
from model import SAGE | ||
from load_graph import load_reddit, inductive_split, load_ogb | ||
|
||
def compute_acc(pred, labels): | ||
""" | ||
Compute the accuracy of prediction given the labels. | ||
""" | ||
labels = labels.long() | ||
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred) | ||
|
||
def evaluate(model, g, nfeat, labels, val_nid, device): | ||
""" | ||
Evaluate the model on the validation set specified by ``val_nid``. | ||
g : The entire graph. | ||
inputs : The features of all the nodes. | ||
labels : The labels of all the nodes. | ||
val_nid : the node Ids for validation. | ||
device : The GPU device to evaluate on. | ||
""" | ||
model.eval() | ||
with th.no_grad(): | ||
pred = model.inference(g, nfeat, device, args.batch_size, args.num_workers) | ||
model.train() | ||
return compute_acc(pred[val_nid], labels[val_nid].to(pred.device)) | ||
|
||
def load_subtensor(nfeat, labels, seeds, input_nodes, device): | ||
""" | ||
Extracts features and labels for a subset of nodes | ||
""" | ||
batch_inputs = nfeat[input_nodes.to(device)] | ||
batch_labels = labels[seeds].to(device) | ||
return batch_inputs, batch_labels | ||
|
||
#### Entry point | ||
def run(args, device, data): | ||
# Unpack data | ||
n_classes, train_g, val_g, test_g, train_nfeat, train_labels, \ | ||
val_nfeat, val_labels, test_nfeat, test_labels = data | ||
in_feats = train_nfeat.shape[1] | ||
train_nid = th.nonzero(train_g.ndata['train_mask'], as_tuple=True)[0] | ||
val_nid = th.nonzero(val_g.ndata['val_mask'], as_tuple=True)[0] | ||
test_nid = th.nonzero(~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']), as_tuple=True)[0] | ||
|
||
dataloader_device = th.device('cpu') | ||
if args.sample_gpu: | ||
train_nid = train_nid.to(device) | ||
# copy only the csc to the GPU | ||
train_g = train_g.formats(['csc']) | ||
train_g = train_g.to(device) | ||
dataloader_device = device | ||
|
||
|
||
# Create PyTorch DataLoader for constructing blocks | ||
sampler = dgl.dataloading.MultiLayerNeighborSampler( | ||
[int(fanout) for fanout in args.fan_out.split(',')]) | ||
dataloader = dgl.dataloading.NodeDataLoader( | ||
train_g, | ||
train_nid, | ||
sampler, | ||
device=dataloader_device, | ||
batch_size=args.batch_size, | ||
shuffle=True, | ||
drop_last=False, | ||
num_workers=args.num_workers) | ||
|
||
if args.data_cpu: | ||
# Convert input feature tensor to unified tensor | ||
train_nfeat = dgl.contrib.UnifiedTensor(train_nfeat, device=device) | ||
|
||
# Define model and optimizer | ||
model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout) | ||
model = model.to(device) | ||
loss_fcn = nn.CrossEntropyLoss() | ||
optimizer = optim.Adam(model.parameters(), lr=args.lr) | ||
|
||
# Training loop | ||
avg = 0 | ||
iter_tput = [] | ||
for epoch in range(args.num_epochs): | ||
tic = time.time() | ||
|
||
# Loop over the dataloader to sample the computation dependency graph as a list of | ||
# blocks. | ||
tic_step = time.time() | ||
for step, (input_nodes, seeds, blocks) in enumerate(dataloader): | ||
# Load the input features as well as output labels | ||
batch_inputs, batch_labels = load_subtensor(train_nfeat, train_labels, | ||
seeds, input_nodes, device) | ||
blocks = [block.int().to(device) for block in blocks] | ||
|
||
# Compute loss and prediction | ||
batch_pred = model(blocks, batch_inputs) | ||
loss = loss_fcn(batch_pred, batch_labels) | ||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
|
||
iter_tput.append(len(seeds) / (time.time() - tic_step)) | ||
if step % args.log_every == 0: | ||
acc = compute_acc(batch_pred, batch_labels) | ||
gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0 | ||
print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB'.format( | ||
epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc)) | ||
tic_step = time.time() | ||
|
||
toc = time.time() | ||
print('Epoch Time(s): {:.4f}'.format(toc - tic)) | ||
if epoch >= 5: | ||
avg += toc - tic | ||
if epoch % args.eval_every == 0 and epoch != 0: | ||
eval_acc = evaluate(model, val_g, val_nfeat, val_labels, val_nid, device) | ||
print('Eval Acc {:.4f}'.format(eval_acc)) | ||
test_acc = evaluate(model, test_g, test_nfeat, test_labels, test_nid, device) | ||
print('Test Acc: {:.4f}'.format(test_acc)) | ||
|
||
print('Avg epoch time: {}'.format(avg / (epoch - 4))) | ||
|
||
if __name__ == '__main__': | ||
argparser = argparse.ArgumentParser() | ||
argparser.add_argument('--gpu', type=int, default=0, | ||
help="GPU device ID. Use -1 for CPU training") | ||
argparser.add_argument('--dataset', type=str, default='reddit') | ||
argparser.add_argument('--num-epochs', type=int, default=20) | ||
argparser.add_argument('--num-hidden', type=int, default=16) | ||
argparser.add_argument('--num-layers', type=int, default=2) | ||
argparser.add_argument('--fan-out', type=str, default='10,25') | ||
argparser.add_argument('--batch-size', type=int, default=1000) | ||
argparser.add_argument('--log-every', type=int, default=20) | ||
argparser.add_argument('--eval-every', type=int, default=5) | ||
argparser.add_argument('--lr', type=float, default=0.003) | ||
argparser.add_argument('--dropout', type=float, default=0.5) | ||
argparser.add_argument('--num-workers', type=int, default=4, | ||
help="Number of sampling processes. Use 0 for no extra process.") | ||
argparser.add_argument('--sample-gpu', action='store_true', | ||
help="Perform the sampling process on the GPU. Must have 0 workers.") | ||
argparser.add_argument('--inductive', action='store_true', | ||
help="Inductive learning setting") | ||
argparser.add_argument('--data-cpu', action='store_true', | ||
help="By default the script puts all node features and labels " | ||
"on GPU when using it to save time for data copy. This may " | ||
"be undesired if they cannot fit in GPU memory at once. " | ||
"Setting this flag makes all node features to be located" | ||
"in the unified tensor instead.") | ||
args = argparser.parse_args() | ||
|
||
if args.gpu >= 0: | ||
device = th.device('cuda:%d' % args.gpu) | ||
else: | ||
device = th.device('cpu') | ||
|
||
if args.dataset == 'reddit': | ||
g, n_classes = load_reddit() | ||
elif args.dataset == 'ogbn-products': | ||
g, n_classes = load_ogb('ogbn-products') | ||
else: | ||
raise Exception('unknown dataset') | ||
|
||
if args.inductive: | ||
train_g, val_g, test_g = inductive_split(g) | ||
train_nfeat = train_g.ndata.pop('features') | ||
val_nfeat = val_g.ndata.pop('features') | ||
test_nfeat = test_g.ndata.pop('features') | ||
train_labels = train_g.ndata.pop('labels') | ||
val_labels = val_g.ndata.pop('labels') | ||
test_labels = test_g.ndata.pop('labels') | ||
else: | ||
train_g = val_g = test_g = g | ||
train_nfeat = val_nfeat = test_nfeat = g.ndata.pop('features') | ||
train_labels = val_labels = test_labels = g.ndata.pop('labels') | ||
|
||
if not args.data_cpu: | ||
train_nfeat = train_nfeat.to(device) | ||
train_labels = train_labels.to(device) | ||
|
||
# Pack data | ||
data = n_classes, train_g, val_g, test_g, train_nfeat, train_labels, \ | ||
val_nfeat, val_labels, test_nfeat, test_labels | ||
|
||
run(args, device, data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
"""Unified Tensor.""" | ||
from .. import backend as F | ||
from .._ffi.function import _init_api | ||
from .. import utils | ||
|
||
|
||
class UnifiedTensor: #UnifiedTensor | ||
'''Class for storing unified tensor. Declaration of | ||
UnifiedTensor automatically pins the input tensor. | ||
Parameters | ||
---------- | ||
input : Tensor | ||
Tensor which we want to convert into the | ||
unified tensor. | ||
device : device | ||
Device to create the mapping of the unified tensor. | ||
''' | ||
|
||
def __init__(self, input, device): | ||
if F.device_type(device) != 'cuda': | ||
raise ValueError("Target device must be a cuda device") | ||
if F.device_type(F.context(input)) != 'cpu': | ||
raise ValueError("Input tensor must be a cpu tensor") | ||
|
||
self._input = input | ||
self._array = F.zerocopy_to_dgl_ndarray(self._input) | ||
self._device = device | ||
|
||
self._array.pin_memory_(utils.to_dgl_context(self._device)) | ||
|
||
def __len__(self): | ||
return len(self._array) | ||
|
||
def __repr__(self): | ||
return self._input.__repr__() | ||
|
||
def __getitem__(self, key): | ||
'''Perform zero-copy access from GPU if the context of | ||
the key is cuda. Otherwise, just safely fallback to the | ||
backend specific indexing scheme. | ||
Parameters | ||
---------- | ||
key : Tensor | ||
Tensor which contains the index ids | ||
''' | ||
if F.device_type(F.context(key)) != 'cuda': | ||
return self._input[key] | ||
else: | ||
return F.zerocopy_from_dgl_ndarray( | ||
_CAPI_DGLIndexSelectCPUFromGPU(self._array, | ||
F.zerocopy_to_dgl_ndarray(key))) | ||
|
||
def __setitem__(self, key, val): | ||
self._input[key] = val | ||
|
||
def __del__(self): | ||
if hasattr(self, '_array') and self._array != None: | ||
self._array.unpin_memory_(utils.to_dgl_context(self._device)) | ||
self._array = None | ||
|
||
if hasattr(self, '_input'): | ||
self._input = None | ||
|
||
@property | ||
def shape(self): | ||
"""Shape of this tensor""" | ||
return self._array.shape | ||
|
||
@property | ||
def dtype(self): | ||
"""Type of this tensor""" | ||
return self._array.dtype | ||
|
||
@property | ||
def device(self): | ||
"""Device of this tensor""" | ||
return self._device | ||
|
||
_init_api("dgl.ndarray.uvm", __name__) |
Oops, something went wrong.