Skip to content

Latest commit

 

History

History
125 lines (94 loc) · 4.18 KB

blend_models.py

File metadata and controls

125 lines (94 loc) · 4.18 KB
 
Aug 22, 2020
Aug 22, 2020
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import tensorflow as tf
import sys, getopt, os
import numpy as np
import dnnlib
import dnnlib.tflib as tflib
from dnnlib.tflib import tfutil
from dnnlib.tflib.autosummary import autosummary
import math
import numpy as np
from training import dataset
from training import misc
import pickle
from pathlib import Path
import typer
from typing import Optional
def extract_conv_names(model):
    """List the synthesis-network variables that take part in blending.

    Scans ``model.trainables`` (presumably a dnnlib.tflib Network -- confirm)
    for variable names under ``G_synthesis/{R}x{R}/`` for R = 4, 8, ..., 1024.
    Within each resolution, names starting with "Conv0_up" or "Const" form
    level 0 and names starting with "Conv1" or "ToRGB" form level 1.

    Returns:
        A list of ``(var_name, "RxR", level, position)`` tuples ordered from
        low to high resolution. ``position`` counts (resolution, level) slots
        and advances for every slot, even when no variable matched it.
    """
    # Square resolutions assumed(?).
    # NOTE(review): a 4x4 conv layer is only picked up if its name starts with
    # "Conv0_up" or "Conv1" -- verify against the network's real variable names.
    trainable_names = list(model.trainables.keys())
    suffixes_per_level = (("Conv0_up", "Const"), ("Conv1", "ToRGB"))
    found = []
    slot = 0
    for exponent in range(9):
        side = 4 * 2 ** exponent
        res_label = f"{side}x{side}"
        for lvl, suffixes in enumerate(suffixes_per_level):
            for suffix in suffixes:
                prefix = f"G_synthesis/{res_label}/{suffix}"
                found.extend(
                    (var_name, res_label, lvl, slot)
                    for var_name in trainable_names
                    if var_name.startswith(prefix)
                )
            # One slot per (resolution, level), whether or not anything matched.
            slot += 1
    return found
Aug 22, 2020
Aug 22, 2020
48
def _logistic(t):
    """Return 1 / (1 + exp(-t)) without overflowing for large |t|.

    ``math.exp`` raises OverflowError for arguments above roughly 709. The
    naive form hits that for large negative ``t`` (e.g. a tiny blend width).
    """
    if t >= 0:
        return 1 / (1 + math.exp(-t))
    # Here exp(t) <= 1, so this algebraically equal form cannot overflow.
    z = math.exp(t)
    return z / (1 + z)


def blend_models(model_1, model_2, resolution, level, blend_width=None, verbose=False):
    """Return a copy of ``model_1`` whose synthesis weights are blended with ``model_2``.

    Every variable found by ``extract_conv_names`` gets a blend amount ``y``.
    It is then set to ``model_2 * y + model_1 * (1 - y)``. ``y = 0`` keeps
    model_1's weights and ``y = 1`` takes model_2's.

    Args:
        model_1: Network supplying the layers before the switch point. It must
            expose ``trainables``, ``vars`` and ``clone()`` (presumably a
            dnnlib.tflib Network -- confirm).
        model_2: Network supplying the layers after the switch point. It must
            have exactly the same synthesis layer layout as ``model_1``.
        resolution: Side length (e.g. 32) of the block at which to switch.
        level: 0 or 1, the level within that resolution block to switch at.
        blend_width: None/0 for a hard switch. Otherwise it is the width, in
            layer positions, of a logistic transition centred on the switch
            point.
        verbose: If True, print the blend amount for every variable.

    Returns:
        A clone of ``model_1`` with the blended variable values assigned.

    Raises:
        ValueError: If the two models' layer layouts differ, or if no layer
            exists at the requested resolution and level.
    """
    # TODO add small x offset for smoother blend animations
    res_label = f"{resolution}x{resolution}"
    model_1_names = extract_conv_names(model_1)
    model_2_names = extract_conv_names(model_2)
    # Explicit check instead of `assert` (stripped under -O). List equality
    # also catches a length mismatch, which zip() would silently truncate.
    if model_1_names != model_2_names:
        raise ValueError("model_1 and model_2 do not have the same synthesis layers")

    short_names = [(res, lvl) for _, res, lvl, _ in model_1_names]
    try:
        mid_point_idx = short_names.index((res_label, level))
    except ValueError:
        raise ValueError(
            f"No synthesis layers found at resolution {res_label}, level {level}"
        ) from None
    mid_point_pos = model_1_names[mid_point_idx][3]

    model_out = model_1.clone()
    updates = {}
    for name, _res, _lvl, position in model_1_names:
        # Positions increase from low to high resolution.
        x = position - mid_point_pos
        if blend_width:
            y = _logistic(x / blend_width)
        else:
            # NOTE(review): the hard switch only takes model_2 from two
            # positions past the requested (resolution, level). The logistic
            # curve is centred on it (y = 0.5 at x = 0). Kept as-is for
            # backward compatibility -- confirm this offset is intended.
            y = 1 if x > 1 else 0
        if verbose:
            print(f"Blending {name} by {y}")
        updates[model_out.vars[name]] = model_2.vars[name] * y + model_1.vars[name] * (1 - y)

    # Evaluate all blended values in one run, then assign them to the clone.
    tfutil.set_vars(tfutil.run(updates))
    return model_out
Aug 22, 2020
Aug 22, 2020
91
92
93
94
95
96
97
98
99
def main(low_res_pkl: Path, # Pickle file from which to take low res layers
         high_res_pkl: Path, # Pickle file from which to take high res layers
         resolution: int, # Resolution level at which to switch between models
         level: int = 0, # Switch at Conv block 0 or 1?
         blend_width: Optional[float] = None, # None = hard switch, float = smooth switch (logistic) with given width
         output_grid: Optional[Path] = "blended.jpg", # Path of image file to save example grid (None = don't save)
         seed: int = 0, # seed for random grid
         output_pkl: Optional[Path] = None, # Output path of pickle (None = don't save)
         verbose: bool = False, # Print out the exact blending fraction
         ):
    # CLI entry point. It blends the Gs network of two checkpoints. It can also
    # render a preview grid and write the blended result to a new pickle.
    # (Comments rather than a docstring, so typer's --help text is unchanged.)
    grid_size = (3, 3)
    tflib.init_tf()
    # Loading, blending, rendering and saving all happen while this session is
    # the default one, so they must stay inside the `with` block.
    with tf.Session(), tf.device('/gpu:0'):
        # Each checkpoint unpacks to (G, D, Gs). NOTE(review): misc.load_pkl
        # presumably unpickles -- only load checkpoints from trusted sources.
        low_G, low_D, low_Gs = misc.load_pkl(low_res_pkl)
        _high_G, _high_D, high_Gs = misc.load_pkl(high_res_pkl)

        blended_Gs = blend_models(
            low_Gs,
            high_Gs,
            resolution,
            level,
            blend_width=blend_width,
            verbose=verbose,
        )

        if output_grid:
            n_images = np.prod(grid_size)
            latent_shape = blended_Gs.input_shape[1:]
            latents = np.random.RandomState(seed).randn(n_images, *latent_shape)
            images = blended_Gs.run(latents, None, is_validation=True, minibatch_size=1)
            misc.save_image_grid(images, output_grid, drange=[-1, 1], grid_size=grid_size)

        # TODO modify all the networks (only Gs is blended; G and D are taken
        # from the low-res checkpoint unchanged)
        if output_pkl:
            misc.save_pkl((low_G, low_D, blended_Gs), output_pkl)
if __name__ == "__main__":
    # Hand main() to typer, which builds the command-line interface from its
    # parameters.
    typer.run(main)