[Enhancements] add a vcuda device to help mitigate compile-time GPU memory usage (#302)

## The problem
During compilation, some passes need to replace certain operators with one or more new operators, for example `resolve_variants` and `subgraph_rewrite`. In doing so, hidet relies on `imperative_run` to generate the new operators. This creates many intermediate tensors whose combined footprint can exceed the actual run-time memory consumption, which makes larger models (such as Llama-7B) impossible to compile in fp16 even on a GPU with 24 GB of memory.

## Fix
We introduce a device that allows GPU tensors to be stored on the CPU and transferred to the GPU only on demand; we call it `vcuda`. With this change, any additional GPU memory usage during compilation is off-loaded to the CPU.

Enabling this may add some compilation overhead, but given the time-consuming nature of compiling such large models, the increase is negligible.
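
For illustration, this is how the feature is enabled through the new `PassContext` hook added below (a minimal sketch; `flow_graph` stands for an already-traced `FlowGraph`):

```python
import hidet

# Opt in to vcuda off-loading while optimizing a traced graph.
# With no argument, the compiler decides heuristically whether to enable it;
# pass True/False to force it on or off.
with hidet.graph.PassContext() as ctx:
    ctx.reduce_cuda_compile_mem()
    flow_graph = hidet.graph.optimize(flow_graph)
```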

On an RTX 3090, this is the memory consumption when running the Llama test:

```
Status of cuda:0 memory pool
   Allocated: 14081 MiB
        Peak: 14081 MiB
    Reserved: 1196 MiB
      Active: 12884 MiB
Status of cpu memory pool
   Allocated: 4 KiB
        Peak: 26022 MiB
    Reserved: 4 KiB
      Active: 0 Bytes
Status of vcuda:0 memory pool
   Allocated: 0 Bytes
        Peak: 25486 MiB
    Reserved: 0 Bytes
      Active: 0 Bytes
```

---------

Co-authored-by: Xin Li <xin@centml.ai>
xinli-git and xinli-centml committed Jul 5, 2023
1 parent 1c1cd11 commit a15f5c0
Showing 11 changed files with 208 additions and 14 deletions.
48 changes: 48 additions & 0 deletions python/hidet/graph/flow_graph.py
@@ -528,6 +528,54 @@ def find_all_nodes(u: Operator):

        return free_vars, nodes, usage_count

    def vcuda_(self) -> None:
        """
        Casts the flow graph object to the vcuda device in place.
        """
        from hidet.runtime.device import instantiate_device, Device

        for x in self.inputs:
            if not x.device.is_cuda():
                raise ValueError("Inputs must be on cuda device")
            x.vcuda_()

        for node in self.nodes:
            if 'device' in node.attrs:
                dev = instantiate_device(node.attrs['device'])
                if dev.is_cuda():
                    dev = Device('vcuda', dev.id)
                node.attrs['device'] = dev
            for inp in node.inputs:
                if inp.device.is_cuda():
                    inp.vcuda_()
            for outp in node.outputs:
                if outp.device.is_cuda():
                    outp.vcuda_()

    def cuda_(self) -> None:
        """
        Casts the flow graph object from the vcuda device in place.
        """
        from hidet.runtime.device import instantiate_device, Device

        for x in self.inputs:
            if not x.device.is_vcuda():
                raise ValueError("Inputs must be on vcuda device")
            x.cuda_()

        for node in self.nodes:
            if 'device' in node.attrs:
                dev = instantiate_device(node.attrs['device'])
                if dev.is_vcuda():
                    dev = Device('cuda', dev.id)
                node.attrs['device'] = dev
            for inp in node.inputs:
                if inp.device.is_vcuda():
                    inp.cuda_()
            for outp in node.outputs:
                if outp.device.is_vcuda():
                    outp.cuda_()


def trace_from(tensor: Union[Tensor, List[Tensor]], inputs: Optional[Union[Tensor, List[Tensor]]] = None) -> FlowGraph:
"""
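As a usage sketch (assuming `graph` is a `FlowGraph` whose inputs live on `cuda`), the two methods above round-trip the whole graph:

```python
# Park all CUDA tensors of the graph in host memory, run the
# memory-hungry passes, then restore the original placement.
graph.vcuda_()   # cuda -> vcuda: storages now live in CPU memory
# ... run graph passes ...
graph.cuda_()    # vcuda -> cuda: storages move back to the GPU
```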
8 changes: 6 additions & 2 deletions python/hidet/graph/operator.py
@@ -75,7 +75,7 @@ def device(self) -> Device:
        if len(self.inputs) == 0:
            raise ValueError('Cannot infer device from an operator with no inputs and "device" attribute')
        # when the operator has inputs, get the device from the inputs
-        if not all(t.device == self.inputs[0].device for t in self.inputs):
+        if not all(t.device.target == self.inputs[0].device.target for t in self.inputs):
            raise ValueError('All inputs of an operator must be on the same device')
        return self.inputs[0].device

@@ -103,8 +103,12 @@ def build_target(self) -> str:

        if isinstance(self, TransferOp):
            return 'cuda'
+        if self.device.kind in ["cuda", "vcuda"]:
+            return "cuda"
+        elif self.device.kind == "cpu":
+            return "cpu"
        else:
-            return self.device.kind
+            raise NotImplementedError()

    @property
    def compiled_task(self) -> CompiledTask:
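Device inference now compares `Device.target` rather than the full device, so `cuda` and `vcuda` inputs can feed the same operator while `build_target` still lowers both to CUDA kernels. A small sketch of the mapping:

```python
from hidet.runtime.device import Device

a = Device('cuda', 0)
b = Device('vcuda', 0)
assert a.target == b.target == 'cuda'  # same build target, different storage location
```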
39 changes: 39 additions & 0 deletions python/hidet/graph/tensor.py
@@ -789,6 +789,45 @@ def cuda(self, device=None):
        else:
            return transfer(self, device)

    def vcuda_(self):
        """Cast the tensor to the vcuda device in place.

        If the current tensor is already on the vcuda device, nothing is performed.

        Returns
        -------
        ret: None
            This operation is in-place.
        """
        if self.device.is_vcuda():
            return
        if not self.device.is_cuda():
            raise ValueError("Tensor must be on cuda device, got {}".format(self.device))
        # if the tensor has no storage, there is no need to cast it
        if self.storage is not None:
            self._storage = self.storage.vcuda(self.device.id)
        self._device = Device('vcuda', self.device.id)

    def cuda_(self):
        """Cast the tensor from the vcuda device in place.

        If the current tensor is already on the cuda device, nothing is performed.

        Returns
        -------
        ret: None
            This operation is in-place.
        """
        if self.device.is_cuda():
            return
        if not self.device.is_vcuda():
            raise ValueError("Tensor must be on vcuda device, got {}".format(self.device))

        if self.storage is not None:
            self._storage = self.storage.cuda(self.device.id)
        self._device = Device('cuda', self.device.id)

    def copy(self) -> Tensor:
        """Create a copy of current tensor.
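A brief sketch of the in-place casts at the tensor level (device kinds as introduced in this change; requires a CUDA-capable machine):

```python
import hidet

x = hidet.randn([1024, 1024], device='cuda')
x.vcuda_()                       # storage moves to the host; kind becomes 'vcuda'
assert x.device.kind == 'vcuda'
x.cuda_()                        # transferred back to the GPU on demand
assert x.device.kind == 'cuda'
```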
13 changes: 13 additions & 0 deletions python/hidet/graph/transforms/base.py
@@ -262,6 +262,19 @@ def profile_pass_instrument(self, log_file: Optional[str] = None, print_stdout:
        self.instruments.append(ProfileInstrument(log_file, print_stdout))
        return self

    def reduce_cuda_compile_mem(self, enable: Optional[bool] = None):
        """
        Reduce the CUDA memory used during compilation by using vcuda tensors; this may incur
        some compile-time cost.

        Parameters
        ----------
        enable: Optional[bool]
            When given, always enable or disable this instrument.
            If no argument is given, the compiler decides whether to enable it with some heuristics.
        """
        from .instruments import ConvertGraphToVCuda  # pylint: disable=import-outside-toplevel

        self.instruments.append(ConvertGraphToVCuda(enable))


class GraphPass:
    def __init__(self):
1 change: 1 addition & 0 deletions python/hidet/graph/transforms/instruments/__init__.py
@@ -12,3 +12,4 @@
from .base import GraphPassInstrument
from .profile_instrument import ProfileInstrument
from .save_graph_instrument import SaveGraphInstrument
+from .convert_flowgraph_to_vcuda import ConvertGraphToVCuda
50 changes: 50 additions & 0 deletions python/hidet/graph/transforms/instruments/convert_flowgraph_to_vcuda.py
@@ -0,0 +1,50 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.graph.flow_graph import FlowGraph
from hidet.graph.graph_utils.functors import graph_collect
from hidet.cuda.device import properties, current_device
from .base import GraphPassInstrument


class ConvertGraphToVCuda(GraphPassInstrument):
    def __init__(self, enable: bool):
        super().__init__()
        self.enable = enable
        self.applied = False

        # passes may take up to 2x memory (80% of total memory)
        self.threshold = 0.4

    def before_all_passes(self, graph: FlowGraph):
        if not self.should_enable(graph):
            return
        graph.vcuda_()
        self.applied = True

    def after_all_passes(self, graph: FlowGraph) -> None:
        if self.applied:
            graph.cuda_()
            self.applied = False

    def should_enable(self, graph):
        from hidet.graph import Tensor

        if self.enable is None:
            tensors = graph_collect(graph, Tensor)
            graph_size = 0
            for t in tensors:
                if t.storage is not None and t.storage.device.is_cuda():
                    graph_size += t.storage.num_bytes

            return graph_size / properties(current_device()).totalGlobalMem > self.threshold
        else:
            return self.enable
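
Concretely, with `threshold = 0.4` the instrument enables itself once the CUDA-resident tensors of the graph exceed 40% of the device's memory, e.g. roughly 9.6 GiB on a 24 GiB card. The same decision in isolation (the byte count is illustrative):

```python
from hidet.cuda.device import properties, current_device

graph_size = 13 * 1024**3  # e.g. ~13 GiB of fp16 Llama-7B weights on the GPU
total = properties(current_device()).totalGlobalMem
print(graph_size / total > 0.4)  # True on a 24 GiB card: off-loading kicks in
```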
5 changes: 4 additions & 1 deletion python/hidet/runtime/compiled_task.py
@@ -176,8 +176,10 @@ def run_async(self, inputs):
        _check_inputs(self.meta_data.inputs, inputs)

        outputs = self.create_outputs()

        candidate = self.candidates[self.pick_best_candidate(inputs, outputs)]
        candidate(*inputs, *outputs)

        return outputs


@@ -213,7 +215,8 @@ def _check_inputs(traced_inputs: Iterable[TensorSignature], inputs):

    symbol_map = {}
    for i, (traced, new) in enumerate(zip(traced_inputs, inputs)):
-        if traced.device.partition(':')[0] != new.device.kind:
+        traced_dev_kind = traced.device.partition(':')[0]
+        if traced_dev_kind != new.device.target:
            raise RuntimeError(
                f"device mismatch at arg {i} between original: {traced.device} and new: {new.device.kind}"
            )
13 changes: 10 additions & 3 deletions python/hidet/runtime/device.py
@@ -54,6 +54,13 @@ def is_cpu(self) -> bool:
    def is_cuda(self) -> bool:
        return self.kind == 'cuda'

    def is_vcuda(self) -> bool:
        return self.kind == 'vcuda'

    @property
    def target(self) -> str:
        return 'cuda' if self.kind in ['cuda', 'vcuda'] else 'cpu'


def device(device_type: str, device_index: Optional[int] = None):
    if ':' in device_type:
@@ -70,8 +77,8 @@ def device(device_type: str, device_index: Optional[int] = None):
            raise ValueError(f'Invalid device_index: {device_index}')
        device_index = int(device_index)

-    if device_type not in ['cpu', 'cuda']:
-        raise ValueError(f'Invalid device_type: {device_type}, must be "cpu" or "cuda"')
+    if device_type not in ['cpu', 'cuda', 'vcuda']:
+        raise ValueError(f'Invalid device_type: {device_type}, must be "cpu", "cuda" or "vcuda"')

    if device_index is not None and not isinstance(device_index, int):
        raise ValueError(f'Invalid device_index: {device_index}, must be an integer')
@@ -111,7 +118,7 @@ def instantiate_device(dev) -> Device:
    if dev.kind == 'cpu':
        dev.id = None  # CPU device does not have a device index
        return dev
-    elif dev.kind == 'cuda':
+    elif dev.kind in ['cuda', 'vcuda']:
        if dev.id is None:
            dev.id = current_device()
        return dev
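With the relaxed validation, `vcuda` parses like any other device string. A small sketch (the id fill-in requires a CUDA-capable machine):

```python
from hidet.runtime.device import device, instantiate_device

d = device('vcuda:0')
assert d.kind == 'vcuda' and d.id == 0
# instantiate_device fills in the current CUDA device id when none is given
print(instantiate_device(device('vcuda')))
```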
18 changes: 17 additions & 1 deletion python/hidet/runtime/storage.py
@@ -243,6 +243,22 @@ def cuda_async(self, dst_id: int, stream: Optional[Stream] = None):
"""
return Storage._convert(self, Device('cuda', dst_id), non_blocking=True, stream=stream)

    def vcuda(self, dst_id: int) -> Storage:
        """
        Copy the storage to the vcuda device. If the storage is already on the device, return itself.

        Parameters
        ----------
        dst_id: int
            The id of the destination CUDA device.

        Returns
        -------
        ret: Storage
            The storage on the destination vcuda device.
        """
        return Storage._convert(self, Device('vcuda', dst_id), non_blocking=False)

    def copy(self) -> Storage:
        """
        Copy the storage to the same device. If the storage is already on the device, return itself.
@@ -368,7 +384,7 @@ def __getitem__(self, device: Device) -> MemoryPool:
            self.device2pool[device] = MemoryPool(
                CudaMemoryAPI(device), block_size=4096, max_reserve_size=4 * 1024**3
            )
-        elif device.is_cpu():
+        elif device.is_cpu() or device.is_vcuda():
            self.device2pool[device] = MemoryPool(
                CUDAHostMemoryAPI(device), block_size=4096, max_reserve_size=512 * 1024**2
            )
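Because `vcuda` storages are served by the CUDA host-memory API, they get their own host-memory pool that can be inspected like the others (as the updated test below does):

```python
from hidet.runtime.storage import current_memory_pool

print(current_memory_pool('vcuda'))  # host memory backing vcuda tensors
print(current_memory_pool('cpu'))
print(current_memory_pool('cuda'))
```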
7 changes: 5 additions & 2 deletions python/hidet/testing/models/llama.py
@@ -395,7 +395,7 @@ def generate(text: str, model, tokenizer, config, num_tokens=20, device='cuda',
    outputs = []
    for _ in range(num_tokens):
        y = model(input_ids, position_ids, *past_keys_values)
-        input_ids = y[0][:, -1:]
+        input_ids = y[0][:, -1:].to(dtype=hidet.int32)
        outputs.append(input_ids[0, -1].item())
        past_keys_values = y[1:]

@@ -440,6 +440,7 @@ def get_compiled_model(name='decapoda-research/llama-7b-hf', device='cuda', opt=

    with torch.device("cuda"):  # reduce the time to load the model
        model = hfLm.from_pretrained(name, torch_dtype=torch.float16)

    model.cpu()
    torch.cuda.empty_cache()

@@ -450,7 +451,9 @@ def get_compiled_model(name='decapoda-research/llama-7b-hf', device='cuda', opt=
    flow_graph = build_flow_graph(model, device=device)

    if opt:
-        flow_graph = hidet.graph.optimize(flow_graph)
+        with hidet.graph.PassContext() as ctx:
+            ctx.reduce_cuda_compile_mem()
+            flow_graph = hidet.graph.optimize(flow_graph)

    compiled = flow_graph.build()
    return compiled, config, tok
20 changes: 15 additions & 5 deletions tests/models/test_llama.py
@@ -11,21 +11,31 @@
# limitations under the License.
import pytest
from hidet.testing.models.llama import get_compiled_model, generate
+from hidet.runtime.storage import current_memory_pool


-@pytest.mark.skip(reason='This test requires a lot of memory')
-def test_llama(device='cuda', opt=False):
+# @pytest.mark.parametrize('device,opt', [('cuda', True)])
+@pytest.mark.skip(reason='This test requires a lot of CPU memory > 32GB')
+def test_llama(device, opt):
    model, config, tokenizer = get_compiled_model(device=device, opt=opt)

    text = generate('In the beginning was the Word.', model, tokenizer, config, num_tokens=12)
-    assert text == 'The Word was with God, and the Word was God.'
+    print(text)
+    expected = 'The Word was with God, and the Word was God.'
+    assert text == expected

    text = generate(
        "A robot may not injure a human being or, through inaction", model, tokenizer, config, num_tokens=55
    )
    expected = (
-        ', allow a human being to come to harm. A robot must obey the orders given it by human beings'
+        ', allow a human being to come to harm. A robot must obey orders given it by human beings'
        ' except where such orders would conflict with the First Law. A robot must protect its own'
-        ' existence as long as such protection does not conflict with the First or Second Laws'
+        ' existence as long as such protection does not conflict with the First or Second Laws.'
    )

    print(text)
    assert text == expected

+    print(current_memory_pool("cuda"))
+    print(current_memory_pool("cpu"))
+    print(current_memory_pool("vcuda"))
