Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Enhancements] add a vcuda device to help mitigate compile time GPU memory usage (#302)

## The problem

During compilation, some passes need to replace certain operators with one or more new operators, for example `resolve_variants` and `subgraph_rewrite`. In doing so, hidet relies on `imperative_run` to generate the new operators. This creates many intermediate tensors whose total size can exceed the actual run-time memory consumption, which makes larger models (such as Llama-7B) unable to compile in fp16 even on a GPU with 24GB of memory.

## Fix

We introduce a `vcuda` device that allows GPU tensors to be stored on the CPU and transferred to the GPU only on demand. With this change, any additional GPU memory usage during compilation is off-loaded to the CPU. This might incur a bit of compilation overhead when enabled, but given how long such large models already take to compile, the increase is negligible.

Now, on an RTX 3090, this is the GPU memory consumption for running the llama test:

```
Status of cuda:0 memory pool
  Allocated: 14081 MiB
  Peak: 14081 MiB
  Reserved: 1196 MiB
  Active: 12884 MiB
Status of cpu memory pool
  Allocated: 4 KiB
  Peak: 26022 MiB
  Reserved: 4 KiB
  Active: 0 Bytes
Status of vcuda:0 memory pool
  Allocated: 0 Bytes
  Peak: 25486 MiB
  Reserved: 0 Bytes
  Active: 0 Bytes
```

---------

Co-authored-by: Xin Li <xin@centml.ai>
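To put the reported figures in perspective, the short calculation below compares the peaks from the memory pool report against the card's nominal capacity. It is only a back-of-the-envelope sketch: it assumes "24GB" means 24 GiB, and it reads the `vcuda:0` peak as memory that would otherwise have been requested on the GPU, which the commit message implies but does not state outright.

```python
# Figures copied from the memory pool report in the commit message (MiB).
CUDA_PEAK_MIB = 14081
VCUDA_PEAK_MIB = 25486

# Assumption: the "24GB" GPU is treated as 24 GiB = 24 * 1024 MiB.
DEVICE_CAPACITY_MIB = 24 * 1024

# How far the offloaded peak alone would overshoot the nominal capacity.
over_capacity_mib = VCUDA_PEAK_MIB - DEVICE_CAPACITY_MIB

# Share of the nominal capacity used by the real cuda:0 pool at its peak.
cuda_peak_fraction = CUDA_PEAK_MIB / DEVICE_CAPACITY_MIB

# Upper bound only: the two peaks need not have occurred at the same time.
combined_peak_mib = CUDA_PEAK_MIB + VCUDA_PEAK_MIB

print(f"vcuda peak exceeds nominal capacity by {over_capacity_mib} MiB")
print(f"cuda:0 peak uses {cuda_peak_fraction:.1%} of nominal capacity")
print(f"combined peak (upper bound): {combined_peak_mib} MiB")
```

Under these assumptions, the memory held on `vcuda:0` at its peak would not have fit on the device by itself, which is consistent with the compilation failure described in the problem statement.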
1 parent 1c1cd11, commit a15f5c0
Showing 11 changed files with 208 additions and 14 deletions.
50 changes: 50 additions & 0 deletions
python/hidet/graph/transforms/instruments/convert_flowgraph_to_vcuda.py
@@ -0,0 +1,50 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.graph.flow_graph import FlowGraph
from hidet.graph.graph_utils.functors import graph_collect
from hidet.cuda.device import properties, current_device
from .base import GraphPassInstrument


class ConvertGraphToVCuda(GraphPassInstrument):
    def __init__(self, enable: bool):
        super().__init__()
        self.enable = enable
        self.applied = False

        # passes may take up to 2x memory (80% of total memory)
        self.threshold = 0.4

    def before_all_passes(self, graph: FlowGraph):
        if not self.should_enable(graph):
            return
        graph.vcuda_()
        self.applied = True

    def after_all_passes(self, graph: FlowGraph) -> None:
        if self.applied:
            graph.cuda_()
            self.applied = False

    def should_enable(self, graph):
        from hidet.graph import Tensor

        if self.enable is None:
            tensors = graph_collect(graph, Tensor)
            graph_size = 0
            for t in tensors:
                if t.storage is not None and t.storage.device.is_cuda():
                    graph_size += t.storage.num_bytes

            return graph_size / properties(current_device()).totalGlobalMem > self.threshold
        else:
            return self.enable
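
The `should_enable` check above is a tri-state switch: an explicit `True` or `False` forces the decision, while `None` falls back to a size heuristic that compares the bytes held by CUDA tensors in the graph against 40% of the device's total memory. The standalone sketch below reproduces only that decision logic without depending on hidet. The function name `should_offload`, its parameters, and the example sizes are hypothetical and chosen for illustration; the real instrument collects tensor sizes from the `FlowGraph` and queries the device properties itself.

```python
from typing import Iterable, Optional

# Mirrors the 0.4 threshold in the diff: passes may take up to 2x the graph size.
THRESHOLD = 0.4


def should_offload(
    enable: Optional[bool],
    tensor_bytes: Iterable[int],
    total_device_bytes: int,
    threshold: float = THRESHOLD,
) -> bool:
    """Hypothetical helper: True/False force the result, None uses the size heuristic."""
    if enable is not None:
        return enable
    graph_size = sum(tensor_bytes)
    # Strictly greater than: a graph sitting exactly at the threshold is not offloaded.
    return graph_size / total_device_bytes > threshold


GIB = 1024 ** 3
total = 24 * GIB  # assumed device capacity, for illustration only

small_graph = should_offload(None, [2 * GIB, 3 * GIB], total)   # 5/24 of the device
large_graph = should_offload(None, [6 * GIB, 7 * GIB], total)   # 13/24 of the device
forced_off = should_offload(False, [6 * GIB, 7 * GIB], total)   # explicit override

print(small_graph, large_graph, forced_off)
```

Keeping the override separate from the heuristic means a user can always disable offloading for a large graph, or enable it for a small one, without touching the threshold.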