From 1d98ac760c4d6174897c0b0cf3ac5ef4537f2b68 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Wed, 2 Jun 2021 11:44:36 -0400 Subject: [PATCH] Fix tf v1 gpu import error --- .../notebooks/sparse_quantized_transfer_learning.ipynb | 2 +- src/sparseml/tensorflow_v1/base.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/integrations/pytorch/notebooks/sparse_quantized_transfer_learning.ipynb b/integrations/pytorch/notebooks/sparse_quantized_transfer_learning.ipynb index 6568a10c31b..96b9b9837b7 100644 --- a/integrations/pytorch/notebooks/sparse_quantized_transfer_learning.ipynb +++ b/integrations/pytorch/notebooks/sparse_quantized_transfer_learning.ipynb @@ -454,4 +454,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/src/sparseml/tensorflow_v1/base.py b/src/sparseml/tensorflow_v1/base.py index 04e196d9607..97093e2eb2a 100644 --- a/src/sparseml/tensorflow_v1/base.py +++ b/src/sparseml/tensorflow_v1/base.py @@ -102,7 +102,13 @@ def check_tensorflow_install( raise tensorflow_err return False - return check_version("tensorflow", min_version, max_version, raise_on_error) + return check_version( + "tensorflow", + min_version, + max_version, + raise_on_error, + alternate_package_names=["tensorflow-gpu"], + ) def check_tf2onnx_install(