Skip to content
Permalink
Browse files

Add dockerfile for model export

  • Loading branch information...
buckhx committed Jul 29, 2019
1 parent 6d3cd27 commit 3ecf35c58c7cbd83fb8ecce1c1462f44f91c25dc
@@ -1,3 +1,3 @@
[submodule "python/bert"]
path = python/bert
[submodule "export/bert"]
path = export/bert
url = git@github.com:google-research/bert.git
@@ -1,44 +1,50 @@
# --- Build configuration ----------------------------------------------------
# Fix: use := so $(shell ...) runs once at parse time instead of re-running
# on every expansion of the variable.
TFLIB := $(shell cd var/lib && pwd)
MOUNT_PATH := $(shell cd var && pwd)
# Library search paths so the Go TensorFlow bindings can locate libtensorflow
# (LIBRARY_PATH/LD_LIBRARY_PATH on Linux, DYLD_LIBRARY_PATH on macOS).
TGO_ENV := LIBRARY_PATH=${LIBRARY_PATH}:${TFLIB} LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${TFLIB} DYLD_LIBRARY_PATH=${DYLD_LIBRARY_PATH}:${TFLIB}
MODEL_PATH ?= var/export/embedding
EXPORT_IMAGE := bert-export
COVERFILE := coverage.out
NUM_LABELS := 2
MODEL ?= bert-base-uncased

# Command-style targets are not files; declare them phony so a file of the
# same name can never shadow them.
.PHONY: check clean cover get go lint test model model/classifier model/embedding ex/search image/export

# Run static checks then the test suite.
check: lint test

# Delete generated artifacts (python export output + Go coverage profile).
# NOTE(review): this chunk is a rendered diff — "rm coverage.out" and
# "rm ${COVERFILE}" appear to be the before/after of the same recipe line;
# only one should survive. Recipe tabs were also lost in this rendering.
clean:
# TODO flexible model
rm -rf python/output/*
rm coverage.out
# TODO python
rm ${COVERFILE}

# Default coverage view: function-level summary.
cover: cover/func

# Run `go tool cover` in mode $* (func|html) over the recorded profile.
cover/%:
go tool cover -$*=coverage.out

# Fetch Go dependencies with the TF library paths exported.
# NOTE(review): a second, post-diff "get" target appears later in this diff.
get:
${TGO_ENV} go get ./...

# Run the service locally against ${MODEL_PATH}.
go:
${TGO_ENV} MODEL_PATH=${MODEL_PATH} go run main.go
# NOTE(review): the next line looks like the post-diff version of the
# cover/% recipe (coverage.out -> ${COVERFILE}) misplaced by the diff render.
go tool cover -$*=${COVERFILE}

# Semantic-search example.
# NOTE(review): the two command lines below are the before/after versions of
# the same recipe from the rendered diff (${MODEL_PATH} vs
# ${MOUNT_PATH}/export/${MODEL}); keep only one.
ex/search:
${TGO_ENV} go run ./examples/semantic-search -seqlen=16 -d='|' ${MODEL_PATH} ./examples/semantic-search/go-faq.csv
${TGO_ENV} go run ./examples/semantic-search -seqlen=16 -d='|' ${MOUNT_PATH}/export/${MODEL} ./examples/semantic-search/go-faq.csv

# Run any example under ./examples/$*.
ex/%:
${TGO_ENV} MODEL_PATH=${MODEL_PATH} go run ./examples/$*

# NOTE(review): pre-diff classifier export (local python); superseded by the
# dockerized model/classifier target later in this file.
model/classifier:
cd python && python export_embedding.py ${MODEL_PATH} var/export/classifier 2

# Post-diff "get": also installs golint (used by the lint target).
get:
go get -u golang.org/x/lint/golint
${TGO_ENV} go get ./...

# NOTE(review): pre-diff embedding export (local python); superseded by the
# dockerized "model" target below.
model/embedding:
# TODO flexible model w/ download
cd python && python export_embedding.py ${MODEL_PATH} var/export/embedding

# Build the docker image used to export models (see export/Dockerfile).
image/export:
cd export && docker build -t ${EXPORT_IMAGE} .

# Dump the signatures of the SavedModel at path $* for debugging.
# NOTE(review): requires TF_ROOT to point at a tensorflow checkout — confirm.
inspect_model/%:
# TODO drop in favor of CMD
python ${TF_ROOT}/tensorflow/python/tools/saved_model_cli.py show --dir=$* --all

# Static analysis: go vet plus golint (installed by "get").
lint:
go vet ./...
golint ./...

# Export the pretrained embedding SavedModel via the dockerized exporter,
# downloading the checkpoint first (--download=${MODEL}).
# Fix: the prerequisite was "export_image", but the image-building target in
# this file is named "image/export" — make would fail with "no rule to make
# target". Recipe tabs restored.
model: image/export
	mkdir -p ${MOUNT_PATH}
	docker run -v ${MOUNT_PATH}:/var/bert ${EXPORT_IMAGE} export_embedding.py --download=${MODEL} /var/bert/model /var/bert/export/${MODEL}

# Export a classifier SavedModel with ${NUM_LABELS} output classes via the
# dockerized exporter.
# Fix: prerequisite "export_image" -> "image/export" (the actual target name
# defined in this file). Recipe tabs restored.
model/classifier: image/export
	mkdir -p ${MOUNT_PATH}
	docker run -v ${MOUNT_PATH}:/var/bert ${EXPORT_IMAGE} export_classifier.py /var/bert/model/${MODEL} /var/bert/export/${MODEL} ${NUM_LABELS}

# Run the Go tests, recording a coverage profile.
# NOTE(review): the two lines below are the before/after of the same recipe
# (coverage.out vs ${COVERFILE}) from the rendered diff; keep only the second.
test:
${TGO_ENV} go test -coverprofile=coverage.out -v ./...
${TGO_ENV} go test -coverprofile=${COVERFILE} -v ./...
@@ -19,7 +19,7 @@ func main() {
tkz := tokenize.NewTokenizer(voc)
ff := tokenize.FeatureFactory{Tokenizer: tkz, SeqLen: 120}
f := ff.Feature("the dog is hairy.")
m, err := tf.LoadSavedModel(modelPath, []string{"bert-untuned"}, nil)
m, err := tf.LoadSavedModel(modelPath, []string{"bert-pretrained"}, nil)
if err != nil {
panic(err)
}
@@ -0,0 +1,126 @@
# Taken from https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
@@ -0,0 +1,8 @@
# Image for exporting BERT SavedModels, pinned to the TF 1.x python3 runtime.
FROM tensorflow/tensorflow:1.13.2-py3

WORKDIR /workspace/export

# Copy the export scripts (and anything else in this build context, e.g. the
# bert submodule) into the image.
COPY . .

# entrypoint.sh runs "python <args>", so CMD supplies the script + its args.
# Default: download bert-base-uncased and export its embedding SavedModel.
ENTRYPOINT ["./entrypoint.sh"]
CMD ["export_embedding.py", "--download=bert-base-uncased", "/var/bert/model", "/var/bert/export/embedding"]
File renamed without changes.
@@ -0,0 +1,75 @@
from __future__ import print_function

import argparse
import os
import os.path
import urllib.request # TODO py2/3 compat
import zipfile

# CLI: positional model name + destination dir, optional --force / --keep.
parser = argparse.ArgumentParser()
parser.add_argument("model", help="Name of the model to download")
parser.add_argument("output_path",
                    help="Path to save model to, model name will be appended")

parser.add_argument("-f", "--force", action='store_true',
                    help="Force a download, even if model exists")
parser.add_argument("-k", "--keep", action='store_true',
                    help="Keep the downloaded archive")

# Official Google Research pretrained BERT checkpoint archives, keyed by the
# short model name accepted on the command line ("wwm" = whole-word masking).
BERT_PRETRAINED_MODEL_URLS = {
    'bert-base-uncased': "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip",  # noqa
    'bert-large-uncased': "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip",  # noqa
    'bert-base-cased': "https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip",  # noqa
    'bert-large-cased': "https://storage.googleapis.com/bert_models/2018_10_18/cased_L-24_H-1024_A-16.zip",  # noqa
    'bert-base-multilingual-cased': "https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip",  # noqa
    'bert-base-chinese': "https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip",  # noqa
    'bert-large-uncased-wwm': "https://storage.googleapis.com/bert_models/2019_05_30/wwm_uncased_L-24_H-1024_A-16.zip",  # noqa
    'bert-large-cased-wwm': "https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip",  # noqa
}


def download(model, output_path, force=False, keep=False):
    """Download and extract a pretrained BERT archive into output_path/model.

    Args:
        model: short model name; must be a key of BERT_PRETRAINED_MODEL_URLS.
        output_path: directory under which a <model> subdirectory is created.
        force: re-download even if the model directory already exists.
        keep: keep the downloaded .zip archive after extraction.

    Exits the process with status 1 if the model name is unknown.
    """
    url = BERT_PRETRAINED_MODEL_URLS.get(model)
    if not url:
        print("Invalid Model Name:", model)
        # Fix: the original printed the "Valid Model Names:" header once per
        # key inside the loop; print it once, then list every valid name.
        print("Valid Model Names:")
        for name in BERT_PRETRAINED_MODEL_URLS:
            print("\t", name)
        exit(1)
    path = os.path.join(output_path, model)
    if os.path.exists(path) and not force:
        print("Model Already Exists:", path)
        print("If desired, use --force to overwrite")
        return
    os.makedirs(path, exist_ok=True)  # TODO (py3.2+) py2 compat
    print("Downloading and extracting model:", model)
    print("Downloading archive:", url)
    # Keep the archive inside the model dir so a partial download is visible.
    archive = os.path.join(path, os.path.basename(url))
    urllib.request.urlretrieve(url, archive)
    print("Extracting archive to path:", path)
    _unzip(archive, path)
    if not keep:
        print("Removing archive:", archive)
        os.remove(archive)
    print("Extracted Model", model, "to", path)
    print("Done.")


def _unzip(archive, dst):
""" unzips archive and flattens directory structure """
with zipfile.ZipFile(archive) as zf:
for zi in zf.infolist():
if zi.filename[-1] == '/': # skip dir
continue
zi.filename = os.path.basename(zi.filename)
zf.extract(zi, dst)


if __name__ == '__main__':
    # Parse the CLI arguments and hand them straight to download().
    cli = parser.parse_args()
    download(cli.model, cli.output_path, force=cli.force, keep=cli.keep)
@@ -0,0 +1,3 @@
#!/bin/sh
# Docker entrypoint: run the requested python script with its arguments
# (see the Dockerfile CMD for the default invocation).
# Fix: quote "$@" so arguments containing spaces/globs survive word
# splitting, and exec so python replaces the shell and receives container
# signals directly as PID 1.
exec python "$@"
File renamed without changes.
@@ -3,7 +3,7 @@
from __future__ import absolute_import
from __future__ import print_function
import sys
sys.path.insert(0,'bert')
sys.path.insert(0, 'bert') # noqa

import os.path
import argparse
@@ -12,18 +12,26 @@

import bert.modeling as modeling
from util import export
from download_pretrained import download


# CLI: positional model/export paths plus optional config and download flags.
parser = argparse.ArgumentParser()
parser.add_argument("model_path", help="Path for pre-trained BERT model")
parser.add_argument("export_path", help="Path to export to")

# Fix: the adjacent string literals concatenated without a separator
# ("not inmodel_path..."); add the missing space.
parser.add_argument("--bert_config_path", help="If bert_config is not in "
                    "model_path/bert_config.json, specify its path here")
# Fix: help text ran together ("by nameModel"), dropped a word, and
# misspelled "appended"; also collapse the duplicated diff continuation line.
parser.add_argument("--download", help="Download pretrained model by name. "
                    "Model will be saved in model_path with the name appended")


def export_embedding(args):
untuned_name = os.path.basename(os.path.normpath(args.model_path))
# pretrained_name = os.path.basename(os.path.normpath(args.model_path))
dl = args.download
if dl:
download(dl, args.model_path)
args.model_path = os.path.join(args.model_path, dl)
print("Model Path updated:", args.model_path)
config_path = os.path.join(args.model_path, "bert_config.json")
if args.bert_config_path:
config_path = args.bert_config_path
@@ -32,7 +40,7 @@ def transfer():
bert_config = modeling.BertConfig.from_json_file(config_path)
# Inputs
# TODO shapes
#unique_ids = tf.compat.v1.placeholder(tf.int32, (None), 'unique_ids')
# unique_ids = tf.compat.v1.placeholder(tf.int32, (None), 'unique_ids')
input_ids = tf.compat.v1.placeholder(tf.int32, (None, None), 'input_ids')
input_mask = tf.compat.v1.placeholder(tf.int32, (None, None), 'input_mask')
segment_ids = tf.compat.v1.placeholder(tf.int32, (None, None), 'input_type_ids')
@@ -53,17 +61,17 @@ def transfer():
output = masker(layers, mask)
embedding = tf.identity(output, 'embedding')
return {
# 'unique_ids': unique_ids,
# 'unique_ids': unique_ids,
'input_ids': input_ids,
'input_mask': input_mask,
'input_type_ids': segment_ids
}, {
# "feature_ids": unique_ids,
# "feature_ids": unique_ids,
"embedding": embedding
}
export(args.model_path, args.export_path, transfer,
method_name="bert/untuned/embedding",
sig_name="embedding", tags=["bert-untuned"]) #, untuned_name])
method_name="bert/pretrained/embedding",
sig_name="embedding", tags=["bert-pretrained"])


if __name__ == '__main__':

0 comments on commit 3ecf35c

Please sign in to comment.
You can’t perform that action at this time.