Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions keras_nlp/api_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types

from tensorflow import keras

try:
import namex
except ImportError:
namex = None


def maybe_register_serializable(symbol):
if isinstance(symbol, types.FunctionType) or hasattr(symbol, "get_config"):
keras.utils.register_keras_serializable(package="keras_nlp")(symbol)


if namex:

class keras_nlp_export(namex.export):
def __init__(self, path):
super().__init__(package="keras_nlp", path=path)

def __call__(self, symbol):
maybe_register_serializable(symbol)
return super().__call__(symbol)

else:

class keras_nlp_export:
def __init__(self, path):
pass

def __call__(self, symbol):
maybe_register_serializable(symbol)
return symbol
3 changes: 2 additions & 1 deletion keras_nlp/layers/transformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.utils.keras_utils import clone_initializer

from keras_nlp.layers.transformer_layer_utils import ( # isort:skip
merge_padding_and_attention_mask,
)


@keras.utils.register_keras_serializable(package="keras_nlp")
@keras_nlp_export("keras_nlp.layers.TransformerEncoder")
class TransformerEncoder(keras.layers.Layer):
"""Transformer encoder.

Expand Down
115 changes: 115 additions & 0 deletions pip_build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to create (and optionally install) a `.whl` archive for KerasNLP.

Usage:

1. Create a `.whl` file in `dist/`:

```
python3 pip_build.py
```

2. Also install the new package immediately after:

```
python3 pip_build.py --install
```
"""
import argparse
import glob
import os
import pathlib
import shutil

import namex

package = "keras_nlp"
build_directory = "tmp_build_dir"
dist_directory = "dist"
to_copy = [
"setup.py",
"setup.cfg",
"README.md",
]


def build():
if os.path.exists(build_directory):
raise ValueError(f"Directory already exists: {build_directory}")

whl_path = None
try:
# Copy sources (`keras_nlp/` directory and setup files) to build directory
root_path = pathlib.Path(__file__).parent.resolve()
os.chdir(root_path)
os.mkdir(build_directory)
shutil.copytree(package, os.path.join(build_directory, package))
for fname in to_copy:
shutil.copy(fname, os.path.join(f"{build_directory}", fname))
os.chdir(build_directory)

# Restructure the codebase so that source files live in `keras_nlp/src`
namex.convert_codebase(package, code_directory="src")

# Generate API __init__.py files in `keras_nlp/`
namex.generate_api_files(package, code_directory="src", verbose=True)

# Make sure to export the __version__ string
from keras_nlp.src import __version__ # noqa: E402

with open(os.path.join(package, "__init__.py")) as f:
init_contents = f.read()
with open(os.path.join(package, "__init__.py"), "w") as f:
f.write(init_contents + "\n\n" + f'__version__ = "{__version__}"\n')

# Build the package
os.system("python3 -m build")

# Save the dist files generated by the build process
os.chdir(root_path)
if not os.path.exists(dist_directory):
os.mkdir(dist_directory)
for fpath in glob.glob(
os.path.join(build_directory, dist_directory, "*.*")
):
shutil.copy(fpath, dist_directory)

# Find the .whl file path
for fname in os.listdir(dist_directory):
if fname.endswith(".whl"):
whl_path = os.path.abspath(os.path.join(dist_directory, fname))
print(f"Build successful. Wheel file available at {whl_path}")
finally:
# Clean up: remove the build directory (no longer needed)
shutil.rmtree(build_directory)
return whl_path


def install_whl(whl_fpath):
print("Installing wheel file.")
os.system(f"pip3 install {whl_fpath} --force-reinstall --no-dependencies")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--install",
action="store_true",
help="Whether to install the generated wheel file.",
)
args = parser.parse_args()
whl_path = build()
if whl_path and args.install:
install_whl(whl_path)
2 changes: 2 additions & 0 deletions requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ flake8
isort
pytest
pytest-cov
build
namex
# Optional deps.
rouge-score
sentencepiece