Skip to content

Commit

Permalink
update projector
Browse files Browse the repository at this point in the history
  • Loading branch information
justinpinkney committed Sep 1, 2020
1 parent a4cd2eb commit dbf69a9
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 12 deletions.
5 changes: 2 additions & 3 deletions blend_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,9 @@ def main(low_res_pkl: Path, # Pickle file from which to take low res layers

out = blend_models(low_res_Gs, high_res_Gs, resolution, level, blend_width=blend_width, verbose=verbose)

rnd = np.random.RandomState(seed)
grid_latents = rnd.randn(np.prod(grid_size), *out.input_shape[1:])

if output_grid:
rnd = np.random.RandomState(seed)
grid_latents = rnd.randn(np.prod(grid_size), *out.input_shape[1:])
grid_fakes = out.run(grid_latents, None, is_validation=True, minibatch_size=1)
misc.save_image_grid(grid_fakes, output_grid, drange= [-1,1], grid_size=grid_size)

Expand Down
122 changes: 122 additions & 0 deletions project_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# from https://github.com/rolux/stylegan2encoder

import argparse
import os
import shutil
import numpy as np

import dnnlib
import dnnlib.tflib as tflib
import pretrained_networks
import projector
import dataset_tool
from training import dataset
from training import misc


def project_image(proj, src_file, dst_dir, tmp_dir, video=False):
    """Project a single image into the generator's latent space.

    Builds a throwaway one-image tfrecord dataset under ``tmp_dir``, runs
    the projector's optimization loop, and writes the final projected
    image (``.png``) plus its dlatents (``.npy``) into ``dst_dir``.

    Args:
        proj: Projector instance; ``proj.set_network(...)`` must already
            have been called by the caller.
        src_file: Path of the aligned source image.
        dst_dir: Output directory (created if missing).
        tmp_dir: Scratch directory for the tfrecord dataset and, when
            ``video`` is true, per-step frames under ``tmp_dir/video``.
        video: When true, save one PNG per optimization step so
            ``render_video`` can assemble them afterwards.
    """
    data_dir = '%s/dataset' % tmp_dir
    # Start from a clean dataset dir so stale tfrecords from a previous
    # image are never picked up.
    if os.path.exists(data_dir):
        shutil.rmtree(data_dir)
    image_dir = '%s/images' % data_dir
    tfrecord_dir = '%s/tfrecords' % data_dir
    os.makedirs(image_dir, exist_ok=True)
    shutil.copy(src_file, image_dir + '/')
    dataset_tool.create_from_images_raw(tfrecord_dir, image_dir, shuffle=0)
    dataset_obj = dataset.load_dataset(
        data_dir=data_dir, tfrecord_dir='tfrecords',
        max_label_size=0, repeat=False, shuffle_mb=0
    )

    print('Projecting image "%s"...' % os.path.basename(src_file))
    images, _labels = dataset_obj.get_minibatch_np(1)
    # The generator works in [-1, 1]; the dataset yields pixel values in [0, 255].
    images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
    proj.start(images)
    if video:
        video_dir = '%s/video' % tmp_dir
        os.makedirs(video_dir, exist_ok=True)
    while proj.get_cur_step() < proj.num_steps:
        print('\r%d / %d ... ' % (proj.get_cur_step(), proj.num_steps), end='', flush=True)
        proj.step()
        if video:
            filename = '%s/%08d.png' % (video_dir, proj.get_cur_step())
            misc.save_image_grid(proj.get_images(), filename, drange=[-1,1])
    print('\r%-30s\r' % '', end='', flush=True)

    os.makedirs(dst_dir, exist_ok=True)
    # Use splitext instead of the fragile [:-4] slice so extensions of any
    # length (e.g. '.jpeg', '.tiff') are stripped correctly.
    stem = os.path.splitext(os.path.basename(src_file))[0]
    filename = os.path.join(dst_dir, stem + '.png')
    misc.save_image_grid(proj.get_images(), filename, drange=[-1,1])
    filename = os.path.join(dst_dir, stem + '.npy')
    np.save(filename, proj.get_dlatents()[0])


def render_video(src_file, dst_dir, tmp_dir, num_frames, mode, size, fps, codec, bitrate):
    """Assemble the per-step projection frames into an .mp4 video.

    Reads the PNGs that ``project_image(..., video=True)`` wrote under
    ``tmp_dir/video`` and writes ``<src basename>.mp4`` into ``dst_dir``.

    Args:
        src_file: Original source image (shown alongside in mode 2).
        dst_dir: Output directory for the video file.
        tmp_dir: Directory containing the ``video/%08d.png`` frames.
        num_frames: Number of saved optimization frames.
        mode: 1 = projection only; 2 = source and projection side by side.
        size: Output video height in pixels (width is ``mode * size``).
        fps: Output framerate.
        codec: Video codec passed to moviepy (e.g. 'libx264').
        bitrate: Video bitrate passed to moviepy (e.g. '5M').
    """
    import PIL.Image
    import moviepy.editor

    def render_frame(t):
        # Map clip time t (seconds) to a 1-based frame index; frames were
        # saved starting at step 1.
        frame = int(np.clip(np.ceil(t * fps), 1, num_frames))
        image = PIL.Image.open('%s/video/%08d.png' % (tmp_dir, frame))
        if mode == 1:
            canvas = image
        else:
            # Side by side: source on the left, current projection on the right.
            canvas = PIL.Image.new('RGB', (2 * src_size, src_size))
            canvas.paste(src_image, (0, 0))
            canvas.paste(image, (src_size, 0))
        if size != src_size:
            canvas = canvas.resize((mode * size, size), PIL.Image.LANCZOS)
        return np.array(canvas)

    src_image = PIL.Image.open(src_file)
    # assumes square-ish input — the height drives the whole layout; TODO confirm
    src_size = src_image.size[1]
    duration = num_frames / fps
    # splitext instead of [:-4] so any extension length is handled correctly.
    stem = os.path.splitext(os.path.basename(src_file))[0]
    filename = os.path.join(dst_dir, stem + '.mp4')
    video_clip = moviepy.editor.VideoClip(render_frame, duration=duration)
    video_clip.write_videofile(filename, fps=fps, codec=codec, bitrate=bitrate)


def main():
    """Command-line entry point: project every image in ``src_dir``.

    Loads the generator once, then projects each file in the source
    directory, optionally rendering a per-image optimization video.
    The temporary directory is removed when all images are done.
    """
    parser = argparse.ArgumentParser(description='Project real-world images into StyleGAN2 latent space')
    parser.add_argument('src_dir', help='Directory with aligned images for projection')
    parser.add_argument('dst_dir', help='Output directory')
    parser.add_argument('--tmp-dir', default='.stylegan2-tmp', help='Temporary directory for tfrecords and video frames')
    parser.add_argument('--network-pkl', default='http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-f.pkl', help='StyleGAN2 network pickle filename')
    parser.add_argument('--vgg16-pkl', default='http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/vgg16_zhang_perceptual.pkl', help='VGG16 network pickle filename')
    parser.add_argument('--num-steps', type=int, default=1000, help='Number of optimization steps')
    parser.add_argument('--initial-learning-rate', type=float, default=0.1, help='Initial learning rate')
    parser.add_argument('--initial-noise-factor', type=float, default=0.05, help='Initial noise factor')
    # NOTE: the previous `type=bool` made '--verbose False' evaluate truthy
    # (bool('False') is True); store_true flags give the intended semantics
    # while keeping the same default when the flag is absent.
    parser.add_argument('--verbose', action='store_true', help='Verbose output')
    parser.add_argument('--video', action='store_true', help='Render video of the optimization process')
    parser.add_argument('--video-mode', type=int, default=1, help='Video mode: 1 for optimization only, 2 for source + optimization')
    parser.add_argument('--video-size', type=int, default=1024, help='Video size (height in px)')
    parser.add_argument('--video-fps', type=int, default=25, help='Video framerate')
    parser.add_argument('--video-codec', default='libx264', help='Video codec')
    parser.add_argument('--video-bitrate', default='5M', help='Video bitrate')
    args = parser.parse_args()

    print('Loading networks from "%s"...' % args.network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(args.network_pkl)
    proj = projector.Projector(
        vgg16_pkl = args.vgg16_pkl,
        num_steps = args.num_steps,
        initial_learning_rate = args.initial_learning_rate,
        initial_noise_factor = args.initial_noise_factor,
        verbose = args.verbose
    )
    proj.set_network(Gs)

    # Skip hidden files and macOS resource forks ('.'/'_' prefixes).
    src_files = sorted([os.path.join(args.src_dir, f) for f in os.listdir(args.src_dir) if f[0] not in '._'])
    for src_file in src_files:
        project_image(proj, src_file, args.dst_dir, args.tmp_dir, video=args.video)
        if args.video:
            render_video(
                src_file, args.dst_dir, args.tmp_dir, args.num_steps, args.video_mode,
                args.video_size, args.video_fps, args.video_codec, args.video_bitrate
            )
        shutil.rmtree(args.tmp_dir)


if __name__ == '__main__':
    main()
26 changes: 17 additions & 9 deletions projector.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,24 @@
#----------------------------------------------------------------------------

class Projector:
def __init__(self):
self.num_steps = 1000
def __init__(self,
vgg16_pkl = 'https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2',
num_steps = 1000,
initial_learning_rate = 0.1,
initial_noise_factor = 0.05,
verbose = False
):

self.vgg16_pkl = vgg16_pkl
self.num_steps = num_steps
self.dlatent_avg_samples = 10000
self.initial_learning_rate = 0.1
self.initial_noise_factor = 0.05
self.initial_learning_rate = initial_learning_rate
self.initial_noise_factor = initial_noise_factor
self.lr_rampdown_length = 0.25
self.lr_rampup_length = 0.05
self.noise_ramp_length = 0.75
self.regularize_noise_weight = 1e5
self.verbose = False
self.verbose = verbose
self.clone_net = True

self._Gs = None
Expand Down Expand Up @@ -63,8 +71,8 @@ def set_network(self, Gs, minibatch_size=1):
# Find dlatent stats.
self._info('Finding W midpoint and stddev using %d samples...' % self.dlatent_avg_samples)
latent_samples = np.random.RandomState(123).randn(self.dlatent_avg_samples, *self._Gs.input_shapes[0][1:])
dlatent_samples = self._Gs.components.mapping.run(latent_samples, None)[:, :1, :] # [N, 1, 512]
self._dlatent_avg = np.mean(dlatent_samples, axis=0, keepdims=True) # [1, 1, 512]
dlatent_samples = self._Gs.components.mapping.run(latent_samples, None) # [N, 18, 512]
self._dlatent_avg = np.mean(dlatent_samples, axis=0, keepdims=True) # [1, 18, 512]
self._dlatent_std = (np.sum((dlatent_samples - self._dlatent_avg) ** 2) / self.dlatent_avg_samples) ** 0.5
self._info('std = %g' % self._dlatent_std)

Expand Down Expand Up @@ -92,7 +100,7 @@ def set_network(self, Gs, minibatch_size=1):
self._dlatents_var = tf.Variable(tf.zeros([self._minibatch_size] + list(self._dlatent_avg.shape[1:])), name='dlatents_var')
self._noise_in = tf.placeholder(tf.float32, [], name='noise_in')
dlatents_noise = tf.random.normal(shape=self._dlatents_var.shape) * self._noise_in
self._dlatents_expr = tf.tile(self._dlatents_var + dlatents_noise, [1, self._Gs.components.synthesis.input_shape[1], 1])
self._dlatents_expr = self._dlatents_var + dlatents_noise
self._images_expr = self._Gs.components.synthesis.get_output_for(self._dlatents_expr, randomize_noise=False)

# Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
Expand All @@ -106,7 +114,7 @@ def set_network(self, Gs, minibatch_size=1):
self._info('Building loss graph...')
self._target_images_var = tf.Variable(tf.zeros(proc_images_expr.shape), name='target_images_var')
if self._lpips is None:
self._lpips = misc.load_pkl('http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/vgg16_zhang_perceptual.pkl')
self._lpips = misc.load_pkl(self.vgg16_pkl) # vgg16_zhang_perceptual.pkl
self._dist = self._lpips.get_output_for(proc_images_expr, self._target_images_var)
self._loss = tf.reduce_sum(self._dist)

Expand Down

0 comments on commit dbf69a9

Please sign in to comment.