-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add more flexible output processing and saving.
Output saving is now encapsulated in output postprocessors, which are basically any callables that perform some operation on the output of the Generator network. They can be included in a plugin or used on their own, as in the generate.py script.
- Loading branch information
Showing
5 changed files
with
249 additions
and
116 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import torch | ||
import output_postprocess | ||
from torch.autograd import Variable | ||
from utils import * | ||
from argparse import ArgumentParser | ||
from functools import partial | ||
from output_postprocess import * | ||
|
||
|
||
# Script-level option defaults. Every key is registered as a ``--<key>``
# command-line argument in the ``__main__`` section, and the type of each
# value is used there as the parsing hint.
default_params = dict(
    generator_path='',
    num_samples=6,
    postprocessors=[],
    description='unknown',
)
|
||
|
||
def output_samples(generator_path, num_samples, postprocessors, description):
    """Load a generator, sample from it and hand the result to postprocessors.

    Args:
        generator_path: Path given to ``torch.load`` to obtain the generator.
        num_samples: Number of latent vectors requested from ``random_latents``.
        postprocessors: Iterable of callables, each invoked as
            ``postprocessor(output, description)``.
        description: Value forwarded unchanged to every postprocessor.
    """
    # NOTE(review): torch.load unpickles the file -- only point this at
    # checkpoints from a trusted source.
    generator = torch.load(generator_path)
    generator.cuda()
    # Old checkpoints carry no ``latent_size`` attribute; fall back to 512.
    latent_size = getattr(generator, 'latent_size', 512)

    print('Sampling noise...')
    latents = random_latents(num_samples, latent_size)
    gen_input = Variable(latents).cuda()

    print('Generating...')
    output = generate_samples(generator, gen_input)
    print('Done.')

    for postprocessor in postprocessors:
        print('Outputting for postprocessor: {}'.format(postprocessor))
        postprocessor(output, description)
        # NOTE(review): the source this was recovered from had lost its
        # indentation; the per-postprocessor placement of this message is an
        # assumption -- confirm against the repository.
        print('Done.')
|
||
|
||
if __name__ == '__main__':
    # Build the command line from the script defaults plus whatever
    # parameters create_params reports for the classes of output_postprocess.
    parser = ArgumentParser()
    postprocessor_classes = get_all_classes(output_postprocess)
    class_params = create_params(postprocessor_classes)
    # default_params.update(auto_args)

    # Script-level options: the type of each default is the parsing hint.
    for key, default in default_params.items():
        typed_parse = partial(generic_arg_parse, hinttype=type(default))
        parser.add_argument('--{}'.format(key), type=typed_parse)

    # Per-class options are exposed as ``--<ClassName>.<param>`` and their
    # defaults are merged into default_params under the same dotted name.
    for cls_name, cls_defaults in class_params.items():
        for key, default in cls_defaults.items():
            qualified = '{}.{}'.format(cls_name, key)
            parser.add_argument('--{}'.format(qualified), type=generic_arg_parse)
            default_params[qualified] = default

    parser.set_defaults(**default_params)
    cli_values = vars(parser.parse_args())
    params = get_structured_params(cli_values)

    # Requested postprocessors are looked up by class name among the module
    # globals and built with the parameters stored under that same name.
    postprocessors = []
    for cls_name in params['postprocessors']:
        postprocessor_cls = globals()[cls_name]
        postprocessors.append(postprocessor_cls(**params[cls_name]))

    output_samples(params['generator_path'], params['num_samples'], postprocessors, params['description'])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import os | ||
import utils | ||
import numpy as np | ||
import PIL.Image | ||
import librosa as lbr | ||
from functools import reduce | ||
from utils import adjust_dynamic_range, numpy_upsample_nearest | ||
|
||
|
||
class Postprocessor:
    """Common base of the output postprocessors.

    It only records the directory the postprocessor was configured with.
    """

    def __init__(self, samples_path='.'):
        """Store ``samples_path`` on the instance."""
        self.samples_path = samples_path
|
||
|
||
class ImageSaver(Postprocessor):
    """Postprocessor that tiles a batch of images into a grid and saves it.

    Calling an instance writes one image file, named after
    ``output_file_format``, into ``samples_path``.
    """

    output_file_format = 'fakes_{}.png'

    def __init__(self, samples_path='.', drange=(-1,1), resolution=512, create_subdirs=True,
                 img_postprocessors=None):
        """Configure the saver.

        Args:
            samples_path: Directory the image files are written to.
            drange: Range handed to ``utils.adjust_dynamic_range`` as the
                source range when mapping values to ``(0, 255)``.
            resolution: Passed as ``size`` to ``numpy_upsample_nearest``
                before tiling; ``None`` skips the upsampling.
            create_subdirs: When true, ``samples_path`` is created if missing.
            img_postprocessors: Optional list of callables. It is stored, but
                the code applying it in ``convert_to_pil_image`` is currently
                disabled.
        """
        # The base class stores samples_path; no need to assign it again.
        super(ImageSaver, self).__init__(samples_path)
        if create_subdirs:
            os.makedirs(self.samples_path, exist_ok=True)
        # Setup snapshot image grid.
        self.resolution = resolution
        self.drange = drange
        self.mode = None
        self.img_postprocessors = img_postprocessors if img_postprocessors else [lambda x: x]

    def create_image_grid(self, images):
        """Tile ``images`` of shape (count, channels, h, w) into one array.

        The grid is as close to square as possible (width is the ceiling of
        ``sqrt(count)``); cells without an image stay zero.

        Returns:
            Array of shape (channels, grid_h * h, grid_w * w) with the dtype
            of ``images``.
        """
        count, channels, img_h, img_w = images.shape

        grid_w = max(int(np.ceil(np.sqrt(count))), 1)
        grid_h = max((count - 1) // grid_w + 1, 1)

        grid = np.zeros((channels, grid_h * img_h, grid_w * img_w), dtype=images.dtype)
        for index in range(count):
            row, col = divmod(index, grid_w)
            top, left = row * img_h, col * img_w
            grid[:, top: top + img_h, left: left + img_w] = images[index]
        return grid

    def convert_to_pil_image(self, image):
        """Convert an array to an 8-bit ``PIL.Image``.

        A 3-D input is treated as channels-first: a single channel becomes a
        grayscale ('L') image, anything else is transposed to channels-last
        and labelled 'RGB'. A 2-D input is treated as grayscale.
        """
        mode = 'RGB'
        if image.ndim == 3:
            if image.shape[0] == 1:
                image = image[0]
                mode = 'L'
            else:
                image = image.transpose(1, 2, 0)
        elif image.ndim == 2:
            # A 2-D array has no channel axis, so it cannot be decoded as RGB.
            mode = 'L'

        # image = reduce(lambda acc, x: x(acc), self.img_postprocessors, image)
        image = utils.adjust_dynamic_range(image, self.drange, (0, 255))

        image = image.round().clip(0, 255).astype(np.uint8)
        return PIL.Image.fromarray(image, mode)

    def __call__(self, output, description):
        """Save ``output`` as a single grid image.

        Args:
            output: Batch of images of shape (count, channels, h, w).
            description: Value inserted into the file name; integers are
                zero-padded to six digits.
        """
        if self.resolution is not None:
            output = numpy_upsample_nearest(output, 2, size=self.resolution)
        grid = self.create_image_grid(output)
        picture = self.convert_to_pil_image(grid)

        fname = self.output_file_format
        # bool is excluded so True/False keep their textual form in the name.
        if isinstance(description, int) and not isinstance(description, bool):
            fname = fname.format('{:06}')
        picture.save(os.path.join(self.samples_path, fname.format(description)))
|
||
|
||
class SoundSaver(Postprocessor):
    """Postprocessor that turns each generated image into a WAV file.

    The conversion depends on ``mode``:
      * 'abslog'  -- the image is used as an STFT magnitude and the phase is
        estimated iteratively (see ``reconstruct_from_magnitude``);
      * 'reallog' -- the image is mapped to a signed spectrogram and inverted
        directly with ``librosa.istft``;
      * 'raw'     -- the flattened image is used as the waveform.
    """

    output_file_format = 'fakes_sound_{}_{}.wav'

    def __init__(self, samples_path='.', drange=(-1, 1), resolution=512, mode='abslog', sample_rate=16000,
                 hop_length=128, create_subdirs=True, verbose=False):
        """Configure the saver.

        Args:
            samples_path: Directory the sound files are written to.
            drange: Source range handed to ``adjust_dynamic_range``.
            resolution: Target size; ``__call__`` derives the upsampling
                factor from it and the last dimension of the output.
            mode: One of 'abslog', 'reallog' or 'raw'.
            sample_rate: Sample rate passed to the WAV writer.
            hop_length: Hop length used for every STFT / inverse STFT.
            create_subdirs: When true, ``samples_path`` is created if missing.
            verbose: When true, phase reconstruction prints a per-iteration
                difference figure.
        """
        # The base class stores samples_path; no need to assign it again.
        super(SoundSaver, self).__init__(samples_path)
        if create_subdirs:
            os.makedirs(self.samples_path, exist_ok=True)
        self.drange = drange
        self.mode = mode
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.verbose = verbose
        self.resolution = resolution

    def reconstruct_from_magnitude(self, stft_mag, it=100):
        """Estimate a waveform whose STFT magnitude matches ``stft_mag``.

        Griffin-Lim-style iteration: start from random noise, then repeatedly
        take the phase of the current signal's STFT, combine it with the
        target magnitude and invert.

        Args:
            stft_mag: Magnitude array of shape (freq_bins, frames).
            it: Number of iterations.

        Returns:
            The reconstructed 1-D signal.
        """
        n_fft = (stft_mag.shape[0] - 1) * 2
        x = np.random.randn((stft_mag.shape[1] - 1) * self.hop_length)
        for _ in range(it):
            phase = np.angle(lbr.stft(x, n_fft=n_fft, hop_length=self.hop_length))
            previous = x
            x = lbr.istft(stft_mag * np.exp(1.0j * phase), hop_length=self.hop_length)
            if self.verbose:  # and i == it - 1:
                # The printed value is the root of the summed squared
                # differences, despite the 'MSE' label.
                mse = np.sqrt(np.square(x - previous).sum())  # logmse would be more appropriate?
                print('MSE between sub- and ultimate iteration: {}'.format(mse))
        return x

    def image_to_sound(self, image):
        """Convert a single 2-D image into a peak-normalised signal.

        Raises:
            ValueError: If ``self.mode`` is not 'reallog', 'abslog' or 'raw'.
        """
        if self.mode in ('reallog', 'abslog'):
            # real spectrograms have 2**i + 1 freq bins, so pad one zero row
            padded = np.zeros((image.shape[0] + 1, image.shape[1]))
            # padded.fill(image.mean())
            padded[:image.shape[0], :image.shape[1]] = image
            if self.mode == 'reallog':
                signed = adjust_dynamic_range(padded, self.drange, (-1, 1))
                real_pt_stft = (np.exp(np.abs(signed)) - 1) * np.sign(signed)
                # hop_length by keyword, matching reconstruct_from_magnitude;
                # newer librosa releases reject it as a positional argument.
                signal = lbr.istft(real_pt_stft, hop_length=self.hop_length)
            else:
                magnitude = adjust_dynamic_range(padded, self.drange, (0, 255))
                signal = self.reconstruct_from_magnitude(magnitude)
        elif self.mode == 'raw':
            signal = image.ravel()
        else:
            raise ValueError(
                'image_to_sound: unrecognized mode: {}. Available modes are: reallog, abslog, raw.'.format(self.mode)
            )
        # Normalise to a peak of 1; an all-zero signal is left untouched
        # instead of being turned into NaNs by a division by zero.
        peak = np.abs(signal).max()
        if peak > 0:
            signal = signal / peak
        return signal

    def output_wav(self, signal, samples_description, ith):
        """Write ``signal`` to a WAV file named after the description and index.

        Any failure while writing is recorded in an ``error_*.txt`` file in
        ``samples_path`` instead of being raised.
        """
        # bool is excluded so True/False keep their textual form in the name.
        if isinstance(samples_description, int) and not isinstance(samples_description, bool):
            fname = self.output_file_format.format('{:06}', '{:02}')
        else:
            fname = self.output_file_format.format('{}', '{:02}')
        try:
            # NOTE(review): librosa.output was removed in librosa 0.8; with a
            # newer librosa this call always ends up in the except branch.
            lbr.output.write_wav(
                os.path.join(self.samples_path, fname.format(samples_description, ith)),
                signal,
                self.sample_rate,
                norm=True
            )
        except Exception as e:
            # Deliberate best-effort: keep processing the remaining samples.
            error_name = 'error_{}_{}.txt'.format(samples_description, ith)
            with open(os.path.join(self.samples_path, error_name), 'w') as f:
                f.write('Exception trying to save sound: {}'.format(e))

    def __call__(self, output, samples_description):
        """Convert every sample of ``output`` to sound and save it.

        Args:
            output: Batch of images; channel 0 of each sample is used.
            samples_description: Value inserted into the file names.
        """
        times_smaller = self.resolution // output.shape[-1]
        if self.mode == 'raw':
            # Raw mode squares the factor -- presumably because the flattened
            # image scales with the square of the resolution; confirm.
            times_smaller *= times_smaller
        for index, img in enumerate(output):
            signal = self.image_to_sound(img[0])
            signal = numpy_upsample_nearest(signal, 1, scale_factor=times_smaller)
            self.output_wav(signal, samples_description, index)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.