Skip to content

Commit

Permalink
[Feature][Performance][GPU] Introducing UnifiedTensor for efficient z…
Browse files Browse the repository at this point in the history
…ero-copy host memory access from GPU (#3086)

* Add pytorch-direct version

* Initial commit of unified tensor

* Merge branch 'master' of https://github.com/davidmin7/dgl

* Remove unnecessary things

* Fix error message

* Fix/Add descriptions

* whitespace fix

* add unpin

* disable IndexSelectCPUFromGPU with no CUDA

* add a newline for unified_tensor.py

* Apply changes based on feedback

* add 'os' module

* skip unified tensor unit test for cpu only

* Update tests/pytorch/test_unified_tensor.py

Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>

* reflect feedback

Co-authored-by: shhssdm <shhssdm@gmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
  • Loading branch information
4 people committed Jul 16, 2021
1 parent 7e92318 commit 905c0aa
Show file tree
Hide file tree
Showing 16 changed files with 606 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cmake/modules/CUDA.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ macro(dgl_config_cuda out_variable)
file(GLOB_RECURSE DGL_CUDA_SRC
src/array/cuda/*.cc
src/array/cuda/*.cu
src/array/cuda/uvm/*.cc
src/array/cuda/uvm/*.cu
src/kernel/cuda/*.cc
src/kernel/cuda/*.cu
src/partition/cuda/*.cu
Expand Down
190 changes: 190 additions & 0 deletions examples/pytorch/graphsage/train_sampling_unified_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm

from model import SAGE
from load_graph import load_reddit, inductive_split, load_ogb

def compute_acc(pred, labels):
    """
    Return the fraction of rows in ``pred`` whose arg-max class matches ``labels``.

    pred : 2-D tensor of per-class scores, one row per sample.
    labels : Ground-truth class ids; cast to int64 before the comparison.
    """
    predicted_classes = th.argmax(pred, dim=1)
    num_correct = (predicted_classes == labels.long()).float().sum()
    return num_correct / len(pred)

def evaluate(model, g, nfeat, labels, val_nid, device):
    """
    Measure the accuracy of ``model`` on the nodes listed in ``val_nid``.

    model : The model to evaluate; it must provide an ``inference`` method.
    g : The entire graph.
    nfeat : The features of all the nodes.
    labels : The labels of all the nodes.
    val_nid : The node IDs to score.
    device : The device handed to ``model.inference``.

    The batch size and worker count are read from the module-level ``args``.
    """
    # Switch to eval mode for inference, then restore training mode afterwards.
    model.eval()
    with th.no_grad():
        all_pred = model.inference(g, nfeat, device, args.batch_size, args.num_workers)
    model.train()

    selected_pred = all_pred[val_nid]
    # Labels are moved to wherever the predictions live before comparing.
    selected_labels = labels[val_nid].to(all_pred.device)
    return compute_acc(selected_pred, selected_labels)

def load_subtensor(nfeat, labels, seeds, input_nodes, device):
    """
    Gather the input features and the seed labels needed for one mini-batch.

    The feature container is indexed with ``input_nodes`` moved to ``device``;
    the labels are indexed with ``seeds`` and the result is moved to ``device``.
    """
    feature_index = input_nodes.to(device)
    batch_inputs = nfeat[feature_index]
    seed_labels = labels[seeds]
    return batch_inputs, seed_labels.to(device)

#### Entry point
def run(args, device, data):
    """
    Train a GraphSAGE model with neighbor sampling and report timing/accuracy.

    args : Parsed command-line arguments (see the ``__main__`` section).
    device : The device the model is trained on.
    data : Tuple of (n_classes, train_g, val_g, test_g, train_nfeat,
           train_labels, val_nfeat, val_labels, test_nfeat, test_labels).
    """
    # Unpack data
    n_classes, train_g, val_g, test_g, train_nfeat, train_labels, \
    val_nfeat, val_labels, test_nfeat, test_labels = data
    in_feats = train_nfeat.shape[1]
    train_nid = th.nonzero(train_g.ndata['train_mask'], as_tuple=True)[0]
    val_nid = th.nonzero(val_g.ndata['val_mask'], as_tuple=True)[0]
    test_nid = th.nonzero(~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']), as_tuple=True)[0]

    dataloader_device = th.device('cpu')
    if args.sample_gpu:
        train_nid = train_nid.to(device)
        # copy only the csc to the GPU
        train_g = train_g.formats(['csc'])
        train_g = train_g.to(device)
        dataloader_device = device

    # Create PyTorch DataLoader for constructing blocks
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(',')])
    dataloader = dgl.dataloading.NodeDataLoader(
        train_g,
        train_nid,
        sampler,
        device=dataloader_device,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers)

    if args.data_cpu:
        # Convert input feature tensor to unified tensor
        train_nfeat = dgl.contrib.UnifiedTensor(train_nfeat, device=device)

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Epochs 0..warmup_epochs-1 are excluded from the average epoch time.
    warmup_epochs = 5

    # Training loop
    avg = 0
    iter_tput = []
    for epoch in range(args.num_epochs):
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        tic_step = time.time()
        for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
            # Load the input features as well as output labels
            batch_inputs, batch_labels = load_subtensor(train_nfeat, train_labels,
                                                        seeds, input_nodes, device)
            blocks = [block.int().to(device) for block in blocks]

            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_tput.append(len(seeds) / (time.time() - tic_step))
            if step % args.log_every == 0:
                acc = compute_acc(batch_pred, batch_labels)
                gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                # The first three throughput samples are left out of the reported mean.
                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB'.format(
                    epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc))
            tic_step = time.time()

        toc = time.time()
        print('Epoch Time(s): {:.4f}'.format(toc - tic))
        if epoch >= warmup_epochs:
            avg += toc - tic
        if epoch % args.eval_every == 0 and epoch != 0:
            eval_acc = evaluate(model, val_g, val_nfeat, val_labels, val_nid, device)
            print('Eval Acc {:.4f}'.format(eval_acc))
            test_acc = evaluate(model, test_g, test_nfeat, test_labels, test_nid, device)
            print('Test Acc: {:.4f}'.format(test_acc))

    # Divide by the number of epochs that were actually timed. The previous
    # ``avg / (epoch - 4)`` raised ZeroDivisionError for exactly 5 epochs,
    # NameError for 0 epochs, and used a negative divisor for 1-4 epochs.
    timed_epochs = args.num_epochs - warmup_epochs
    if timed_epochs > 0:
        print('Avg epoch time: {}'.format(avg / timed_epochs))
    else:
        print('Avg epoch time: n/a (needs more than {} epochs)'.format(warmup_epochs))

if __name__ == '__main__':
    # Command-line interface of the example.
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--gpu', type=int, default=0,
                           help="GPU device ID. Use -1 for CPU training")
    argparser.add_argument('--dataset', type=str, default='reddit')
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=1000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=5)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=4,
                           help="Number of sampling processes. Use 0 for no extra process.")
    argparser.add_argument('--sample-gpu', action='store_true',
                           help="Perform the sampling process on the GPU. Must have 0 workers.")
    argparser.add_argument('--inductive', action='store_true',
                           help="Inductive learning setting")
    # The last two fragments used to be joined without a space ("locatedin").
    argparser.add_argument('--data-cpu', action='store_true',
                           help="By default the script puts all node features and labels "
                                "on GPU when using it to save time for data copy. This may "
                                "be undesired if they cannot fit in GPU memory at once. "
                                "Setting this flag makes all node features to be located "
                                "in the unified tensor instead.")
    args = argparser.parse_args()

    # A negative GPU id selects CPU training.
    if args.gpu >= 0:
        device = th.device('cuda:%d' % args.gpu)
    else:
        device = th.device('cpu')

    if args.dataset == 'reddit':
        g, n_classes = load_reddit()
    elif args.dataset == 'ogbn-products':
        g, n_classes = load_ogb('ogbn-products')
    else:
        raise Exception('unknown dataset')

    if args.inductive:
        # Separate graphs (and feature/label tensors) per split.
        train_g, val_g, test_g = inductive_split(g)
        train_nfeat = train_g.ndata.pop('features')
        val_nfeat = val_g.ndata.pop('features')
        test_nfeat = test_g.ndata.pop('features')
        train_labels = train_g.ndata.pop('labels')
        val_labels = val_g.ndata.pop('labels')
        test_labels = test_g.ndata.pop('labels')
    else:
        # All splits share the same graph, features and labels.
        train_g = val_g = test_g = g
        train_nfeat = val_nfeat = test_nfeat = g.ndata.pop('features')
        train_labels = val_labels = test_labels = g.ndata.pop('labels')

    if not args.data_cpu:
        train_nfeat = train_nfeat.to(device)
        train_labels = train_labels.to(device)

    # Pack data
    data = n_classes, train_g, val_g, test_g, train_nfeat, train_labels, \
           val_nfeat, val_labels, test_nfeat, test_labels

    run(args, device, data)
29 changes: 29 additions & 0 deletions include/dgl/aten/macro.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,35 @@
} \
} while (0)

/*
 * Dispatch data type only based on bit-width (8-bit, 16-bit, 32-bit, 64-bit):
 *
 * ATEN_DTYPE_BITS_ONLY_SWITCH(array->dtype, DType, "array", {
 * // Now DType is the type which has the same bit-width with the
 * // data type in array.
 * // Do not use for computation, but only for read and write.
 * // For instance, one can do this for a CPU array:
 * DType *data = static_cast<DType *>(array->data);
 * });
 *
 * Arguments:
 * val      - the data type descriptor; only its `bits` field is inspected.
 * DType    - name of the typedef made visible to the body; it is always one of
 *            int8_t, int16_t, int32_t or int64_t, whatever the real type is.
 * val_name - text streamed into the fatal log message for an unsupported width.
 * ...      - the body to run with DType defined.
 */
#define ATEN_DTYPE_BITS_ONLY_SWITCH(val, DType, val_name, ...) do { \
if ((val).bits == 8) { \
typedef int8_t DType; \
{__VA_ARGS__} \
} else if ((val).bits == 16) { \
typedef int16_t DType; \
{__VA_ARGS__} \
} else if ((val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be 8-bit, 16-bit, 32-bit, or 64-bit"; \
} \
} while (0)

/*
* Dispatch according to integral type of CSR graphs.
* Identical to ATEN_ID_TYPE_SWITCH except for a different error message.
Expand Down
10 changes: 10 additions & 0 deletions include/dgl/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,16 @@ DGL_DLL int DGLStreamStreamSynchronize(int device_type,
*/
DGL_DLL int DGLLoadTensorAdapter(const char *path);

/*!
* \brief Pin host memory.
*/
int DGLArrayPinData(DGLArrayHandle handle, DLContext ctx);

/*!
* \brief Unpin host memory.
*/
int DGLArrayUnpinData(DGLArrayHandle handle, DLContext ctx);

/*!
* \brief Bug report macro.
*
Expand Down
18 changes: 18 additions & 0 deletions include/dgl/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,24 @@ class DeviceAPI {
DGL_DLL virtual void SyncStreamFromTo(DGLContext ctx,
DGLStreamHandle event_src,
DGLStreamHandle event_dst);

/*!
 * \brief Pin host memory using cudaHostRegister().
 *
 * Virtual, so a device-specific subclass may override it.
 * NOTE(review): the default implementation is not visible in this header;
 * confirm what it does for devices without pinning support.
 *
 * \param ctx The context of pinning and mapping.
 * \param ptr The host memory pointer to be pinned.
 * \param nbytes The size, in bytes, to be pinned.
 */
DGL_DLL virtual void PinData(DGLContext ctx, void* ptr, size_t nbytes);

/*!
 * \brief Unpin host memory using cudaHostUnregister().
 *
 * Virtual, so a device-specific subclass may override it.
 *
 * \param ctx The context to unmap and unpin.
 * \param ptr The host memory pointer to be unpinned.
 */
DGL_DLL virtual void UnpinData(DGLContext ctx, void* ptr);

/*!
* \brief Allocate temporal workspace for backend execution.
*
Expand Down
20 changes: 20 additions & 0 deletions python/dgl/_ffi/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,26 @@ def copyto(self, target):
raise ValueError("Unsupported target type %s" % str(type(target)))
return target

def pin_memory_(self, ctx):
    """Pin the host memory of this array and map it into GPU address space (in-place).

    Parameters
    ----------
    ctx : DGLContext
        The target GPU into which the host memory is mapped.
    """
    status = _LIB.DGLArrayPinData(self.handle, ctx)
    check_call(status)

def unpin_memory_(self, ctx):
    """Unpin host memory that was previously pinned by pin_memory_().

    Parameters
    ----------
    ctx : DGLContext
        The GPU context handed to the unpin call.
    """
    status = _LIB.DGLArrayUnpinData(self.handle, ctx)
    check_call(status)


def free_extension_handle(handle, type_code):
"""Free c++ extension type handle
Expand Down
1 change: 1 addition & 0 deletions python/dgl/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from . import graph_store
from .dis_kvstore import KVClient, KVServer
from .dis_kvstore import read_ip_config
from .unified_tensor import UnifiedTensor
81 changes: 81 additions & 0 deletions python/dgl/contrib/unified_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Unified Tensor."""
from .. import backend as F
from .._ffi.function import _init_api
from .. import utils


class UnifiedTensor:
    '''Class for storing unified tensor.

    Constructing a UnifiedTensor automatically pins the input tensor and maps
    it for the given CUDA device; the memory is unpinned again when the object
    is destroyed.

    Parameters
    ----------
    input : Tensor
        CPU tensor which we want to convert into the unified tensor.
    device : device
        CUDA device to create the mapping of the unified tensor.

    Raises
    ------
    ValueError
        If ``device`` is not a cuda device or ``input`` is not a cpu tensor.
    '''

    def __init__(self, input, device):
        if F.device_type(device) != 'cuda':
            raise ValueError("Target device must be a cuda device")
        if F.device_type(F.context(input)) != 'cpu':
            raise ValueError("Input tensor must be a cpu tensor")

        self._input = input
        self._array = F.zerocopy_to_dgl_ndarray(self._input)
        self._device = device
        # Only flipped to True once pinning has succeeded, so that __del__
        # never tries to unpin memory that was not pinned by this object.
        self._pinned = False

        self._array.pin_memory_(utils.to_dgl_context(self._device))
        self._pinned = True

    def __len__(self):
        return len(self._array)

    def __repr__(self):
        return repr(self._input)

    def __getitem__(self, key):
        '''Perform zero-copy access from GPU if the context of
        the key is cuda. Otherwise, just safely fallback to the
        backend specific indexing scheme.

        Parameters
        ----------
        key : Tensor
            Tensor which contains the index ids
        '''
        if F.device_type(F.context(key)) != 'cuda':
            return self._input[key]
        gathered = _CAPI_DGLIndexSelectCPUFromGPU(
            self._array, F.zerocopy_to_dgl_ndarray(key))
        return F.zerocopy_from_dgl_ndarray(gathered)

    def __setitem__(self, key, val):
        # Writes go straight to the wrapped backend tensor.
        self._input[key] = val

    def __del__(self):
        # getattr() guards against a partially constructed object (__init__
        # may have raised before these attributes were assigned).
        array = getattr(self, '_array', None)
        if array is not None and getattr(self, '_pinned', False):
            self._pinned = False
            array.unpin_memory_(utils.to_dgl_context(self._device))
        self._array = None
        self._input = None

    @property
    def shape(self):
        """Shape of this tensor"""
        return self._array.shape

    @property
    def dtype(self):
        """Type of this tensor"""
        return self._array.dtype

    @property
    def device(self):
        """Device of this tensor"""
        return self._device

# NOTE(review): presumably this exposes the C API functions registered under
# "dgl.ndarray.uvm" (e.g. _CAPI_DGLIndexSelectCPUFromGPU used by UnifiedTensor)
# as names in this module -- confirm against _init_api.
_init_api("dgl.ndarray.uvm", __name__)
Loading

0 comments on commit 905c0aa

Please sign in to comment.