Commit

add dump checkpoints into main script

cvalenzuela committed Jul 3, 2018
1 parent dd678f5 commit 1ec70d8
Showing 7 changed files with 98 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
.pyc
2 changes: 1 addition & 1 deletion Dockerfile
@@ -7,7 +7,7 @@ RUN apt-get install wget -y
RUN apt-get update && apt-get install -y software-properties-common

# Install "ffmpeg"
-RUN add-apt-repository ppa:mc3man/trusty-media
+RUN add-apt-repository ppa:mc3man/xerus-media
RUN apt-get update && apt-get install -y ffmpeg

# Copy all files in directory
Binary file added checkpoints/scream.ckpt
Binary file not shown.
85 changes: 85 additions & 0 deletions dump_checkpoints.py
@@ -0,0 +1,85 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to dump tensorflow checkpoint variables to deeplearnjs.
This script takes a checkpoint file and writes all of the variables in the
checkpoint to a directory.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import os
import re
import string
import tensorflow as tf

FILENAME_CHARS = string.ascii_letters + string.digits + '_'

def _var_name_to_filename(var_name):
  chars = []
  for c in var_name:
    if c in FILENAME_CHARS:
      chars.append(c)
    elif c == '/':
      chars.append('_')
  return ''.join(chars)

def remove_optimizer_variables(output):
  vars_dir = os.path.expanduser(output)
  manifest_file = os.path.join(output, 'manifest.json')
  with open(manifest_file) as f:
    manifest = json.load(f)
  new_manifest = {key: manifest[key] for key in manifest
                  if 'Adam' not in key and 'beta' not in key}
  with open(manifest_file, 'w') as f:
    json.dump(new_manifest, f, indent=2, sort_keys=True)

  for name in os.listdir(output):
    if 'Adam' in name or 'beta' in name:
      os.remove(os.path.join(output, name))

def dump_checkpoints(checkpoint_dir, output):
  chk_fpath = os.path.expanduser(checkpoint_dir)
  reader = tf.train.NewCheckpointReader(chk_fpath)
  var_to_shape_map = reader.get_variable_to_shape_map()
  output_dir = os.path.expanduser(output)
  tf.gfile.MakeDirs(output_dir)
  manifest = {}
  remove_vars_compiled_re = re.compile('')

  var_filenames_strs = []
  for name in var_to_shape_map:
    if ('' and
        re.match(remove_vars_compiled_re, name)) or name == 'global_step':
      continue
    var_filename = _var_name_to_filename(name)
    manifest[name] = {'filename': var_filename, 'shape': var_to_shape_map[name]}

    tensor = reader.get_tensor(name)
    with open(os.path.join(output_dir, var_filename), 'wb') as f:
      f.write(tensor.tobytes())

    var_filenames_strs.append("\"" + var_filename + "\"")

  manifest_fpath = os.path.join(output_dir, 'manifest.json')
  print('Writing manifest to ' + manifest_fpath)
  with open(manifest_fpath, 'w') as f:
    f.write(json.dumps(manifest, indent=2, sort_keys=True))

  remove_optimizer_variables(output_dir)

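For orientation (not part of the commit itself): dump_checkpoints writes one binary file of raw tensor bytes per variable plus a manifest.json that records each variable's filename and shape, then strips the Adam/beta optimizer entries. Below is a minimal usage sketch; the checkpoints/fns.ckpt and models/ paths simply mirror the defaults used in run.sh and style.py above, and the float32 dtype in the read-back is an assumption, since the script does not record dtypes.

# Usage sketch only -- not part of this commit; paths mirror run.sh / style.py.
import json
import numpy as np
from dump_checkpoints import dump_checkpoints

# Convert a trained TensorFlow checkpoint into per-variable binary files
# plus manifest.json, then strip the Adam/beta optimizer entries.
dump_checkpoints('checkpoints/fns.ckpt', 'models/')

# Read one variable back to sanity-check the dump (assumes float32 weights,
# since the manifest records shapes but not dtypes).
with open('models/manifest.json') as f:
  manifest = json.load(f)
name, meta = next(iter(manifest.items()))
with open('models/' + meta['filename'], 'rb') as f:
  arr = np.frombuffer(f.read(), dtype=np.float32).reshape(meta['shape'])
print(name, arr.shape)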
File renamed without changes.
1 change: 1 addition & 0 deletions run.sh
@@ -2,6 +2,7 @@

python style.py --style images/matildeperez.jpg \
--checkpoint-dir checkpoints/ \
--model-dir models/ \
--test images/violetaparra.jpg \
--test-dir tests/ \
--content-weight 1.5e1 \
12 changes: 10 additions & 2 deletions style.py
@@ -6,6 +6,7 @@
from argparse import ArgumentParser
from utils import save_img, get_img, exists, list_files
import evaluate
from dump_checkpoints import dump_checkpoints

CONTENT_WEIGHT = 7.5e0
STYLE_WEIGHT = 1e2
@@ -17,12 +18,16 @@
CHECKPOINT_ITERATIONS = 2000
VGG_PATH = 'data/imagenet-vgg-verydeep-19.mat'
TRAIN_PATH = 'data/train2014'
MODEL_PATH = 'models'
BATCH_SIZE = 4
DEVICE = '/gpu:0'
FRAC_GPU = 1

def build_parser():
    parser = ArgumentParser()
    parser.add_argument('--model-dir', type=str,
                        dest='model_dir', help='dir to save ml5 models in',
                        metavar='MODELS_DIR', default=MODEL_PATH)
    parser.add_argument('--checkpoint-dir', type=str,
                        dest='checkpoint_dir', help='dir to save checkpoint in',
                        metavar='CHECKPOINT_DIR', required=True)
@@ -126,9 +131,9 @@ def main():
"print_iterations":options.checkpoint_iterations,
"batch_size":options.batch_size,
"save_path":os.path.join(options.checkpoint_dir,'fns.ckpt'),
"learning_rate":options.learning_rate
"learning_rate":options.learning_rate,
}

if options.slow:
if options.epochs < 10:
kwargs['epochs'] = 1000
@@ -162,6 +167,9 @@ def main():
    ckpt_dir = options.checkpoint_dir
    cmd_text = 'python evaluate.py --checkpoint %s ...' % ckpt_dir
    print("Training complete. For evaluation:\n `%s`" % cmd_text)
    print('Converting model to ml5js')
    dump_checkpoints(kwargs['save_path'], options.model_dir)
    print('Done! Checkpoint saved. Visit https://ml5js.org/docs/StyleTransfer for more information')

if __name__ == '__main__':
    main()
