diff --git a/requirements.txt b/requirements.txt index 7fbd0834fa..cc2d563f8b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ Pillow>=8.3.2 tqdm>=4.30.0 tensorflow-addons>=0.13.0 rapidfuzz>=1.6.0 +keras<2.7.0 diff --git a/setup.py b/setup.py index f34223b5d8..6431e92ad0 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ "tqdm>=4.30.0", "tensorflow-addons>=0.13.0", "rapidfuzz>=1.6.0", + "keras<2.7.0", ] deps = {b: a for a, b in (re.findall(r"^(([^!=<>]+)(?:[!=<>].*)?$)", x)[0] for x in _deps)} @@ -86,8 +87,8 @@ def deps_list(*pkgs): ] extras = {} -extras["tf"] = deps_list("tensorflow", "tensorflow-addons") -extras["tf-cpu"] = deps_list("tensorflow-cpu", "tensorflow-addons") +extras["tf"] = deps_list("tensorflow", "tensorflow-addons", "keras") +extras["tf-cpu"] = deps_list("tensorflow-cpu", "tensorflow-addons", "keras") extras["torch"] = deps_list("torch", "torchvision") extras["all"] = ( extras["tf"]