Skip to content

Commit

Permalink
Add more flexible output processing and saving.
Browse files Browse the repository at this point in the history
Output saving is now encapsulated in output postprocessors, which are
basically any callables performing some operation on output of Generator
network. They can be included in a plugin or used on their own, as in
generate.py script.
  • Loading branch information
Michalaq committed Jan 27, 2018
1 parent 7c3c446 commit aea3646
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 116 deletions.
48 changes: 48 additions & 0 deletions generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
import output_postprocess
from torch.autograd import Variable
from utils import *
from argparse import ArgumentParser
from functools import partial
from output_postprocess import *


default_params = {
'generator_path': '',
'num_samples': 6,
'postprocessors': [],
'description': 'unknown',
}


def output_samples(generator_path, num_samples, postprocessors, description):
G = torch.load(generator_path)
G.cuda()
latent_size = getattr(G, 'latent_size', 512) # yup I just want to use old checkpoints
print('Sampling noise...')
gen_input = Variable(random_latents(num_samples, latent_size)).cuda()
print('Generating...')
output = generate_samples(G, gen_input)
print('Done.')
for proc in postprocessors:
print('Outputting for postprocessor: {}'.format(proc))
proc(output, description)
print('Done.')


if __name__ == '__main__':
parser = ArgumentParser()
needarg_classes = get_all_classes(output_postprocess)
auto_args = create_params(needarg_classes)
# default_params.update(auto_args)
for k in default_params:
parser.add_argument('--{}'.format(k), type=partial(generic_arg_parse, hinttype=type(default_params[k])))
for cls in auto_args:
for k in auto_args[cls]:
name = '{}.{}'.format(cls, k)
parser.add_argument('--{}'.format(name), type=generic_arg_parse)
default_params[name] = auto_args[cls][k]
parser.set_defaults(**default_params)
params = get_structured_params(vars(parser.parse_args()))
postprocessors = [ globals()[x](**params[x]) for x in params['postprocessors'] ]
output_samples(params['generator_path'], params['num_samples'], postprocessors, params['description'])
149 changes: 149 additions & 0 deletions output_postprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import os
import utils
import numpy as np
import PIL.Image
import librosa as lbr
from functools import reduce
from utils import adjust_dynamic_range, numpy_upsample_nearest


class Postprocessor(object):

def __init__(self, samples_path='.'):
self.samples_path = samples_path


class ImageSaver(Postprocessor):

output_file_format = 'fakes_{}.png'

def __init__(self, samples_path='.', drange=(-1,1), resolution=512, create_subdirs=True, img_postprocessors=None):
super(ImageSaver, self).__init__(samples_path)
self.samples_path = samples_path
if create_subdirs:
os.makedirs(self.samples_path, exist_ok=True)
# Setup snapshot image grid.
self.resolution = resolution
self.drange = drange
self.mode = None
self.img_postprocessors = [lambda x: x] if not img_postprocessors else img_postprocessors

def create_image_grid(self, images):
(count, channels, img_h, img_w) = images.shape

grid_w = max(int(np.ceil(np.sqrt(count))), 1)
grid_h = max((count - 1) // grid_w + 1, 1)

grid = np.zeros((channels,) + (grid_h * img_h, grid_w * img_w), dtype=images.dtype)
for i in range(count):
x = (i % grid_w) * img_w
y = (i // grid_w) * img_h
grid[:, y: y + img_h, x: x + img_w] = images[i]
return grid

def convert_to_pil_image(self, image):
format = 'RGB'
if image.ndim == 3:
if image.shape[0] == 1:
image = image[0]
format = 'L'
else:
image = image.transpose(1, 2, 0)
format = 'RGB'

# image = reduce(lambda acc, x: x(acc), self.img_postprocessors, image)
image = utils.adjust_dynamic_range(image, self.drange, (0, 255))

image = image.round().clip(0, 255).astype(np.uint8)
return PIL.Image.fromarray(image, format)

def __call__(self, output, description):
if self.resolution is not None:
output = numpy_upsample_nearest(output, 2, size=self.resolution)
im = self.create_image_grid(output)
im = self.convert_to_pil_image(im)
fname = self.output_file_format
if type(description) is int:
fname = fname.format('{:06}')
im.save(os.path.join(self.samples_path, fname.format(description)))


class SoundSaver(Postprocessor):

output_file_format = 'fakes_sound_{}_{}.wav'

def __init__(self, samples_path='.', drange=(-1, 1), resolution=512, mode='abslog', sample_rate=16000,
hop_length=128, create_subdirs=True, verbose=False):
super(SoundSaver, self).__init__(samples_path)
self.samples_path = samples_path
if create_subdirs:
os.makedirs(self.samples_path, exist_ok=True)
self.drange = drange
self.mode = mode
self.sample_rate = sample_rate
self.hop_length = hop_length
self.verbose = verbose
self.resolution = resolution

def reconstruct_from_magnitude(self, stft_mag, it=100):
n_fft = (stft_mag.shape[0] - 1) * 2
x = np.random.randn((stft_mag.shape[1] - 1) * self.hop_length)
for i in range(it):
stft_rec = lbr.stft(x, n_fft=n_fft, hop_length=self.hop_length)
angle = np.angle(stft_rec)
my_stft = stft_mag * np.exp(1.0j * angle)
if self.verbose: # and i == it - 1:
prev_x = x
x = lbr.istft(my_stft, hop_length=self.hop_length)
if self.verbose: # and i == it - 1:
mse = np.sqrt(np.square(x - prev_x).sum()) # logmse would be more appropriate?
print('MSE between sub- and ultimate iteration: {}'.format(mse))
return x

def image_to_sound(self, image):
if self.mode == 'reallog' or self.mode == 'abslog':
x = np.zeros((image.shape[0] + 1, image.shape[1])) # real spectrograms have 2**i + 1 freq bins
# x.fill(image.mean())
x[:image.shape[0], :image.shape[1]] = image
if self.mode == 'reallog':
signed = adjust_dynamic_range(x, self.drange, (-1, 1))
sgn = np.sign(signed)
real_pt_stft = (np.exp(np.abs(signed)) - 1) * sgn
signal = lbr.istft(real_pt_stft, self.hop_length)
else:
x = adjust_dynamic_range(x, self.drange, (0, 255))
signal = self.reconstruct_from_magnitude(x)
elif self.mode == 'raw':
signal = image.ravel()
else:
raise Exception(
'image_to_sound: unrecognized mode: {}. Available modes are: reallog, abslog, raw.'.format(self.mode)
)
signal = signal / np.abs(signal).max()
return signal

def output_wav(self, signal, samples_description, ith):
fname = self.output_file_format
if type(samples_description) is int:
fname = fname.format('{:06}', '{:02}')
else:
fname = fname.format('{}', '{:02}')
try:
lbr.output.write_wav(
os.path.join(self.samples_path, fname.format(samples_description, ith)),
signal,
self.sample_rate,
norm=True
)
except Exception as e:
with open(os.path.join(self.samples_path, 'error_{}_{}.txt'.format(samples_description, ith)), 'w') as f:
f.write('Exception trying to save sound: {}'.format(e))

def __call__(self, output, samples_description):
times_smaller = self.resolution // output.shape[-1]
if self.mode == 'raw':
times_smaller *= times_smaller
for i, img in enumerate(output):
signal = self.image_to_sound(img[0])
signal = numpy_upsample_nearest(signal, 1, scale_factor=times_smaller)
self.output_wav(signal, samples_description, i)
17 changes: 13 additions & 4 deletions pggan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from trainer import Trainer
import dataset
from dataset import *
import output_postprocess
from output_postprocess import *
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler
from plugins import *
Expand Down Expand Up @@ -42,6 +44,7 @@
save_dataset='',
load_dataset='',
dataset_class='',
postprocessors=[]
)


Expand Down Expand Up @@ -167,7 +170,12 @@ def rampup(cur_nimg):
for i, loss_name in enumerate(losses):
trainer.register_plugin(EfficientLossMonitor(i, loss_name))
trainer.register_plugin(SaverPlugin(result_dir, True, params['network_snapshot_ticks']))
trainer.register_plugin(SampleGenerator(result_dir, lambda x: random_latents(x, latent_size), **params['SampleGenerator']))

def subsitute_samples_path(d):
return {k:(os.path.join(result_dir, v) if k == 'samples_path' else v) for k,v in d.items()}
postprocessors = [ globals()[x](**subsitute_samples_path(params[x])) for x in params['postprocessors'] ]
trainer.register_plugin(OutputGenerator(lambda x: random_latents(x, latent_size),
postprocessors, **params['OutputGenerator']))
trainer.register_plugin(AbsoluteTimeMonitor(params['resume_time']))
trainer.register_plugin(LRScheduler(lr_scheduler_d, lr_scheduler_g))
trainer.register_plugin(logger)
Expand All @@ -178,18 +186,19 @@ def rampup(cur_nimg):

if __name__ == "__main__":
parser = ArgumentParser()
needarg_classes = [Trainer, Generator, Discriminator, DepthManager, SaverPlugin, SampleGenerator, Adam]
needarg_classes = [Trainer, Generator, Discriminator, DepthManager, SaverPlugin, OutputGenerator, Adam]
needarg_classes += get_all_classes(dataset)
needarg_classes += get_all_classes(output_postprocess)
excludes = {'Adam': {'lr'}}
default_overrides = {'Adam': {'betas': (0.0, 0.99)}}
auto_args = create_params(needarg_classes, excludes, default_overrides)
# default_params.update(auto_args)
for k in default_params:
parser.add_argument('--{}'.format(k), type=partial(generic_arg_parse, hinttype=type(default_params[k])))
for cls in auto_args:
group = parser.add_argument_group(cls, 'Arguments for initialization of class {}'.format(cls))
for k in auto_args[cls]:
name = '{}.{}'.format(cls, k)
parser.add_argument('--{}'.format(name), type=generic_arg_parse)
group.add_argument('--{}'.format(name), type=generic_arg_parse)
default_params[name] = auto_args[cls][k]
parser.set_defaults(**default_params)
params = get_structured_params(vars(parser.parse_args()))
Expand Down
Loading

0 comments on commit aea3646

Please sign in to comment.