[Enhancements] add a vcuda device to help mitigate compile time GPU memory usage #302

Merged 13 commits on Jul 5, 2023
48 changes: 48 additions & 0 deletions python/hidet/graph/flow_graph.py
@@ -528,6 +528,54 @@ def find_all_nodes(u: Operator):

return free_vars, nodes, usage_count

def vcuda_(self) -> None:
"""
Cast the flow graph to the vcuda device in place.
"""
from hidet.runtime.device import instantiate_device, Device

for x in self.inputs:
if not x.device.is_cuda():
raise ValueError("Inputs must be on cuda device")
x.vcuda_()

for node in self.nodes:
if 'device' in node.attrs:
dev = instantiate_device(node.attrs['device'])
if dev.is_cuda():
dev = Device('vcuda', dev.id)
node.attrs['device'] = dev
for inp in node.inputs:
if inp.device.is_cuda():
inp.vcuda_()
for outp in node.outputs:
if outp.device.is_cuda():
outp.vcuda_()

def cuda_(self) -> None:
"""
Cast the flow graph from the vcuda device back to cuda in place.
"""
from hidet.runtime.device import instantiate_device, Device

for x in self.inputs:
if not x.device.is_vcuda():
raise ValueError("Inputs must be on vcuda device")
x.cuda_()

for node in self.nodes:
if 'device' in node.attrs:
dev = instantiate_device(node.attrs['device'])
if dev.is_vcuda():
dev = Device('cuda', dev.id)
node.attrs['device'] = dev
for inp in node.inputs:
if inp.device.is_vcuda():
inp.cuda_()
for outp in node.outputs:
if outp.device.is_vcuda():
outp.cuda_()


def trace_from(tensor: Union[Tensor, List[Tensor]], inputs: Optional[Union[Tensor, List[Tensor]]] = None) -> FlowGraph:
"""
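A minimal usage sketch of the two new in-place casts (hidet.symbol and hidet.ops.relu are assumed to behave as in the rest of hidet; only vcuda_/cuda_ and trace_from come from this diff):

import hidet

x = hidet.symbol([1, 3, 224, 224], dtype='float32', device='cuda')   # hypothetical input
y = hidet.ops.relu(x)                                                 # hypothetical operator
graph = hidet.trace_from(y, inputs=[x])

graph.vcuda_()   # every cuda tensor and device attribute becomes vcuda, freeing GPU memory
# ... run memory-hungry graph passes here ...
graph.cuda_()    # cast everything back to cuda before building or running the graph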
8 changes: 6 additions & 2 deletions python/hidet/graph/operator.py
@@ -66,7 +66,7 @@ def device(self) -> Device:
if len(self.inputs) == 0:
raise ValueError('Cannot infer device from an operator with no inputs and "device" attribute')
# when the operator has inputs, get the device from the inputs
if not all(t.device == self.inputs[0].device for t in self.inputs):
if not all(t.device.target == self.inputs[0].device.target for t in self.inputs):
raise ValueError('All inputs of an operator must be on the same device')
return self.inputs[0].device

@@ -84,8 +84,12 @@ def build_target(self) -> str:

if isinstance(self, TransferOp):
return 'cuda'
if self.device.kind in ["cuda", "vcuda"]:
return "cuda"
elif self.device.kind == "cpu":
return "cpu"
else:
return self.device.kind
raise NotImplementedError()

@property
def compiled_task(self) -> CompiledTask:
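A standalone illustration of the kind-to-build-target mapping added above (not the Operator class itself, just the same branching written as a free function):

def build_target_for(kind: str) -> str:
    # vcuda operators are still compiled as regular cuda kernels
    if kind in ('cuda', 'vcuda'):
        return 'cuda'
    if kind == 'cpu':
        return 'cpu'
    raise NotImplementedError(f'unsupported device kind: {kind}')

assert build_target_for('vcuda') == 'cuda'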
39 changes: 39 additions & 0 deletions python/hidet/graph/tensor.py
@@ -789,6 +789,45 @@ def cuda(self, device=None):
else:
return transfer(self, device)

def vcuda_(self):
"""Cast the tensor to vcuda device in place.

If the current tensor is already on vcuda device, nothing is performed

Returns
-------
ret: None
This operation is in-place
"""

if self.device.is_vcuda():
return
if not self.device.is_cuda():
raise ValueError("Tensor must be on cuda device, got {}".format(self.device))
# if the tensor has no storage, there is no need to cast
if self.storage is not None:
self._storage = self.storage.vcuda(self.device.id)
self._device = Device('vcuda', self.device.id)

def cuda_(self):
"""Cast the tensor from vcuda device in place.

If the current tensor is already on cuda device, nothing is performed

Returns
-------
ret: None
This operation is in-place
"""
if self.device.is_cuda():
return
if not self.device.is_vcuda():
raise ValueError("Tensor must be on vcuda device, got {}".format(self.device))

if self.storage is not None:
self._storage = self.storage.cuda(self.device.id)
self._device = Device('cuda', self.device.id)

def copy(self) -> Tensor:
"""Create a copy of current tensor.

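A small usage sketch of the new in-place tensor casts (hidet.zeros is assumed to allocate a cuda tensor as usual; the rest follows the methods above):

import hidet

t = hidet.zeros([1024, 1024], device='cuda')
t.vcuda_()                    # the storage becomes a vcuda storage; t.device becomes 'vcuda'
assert t.device.is_vcuda()
t.cuda_()                     # convert it back; t.device becomes 'cuda' again
assert t.device.is_cuda()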
13 changes: 13 additions & 0 deletions python/hidet/graph/transforms/base.py
@@ -262,6 +262,19 @@ def profile_pass_instrument(self, log_file: Optional[str] = None, print_stdout:
self.instruments.append(ProfileInstrument(log_file, print_stdout))
return self

def reduce_cuda_compile_mem(self, enable: Optional[bool] = None):
"""
Reduce the CUDA memory used during compilation by casting the graph to vcuda tensors.
This may increase compilation time.

Parameters
----------
enable: Optional[bool]
When given, always enable or disable this instrument.
When omitted, the compiler decides whether to enable it using a heuristic.
"""
from .instruments import ConvertGraphToVCuda # pylint: disable=import-outside-toplevel

self.instruments.append(ConvertGraphToVCuda(enable))


class GraphPass:
def __init__(self):
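A hedged usage sketch of the new hook; it mirrors the llama.py change further down in this diff (flow_graph is assumed to be a FlowGraph traced on cuda):

import hidet

with hidet.graph.PassContext() as ctx:
    ctx.reduce_cuda_compile_mem()                  # pass enable=True or False to force the behaviour
    graph_opt = hidet.graph.optimize(flow_graph)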
1 change: 1 addition & 0 deletions python/hidet/graph/transforms/instruments/__init__.py
@@ -12,3 +12,4 @@
from .base import GraphPassInstrument
from .profile_instrument import ProfileInstrument
from .save_graph_instrument import SaveGraphInstrument
from .convert_flowgraph_to_vcuda import ConvertGraphToVCuda
50 changes: 50 additions & 0 deletions python/hidet/graph/transforms/instruments/convert_flowgraph_to_vcuda.py
@@ -0,0 +1,50 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from hidet.graph.flow_graph import FlowGraph
from hidet.graph.graph_utils.functors import graph_collect
from hidet.cuda.device import properties, current_device
from .base import GraphPassInstrument


class ConvertGraphToVCuda(GraphPassInstrument):
def __init__(self, enable: Optional[bool]):
super().__init__()
self.enable = enable
self.applied = False

# passes may take up to 2x the graph's memory, so 0.4 keeps the projected peak near 80% of device memory
self.threshold = 0.4

def before_all_passes(self, graph: FlowGraph):
if not self.should_enable(graph):
return
graph.vcuda_()
self.applied = True

def after_all_passes(self, graph: FlowGraph) -> None:
if self.applied:
graph.cuda_()
self.applied = False

def should_enable(self, graph: FlowGraph) -> bool:
from hidet.graph import Tensor

if self.enable is None:
tensors = graph_collect(graph, Tensor)
graph_size = 0
for t in tensors:
if t.storage is not None and t.storage.device.is_cuda():
graph_size += t.storage.num_bytes

return graph_size / properties(current_device()).totalGlobalMem > self.threshold
else:
return self.enable
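A back-of-the-envelope illustration of the enable=None heuristic above, with made-up sizes:

graph_size = 20 * 1024**3             # e.g. ~20 GiB of weights already resident on the GPU
total_mem = 40 * 1024**3              # e.g. totalGlobalMem of a 40 GiB device
threshold = 0.4                       # passes may need ~2x, i.e. ~80% of device memory
should_convert = graph_size / total_mem > threshold   # True, so the graph is cast to vcuda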
5 changes: 4 additions & 1 deletion python/hidet/runtime/compiled_task.py
@@ -176,8 +176,10 @@ def run_async(self, inputs):
_check_inputs(self.meta_data.inputs, inputs)

outputs = self.create_outputs()

candidate = self.candidates[self.pick_best_candidate(inputs, outputs)]
candidate(*inputs, *outputs)

return outputs


@@ -213,7 +215,8 @@ def _check_inputs(traced_inputs: Iterable[TensorSignature], inputs):

symbol_map = {}
for i, (traced, new) in enumerate(zip(traced_inputs, inputs)):
if traced.device.partition(':')[0] != new.device.kind:
traced_dev_kind = traced.device.partition(':')[0]
if traced_dev_kind != new.device.target:
raise RuntimeError(
f"device mismatch at arg {i} between original: {traced.device} and new: {new.device.kind}"
)
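The relaxed check above compares the recorded device kind with the new tensor's target, so a graph traced with cuda inputs also accepts vcuda tensors at run time. A tiny illustration with literal values (the 'cuda:0' string format is an assumption):

traced_device = 'cuda:0'                              # device string stored in the task metadata
traced_dev_kind = traced_device.partition(':')[0]     # 'cuda'
# Device('vcuda', 0).target is also 'cuda', so no mismatch error is raised for vcuda inputs
assert traced_dev_kind == 'cuda'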
13 changes: 10 additions & 3 deletions python/hidet/runtime/device.py
@@ -54,6 +54,13 @@ def is_cpu(self) -> bool:
def is_cuda(self) -> bool:
return self.kind == 'cuda'

def is_vcuda(self) -> bool:
return self.kind == 'vcuda'

@property
def target(self) -> str:
return 'cuda' if self.kind in ['cuda', 'vcuda'] else 'cpu'


def device(device_type: str, device_index: Optional[int] = None):
if ':' in device_type:
@@ -70,8 +77,8 @@ def device(device_type: str, device_index: Optional[int] = None):
raise ValueError(f'Invalid device_index: {device_index}')
device_index = int(device_index)

if device_type not in ['cpu', 'cuda']:
raise ValueError(f'Invalid device_type: {device_type}, must be "cpu" or "cuda"')
if device_type not in ['cpu', 'cuda', 'vcuda']:
raise ValueError(f'Invalid device_type: {device_type}, must be "cpu", "cuda", or "vcuda"')

if device_index is not None and not isinstance(device_index, int):
raise ValueError(f'Invalid device_index: {device_index}, must be an integer')
@@ -111,7 +118,7 @@ def instantiate_device(dev) -> Device:
if dev.kind == 'cpu':
dev.id = None # CPU device does not have a device index
return dev
elif dev.kind == 'cuda':
elif dev.kind in ['cuda', 'vcuda']:
if dev.id is None:
dev.id = current_device()
return dev
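A small sketch of how the new kind/target split behaves (values follow the code above; device() is the factory defined in this file and is assumed to return a Device):

from hidet.runtime.device import device

d = device('vcuda', 0)
assert d.is_vcuda() and not d.is_cuda()
assert d.target == 'cuda'    # vcuda devices still target cuda for code generation and transfers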
18 changes: 17 additions & 1 deletion python/hidet/runtime/storage.py
@@ -243,6 +243,22 @@ def cuda_async(self, dst_id: int, stream: Optional[Stream] = None):
"""
return Storage._convert(self, Device('cuda', dst_id), non_blocking=True, stream=stream)

def vcuda(self, dst_id: int) -> Storage:
"""
Copy the storage to the vcuda device. If the storage is already on that device, return itself.

Parameters
----------
dst_id: int
The id of the destination CUDA device.

Returns
-------
ret: Storage
The storage on the destination vcuda device.
"""
return Storage._convert(self, Device('vcuda', dst_id), non_blocking=False)

def copy(self) -> Storage:
"""
Copy the storage to the same device. If the storage is already on the device, return itself.
@@ -368,7 +384,7 @@ def __getitem__(self, device: Device) -> MemoryPool:
self.device2pool[device] = MemoryPool(
CudaMemoryAPI(device), block_size=4096, max_reserve_size=4 * 1024**3
)
elif device.is_cpu():
elif device.is_cpu() or device.is_vcuda():
self.device2pool[device] = MemoryPool(
CUDAHostMemoryAPI(device), block_size=4096, max_reserve_size=512 * 1024**2
)
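Per the pool change above, vcuda storages are served by the CUDA host (pinned) memory pool rather than the device pool, which is what releases GPU memory during compilation. A sketch using the helper that the tests below also use:

from hidet.runtime.storage import current_memory_pool

print(current_memory_pool('vcuda'))   # backed by CUDAHostMemoryAPI per the change above
print(current_memory_pool('cuda'))    # regular device memory pool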
7 changes: 5 additions & 2 deletions python/hidet/testing/models/llama.py
@@ -395,7 +395,7 @@ def generate(text: str, model, tokenizer, config, num_tokens=20, device='cuda',
outputs = []
for _ in range(num_tokens):
y = model(input_ids, position_ids, *past_keys_values)
input_ids = y[0][:, -1:]
input_ids = y[0][:, -1:].to(dtype=hidet.int32)
outputs.append(input_ids[0, -1].item())
past_keys_values = y[1:]

@@ -440,6 +440,7 @@ def get_compiled_model(name='decapoda-research/llama-7b-hf', device='cuda', opt=

with torch.device("cuda"): # reduce the time to load the model
model = hfLm.from_pretrained(name, torch_dtype=torch.float16)

model.cpu()
torch.cuda.empty_cache()

@@ -450,7 +451,9 @@ def get_compiled_model(name='decapoda-research/llama-7b-hf', device='cuda', opt=
flow_graph = build_flow_graph(model, device=device)

if opt:
flow_graph = hidet.graph.optimize(flow_graph)
with hidet.graph.PassContext() as ctx:
ctx.reduce_cuda_compile_mem()
flow_graph = hidet.graph.optimize(flow_graph)

compiled = flow_graph.build()
return compiled, config, tok
20 changes: 15 additions & 5 deletions tests/models/test_llama.py
@@ -11,21 +11,31 @@
# limitations under the License.
import pytest
from hidet.testing.models.llama import get_compiled_model, generate
from hidet.runtime.storage import current_memory_pool


@pytest.mark.skip(reason='This test requires a lot of memory')
def test_llama(device='cuda', opt=False):
# @pytest.mark.parametrize('device,opt', [('cuda', True)])
@pytest.mark.skip(reason='This test requires a lot of CPU memory (>32 GB)')
def test_llama(device, opt):
model, config, tokenizer = get_compiled_model(device=device, opt=opt)

text = generate('In the beginning was the Word.', model, tokenizer, config, num_tokens=12)
assert text == 'The Word was with God, and the Word was God.'
print(text)
expected = 'The Word was with God, and the Word was God.'
assert text == expected

text = generate(
"A robot may not injure a human being or, through inaction", model, tokenizer, config, num_tokens=55
)
expected = (
', allow a human being to come to harm. A robot must obey the orders given it by human beings'
', allow a human being to come to harm. A robot must obey orders given it by human beings'
' except where such orders would conflict with the First Law. A robot must protect its own'
' existence as long as such protection does not conflict with the First or Second Laws'
' existence as long as such protection does not conflict with the First or Second Laws.'
)

print(text)
assert text == expected

print(current_memory_pool("cuda"))
print(current_memory_pool("cpu"))
print(current_memory_pool("vcuda"))