In [None]:
from os.path import isdir, join, isfile
from os import getcwd, system

import chainer
import yaml
# # local imports from this repository.
from source.yaml_utils import Config, load_dataset
from source.misc_train_utils import load_models_cgan, ensure_config_paths
from evaluations.extensions_cgan import gen_images_cgan

###### Load config and base paths

In [None]:
# # the base path (the current working directory, which should contain the code).
# # join(..., '') appends a trailing path separator to it.
pb = join(getcwd(), '')
# # load config; the context manager closes the file handle after parsing.
# # NOTE(review): the Loader is passed explicitly because PyYAML >= 6 rejects
# # yaml.load() without one. yaml.Loader keeps the legacy semantics, but it can
# # construct arbitrary Python objects, so only use it on trusted files; prefer
# # yaml.safe_load if the config only contains plain YAML -- TODO confirm.
with open(join(pb, 'jobs/rocgan/demo.yml')) as config_file:
    config = Config(yaml.load(config_file, Loader=yaml.Loader))
# # ensure that the paths of the config are correct.
config = ensure_config_paths(config, pb=pb)

###### Load model/iterator

In [None]:
# # load the class for the db reader.
db_test = load_dataset(config, validation=True, valid_path='files_test.txt')
# # load the iterator: one ordered pass over the data (no shuffle, no repeat).
iterator = chainer.iterators.SerialIterator(db_test, config.batchsize,
                                            shuffle=False, repeat=False)

# # load the encoder/decoder architecture (the third returned value is unused).
enc, dec, _ = load_models_cgan(config)
# # template for the weight files; '{}' is replaced by the model name.
pfold = join(pb, 'models_rocgan', '{}_best.npz')
# # verify that BOTH weight files exist before loading, and stop with the
# # download hint instead of letting load_npz fail with a less helpful error.
missing = [pfold.format(name) for name in ('Encoder', 'Decoder')
           if not isfile(pfold.format(name))]
if missing:
    msg = ('Please download the model from http://bit.ly/2GBtx0z '
           'and place the models in the models_rocgan/. '
           'Missing file(s): {}'.format(', '.join(missing)))
    raise FileNotFoundError(msg)
# # load the weights.
chainer.serializers.load_npz(pfold.format('Encoder'), enc)
chainer.serializers.load_npz(pfold.format('Decoder'), dec)

###### Evaluate and visualize

In [None]:
# # rewind the iterator so that re-running this cell starts from the first
# # batch again; it was created with repeat=False, so it does not restart on
# # its own (presumably gen_images_cgan advances it -- TODO confirm).
iterator.reset()
# # generate the images with the loaded encoder/decoder.
ims = gen_images_cgan(enc, dec, iterator, n=100, func=None)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# # visualize the image idx.
# # From left to right: The output of the network, the 
# # corrupted image and the ground-truth one.
idx = 0
im_i = lambda ims, idx=idx: ims[idx].transpose(1, 2, 0)
im = np.concatenate((im_i(ims[0]), im_i(ims[1]), im_i(ims[2])), axis=1)
plt.imshow(im)