Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions deepspeed/inference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ def __init__(self, model, config):
assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
"If you want to use cuda graph, please upgrade torch to at least v1.10"

# Check if model passed to engine is loaded w/ meta tensors, in which case
# kernel injection must be enabled.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update the comment to add why this only works for HF models.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a note saying that the device type is sourced assuming a Hugging Face hierarchy.

# NOTE: This check assumes a Hugging Face hierarchy for the device type i.e. module.device.type
self.model_meta_device = self.module.device.type == 'meta' if hasattr(self.module, "device") else False

if self.model_meta_device:
assert config.replace_with_kernel_inject, "Meta tensor support is only available when kernel injection is enabled"

# convert model to intended dtype
if config.dtype:
self._convert_to_dtype(config)
Expand Down
8 changes: 8 additions & 0 deletions deepspeed/module_inject/containers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,19 @@ def create_ds_model_config(self):

return self.ds_model_config

def check_meta_tensor_support(self):
    """Validate that meta-tensor weights are usable in this container.

    Raises:
        NotImplementedError: when the attention weight tensor lacks the
            ``is_meta`` attribute, i.e. the installed torch predates 1.10
            and cannot represent meta tensors at all.
    Asserts:
        checkpoint loading is enabled whenever the weights really are
        meta tensors (meta weights carry no data to run with).
    """
    if not hasattr(self.qkvw, 'is_meta'):
        # `is_meta` appeared on tensors in torch 1.10; its absence is the
        # signal that the runtime has no meta-tensor support.
        raise NotImplementedError("Meta tensor support is not available, please upgrade to torch 1.10+")
    if self.qkvw.is_meta:
        # Meta weights must be materialized from a checkpoint later on.
        assert self.ckpt_load_enabled, "Meta tensors are not supported for this model currently."

def initialize_tensors(self, enable_training=False):
    """Copy the tensors from the policy (user module) into the container (DS module).

    Args:
        enable_training: forwarded to the policy's attention/MLP extraction
            so it can hand out trainable parameters when requested.
    """
    attention_params = self.policy.attention(enable_training=enable_training)
    self.set_attention(*attention_params)
    mlp_params = self.policy.mlp(enable_training=enable_training)
    self.set_mlp(*mlp_params)
    layernorm_params = self.policy.layernorm()
    self.set_layernorm(*layernorm_params)
    # Runs after the weights are set so it can inspect the actual tensors.
    self.check_meta_tensor_support()

def convert_to_required_dtype(self):
# Note: converting tensors to fp16 requires that we do it in-place using self.__dict__ and not make a list/dict copy
Expand Down
4 changes: 4 additions & 0 deletions deepspeed/module_inject/containers/features/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# DeepSpeed Team

from abc import ABC, abstractmethod
from packaging import version as pkg_version
import torch


class MetaTensorContainer(ABC):
Expand All @@ -14,6 +16,8 @@ class MetaTensorContainer(ABC):
"""

def __init__(self, **kwargs):
    """Set up meta-tensor bookkeeping for the container.

    Raises:
        NotImplementedError: on torch versions older than 1.10, which do
            not support meta tensors.
    """
    # Meta tensors only exist from torch 1.10 onward; fail fast otherwise.
    if pkg_version.parse(torch.__version__) < pkg_version.parse('1.10'):
        raise NotImplementedError("Meta tensor support is not available, please upgrade to torch 1.10+")
    super().__init__(**kwargs)
    # Weights start out as real tensors; flipped when meta weights are seen.
    self.is_meta = False
    # A meta-tensor container must materialize values from a checkpoint.
    self.ckpt_load_enabled = True
Expand Down