In [1]:
try:
    from google.colab import drive
    NOTEBOOK = 'colab'
except:
    import os    
    if list(os.walk('/kaggle/input')):            
        NOTEBOOK = 'kaggle'
    else:
        NOTEBOOK = 'home'
        
if NOTEBOOK == 'colab':
    !pip install -q dm-sonnet
    !pip install -q tensorflow-gan
!grep Model: /proc/driver/nvidia/gpus/*/information | awk '{$1="";print$0}'

 GeForce GTX 650


In [2]:
import tensorflow as tf
# Log the device every op is placed on; this is the source of the
# "Executing op ... in device ..." lines in later cell outputs.
tf.debugging.set_log_device_placement(True)
# Make no GPU visible to TensorFlow, so everything below runs on the CPU.
tf.config.set_visible_devices([], 'GPU')
gpus = tf.config.experimental.list_physical_devices('GPU')
# NOTE(review): GPUs were hidden just above, so this check presumably always
# takes the `else` branch and the message is misleading - confirm intent.
if tf.test.gpu_device_name():
    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))
else:
    print("Please install GPU version of TF")
print(gpus)
if gpus:
    try:
        # Ask TensorFlow to grow GPU memory on demand for each listed GPU.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Best effort: report the error and carry on.
        print(e)

Please install GPU version of TF
[]


In [3]:
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Network utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import re
import math
# from PIL import Image
from tqdm import tqdm
import numpy as np
import sonnet as snt
import tensorflow_probability as tfp
import tensorflow_gan as tfgan
import collections
import os
import sys


from absl import app
from absl import flags
from absl import logging


def _sn_custom_getter():
    """Builds a spectral-normalisation custom getter for weight variables.

    The getter is restricted to variables whose name ends in `w`, optionally
    followed by an underscore-prefixed suffix.
    """
    weight_name = re.compile(r'.*w(_.*)?$')

    def name_filter(name):
        # True when the variable name matches the weight pattern.
        return weight_name.match(name) is not None

    return tfgan.features.spectral_normalization_custom_getter(name_filter=name_filter)


class SNGenNet(snt.Module):
    """As in the SN paper.

    Generator for the 'cifar' dataset: a linear layer projects the latents to a
    4x4x512 tensor, which a transposed-convolution stack upsamples to 32x32
    with 3 channels; the output is passed through `tanh`.

    NOTE(review): `snt.BatchNormV2`, `snt.SAME` and
    `snt.nets.ConvNet2DTranspose` look like Sonnet 1 names, while other cells
    use Sonnet 2 style APIs (`snt.optimizers.Adam`, `snt.flatten`) - confirm
    that they exist in the installed Sonnet version.
    NOTE(review): every layer is constructed inside `__call__`, and the class
    defines no `net` attribute although `_get_and_check_variables` reads
    `module.net` - verify before using the 'cifar' path.
    """
    
    def __init__(self, name='conv_gen'):
        super(SNGenNet, self).__init__(name=name)
        
    @tf.function
    def __call__(self, inputs, is_training):
        """Maps latents `inputs` to images.

        Args:
          inputs: latent tensor; its first dimension is read statically as the
            batch size.
          is_training: forwarded to the transposed-convolution network.

        Returns:
          The network output after `tf.nn.tanh`.
        """
        # Static batch size; presumably must be known (not None) - TODO confirm.
        batch_size = inputs.get_shape().as_list()[0]
        first_shape = [4, 4, 512]
        norm_ctor = snt.BatchNormV2
        norm_ctor_config = {'scale': True}
        # Project the latents so that they can be reshaped to `first_shape`.
        up_tensor = snt.Linear(np.prod(first_shape))(inputs)
        first_tensor = tf.reshape(up_tensor, shape=[batch_size] + first_shape)

        net = snt.nets.ConvNet2DTranspose(
            output_channels=[256, 128, 64, 3],
            output_shapes=[(8, 8), (16, 16), (32, 32), (32, 32)],
            kernel_shapes=[(4, 4), (4, 4), (4, 4), (3, 3)],
            strides=[2, 2, 2, 1],
            normalization_ctor=norm_ctor,
            normalization_kwargs=norm_ctor_config,
            normalize_final=False,
            paddings=[snt.SAME], activate_final=False, activation=tf.nn.relu)
        output = net(first_tensor, is_training=is_training)
        return tf.nn.tanh(output)


class SNMetricNet(snt.Module):
    """Spectral normalization discriminator (metric) architecture.

    A seven-layer convolutional stack with leaky-ReLU activations, followed by
    a linear layer producing `num_outputs` values per example.

    NOTE(review): `tf.variable_scope` with a custom getter is a TF1 graph-mode
    API and `snt.BatchFlatten` / `snt.SAME` look like Sonnet 1 names - confirm
    that they are available in the installed TF / Sonnet versions.
    NOTE(review): the class defines no `net` attribute although
    `_get_and_check_variables` reads `module.net` - verify before using the
    'cifar' path.
    """
    
    def __init__(self, num_outputs=2, name='sn_metric'):
        super(SNMetricNet, self).__init__(name=name)
        # Size of the final linear layer.
        self._num_outputs = num_outputs

    @tf.function
    def __call__(self, inputs):
        """Applies the convolutional stack and the linear head to `inputs`."""
        # Layers are created under the spectral-normalisation custom getter.
        with tf.variable_scope('', custom_getter=_sn_custom_getter()):
          net = snt.nets.ConvNet2D(
              output_channels=[64, 64, 128, 128, 256, 256, 512],
              kernel_shapes=[
                  (3, 3), (4, 4), (3, 3), (4, 4), (3, 3), (4, 4), (3, 3)],
              strides=[1, 2, 1, 2, 1, 2, 1],
              paddings=[snt.SAME], activate_final=True,
              activation=functools.partial(tf.nn.leaky_relu, alpha=0.1))
          linear = snt.Linear(self._num_outputs)
        # Flatten the convolutional features before the linear head.
        output = linear(snt.BatchFlatten()(net(inputs)))
        return output


class MLPGeneratorNet(snt.Module):
    """MNIST generator: an MLP whose 784 outputs are shaped into 28x28x1 images."""

    def __init__(self, name='mlp_generator'):
        super().__init__(name=name)
        # Two hidden layers of 500 units, then one output unit per pixel.
        self.net = snt.nets.MLP([500, 500, 784], activation=tf.nn.leaky_relu)

    @tf.function
    def __call__(self, inputs, is_training=True):
        """Generates images from latents; `is_training` is accepted but unused."""
        del is_training
        squashed = tf.nn.tanh(self.net(inputs))
        to_image = snt.Reshape([28, 28, 1])
        return to_image(squashed)


class MLPMetricNet(snt.Module):
    """MLP measurement network (same as in Grover and Ermon, ICLR workshop 2017)."""

    def __init__(self, num_outputs=2, name='mlp_metric'):
        super().__init__(name=name)
        # Two hidden layers of 500 units followed by `num_outputs` outputs.
        self._layer_size = [500, 500, num_outputs]
        self.net = snt.nets.MLP(self._layer_size, activation=tf.nn.leaky_relu)

    @tf.function
    def __call__(self, inputs):
        """Flattens `inputs` and returns the MLP output for each example."""
        flat_inputs = snt.flatten(inputs)
        return self.net(flat_inputs)





In [4]:
# python3

# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for latent optimisation."""


# import nets

# Short alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions


class ModelOutputs(
    collections.namedtuple('AdversarialModelOutputs',
                           'optimization_components debug_ops')):
  """Everything a model step hands back to the training loop.

  Fields:

    * `optimization_components`: the optimisation information for the trained
      modules. Described upstream as a dictionary from component name to
      `OptimizationComponent`, so that each module can be trained by its own
      optimizer and configured by name from the main train loop.
    * `debug_ops`: a dictionary from string to a scalar `tf.Tensor`; the
      quantities tracked during training.
  """


class OptimizationComponent(
    collections.namedtuple('OptimizationComponent', 'loss vars')):
  """What an optimizer needs in order to train a set of modules.

  Usage:
      `optimizer.minimize(
          opt_component.loss, var_list=opt_component.vars)`

  Fields:

    * `loss`: a `tf.Tensor`, the loss to minimise.
    * `vars`: the variables that are updated to minimise `loss`.
  """


def cross_entropy_loss(logits, expected):
  """Mean sparse softmax cross entropy between two-way logits and labels.

  This is the loss of the original GAN paper: https://arxiv.org/abs/1406.2661.

  Args:
    logits: a `tf.Tensor` of model logits; its second dimension must be 2.
    expected: a `tf.Tensor`, the expected class labels.

  Returns:
    A scalar `tf.Tensor`, the loss averaged over the given inputs.

  Raises:
    ValueError: if the logits do not have shape [batch_size, 2].
  """
  if logits.get_shape()[1] != 2:
    raise ValueError(('Invalid number of logits for cross_entropy_loss! '
                      'cross_entropy_loss supports only 2 output logits!'))

  per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=logits, labels=expected)
  return tf.reduce_mean(per_example_loss)


def optimise_and_sample(init_z, module, data, is_training):
  """Optimises the generator latents by projected gradient descent, then samples.

  Starting from the projected `init_z`, `module.num_z_iters` gradient steps of
  size `module.z_step_size` are taken on `module.gen_loss_fn(data, samples)`,
  projecting the latents with `module.z_project_method` after every step.

  The latents are kept as ordinary tensors (not `tf.Variable`s updated with
  `assign`), so when this function is called under an outer `tf.GradientTape`
  the optimisation steps themselves are recorded and the outer loss can be
  differentiated through them (e.g. with respect to the generator weights and
  the step size).

  Args:
    init_z: the initial latents.
    module: object exposing `num_z_iters`, `z_step_size`, `z_project_method`,
      `generator(z, is_training)` and `gen_loss_fn(data, samples)`.
    data: the target data passed to `module.gen_loss_fn`.
    is_training: forwarded to `module.generator`.

  Returns:
    A tuple `(samples, z_final)`: the generator output at the final latents,
    and the final latents. When `module.num_z_iters` is 0 or None the latents
    are returned unchanged (and unprojected).
  """
  if not module.num_z_iters:
    # Covers both 0 and None: no latent optimisation requested.
    z_final = init_z
  else:
    z = _project_z(init_z, module.z_project_method)
    for _ in range(module.num_z_iters):
      with tf.GradientTape() as tape:
        # `z` is a plain tensor, so it has to be watched explicitly.
        tape.watch(z)
        loop_samples = module.generator(z, is_training)
        gen_loss = module.gen_loss_fn(data, loop_samples)
      z_grad = tape.gradient(gen_loss, z)
      # Gradient step followed by projection back onto the feasible set.
      z = _project_z(z - module.z_step_size * z_grad, module.z_project_method)
    z_final = z
  return module.generator(z_final, is_training), z_final


def get_optimisation_cost(initial_z, optimised_z):
  """Mean squared distance moved by the latents during optimisation.

  The squared difference is summed over the last axis and then averaged over
  the remaining ones.
  """
  displacement = optimised_z - initial_z
  squared_distance = tf.reduce_sum(displacement**2, -1)
  return tf.reduce_mean(squared_distance)


def _project_z(z, project_method='clip'):
  """To be used for projected gradient descent over z."""
  if project_method == 'norm':
    z_p = tf.nn.l2_normalize(z, axis=-1)
  elif project_method == 'clip':
    z_p = tf.clip_by_value(z, -1, 1)
  else:
    raise ValueError('Unknown project_method: {}'.format(project_method))
  return z_p


class DataProcessor:
  """Affine rescaling between a value and its `2 * value - 1` image."""

  def preprocess(self, x):
    """Maps `x` to `x * 2 - 1` (so 0 goes to -1 and 1 goes to 1)."""
    doubled = x * 2
    return doubled - 1

  def postprocess(self, x):
    """Inverse of `preprocess`: maps `x` to `(x + 1) / 2`."""
    shifted = x + 1
    return shifted / 2.


def _get_np_data(data_processor, dataset, split='train'):
  """Get the dataset as numpy arrays."""
  index = 0 if split == 'train' else 1
  if dataset == 'mnist':
    # Construct the dataset.
    x, _ = tf.keras.datasets.mnist.load_data()[index]
    # Note: tf dataset is binary so we convert it to float.
    x = x.astype(np.float32)
    x = x / 255.
    x = x.reshape((-1, 28, 28, 1))

  if dataset == 'cifar':
    x, _ = tf.keras.datasets.cifar10.load_data()[index]
    x = x.astype(np.float32)
    x = x / 255.

  if data_processor:
    # Normalize data if a processor is given.
    x = data_processor.preprocess(x)
  return x


def make_output_dir(output_dir):
  """Creates `output_dir` (including parents) unless it already exists."""
  # tf.print does no printf-style substitution, so the path is passed as a
  # separate argument rather than through a '%s' placeholder.
  tf.print('Creating output dir', output_dir)
  if not tf.io.gfile.isdir(output_dir):
    tf.io.gfile.makedirs(output_dir)


def get_ckpt_dir(output_dir):
  """Returns `<output_dir>/ckpt`, creating the directory when it is missing."""
  ckpt_dir = os.path.join(output_dir, 'ckpt')
  if tf.io.gfile.isdir(ckpt_dir):
    return ckpt_dir
  tf.io.gfile.makedirs(ckpt_dir)
  return ckpt_dir


def get_real_data_for_eval(num_eval_samples, dataset, split='valid'):
  """Returns the first `num_eval_samples` unprocessed images as a constant.

  No data processor is applied, so the images keep the scaling produced by
  `_get_np_data`.
  """
  eval_images = _get_np_data(data_processor=None, dataset=dataset, split=split)
  return tf.constant(eval_images[:num_eval_samples])


def get_summaries(ops, logger):
  """Creates one scalar summary per entry of `ops`, printing each value first.

  NOTE(review): `tf.control_dependencies` and collecting summary ops are
  graph-mode idioms; `tf.summary.scalar` is called here without a step -
  confirm that this helper still works under TF2 eager execution. It is not
  called by the training loop visible in this notebook.

  Args:
    ops: a dictionary from name to the value to summarise.
    logger: passed to `tf.print` as `output_stream`.

  Returns:
    The list of values returned by `tf.summary.scalar`, one per entry.
  """
  summaries = []
  for name, op in ops.items():
    # Ensure to log the value ops before writing them in the summary.
    # We do this instead of a hook to ensure IS/FID are never computed twice.
    print_op = tf.print(name, [op], output_stream=logger)#tf.logging.info)
    with tf.control_dependencies([print_op]):
      summary = tf.summary.scalar(name, op)
      summaries.append(summary)
  return summaries


def get_train_dataset(data_processor, dataset, batch_size):
  """Creates the shuffled, batched training dataset.

  Args:
    data_processor: optional processor applied to the images (see
      `_get_np_data`).
    dataset: the dataset name passed to `_get_np_data`.
    batch_size: number of examples in every batch.

  Returns:
    A `tf.data.Dataset` yielding batches of exactly `batch_size` images. The
    dataset is not repeated: iterate over it once per epoch.
  """
  x_train = _get_np_data(data_processor, dataset, split='train')
  tf.print(x_train.shape)
  train_ds = tf.data.Dataset.from_tensor_slices(x_train)

  # Shuffle over the whole training set instead of a hard-coded buffer size.
  # The final incomplete batch is dropped so every batch has the same size as
  # the fixed-size latent batch the trainer pairs it with.
  train_ds = train_ds.shuffle(len(x_train)).batch(
      batch_size, drop_remainder=True)
  return train_ds


def get_generator(dataset):
  """Returns a new generator network for `dataset` ('mnist' or 'cifar').

  Raises:
    ValueError: if `dataset` is not supported (rather than silently returning
      None and failing later at first use).
  """
  if dataset == 'mnist':
    return MLPGeneratorNet()
  if dataset == 'cifar':
    return SNGenNet()
  raise ValueError('Unknown dataset: {}'.format(dataset))


def get_metric_net(dataset, num_outputs=2):
  """Returns a new measurement network for `dataset` ('mnist' or 'cifar').

  Args:
    dataset: the dataset name.
    num_outputs: number of outputs of the network.

  Raises:
    ValueError: if `dataset` is not supported (rather than silently returning
      None and failing later at first use).
  """
  if dataset == 'mnist':
    return MLPMetricNet(num_outputs)
  if dataset == 'cifar':
    return SNMetricNet(num_outputs)
  raise ValueError('Unknown dataset: {}'.format(dataset))


def make_prior(num_latents):
  """Returns a zero-mean, unit-scale Normal distribution over `num_latents`."""
  return tfd.Normal(
      loc=tf.zeros(shape=(num_latents), dtype=tf.float32),
      scale=tf.ones(shape=(num_latents), dtype=tf.float32))



In [5]:
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GAN modules."""



# class TrainableVariable(tf.Module):
#   def __call__(self, x):
#     if not hasattr(self, 'w'):
#       self.w = tf.Variable(tf.random.normal([x.shape[1], 64]))
#     return tf.matmul(x, self.w)

class TrainableVariable(snt.Module):
    """Provides a single learnable parameter tensor.

    The variable is created lazily on the first call and the same variable is
    returned by every later call, so calling the module repeatedly never
    creates fresh, untrained copies.

    Args:
      shape: shape of the variable.
      dtype: dtype of the variable.
      initializers: the initial value handed to `tf.Variable` (despite the
        name, this is a value, not an initializer object).
      name: name of the module.
    """
    def __init__(self, shape, dtype=tf.float32, initializers=None,  name="trainable_variable"):
        super(TrainableVariable, self).__init__(name=name)
        self._shape = tuple(shape)
        self._dtype = dtype
        self._initializers = initializers
        # Created on first call; see __call__.
        self._w = None

    def __call__(self):
        """Returns the module's variable, creating it on the first call.

        Returns:
          A `tf.Variable` with the shape and dtype given to the constructor.
        """
        if self._w is None:
            self._w = tf.Variable(name="z_step_size",
                                  shape=self._shape,
                                  dtype=self._dtype,
                                  initial_value=self._initializers,
                                  )
        return self._w




class CS(snt.Module):
    """Compressed Sensing Module.

    Ties together a generator, a measurement network and a learnable latent
    optimisation step size, and performs one optimiser update per `step` call.
    """

    def __init__(self, metric_net, generator,
                 num_z_iters, z_step_size, z_project_method, optimizer):
        """Constructs the module.

        Args:
          metric_net: the measurement network.
          generator: The generator network. A sonnet module. For examples, see
            `nets.py`.
          num_z_iters: an integer, the number of latent optimisation steps.
          z_step_size: a positive float, the initial latent optimisation step
            size. It is stored as a trainable log-step-size variable.
          z_project_method: the method for projecting latent after optimisation,
            a string from {'norm', 'clip'}.
          optimizer: the optimizer used by `step`; must provide
            `apply(gradients, variables)`.
        """
        super(CS, self).__init__()
        self._measure = metric_net
        self.generator = generator
        self.num_z_iters = num_z_iters
        self.z_project_method = z_project_method
        self._log_step_size_module = TrainableVariable(shape=[], dtype=tf.float32, initializers=math.log(z_step_size))
        # Build the variable exactly once and keep a handle to it; the step
        # size itself is recomputed from it on every access (see z_step_size).
        self._log_step_size = self._log_step_size_module()
        self.optimizer = optimizer
        self.z = None

    @property
    def z_step_size(self):
        """The current latent optimisation step size, `exp(log_step_size)`.

        Computed on access rather than cached at construction time, so that
        reads made under a `tf.GradientTape` are recorded and the value always
        reflects the latest optimiser update.
        """
        return tf.exp(self._log_step_size)

    def step(self, data, generator_inputs):
        """Runs one training step and returns the losses, variables and debug ops.

        Args:
          data: a `tf.Tensor`: `[batch_size, ...]`. There are no constraints on
            the rank of this tensor, but it has to be compatible with the
            shapes expected by the measurement network.
          generator_inputs: a `tf.Tensor`: `[batch_size, ...]` of latents. It
            must have the same batch size as `data`, because generated samples
            and data are subtracted element-wise in the losses.

        Returns:
          A `ModelOutputs` instance whose `optimization_components` is a single
          `OptimizationComponent` (total loss and trained variables) and whose
          `debug_ops` is a dictionary of tracked quantities.
        """
        debug_ops = {}

        with tf.GradientTape() as tape:
            samples, optimised_z = optimise_and_sample(
                generator_inputs, self, data, is_training=True)
            optimisation_cost = get_optimisation_cost(generator_inputs,
                                                      optimised_z)
            initial_samples = self.generator(generator_inputs, is_training=True)
            generator_loss = tf.reduce_mean(self.gen_loss_fn(data, samples))
            # compute the RIP loss
            # (\sqrt{F(x_1 - x_2)^2} - \sqrt{(x_1 - x_2)^2})^2
            # as a triplet loss for 3 pairs of images.
            r1 = self._get_rip_loss(samples, initial_samples)
            r2 = self._get_rip_loss(samples, data)
            r3 = self._get_rip_loss(initial_samples, data)
            rip_loss = tf.reduce_mean((r1 + r2 + r3) / 3.0)
            total_loss = generator_loss + rip_loss
        optimization_components = self._build_optimization_components(
            generator_loss=total_loss)
        grads = tape.gradient(optimization_components.loss, optimization_components.vars)
        self.optimizer.apply(grads, optimization_components.vars)

        debug_ops['rip_loss'] = rip_loss
        debug_ops['recons_loss'] = tf.reduce_mean(
            tf.norm(snt.flatten(samples)
                    - snt.flatten(data), axis=-1))
        debug_ops['z_step_size'] = self.z_step_size
        debug_ops['opt_cost'] = optimisation_cost
        debug_ops['gen_loss'] = generator_loss

        return ModelOutputs(
            optimization_components, debug_ops)

    def _get_rip_loss(self, img1, img2):
        r"""Compute the RIP loss from two images.

          The RIP loss: (\sqrt{F(x_1 - x_2)^2} - \sqrt{(x_1 - x_2)^2})^2

        Args:
          img1: an image (x_1), 4D tensor of shape [batch_size, W, H, C].
          img2: an other image (x_2), 4D tensor of shape [batch_size, W, H, C].

        Returns:
          The squared difference between the norm of the image difference and
          the norm of the measurement difference, one value per example.
        """

        m1 = self._measure(img1)
        m2 = self._measure(img2)

        img_diff_norm = tf.norm(snt.flatten(img1)
                                - snt.flatten(img2), axis=-1)
        m_diff_norm = tf.norm(m1 - m2, axis=-1)

        return tf.square(img_diff_norm - m_diff_norm)

    def _get_measurement_error(self, target_img, sample_img):
        """Compute the measurement error of sample images given the targets.

        Returns the squared measurement difference summed over the last axis.
        """

        m_targets = self._measure(target_img)
        m_samples = self._measure(sample_img)

        return tf.reduce_sum(tf.square(m_targets - m_samples), -1)

    def gen_loss_fn(self, data, samples):
        """Generator loss as latent optimisation's error function."""
        return self._get_measurement_error(data, samples)

    def _build_optimization_components(
        self, generator_loss=None, discriminator_loss=None):
        """Create the optimization components for this module.

        Args:
          generator_loss: the loss to minimise.
          discriminator_loss: must be None; this module has no discriminator.

        Returns:
          An `OptimizationComponent` pairing `generator_loss` with the
          generator, measurement-network and step-size variables.
        """

        metric_vars = _get_and_check_variables(self._measure)
        generator_vars = _get_and_check_variables(self.generator)
        step_vars = self._log_step_size_module.trainable_variables

        assert discriminator_loss is None
        optimization_components = OptimizationComponent(
            generator_loss, generator_vars + metric_vars + step_vars)
        return optimization_components


def _get_and_check_variables(module):
    # module_variables = module.get_all_variables()
    module_variables = module.net.trainable_variables
    if not module_variables:
        raise ValueError(
            'Module {} has no variables! Variables needed for training.'.format(
                module.module_name))

    # TensorFlow optimizers require lists to be passed in.
    return module_variables


In [6]:
from datetime import datetime as dt
# Clear any logs from previous runs
!rm -rf ./logs/
# Each run writes its summaries to a fresh, timestamped sub-directory.
log_dir = "logs/" + dt.now().strftime("%Y%m%d-%H%M%S")
# Module-level summary writer; the training loop uses it via `writer.as_default()`.
writer = tf.summary.create_file_writer(log_dir)

# Load (or reload) the TensorBoard notebook extension and open it on ./logs.
%reload_ext tensorboard
%tensorboard --logdir logs


Executing op Add in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op SummaryWriter in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op CreateSummaryFileWriter in device /job:localhost/replica:0/task:0/device:CPU:0


Launching TensorBoard...

KeyboardInterrupt: 

In [None]:
def del_all_flags(FLAGS):
    """Removes every flag registered on `FLAGS`.

    This allows the flag-defining cell to be re-run without the definitions
    clashing with those left over from the previous run.

    Args:
      FLAGS: the flag container; iterating over it must yield flag names and
        it must support `delattr` for each of them.
    """
    # Snapshot the names first: the container is modified while we go.
    for name in list(FLAGS):
        delattr(FLAGS, name)

# Sentinel read by the flag-defining cell: while False the flags have not been
# defined yet; that cell sets it to True and, when it is already True, removes
# the existing flags before defining them again.
FIRST = False

def run(main=None, argv=None):
  """Parses flags from `argv` and calls `main` with the leftover arguments.

  Args:
    main: the function to call; defaults to `__main__.main`.
    argv: argument list whose first element is skipped; when empty or None,
      None is handed to the flag parser.
  """
  remaining = app.parse_flags_with_usage(args=argv[1:] if argv else None)
  entry_point = main if main else sys.modules['__main__'].main
  # The program name from the real command line is put back in front.
  entry_point(sys.argv[:1] + remaining)





In [None]:
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training script."""

# Re-running this cell would otherwise fail on already-defined flags, so on
# every run but the first the previous definitions are removed beforehand.
if FIRST:
    del_all_flags(FLAGS)
FIRST = True


# import logging as log

logger = tf.get_logger()
# `logging` in this notebook is absl's module, whose INFO constant is 0; on a
# standard-library logger that value means NOTSET. The level name selects the
# intended INFO threshold.
logger.setLevel('INFO')


# Presumably absorbs the `-f <connection file>` argument that Jupyter passes
# to the kernel - TODO confirm.
flags.DEFINE_string('f', '', 'kernel')
flags.DEFINE_boolean('debug', True, 'Produces debugging output.')
flags.DEFINE_string(
    'mode', 'recons', 'Model mode.')
flags.DEFINE_integer(
    'num_training_iterations', 20,
    'Number of training iterations.')
flags.DEFINE_integer(
    'batch_size', 64, 'Training batch size.')
flags.DEFINE_integer(
    'num_measurements', 25, 'The number of measurements')
flags.DEFINE_integer(
    'num_latents', 100, 'The number of latents')
flags.DEFINE_integer(
    'num_z_iters', 3, 'The number of latent optimisation steps.')
flags.DEFINE_float(
    'z_step_size', 0.01, 'Step size for latent optimisation.')
flags.DEFINE_string(
    'z_project_method', 'norm', 'The method to project z.')
flags.DEFINE_integer(
    'summary_every_step', 1000,
    'The interval at which to log debug ops.')
flags.DEFINE_integer(
    'export_every', 1000,
    'The interval at which to export samples.')
flags.DEFINE_string(
    'dataset', 'mnist', 'The dataset used for learning (cifar|mnist).')
flags.DEFINE_float('learning_rate', 1e-4, 'Learning rate.')
flags.DEFINE_string(
    'output_dir', './cs', 'Location where to save output files.')

FLAGS = flags.FLAGS


def main(argv):
    """Trains the compressed sensing model configured by the flags.

    Builds the generator, the measurement network and the `CS` model, then
    runs `FLAGS.num_training_iterations` passes over the training set. Every
    `FLAGS.summary_every_step` batches of a pass the debug quantities are
    written to the module-level `writer`.

    Args:
      argv: the non-flag command line arguments (only printed when debugging).
    """
    import shutil  # Local import: only needed for the clean-up below.

    if FLAGS.debug:
        tf.print('non-flag arguments:', argv)
    tf.print(FLAGS.output_dir)

    tf.print("TensorFlow version: {}".format(tf.__version__))
    tf.print("Tensorflow eager execution: {}".format(tf.executing_eagerly()))
    tf.print("    Sonnet version: {}".format(snt.__version__))
    tf.print("    Numpy  version: {}".format(np.__version__))

    # Remove stale summaries without handing a flag-supplied path to a shell.
    output_dir = os.path.join(FLAGS.output_dir, 'summaries')
    shutil.rmtree(output_dir, ignore_errors=True)

    make_output_dir(FLAGS.output_dir)
    data_processor = DataProcessor()

    # tf.print does no printf-style substitution; pass the value separately.
    tf.print('Learning rate:', FLAGS.learning_rate)

    # Construct optimizers.
    optimizer = snt.optimizers.Adam(FLAGS.learning_rate)

    # Create the networks and models.
    generator = get_generator(FLAGS.dataset)
    metric_net = get_metric_net(FLAGS.dataset, FLAGS.num_measurements)
    model = CS(metric_net, generator, FLAGS.num_z_iters, FLAGS.z_step_size, FLAGS.z_project_method, optimizer)

    # NOTE(review): `FileExporter` is not defined in the visible notebook and
    # `sample_exporter` is never used below - confirm where it comes from.
    sample_exporter = FileExporter(os.path.join(FLAGS.output_dir, 'reconstructions'))

    # Build the dataset once; iterating it again starts a new pass.
    images = get_train_dataset(data_processor, FLAGS.dataset, FLAGS.batch_size)

    # Derive the progress total from the dataset rather than from hard-coded
    # dataset and batch sizes; negative cardinalities mean "unknown".
    steps_per_epoch = int(tf.data.experimental.cardinality(images))
    total_steps = (FLAGS.num_training_iterations * steps_per_epoch
                   if steps_per_epoch > 0 else None)
    t = tqdm(total=total_steps, unit='sig', unit_scale=FLAGS.batch_size, position=0)
    prior = make_prior(FLAGS.num_latents)
    tf.print('starting training')
    global_step = 0
    for num_epochs in range(FLAGS.num_training_iterations):
        for i, batch in enumerate(images):
            # Sample one latent per image actually present in the batch, so a
            # short final batch cannot cause a shape mismatch in the losses.
            generator_inputs = prior.sample(batch.shape[0])
            model_output = model.step(batch, generator_inputs)
            if global_step == 0:
                print(snt.format_variables(model_output.optimization_components.vars))
            debug_ops = model_output.debug_ops
            # NOTE(review): the reconstructions are computed on every step but
            # not consumed anywhere in this function.
            reconstructions, _ = optimise_and_sample(
                generator_inputs, model, batch, is_training=False)
            t.update(1)
            debug_ops['it'] = i
            if i % FLAGS.summary_every_step == 0:
                t.write('Epoch = {}/{} (lr_mult = {:0.09f}, loss = {:0.02f}, accuracy = {:0.02f}, incorrect = {} done.'.format(
                    num_epochs, FLAGS.num_training_iterations, FLAGS.learning_rate,
                    float(model_output.optimization_components.loss), 0, 0))
                with writer.as_default():
                    for name, op in debug_ops.items():
                        # A step that keeps growing across epochs, so later
                        # epochs do not write over the points of earlier ones.
                        tf.summary.scalar(name, op, step=global_step)
            global_step += 1
    t.close()
# Start training. An explicit placeholder argument list is passed instead of
# the kernel's own command line; `run` skips its first element before handing
# the rest to the flag parser.
run(main, ['Avoid', 'problems', 'with', 'FLGAS', ])