Universal Style Transfer via Feature Transforms with TensorFlow & Keras
This is a TensorFlow/Keras implementation of Universal Style Transfer via Feature Transforms by Li et al. The core architecture is an auto-encoder trained to reconstruct from intermediate layers of a pre-trained VGG19 image classification net. Stylization is accomplished by matching the statistics of content/style image features through the Whiten-Color Transform (WCT), which is implemented here in both TensorFlow and NumPy. No style images are used for training, and the WCT allows for 'universal' style transfer for arbitrary content/style image pairs.
As in the original paper, reconstruction decoders for layers reluX_1 (X=1,2,3,4,5) are trained separately and then hooked up in a multi-level stylization pipeline in a single graph. To reduce memory usage, a single VGG encoder is loaded up to the deepest relu layer and is shared by all decoders.
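The Whiten-Color Transform described above can be illustrated with a short NumPy sketch on (C, H, W) feature arrays. This is not the repo's ops.py code. It uses `np.linalg.eigh` on the symmetric covariance matrices instead of the SVD the repo uses (the two decompositions agree for symmetric positive semi-definite matrices), adds a small epsilon ridge for stability, and applies an `alpha` blend in the spirit of the `--alpha` option. The function name `wct` and its signature are hypothetical.

```python
import numpy as np


def wct(content_feat, style_feat, alpha=0.8, eps=1e-5):
    """Whiten-color transform on encoder feature maps shaped (C, H, W).

    alpha blends the transformed features (alpha=1) with the untouched
    content features (alpha=0); eps is a small ridge added to both
    covariance matrices to keep the decomposition well conditioned.
    """
    c, h, w = content_feat.shape
    fc = content_feat.reshape(c, -1)
    fs = style_feat.reshape(style_feat.shape[0], -1)

    # Center each channel; the style mean is re-added after coloring.
    mc = fc.mean(axis=1, keepdims=True)
    ms = fs.mean(axis=1, keepdims=True)
    fc_centered = fc - mc
    fs_centered = fs - ms

    # Channel covariance matrices (C x C) with a small ridge on the diagonal.
    cov_c = fc_centered @ fc_centered.T / (fc.shape[1] - 1) + eps * np.eye(c)
    cov_s = fs_centered @ fs_centered.T / (fs.shape[1] - 1) + eps * np.eye(fs.shape[0])

    # Whitening: multiply by cov_c^(-1/2) so the content channels become
    # (approximately) uncorrelated with unit variance.
    ec, vc = np.linalg.eigh(cov_c)
    whitened = vc @ np.diag(ec ** -0.5) @ vc.T @ fc_centered

    # Coloring: multiply by cov_s^(1/2) to impose the style correlations,
    # then add the style mean back.
    es, vs = np.linalg.eigh(cov_s)
    colored = vs @ np.diag(es ** 0.5) @ vs.T @ whitened + ms

    # Blend with the original content features to control stylization strength.
    blended = alpha * colored + (1.0 - alpha) * fc
    return blended.reshape(c, h, w)
```

With `alpha=1.0`, the output's per-channel mean and channel covariance match the style features; in a multi-level pipeline, a transform like this sits between the encoder and the decoder at each relu target.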
Requirements
- Python 3.x
- tensorflow 1.2.1+
- keras 2.0.x
- torchfile (a modified torchfile.py that is compatible with Windows is included)
- OpenCV with contrib modules (not required by stylize.py)
- ffmpeg (for video stylization)
Running a pre-trained model
Download VGG19 model:
Download checkpoints for the five decoders:
Run stylization for live video with webcam.py or for images with stylize.py. Both scripts share the same required arguments. For instance, to run a multi-level stylization pipeline that goes from relu5_1 -> relu4_1 -> relu3_1 -> relu2_1 -> relu1_1:
python webcam.py --checkpoints models/relu5_1 models/relu4_1 models/relu3_1 models/relu2_1 models/relu1_1 --relu-targets relu5_1 relu4_1 relu3_1 relu2_1 relu1_1 --style-size 512 --alpha 0.8 --style-path /path/to/styleimgs
--checkpoints and --relu-targets specify space-delimited lists of decoder checkpoint folders and their corresponding relu layer targets. The order of the relu targets determines the stylization pipeline order, where the output of one encoder/decoder becomes the input for the next. Specifying a single checkpoint/relu target performs single-level stylization.
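To make the pairing and ordering concrete, here is a hedged sketch of how the two lists could drive a multi-level pipeline. It is not the repo's argument parsing: it assumes argparse with `nargs='+'`, and `stylize_level` is a hypothetical callback standing in for one encode, transform, decode step at a given relu target.

```python
import argparse


def build_pipeline(argv=None):
    """Pair each decoder checkpoint with its relu target, preserving CLI order."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoints', nargs='+', required=True)
    parser.add_argument('--relu-targets', nargs='+', required=True)
    parser.add_argument('--passes', type=int, default=1)
    args = parser.parse_args(argv)
    if len(args.checkpoints) != len(args.relu_targets):
        parser.error('each checkpoint needs exactly one relu target')
    return list(zip(args.checkpoints, args.relu_targets)), args.passes


def run_pipeline(image, levels, passes, stylize_level):
    """Feed the output of each level into the next, repeating `passes` times.

    stylize_level(image, checkpoint, relu_target) is a hypothetical callback
    standing in for one encode -> transform -> decode step.
    """
    for _ in range(passes):
        for checkpoint, relu_target in levels:
            image = stylize_level(image, checkpoint, relu_target)
    return image
```

Because the pairs are zipped in command-line order, reversing the lists reverses the pipeline, which is how a fine-to-coarse ordering would be expressed.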
Other args to take note of:
- --style-path: Folder of style images or a single style image
- --style-size: Resize the smaller side of the style image to this size
- --crop-size: If specified, center-crop a square of this size from the (resized) style image
- --alpha: Blending weight in [0,1] between the content features and the whiten-color transformed features, controlling the degree of stylization
- --passes: Number of times to run the stylization pipeline
- --source: Camera input ID (default 0)
- --height: Set the height of camera frames
- --video-out: Write stylized frames to an .mp4 file at this path
- --fps: Frames per second for the video output
- --scale: Resize content images by this factor before stylizing
- --keep-colors: Apply the CORAL transform to preserve the colors of the content image
- --device: Device to perform compute on, default
- --concat: Append the style image to the stylized output
- --noise: Generate textures from a random noise image instead of the webcam
- --random: Load a new random image every given number of frames
- --adain: Use Adaptive Instance Normalization as the transfer op instead of WCT
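For comparison with WCT, AdaIN (Huang et al.) only aligns the per-channel mean and standard deviation of the content features with those of the style features, rather than the full channel covariance. Below is a minimal NumPy sketch on (C, H, W) features. The function name, the placement of `eps`, and the `alpha` blend are assumptions and not necessarily what the repo does.

```python
import numpy as np


def adain(content_feat, style_feat, alpha=1.0, eps=1e-5):
    """Adaptive Instance Normalization on (C, H, W) feature maps."""
    # Per-channel statistics over the spatial dimensions.
    c_mean = content_feat.mean(axis=(1, 2), keepdims=True)
    c_std = content_feat.std(axis=(1, 2), keepdims=True) + eps  # avoid division by zero
    s_mean = style_feat.mean(axis=(1, 2), keepdims=True)
    s_std = style_feat.std(axis=(1, 2), keepdims=True)

    # Normalize the content channels, then rescale/shift with the style statistics.
    stylized = (content_feat - c_mean) / c_std * s_std + s_mean

    # Optional blend with the original content features.
    return alpha * stylized + (1.0 - alpha) * content_feat
```

Matching only diagonal statistics is cheaper than the eigendecomposition or SVD that WCT needs, but it leaves cross-channel correlations of the content features untouched.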
There are also a couple of keyboard shortcuts:
- r: Load a random image from the style folder
- w: Write the frame to a .png
- c: Toggle color preservation
- s: Toggle style swap (only applied on layer relu5_1)
- a: Toggle AdaIN as the transform instead of WCT
- q: Quit cleanly and close streams
stylize.py will stylize content images and does not require OpenCV. The options are the same as for the webcam script, with the addition of --content-path, which can be a single image file or folder, and --out-path to specify the output folder. Each style in --style-path will be applied to each content image.
Running with Docker
Download VGG19 model:
Download checkpoints for the five decoders:
To run the webcam example:
nvidia-docker build -t wct-tf .  # It will take several minutes.
xhost +local:root
nvidia-docker run \
  -ti \
  --rm \
  -v $PWD/models:/usr/src/app/models \
  -v $PWD/images:/usr/src/app/images \
  -v /tmp/.X11-unix:/tmp/.X11-unix:rw \
  -e QT_X11_NO_MITSHM=1 \
  -e DISPLAY \
  --device=/dev/video0:/dev/video0 \
  wct-tf
Training
Download MS COCO images for content data.
Download VGG19 model:
Train one decoder per relu target layer. E.g. to train a decoder to reconstruct from relu3_1:
python train.py --relu-target relu3_1 --content-path /path/to/coco --batch-size 8 --feature-weight 1 --pixel-weight 1 --tv-weight 0 --checkpoint /path/to/checkpointdir --learning-rate 1e-4 --max-iter 15000
Monitor training with TensorBoard:
tensorboard --logdir /path/to/checkpointdir
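The original paper trains each decoder with a pixel reconstruction loss plus a feature loss computed on re-encoded outputs, and the --pixel-weight, --feature-weight and --tv-weight flags suggest a weighted sum of such terms plus a total-variation penalty. The NumPy sketch below illustrates that reading; it has not been checked against train.py, and the function names, the use of mean squared error, and the anisotropic L1 total variation are all assumptions.

```python
import numpy as np


def total_variation(img):
    """Anisotropic TV: sum of absolute differences between neighbouring pixels of an (H, W, C) image."""
    dh = np.abs(img[1:, :, :] - img[:-1, :, :]).sum()
    dw = np.abs(img[:, 1:, :] - img[:, :-1, :]).sum()
    return dh + dw


def decoder_loss(inp, out, feat_in, feat_out,
                 pixel_weight=1.0, feature_weight=1.0, tv_weight=0.0):
    """Weighted reconstruction loss for one decoder (hypothetical helper).

    inp/out: input image and decoder reconstruction.
    feat_in/feat_out: VGG features of inp and out at the decoder's relu target.
    """
    pixel = np.mean((out - inp) ** 2)
    feature = np.mean((feat_out - feat_in) ** 2)
    return pixel_weight * pixel + feature_weight * feature + tv_weight * total_variation(out)
```

With --tv-weight 0, as in the example command, the TV term drops out and only the pixel and feature terms contribute.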
Style Swap
Style-swap is another style transfer approach, from Fast Patch-based Style Transfer of Arbitrary Style by Chen and Schmidt, that works by substituting patches in a content encoding with nearest-neighbor patches in a style encoding. As in the official Torch WCT, I have included this as an option for the relu5_1 layer, where the feature encodings are small enough for this to be computationally feasible. This option may enhance the stylization effect by transferring local structure from the style image in addition to the overall style.
Note how eyes and noses are transferred to semantically similar locations. Because the visual structure is reconstructed using features found in the style image, regions in the content without style counterparts may have odd replacements (like tongues in the first image).
The style-swap procedure implemented here is:
1. Encode the content & style images up to relu5_1 and whiten both to remove style information.
2. Extract patches from the whitened style encoding with tf.extract_image_patches.
3. Use the (normalized) style patches as conv2d filters to convolve with each spatial patch region in the content encoding. This is an efficient way to compute the cross-correlation between all content/style patch pairs.
4. Find the channel-wise argmax at each spatial position to determine the best-matching style patch for that location, and replace the scores with a channel-wise one-hot encoding.
5. For each content patch location, swap in the closest style patch using a transposed convolution over the one-hot encoding with the style patches as filters. The content encoding is now reconstructed using (hopefully) similar structures from the style encoding.
6. Apply WCT coloring to the style-swapped encoding to add style.
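The patch extraction, matching and swapping steps of the procedure above (everything except the initial whitening and the final WCT coloring) can be written with explicit loops in NumPy. This is slower than the conv2d/transposed-convolution formulation but makes the matching easier to follow. Assumptions: features are (H, W, C) arrays that have already been whitened, the same stride is used for content and style patches, and overlapping swapped patches are averaged. The function names are hypothetical.

```python
import numpy as np


def extract_patches(feat, k, stride):
    """Return (N, k, k, C) patches from an (H, W, C) feature map and their top-left corners."""
    h, w, _ = feat.shape
    coords = [(i, j) for i in range(0, h - k + 1, stride)
              for j in range(0, w - k + 1, stride)]
    patches = np.stack([feat[i:i + k, j:j + k, :] for i, j in coords])
    return patches, coords


def style_swap(content, style, patch_size=3, stride=1):
    k = patch_size
    style_patches, _ = extract_patches(style, k, stride)
    # Normalize style patches so the score measures direction, not magnitude.
    norms = np.linalg.norm(style_patches.reshape(len(style_patches), -1), axis=1)
    normed = style_patches / (norms[:, None, None, None] + 1e-8)

    h, w, _ = content.shape
    out = np.zeros_like(content, dtype=float)
    counts = np.zeros((h, w, 1))
    content_patches, coords = extract_patches(content, k, stride)
    for patch, (i, j) in zip(content_patches, coords):
        # Cross-correlation with every normalized style patch (what conv2d computes per location).
        scores = np.tensordot(normed, patch, axes=([1, 2, 3], [0, 1, 2]))
        best = np.argmax(scores)  # index of the "hot" channel in the one-hot encoding
        # Paste the unnormalized best style patch (what the transposed conv does).
        out[i:i + k, j:j + k, :] += style_patches[best]
        counts[i:i + k, j:j + k, :] += 1
    # Average overlaps; positions no patch covered (possible when stride > 1) stay zero.
    return out / np.maximum(counts, 1)
```

Setting `stride` equal to `patch_size` gives non-overlapping patches, so each output block is an exact copy of one style patch, mirroring the --ss-stride note.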
The args to use this with webcam.py and stylize.py:
- --swap5: Enable style swap. This will only be applied if relu5_1 is one of the target layers.
- --ss-patch-size: Patch size for the convolution kernel. This is the size of patches in the feature encoding, not the full-size image, so small values like 3 or 5 will typically work well.
- --ss-stride: Stride for the patch kernel. Setting this equal to the patch size will extract non-overlapping patches.
- --ss-alpha: Blending between the style-swapped encoding and the original content encoding.
python webcam.py --checkpoints models/relu5_1 models/relu4_1 models/relu3_1 models/relu2_1 models/relu1_1 --relu-targets relu5_1 relu4_1 relu3_1 relu2_1 relu1_1 --style-size 512 --alpha 0.8 --style-path /path/to/styleimgs --swap5 --ss-patch-size 3 --ss-stride 1 --ss-alpha .7
- This repo is based on my implementation of Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization by Huang et al. The AdaIN op is included here as an alternative transform to WCT. It generally requires multiple stylization passes to achieve a comparable effect.
- The stylization pipeline can be hooked up with decoders in any order. For instance, to reproduce the (sub-optimal) reversed fine-to-coarse pipeline in figure 5(d) from the original paper, use the option --relu-targets relu1_1 relu2_1 relu3_1 relu4_1 relu5_1 in webcam.py/stylize.py.
- coral.py implements CORrelation ALignment (CORAL) to transfer colors from the content image to the style image in order to preserve colors in the stylized output. The default method uses NumPy, and there is also a commented-out version in PyTorch that is slightly faster.
- WCT involves two tf.svd ops, which as of TF r1.4 have a GPU implementation. However, the GPU version appears to be 2-4x slower than the CPU version, so these ops are explicitly placed on /cpu:0 in ops.py. See here for an interesting discussion of the issue.
- There is an open issue where, for some ill-conditioned matrices, the CPU version of tf.svd segfaults. Adding a small epsilon to the covariance matrices appears to avoid this without visibly affecting the results. If this issue does occur, there is a commented-out block that uses np.linalg.svd through tf.py_func. This is stable but incurs a 30%+ performance penalty.
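The color preservation behind --keep-colors can be sketched in NumPy: CORAL recolors the style image so that its RGB mean and covariance match those of the content image, before the style image is used for stylization. This is not the repo's coral.py; the function names and the `eps` ridge are assumptions, and the output is left unclipped (a real pipeline would clip to the valid pixel range).

```python
import numpy as np


def _sym_matrix_power(m, power):
    """Raise a symmetric positive-definite matrix to a real power via eigendecomposition."""
    vals, vecs = np.linalg.eigh(m)
    return vecs @ np.diag(vals ** power) @ vecs.T


def coral(source, target, eps=1e-5):
    """Recolor `source` (H, W, 3) so its RGB mean/covariance match `target`'s.

    For color preservation, source is the style image and target is the content image.
    """
    src = source.reshape(-1, 3).T  # (3, N): one row per color channel
    tgt = target.reshape(-1, 3).T
    src_mean = src.mean(axis=1, keepdims=True)
    tgt_mean = tgt.mean(axis=1, keepdims=True)
    # Small ridge keeps both 3x3 covariances invertible.
    src_cov = np.cov(src) + eps * np.eye(3)
    tgt_cov = np.cov(tgt) + eps * np.eye(3)
    # Whiten with src_cov^(-1/2), then color with tgt_cov^(1/2).
    transform = _sym_matrix_power(tgt_cov, 0.5) @ _sym_matrix_power(src_cov, -0.5)
    out = transform @ (src - src_mean) + tgt_mean
    return out.T.reshape(source.shape)
```

This is the same whiten-then-color idea as WCT, applied to 3-channel pixel colors instead of deep features, which is why a 3x3 eigendecomposition is enough.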
Many thanks to the authors Yijun Li & collaborators at UC Merced/Adobe/NVIDIA for their work that inspired this fun project. After building the first version of this TF implementation, I discovered their official Torch implementation, which I referred to when tweaking the WCT op to be more stable.
Thanks also to Xun Huang for the normalized VGG and Torch version of CORAL.
Windows is now supported thanks to a torchfile compatibility fix by @xdaimon.
Docker support was graciously provided by @bryant1410.