diff --git a/docs/source/conf.py b/docs/source/conf.py index 28eed91..d2805c1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -25,9 +25,9 @@ author = "The Inseq Team" # The short X.Y version -version = "0.6" +version = "0.7" # The full version, including alpha/beta/rc tags -release = "0.6.0" +release = "0.7.0.dev0" # Prefix link to point to master, comment this during version release and uncomment below line diff --git a/inseq/models/attribution_model.py b/inseq/models/attribution_model.py index b96db5c..2f259a4 100644 --- a/inseq/models/attribution_model.py +++ b/inseq/models/attribution_model.py @@ -219,6 +219,7 @@ def __init__(self, **kwargs) -> None: self.pad_token: Optional[str] = None self.embed_scale: Optional[float] = None self._device: Optional[str] = None + self.device_map: Optional[dict[str, Union[str, int, torch.device]]] = None self.attribution_method: Optional[FeatureAttribution] = None self.is_hooked: bool = False self._default_attributed_fn_id: str = "probability" diff --git a/inseq/models/huggingface_model.py b/inseq/models/huggingface_model.py index c7416a5..d6cc3f6 100644 --- a/inseq/models/huggingface_model.py +++ b/inseq/models/huggingface_model.py @@ -127,6 +127,9 @@ def __init__( self.embed_scale = 1.0 self.encoder_int_embeds = None self.decoder_int_embeds = None + self.device_map = None + if hasattr(self.model, "hf_device_map") and self.model.hf_device_map is not None: + self.device_map = self.model.hf_device_map self.is_encoder_decoder = self.model.config.is_encoder_decoder self.configure_embeddings_scale() self.setup(device, attribution_method, **kwargs) @@ -162,16 +165,19 @@ def device(self, new_device: str) -> None: is_loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) is_loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False) is_quantized = is_loaded_in_8bit or is_loaded_in_4bit + has_device_map = self.device_map is not None # Enable compatibility with 8bit models if self.model: - if not is_quantized: - self.model.to(self._device) - else: + if is_quantized: mode = "8bit" if is_loaded_in_8bit else "4bit" logger.warning( f"The model is loaded in {mode} mode. The device cannot be changed after loading the model." ) + elif has_device_map: + logger.warning("The model is loaded with a device map. The device cannot be changed after loading.") + else: + self.model.to(self._device) @abstractmethod def configure_embeddings_scale(self) -> None: diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py index f632ba3..9eb39ab 100644 --- a/inseq/utils/__init__.py +++ b/inseq/utils/__init__.py @@ -10,6 +10,7 @@ ) from .hooks import StackFrame, get_post_variable_assignment_hook from .import_utils import ( + is_accelerate_available, is_captum_available, is_datasets_available, is_ipywidgets_available, @@ -130,4 +131,5 @@ "validate_indices", "pad_with_nan", "recursive_get_submodule", + "is_accelerate_available", ] diff --git a/inseq/utils/import_utils.py b/inseq/utils/import_utils.py index 2a1ccc2..e8ae455 100644 --- a/inseq/utils/import_utils.py +++ b/inseq/utils/import_utils.py @@ -8,6 +8,7 @@ _captum_available = find_spec("captum") is not None _joblib_available = find_spec("joblib") is not None _nltk_available = find_spec("nltk") is not None +_accelerate_available = find_spec("accelerate") is not None def is_ipywidgets_available(): @@ -40,3 +41,7 @@ def is_joblib_available(): def is_nltk_available(): return _nltk_available + + +def is_accelerate_available(): + return _accelerate_available diff --git a/pyproject.toml b/pyproject.toml index 8b0bdd0..32592ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "inseq" -version = "0.6.0" +version = "0.7.0.dev0" description = "Interpretability for Sequence Generation Models 🔍" readme = "README.md" requires-python = ">=3.9"