# In [None]:
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training script."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
from absl import logging

import tensorflow as tf
# Process-wide TensorFlow device configuration, executed at import time before
# any model code runs.
# Log the device every op is placed on (verbose; a debugging aid).
tf.debugging.set_log_device_placement(True)
# Make no GPU visible to TensorFlow, so this process computes on the CPU.
tf.config.set_visible_devices([], 'GPU')
# NOTE(review): all GPUs were just hidden above, so enabling memory growth on
# them below appears to have no effect on this process — confirm whether the
# CPU-only setting or the memory-growth loop is the intended configuration.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Presumably raised if the TF runtime already initialised its devices
        # (per TF docs) — the error is printed and start-up continues.
        print(e)
import numpy as np

import sonnet as snt
# import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

import logging as log

# TensorFlow's Python-side logger; set its threshold to INFO so INFO-level
# records (and anything more severe) are emitted.
logger = tf.get_logger()
logger.setLevel(level=log.INFO)

import cs
import file_utils
import utils

# Short alias for TensorFlow Probability distributions.
tfd = tfp.distributions

# Command-line configuration. Values are read through `FLAGS` once parsed.
flags.DEFINE_string(
    'mode', 'recons', 'Model mode.')
flags.DEFINE_integer(
    'num_training_iterations', 10000000,
    'Number of training iterations.')
flags.DEFINE_integer(
    'batch_size', 64, 'Training batch size.')
flags.DEFINE_integer(
    'num_measurements', 25, 'The number of measurements')
flags.DEFINE_integer(
    'num_latents', 100, 'The number of latents')
flags.DEFINE_integer(
    'num_z_iters', 3, 'The number of latent optimisation steps.')
flags.DEFINE_float(
    'z_step_size', 0.01, 'Step size for latent optimisation.')
flags.DEFINE_string(
    'z_project_method', 'norm', 'The method to project z.')
flags.DEFINE_integer(
    'summary_every_step', 10,
    'The interval at which to log debug ops.')
flags.DEFINE_integer(
    'export_every', 10,
    'The interval at which to export samples.')
flags.DEFINE_string(
    'dataset', 'mnist', 'The dataset used for learning (cifar|mnist).')
flags.DEFINE_float('learning_rate', 1e-4, 'Learning rate.')
flags.DEFINE_string(
    'output_dir', '/tmp/cs_gan/cs', 'Location where to save output files.')

FLAGS = flags.FLAGS

import shutil
import sys

logging.info("TensorFlow version: {}".format(tf.__version__))
logging.info("    Sonnet version: {}".format(snt.__version__))
logging.info("    Numpy  version: {}".format(np.__version__))

# This script trains at module level (it was exported from a notebook) rather
# than inside `app.run(main)`, so absl never parses the command line by itself
# and reading any flag would raise `UnparsedFlagAccessError`. `known_only=True`
# leaves arguments absl does not recognise (e.g. the `-f <file>` argument a
# Jupyter kernel is started with) in place instead of failing on them.
if not FLAGS.is_parsed():
    FLAGS(sys.argv, known_only=True)

# TensorBoard summaries live in <output_dir>/summaries and are cleared at
# start-up so curves from earlier runs are not mixed in. `shutil.rmtree` is
# used instead of a shell command so the flag value never reaches a shell and
# paths containing spaces are handled.
output_dir = os.path.join(FLAGS.output_dir, 'summaries')
shutil.rmtree(output_dir, ignore_errors=True)

utils.make_output_dir(FLAGS.output_dir)
data_processor = utils.DataProcessor()

# The learning rate is a float (default 1e-4); `%g` keeps it readable where
# `%d` would truncate it to 0.
logging.info('Learning rate: %g', FLAGS.learning_rate)

# Construct optimizers.
optimizer = snt.optimizers.Adam(FLAGS.learning_rate)

# Create the networks and models.
generator = utils.get_generator(FLAGS.dataset)
metric_net = utils.get_metric_net(FLAGS.dataset, FLAGS.num_measurements)
model = cs.CS(metric_net, generator,
              FLAGS.num_z_iters, FLAGS.z_step_size, FLAGS.z_project_method,
              optimizer)

sample_exporter = file_utils.FileExporter(
    os.path.join(FLAGS.output_dir, 'reconstructions'))

# The latent prior depends only on a flag, so it is built once, outside the
# training loop.
prior = utils.make_prior(FLAGS.num_latents)

writer = tf.summary.create_file_writer(output_dir)
logging.info('starting training')

# Monotonic step counter across all epochs, used as the TensorBoard step. The
# per-epoch batch index `i` restarts at 0 every epoch and would make each
# epoch overwrite the previous epoch's summary points.
global_step = 0
# NOTE(review): each pass iterates the whole dataset, so this loop counts
# epochs even though its bound is named `num_training_iterations` — confirm.
for num_epochs in range(FLAGS.num_training_iterations):
    # A fresh dataset is requested every epoch.
    images = utils.get_train_dataset(data_processor, FLAGS.dataset,
                                     FLAGS.batch_size)
    for i, batch in enumerate(images):
        generator_inputs = prior.sample(FLAGS.batch_size)
        model_output = model.step(batch, generator_inputs)

        if global_step % FLAGS.summary_every_step == 0:
            # Copy so the model's own debug-op mapping is not mutated; 'it'
            # records the within-epoch batch index.
            scalars = dict(model_output.debug_ops)
            scalars['it'] = i
            with writer.as_default():
                for name, value in scalars.items():
                    tf.summary.scalar(name, value, step=global_step)

        # Reconstructions require a full latent optimisation, so they are only
        # computed when exported: on the first batch of every
        # `export_every`-th epoch.
        # NOTE(review): assumes `batch` is an image tensor accepted by
        # `postprocess`, as the images tensor was in the earlier sess.run-based
        # export code — confirm.
        if num_epochs % FLAGS.export_every == 0 and i == 0:
            reconstructions, _ = utils.optimise_and_sample(
                generator_inputs, model, batch, is_training=False)
            reconstructions_np = data_processor.postprocess(
                np.asarray(reconstructions))
            data_np = data_processor.postprocess(np.asarray(batch))
            sample_exporter.save(reconstructions_np, 'reconstructions')
            sample_exporter.save(data_np, 'data')

        global_step += 1