Skip to content

Commit

Permalink
Merge pull request #51 from fbcotter/py3support
Browse files Browse the repository at this point in the history
Py3support
  • Loading branch information
lengstrom committed Feb 12, 2017
2 parents 1d84ace + 021750b commit e218db9
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 21 deletions.
7 changes: 3 additions & 4 deletions evaluate.py
Expand Up @@ -4,7 +4,7 @@
import transform, numpy as np, vgg, pdb, os
import scipy.misc
import tensorflow as tf
from utils import save_img, get_img, exists, list_files, check_version
from utils import save_img, get_img, exists, list_files
from argparse import ArgumentParser
from collections import defaultdict
import time
Expand Down Expand Up @@ -238,7 +238,6 @@ def check_opts(opts):
assert opts.batch_size > 0

def main():
check_version()
parser = build_parser()
opts = parser.parse_args()
check_opts(opts)
Expand All @@ -254,8 +253,8 @@ def main():
device=opts.device)
else:
files = list_files(opts.in_path)
full_in = map(lambda x: os.path.join(opts.in_path,x), files)
full_out = map(lambda x: os.path.join(opts.out_path,x), files)
full_in = [os.path.join(opts.in_path,x) for x in files]
full_out = [os.path.join(opts.out_path,x) for x in files]
if opts.allow_different_dimensions:
ffwd_different_dimensions(full_in, full_out, opts.checkpoint_dir,
device_t=opts.device, batch_size=opts.batch_size)
Expand Down
2 changes: 2 additions & 0 deletions setup.sh
@@ -1,3 +1,5 @@
#! /bin/bash

mkdir data
cd data
wget http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat
Expand Down
5 changes: 3 additions & 2 deletions src/optimize.py
@@ -1,4 +1,5 @@
from __future__ import print_function
import functools
import vgg, pdb, time
import tensorflow as tf, numpy as np, os
import transform
Expand Down Expand Up @@ -75,7 +76,7 @@ def optimize(content_targets, style_target, content_weight, style_weight,
style_gram = style_features[style_layer]
style_losses.append(2 * tf.nn.l2_loss(grams - style_gram)/style_gram.size)

style_loss = style_weight * reduce(tf.add, style_losses) / batch_size
style_loss = style_weight * functools.reduce(tf.add, style_losses) / batch_size

# total variation denoising
tv_y_size = _tensor_size(preds[:,1:,:,:])
Expand Down Expand Up @@ -138,4 +139,4 @@ def optimize(content_targets, style_target, content_weight, style_weight,

def _tensor_size(tensor):
from operator import mul
return reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1)
return functools.reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1)
10 changes: 0 additions & 10 deletions src/utils.py
Expand Up @@ -31,13 +31,3 @@ def list_files(in_path):

return files

def check_version():
if sys.version_info[0] != 2:
err_str = (
"This project only supports Python 2! Either run using "
"Python 2 or submit a pull request to "
"https://github.com/lengstrom/fast-style-transfer/ "
"to make the project version neutral!"
)

raise Exception(err_str)
6 changes: 3 additions & 3 deletions style.py
Expand Up @@ -4,7 +4,7 @@
import numpy as np, scipy.misc
from optimize import optimize
from argparse import ArgumentParser
from utils import save_img, get_img, exists, list_files, check_version
from utils import save_img, get_img, exists, list_files
import evaluate

CONTENT_WEIGHT = 7.5e0
Expand Down Expand Up @@ -106,10 +106,10 @@ def check_opts(opts):

def _get_files(img_dir):
files = list_files(img_dir)
return map(lambda x: os.path.join(img_dir,x), files)
return [os.path.join(img_dir,x) for x in files]


def main():
check_version()
parser = build_parser()
options = parser.parse_args()
check_opts(options)
Expand Down
4 changes: 2 additions & 2 deletions transform_video.py
Expand Up @@ -67,8 +67,8 @@ def main():

subprocess.call(" ".join(in_args), shell=True)
base_names = list_files(in_dir)
in_files = list(map(lambda x: os.path.join(in_dir, x), base_names))
out_files = list(map(lambda x: os.path.join(out_dir, x), base_names))
in_files = [os.path.join(in_dir, x) for x in base_names]
out_files = [os.path.join(out_dir, x) for x in base_names]
evaluate.ffwd(in_files, out_files, opts.checkpoint, device_t=opts.device,
batch_size=opts.batch_size)
fr = 30 # wtf
Expand Down

0 comments on commit e218db9

Please sign in to comment.