From 124a525e4cf019e49ee47b08377998a4f87e0c9f Mon Sep 17 00:00:00 2001
From: B-Step62
Date: Fri, 1 Mar 2024 23:28:43 +0900
Subject: [PATCH 1/4] Populate torch_dtype from model to pipeline

Signed-off-by: B-Step62
---
 src/transformers/pipelines/base.py       | 10 ++++++++++
 tests/pipelines/test_pipelines_common.py | 23 +++++++++++++++++++++++
 2 files changed, 33 insertions(+)

diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 758484107b76..f8cb67ae122f 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -862,6 +862,16 @@ def __init__(
         else:
             self.device = device if device is not None else -1
         self.torch_dtype = torch_dtype
+
+        if not self.torch_dtype and is_torch_available():
+            # If the pipeline dtype is not specified, populate it from the model.
+            # NB: We should only do this when the extracted dtype is not the default one (float32),
+            # because not all models/pipelines support torch_dtype. Here we assume that if the
+            # model dtype is not float32, it was set by the user via the torch_dtype param, so the
+            # model or pipeline should support it.
+            if hasattr(model, "dtype") and model.dtype not in (torch.float32, "float32", "torch.float32"):
+                self.torch_dtype = model.dtype
+
         self.binary_output = binary_output

         # We shouldn't call `model.to()` for models loaded with accelerate

diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 5e3e15f39c10..52c118a296d0 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -199,6 +199,29 @@ def test_unbatch_attentions_hidden_states(self):
         outputs = text_classifier(["This is great !"] * 20, batch_size=32)
         self.assertEqual(len(outputs), 20)

+    @require_torch
+    def test_torch_dtype_set_to_pipeline(self):
+        import torch
+
+        # If dtype is specified in the pipeline constructor, it should be set on the pipeline and the model config
+        pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert", torch_dtype=torch.float16)
+        self.assertEqual(pipe.torch_dtype, torch.float16)
+        self.assertEqual(pipe.model.config.torch_dtype, torch.float16)
+
+        # If dtype is not specified, it should be set based on the model config
+        model = DistilBertForSequenceClassification.from_pretrained(
+            "hf-internal-testing/tiny-random-distilbert", torch_dtype=torch.bfloat16
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-distilbert")
+        pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer)
+        self.assertEqual(pipe.torch_dtype, torch.bfloat16)
+
+        # If dtype is not specified and not available in the model config, it should be set based
+        # on the dtype of the model's parameters
+        model.config.torch_dtype = None
+        pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer)
+        self.assertEqual(pipe.torch_dtype, torch.bfloat16)
+

 @is_pipeline_test
 class PipelineScikitCompatTest(unittest.TestCase):

From e575f35df34d649558a5f0ed0ad4f1d6180e823c Mon Sep 17 00:00:00 2001
From: B-Step62
Date: Sat, 2 Mar 2024 00:09:55 +0900
Subject: [PATCH 2/4] use property

Signed-off-by: B-Step62
---
 src/transformers/pipelines/base.py       | 25 ++++++++++------
 tests/pipelines/test_pipelines_common.py | 38 ++++++++++++++----------
 2 files changed, 39 insertions(+), 24 deletions(-)

diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index f8cb67ae122f..a3f2bd366c6a 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -861,16 +861,8 @@ def __init__(
                raise ValueError(f"{device} unrecognized or not available.")
         else:
             self.device = device if device is not None else -1
-        self.torch_dtype = torch_dtype

-        if not self.torch_dtype and is_torch_available():
-            # If the pipeline dtype is not specified, populate it from the model.
-            # NB: We should only do this when the extracted dtype is not the default one (float32),
-            # because not all models/pipelines support torch_dtype. Here we assume that if the
-            # model dtype is not float32, it was set by the user via the torch_dtype param, so the
-            # model or pipeline should support it.
-            if hasattr(model, "dtype") and model.dtype not in (torch.float32, "float32", "torch.float32"):
-                self.torch_dtype = model.dtype
+        self._initial_torch_dtype = torch_dtype

         self.binary_output = binary_output

@@ -964,6 +956,21 @@ def predict(self, X):
         """
         return self(X)

+    @property
+    def torch_dtype(self):
+        if hasattr(self.model, "dtype"):
+            # NB: We extract the dtype from the underlying model, but it is possible that the model has a dtype
+            # but the pipeline subclass doesn't support it. In such a case we should not return anything,
+            # but it is not straightforward to detect this in a generic way. Therefore, we assume that the
+            # pipeline supports torch_dtype if (1) the extracted dtype is not the default one (float32), or
+            # (2) the torch_dtype argument was set by the user when creating the pipeline.
+            if (
+                self._initial_torch_dtype is not None
+                or self.model.dtype not in (torch.float32, "float32", "torch.float32")
+            ):
+                return self.model.dtype
+        return self._initial_torch_dtype
+
     @contextmanager
     def device_placement(self):
         """

diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 52c118a296d0..92181ee58e99 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -200,27 +200,35 @@ def test_unbatch_attentions_hidden_states(self):
         self.assertEqual(len(outputs), 20)

     @require_torch
-    def test_torch_dtype_set_to_pipeline(self):
+    def test_torch_dtype_property(self):
         import torch
+        model_id = "hf-internal-testing/tiny-random-distilbert"

-        # If dtype is specified in the pipeline constructor, it should be set on the pipeline and the model config
-        pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert", torch_dtype=torch.float16)
+        # If dtype is specified in the pipeline constructor, the property should return that type
+        pipe = pipeline(model=model_id, torch_dtype=torch.float16)
         self.assertEqual(pipe.torch_dtype, torch.float16)
-        self.assertEqual(pipe.model.config.torch_dtype, torch.float16)

-        # If dtype is not specified, it should be set based on the model config
-        model = DistilBertForSequenceClassification.from_pretrained(
-            "hf-internal-testing/tiny-random-distilbert", torch_dtype=torch.bfloat16
-        )
-        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-distilbert")
-        pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer)
+        # If the underlying model changes dtype, the property should return the new type
+        pipe.model.to(torch.bfloat16)
         self.assertEqual(pipe.torch_dtype, torch.bfloat16)

-        # If dtype is not specified and not available in the model config, it should be set based
-        # on the dtype of the model's parameters
-        model.config.torch_dtype = None
-        pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer)
-        self.assertEqual(pipe.torch_dtype, torch.bfloat16)
+        # Even if the model dtype is the default one, we can safely assume the pipeline supports torch_dtype
+        # as it was constructed with torch_dtype specified
+        pipe.model.to(torch.float32)
+        self.assertEqual(pipe.torch_dtype, torch.float32)
+
+        # If dtype is NOT specified in the pipeline constructor, the property should NOT return a type,
+        # as we don't know whether the pipeline supports torch_dtype
+        pipe = pipeline(model=model_id)
+        self.assertEqual(pipe.torch_dtype, None)
+
+        # If the model changes to a non-default dtype, we assume the pipeline supports torch_dtype
+        pipe.model.to(torch.float16)
+        self.assertEqual(pipe.torch_dtype, torch.float16)
+
+        # If the model dtype is the default, we conservatively assume the pipeline doesn't support torch_dtype
+        pipe.model.to(torch.float32)
+        self.assertEqual(pipe.torch_dtype, None)


 @is_pipeline_test

From c9216787e40d8e80dcf3cacaa3300186a796c3af Mon Sep 17 00:00:00 2001
From: B-Step62
Date: Sat, 2 Mar 2024 01:02:33 +0900
Subject: [PATCH 3/4] lint

Signed-off-by: B-Step62
---
 src/transformers/pipelines/base.py       | 7 ++++---
 tests/pipelines/test_pipelines_common.py | 1 +
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index a3f2bd366c6a..959feee21b78 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -964,9 +964,10 @@ def torch_dtype(self):
             # but it is not straightforward to detect this in a generic way. Therefore, we assume that the
             # pipeline supports torch_dtype if (1) the extracted dtype is not the default one (float32), or
             # (2) the torch_dtype argument was set by the user when creating the pipeline.
-            if (
-                self._initial_torch_dtype is not None
-                or self.model.dtype not in (torch.float32, "float32", "torch.float32")
+            if self._initial_torch_dtype is not None or self.model.dtype not in (
+                torch.float32,
+                "float32",
+                "torch.float32",
             ):
                 return self.model.dtype
         return self._initial_torch_dtype

diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 92181ee58e99..d6f6f65eb1c4 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -202,6 +202,7 @@ def test_unbatch_attentions_hidden_states(self):
     @require_torch
     def test_torch_dtype_property(self):
         import torch
+
         model_id = "hf-internal-testing/tiny-random-distilbert"

         # If dtype is specified in the pipeline constructor, the property should return that type

From e9ae6c9a46c920438884268ff9fcd9692a3e8655 Mon Sep 17 00:00:00 2001
From: B-Step62
Date: Tue, 12 Mar 2024 00:20:56 +0900
Subject: [PATCH 4/4] Remove default handling

Signed-off-by: B-Step62
---
 src/transformers/pipelines/base.py       | 21 +++++----------------
 tests/pipelines/test_pipelines_common.py | 21 ++++++---------------
 2 files changed, 11 insertions(+), 31 deletions(-)

diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 959feee21b78..079da4980851 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -862,8 +862,6 @@ def __init__(
         else:
             self.device = device if device is not None else -1

-        self._initial_torch_dtype = torch_dtype
-
         self.binary_output = binary_output

         # We shouldn't call `model.to()` for models loaded with accelerate
@@ -957,20 +955,11 @@ def predict(self, X):
         return self(X)

     @property
-    def torch_dtype(self):
-        if hasattr(self.model, "dtype"):
-            # NB: We extract the dtype from the underlying model, but it is possible that the model has a dtype
-            # but the pipeline subclass doesn't support it. In such a case we should not return anything,
-            # but it is not straightforward to detect this in a generic way. Therefore, we assume that the
-            # pipeline supports torch_dtype if (1) the extracted dtype is not the default one (float32), or
-            # (2) the torch_dtype argument was set by the user when creating the pipeline.
-            if self._initial_torch_dtype is not None or self.model.dtype not in (
-                torch.float32,
-                "float32",
-                "torch.float32",
-            ):
-                return self.model.dtype
-        return self._initial_torch_dtype
+    def torch_dtype(self) -> Optional["torch.dtype"]:
+        """
+        Torch dtype of the model (if it's a PyTorch model), `None` otherwise.
+        """
+        return getattr(self.model, "dtype", None)

     @contextmanager
     def device_placement(self):

diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index d6f6f65eb1c4..13b97aff3216 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -213,23 +213,14 @@ def test_torch_dtype_property(self):
         pipe.model.to(torch.bfloat16)
         self.assertEqual(pipe.torch_dtype, torch.bfloat16)

-        # Even if the model dtype is the default one, we can safely assume the pipeline supports torch_dtype
-        # as it was constructed with torch_dtype specified
-        pipe.model.to(torch.float32)
-        self.assertEqual(pipe.torch_dtype, torch.float32)
-
-        # If dtype is NOT specified in the pipeline constructor, the property should NOT return a type,
-        # as we don't know whether the pipeline supports torch_dtype
+        # If dtype is NOT specified in the pipeline constructor, the property should just return
+        # the dtype of the underlying model (float32 by default)
         pipe = pipeline(model=model_id)
-        self.assertEqual(pipe.torch_dtype, None)
-
-        # If the model changes to a non-default dtype, we assume the pipeline supports torch_dtype
-        pipe.model.to(torch.float16)
-        self.assertEqual(pipe.torch_dtype, torch.float16)
+        self.assertEqual(pipe.torch_dtype, torch.float32)

-        # If the model dtype is the default, we conservatively assume the pipeline doesn't support torch_dtype
-        pipe.model.to(torch.float32)
-        self.assertEqual(pipe.torch_dtype, None)
+        # If the underlying model doesn't have a dtype property, simply return None
+        pipe.model = None
+        self.assertIsNone(pipe.torch_dtype)


 @is_pipeline_test
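
Taken together, the series replaces PATCH 1/4's constructor-time bookkeeping with a single read-only property: `pipe.torch_dtype` now just mirrors `pipe.model.dtype` on every access. A minimal usage sketch of that end state (not part of the patches; it assumes the same `hf-internal-testing/tiny-random-distilbert` checkpoint the tests above use):

    import torch
    from transformers import pipeline

    # An explicitly requested dtype is propagated into the model, so the
    # property reflects it back
    pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert", torch_dtype=torch.float16)
    assert pipe.torch_dtype == torch.float16

    # With no dtype argument, the property reports the model's actual dtype
    # (float32 by default) instead of the None returned by earlier revisions
    pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert")
    assert pipe.torch_dtype == torch.float32

    # Casting the model afterwards is picked up too, because the property reads
    # getattr(self.model, "dtype", None) live rather than caching a constructor arg
    pipe.model.to(torch.bfloat16)
    assert pipe.torch_dtype == torch.bfloat16

Because nothing is cached, there is no pipeline-side state that can drift out of sync with the model, which is exactly the corner case the intermediate `_initial_torch_dtype` heuristic in PATCH 2/4 had to reason around.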