In [1]:
!pip install librosa scipy ipywidgets matplotlib absl-py==0.7.1 dm-sonnet==1.34 numpy==1.16.4 Pillow tensorflow==1.15 tensorflow-probability==0.7.0 tensorflow-gan==2.0.0 protobuf==3.20.3




In [3]:
import os
# Force the pure-Python protobuf implementation.
# NOTE(review): presumably a workaround for incompatibilities between the
# pinned protobuf==3.20.3 and the TF 1.15 / tensorflow-probability stack, and
# presumably it has to be set before protobuf is first imported — TODO confirm.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

In [4]:
from os import listdir
import scipy.io.wavfile as wav
from os.path import isfile, join
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
def create_specs(fft, time_long, step_size, log_ref,
                 audio_dir='/home/studio-lab-user/sagemaker-studiolab-notebooks/free-spoken-digit-dataset/recordings/'):
  """Build fixed-size dB spectrogram chunks from every .wav file in a folder.

  Each recording is padded with Gaussian noise (std 10) up to a multiple of
  `time_long` samples. It is then turned into a linear magnitude spectrogram
  with `pretty_spectrogram(log=False)` and converted to dB with
  `librosa.power_to_db(ref=log_ref)`. Finally it is cut along time into
  chunks of `time_long / step_size` frames.

  Args:
    fft: FFT window size; chunks have `fft / 2` frequency bins.
    time_long: chunk length in samples.
    step_size: hop between STFT frames in samples.
    log_ref: `ref` passed to `librosa.power_to_db`.
    audio_dir: directory containing the .wav recordings.

  Returns:
    `(X_train, X_test, clip)`. The first two are an 80/20 split
    (random_state=42) of the chunk array, which has shape
    `[n_chunks, fft/2, time_long/step_size]`. `clip` is the maximum dB
    value rounded up.

  Raises:
    ValueError: if `audio_dir` contains no .wav files.
  """
  # Sort so the split is reproducible: listdir order is arbitrary and
  # filesystem-dependent, which would defeat random_state=42.
  file_names = sorted(
      f for f in listdir(audio_dir)
      if isfile(join(audio_dir, f)) and f.endswith('.wav'))
  if not file_names:
    raise ValueError('No .wav files found in {}'.format(audio_dir))

  sp_sz = int(time_long)
  chunks = []
  for file_name in file_names:
    _, samples = wav.read(join(audio_dir, file_name))
    # Pad with quiet noise up to the next multiple of sp_sz. `(-n) % sp_sz`
    # is 0 when the length already is a multiple, so no all-noise chunk is
    # added in that case.
    pad = (-samples.shape[0]) % sp_sz
    samples = np.append(samples, np.random.randn(pad) * 10, axis=0)
    n_ms = samples.shape[0] // sp_sz
    if n_ms == 0:
      # Empty recording: nothing to split.
      continue
    ms = np.transpose(pretty_spectrogram(samples.astype("float32"),
                                         fft_size=fft, step_size=step_size,
                                         log=False))
    ms = np.expand_dims(librosa.power_to_db(ms, ref=log_ref), axis=0)
    # Split along the time axis into n_ms equally sized chunks.
    chunks.append(np.concatenate(np.split(ms, n_ms, axis=2)))

  # Concatenate once instead of np.append in the loop (quadratic copying).
  ms_list = np.concatenate(chunks, axis=0)
  clip = float(np.ceil(np.amax(ms_list)))
  print("clip", clip)
  print("min value spec", np.amin(ms_list))
  print("max value spec", np.amax(ms_list))
  X_train, X_test = train_test_split(
      ms_list, test_size=0.20, random_state=42)
  return X_train, X_test, clip

In [6]:
#@title Audio library (spectrogram and inversion helpers)

import IPython.display
from ipywidgets import interact, interactive, fixed

# Packages we're using
import numpy as np
import matplotlib.pyplot as plt
import copy
from scipy.io import wavfile
from scipy.signal import butter, lfilter
import scipy.ndimage
# Most of the spectrogram and inversion code below is taken from: https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe


def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    The band edges `lowcut` and `highcut` (Hz) are normalised by the Nyquist
    frequency of sample rate `fs`, as `scipy.signal.butter` expects. Returns
    the `(b, a)` transfer-function coefficients.
    """
    nyquist = fs / 2.0
    normalised_band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, normalised_band, btype="band")


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass `data` between `lowcut` and `highcut` Hz at sample rate `fs`.

    Uses the Butterworth design from `butter_bandpass` and applies it causally
    with `scipy.signal.lfilter`. Returns the filtered signal.
    """
    coeffs_b, coeffs_a = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(coeffs_b, coeffs_a, data)


def overlap(X, window_size, window_step):
    """
    Slice a 1-D signal into overlapping frames.

    Parameters
    ----------
    X : ndarray, shape=(n_samples,)
        Input signal to window and overlap.
    window_size : int
        Frame length in samples; must be even.
    window_step : int
        Hop between consecutive frame starts.

    Returns
    -------
    X_strided : ndarray, shape=(n_windows, window_size)
        Row ``i`` is ``padded[i * window_step : i * window_step + window_size]``.

    Notes
    -----
    The signal is zero-padded up to the next multiple of ``window_size``. If it
    already is a multiple, a full extra window of zeros is appended.
    ``n_windows`` is ``(len(padded) - window_size) // window_step``, so a frame
    starting exactly at ``len(padded) - window_size`` is not emitted. Changing
    this changes the spectrogram frame count.
    """
    if window_size % 2 != 0:
        raise ValueError("Window size must be even!")
    padding = np.zeros((window_size - len(X) % window_size))
    padded = np.hstack((X, padding))

    n_windows = (len(padded) - window_size) // window_step
    frames = np.empty((n_windows, window_size), dtype=padded.dtype)
    # Row i of `index` holds the sample positions covered by frame i.
    starts = np.arange(n_windows) * window_step
    index = starts[:, None] + np.arange(window_size)[None, :]
    frames[:] = padded[index]
    return frames


def stft(
    X, fftsize=128, step=65, mean_normalize=True, real=False, compute_onesided=True
):
    """
    Compute the short-time Fourier transform of a 1-D real-valued signal.

    Parameters
    ----------
    X : ndarray, shape=(n_samples,)
        Input signal. It is not modified.
    fftsize : int
        Frame / FFT length (must be even, see ``overlap``).
    step : int
        Hop between frames in samples.
    mean_normalize : bool
        Subtract the signal mean before framing.
    real : bool
        Use ``np.fft.rfft`` instead of ``np.fft.fft``.
    compute_onesided : bool
        Keep only the first ``fftsize // 2`` frequency bins.

    Returns
    -------
    ndarray, complex, shape=(n_frames, n_bins)
    """
    if real:
        local_fft = np.fft.rfft
        cut = -1
    else:
        local_fft = np.fft.fft
        cut = None
    if compute_onesided:
        cut = fftsize // 2
    if mean_normalize:
        # Out-of-place: `X -= ...` would silently modify the caller's array,
        # and it fails outright for integer input.
        X = X - X.mean()

    X = overlap(X, fftsize, step)

    size = fftsize
    # Hamming window.
    win = 0.54 - 0.46 * np.cos(2 * np.pi * np.arange(size) / (size - 1))
    X = X * win[None]
    X = local_fft(X)[:, :cut]
    return X


def pretty_spectrogram(d, log=True, thresh=5, fft_size=512, step_size=64):
    """
    Compute a magnitude spectrogram with a noise floor.

    d : 1-D input signal.
    log : if True, normalise to a peak of 1, take log10 and floor the result
        at ``-thresh``. Otherwise, floor the linear magnitudes at ``thresh``.
    thresh : noise floor (log10 units when ``log`` is True, linear otherwise).
    fft_size, step_size : STFT frame length and hop, passed to ``stft``.

    Returns an array of shape (n_frames, fft_size // 2).
    """
    specgram = np.abs(
        stft(d, fftsize=fft_size, step=step_size, real=False, compute_onesided=True)
    )

    if log:
        peak = specgram.max()
        # Guard against silent input: dividing by a zero peak would produce
        # NaN everywhere instead of the floor value.
        if peak > 0:
            specgram /= peak  # volume normalize to max 1
        with np.errstate(divide="ignore"):
            specgram = np.log10(specgram)  # log10(0) -> -inf, floored below
        # Set anything below the threshold to the threshold.
        specgram = np.maximum(specgram, -thresh)
    else:
        specgram = np.maximum(specgram, thresh)

    return specgram


# Also taken, mostly with modifications, from https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
def invert_pretty_spectrogram(
    X_s, log=True, fft_size=512, step_size=512 // 4, n_iter=10
):
    """
    Reconstruct a waveform from a one-sided ``pretty_spectrogram`` output.

    X_s : array of shape (n_frames, n_bins) holding magnitudes, or log10
        magnitudes when ``log`` is True.
    log : undo the log10 with ``10 ** X_s`` before inverting.
    fft_size : FFT length used to build the spectrogram.
    step_size : hop between frames in samples.
    n_iter : number of Griffin-Lim iterations.

    Returns the reconstructed real-valued waveform.
    """
    # The hop is used as an array size and slice bound downstream, so it must
    # be an int. The old default `512 / 4` evaluated to the float 128.0.
    step_size = int(step_size)
    if log:
        X_s = np.power(10, X_s)

    # Rebuild a two-sided spectrum by appending the mirrored bins.
    X_s = np.concatenate([X_s, X_s[:, ::-1]], axis=1)
    X_t = iterate_invert_spectrogram(X_s, fft_size, step_size, n_iter=n_iter)
    return X_t


def iterate_invert_spectrogram(X_s, fftsize, step, n_iter=10, verbose=False):
    """
    Griffin-Lim style phase reconstruction from a two-sided magnitude
    spectrogram ``X_s`` (shape (n_frames, n_bins)).

    Each iteration inverts the current complex estimate, re-analyses it with
    ``stft`` and keeps only the resulting phase, combined with the target
    magnitudes ``X_s``. Returns the real part of the final inversion.

    Under MSR-LA License
    Based on MATLAB implementation from Spectrogram Inversion Toolbox
    References
    ----------
    D. Griffin and J. Lim. Signal estimation from modified
    short-time Fourier transform. IEEE Trans. Acoust. Speech
    Signal Process., 32(2):236-243, 1984.
    Malcolm Slaney, Daniel Naar and Richard F. Lyon. Auditory
    Model Inversion for Sound Separation. Proc. IEEE-ICASSP,
    Adelaide, 1994, II.77-80.
    Xinglei Zhu, G. Beauregard, L. Wyse. Real-Time Signal
    Estimation from Modified Short-Time Fourier Transform
    Magnitude Spectra. IEEE Transactions on Audio Speech and
    Language Processing, 08/2007.
    """
    # Floor for |est| so the phase normalisation never divides by zero.
    reg = np.max(X_s) / 1e8
    X_best = copy.deepcopy(X_s)
    for i in range(n_iter):
        if verbose:
            print("Running iter %i" % i)
        # Only the first pass starts from zero phase; later passes keep the
        # phase estimated in the previous iteration.
        # Calculate offset was False in the MATLAB version
        # but in mine it massively improves the result
        # Possible bug in my impl?
        X_t = invert_spectrogram(
            X_best, step, calculate_offset=True, set_zero_phase=(i == 0)
        )
        est = stft(X_t, fftsize=fftsize, step=step, compute_onesided=False)
        phase = est / np.maximum(reg, np.abs(est))
        X_best = X_s * phase[: len(X_s)]
    X_t = invert_spectrogram(X_best, step, calculate_offset=True, set_zero_phase=False)
    return np.real(X_t)


def invert_spectrogram(X_s, step, calculate_offset=True, set_zero_phase=True):
    """
    Overlap-add inversion of a spectrogram.

    X_s : array of shape (n_frames, 2 * size). Each row is inverse-FFT'd (real
        part, time-reversed). A ``size``-sample slice starting at
        ``size // 2 - 1`` is Hamming-windowed and overlap-added at
        ``i * step``.
    step : hop between frames in samples; used in array sizes, so it must be
        an int.
    calculate_offset : for frames after the first, shift the frame by the lag
        found by ``xcorr_offset`` against the output accumulated so far.
    set_zero_phase : use only the real part of each row of ``X_s``.

    Returns the real waveform of length ``n_frames * step + size``, normalised
    by the summed window.

    Under MSR-LA License
    Based on MATLAB implementation from Spectrogram Inversion Toolbox
    References
    ----------
    D. Griffin and J. Lim. Signal estimation from modified
    short-time Fourier transform. IEEE Trans. Acoust. Speech
    Signal Process., 32(2):236-243, 1984.
    Malcolm Slaney, Daniel Naar and Richard F. Lyon. Auditory
    Model Inversion for Sound Separation. Proc. IEEE-ICASSP,
    Adelaide, 1994, II.77-80.
    Xinglei Zhu, G. Beauregard, L. Wyse. Real-Time Signal
    Estimation from Modified Short-Time Fourier Transform
    Magnitude Spectra. IEEE Transactions on Audio Speech and
    Language Processing, 08/2007.
    """
    size = int(X_s.shape[1] // 2)
    # float64 accumulators (the np.zeros default); 32 bit gave overflow warnings.
    wave = np.zeros((X_s.shape[0] * step + size))
    total_windowing_sum = np.zeros((X_s.shape[0] * step + size))
    win = 0.54 - 0.46 * np.cos(2 * np.pi * np.arange(size) / (size - 1))

    est_start = int(size // 2) - 1
    est_end = est_start + size

    # The alignment overlap is the same for every frame, so validate it and
    # warn once instead of on every iteration.
    offset_size = size - step
    if calculate_offset and X_s.shape[0] > 1 and offset_size <= 0:
        print(
            "WARNING: Large step size >50% detected! "
            "This code works best with high overlap - try "
            "with 75% or greater"
        )
        offset_size = step

    for i in range(X_s.shape[0]):
        wave_start = int(step * i)
        wave_end = wave_start + size
        if set_zero_phase:
            spectral_slice = X_s[i].real + 0j
        else:
            # already complex
            spectral_slice = X_s[i]

        # Don't need fftshift due to different impl.
        wave_est = np.real(np.fft.ifft(spectral_slice))[::-1]
        if calculate_offset and i > 0:
            offset = xcorr_offset(
                wave[wave_start : wave_start + offset_size],
                wave_est[est_start : est_start + offset_size],
            )
        else:
            offset = 0
        wave[wave_start:wave_end] += (
            win * wave_est[est_start - offset : est_end - offset]
        )
        total_windowing_sum[wave_start:wave_end] += win
    wave = np.real(wave) / (total_windowing_sum + 1e-6)
    return wave

def xcorr_offset(x1, x2):
    """
    Lag that best aligns ``x2`` with ``x1`` by cross-correlation.

    Both inputs are mean-removed and correlated in float32. The outer
    ``len(x2) // 2`` lags at each end of the full correlation are excluded.
    The result is ``argmax - len(x1)``.

    Under MSR-LA License
    Based on MATLAB implementation from Spectrogram Inversion Toolbox
    References
    ----------
    D. Griffin and J. Lim. Signal estimation from modified
    short-time Fourier transform. IEEE Trans. Acoust. Speech
    Signal Process., 32(2):236-243, 1984.
    Malcolm Slaney, Daniel Naar and Richard F. Lyon. Auditory
    Model Inversion for Sound Separation. Proc. IEEE-ICASSP,
    Adelaide, 1994, II.77-80.
    Xinglei Zhu, G. Beauregard, L. Wyse. Real-Time Signal
    Estimation from Modified Short-Time Fourier Transform
    Magnitude Spectra. IEEE Transactions on Audio Speech and
    Language Processing, 08/2007.
    """
    centred_a = x1 - x1.mean()
    centred_b = x2 - x2.mean()
    margin = len(centred_b) // 2
    scores = np.convolve(centred_a.astype("float32"),
                         centred_b[::-1].astype("float32"))
    # Restrict the search to the central lags.
    scores[:margin] = -1e30
    scores[-margin:] = -1e30
    return scores.argmax() - len(centred_a)

import scipy.io.wavfile as wav

### Parameters ###
# NOTE(review): only fft_size, step_size and spec_thresh are used in this
# cell. The band-pass and mel settings below are not referenced anywhere in
# the visible notebook. `step_size` is also re-assigned in a later cell.
fft_size = 512  # window size for the FFT
step_size = fft_size // 16  # distance to slide along the window (in time)
spec_thresh = 4  # threshold for spectrograms (lower filters out more noise)
lowcut = 500  # Hz # Low cut for our butter bandpass filter
highcut = 4000  # Hz # High cut for our butter bandpass filter
# For mels
n_mel_freq_components = 64  # number of mel frequency channels
shorten_factor = 10  # how much should we compress the x-axis (time)
start_freq = 50  # Hz # What frequency to start sampling our mels from
end_freq = 4000  # presumably Hz: upper frequency for the mels
# Demo: compute a log spectrogram of one recording, then invert it back to
# audio (10 Griffin-Lim iterations) as a round-trip sanity check.
audio_path='/home/studio-lab-user/sagemaker-studiolab-notebooks/free-spoken-digit-dataset/recordings/0_george_0.wav'
data_rate, data = wav.read(audio_path)
wav_spectrogram = pretty_spectrogram(
data.astype("float64"),
fft_size=fft_size,
step_size=step_size,
log=True,
thresh=spec_thresh,
)

# Invert from the spectrogram back to a waveform
recovered_audio_orig = invert_pretty_spectrogram(
    wav_spectrogram, fft_size=fft_size, step_size=step_size, log=True, n_iter=10
)

In [7]:
# Spectrogram / chunking configuration.
fft = 256          # FFT window; spectrograms have fft / 2 frequency bins
time_long = 512    # samples per training chunk
step_size = 32     # STFT hop in samples
fft_step_size_ratio = int(fft / step_size)
log_ref = 5.0      # `ref` passed to librosa.power_to_db

# Keep the decimal point so std stays a float.
std = 0.0

# Compressed sensing: number of measurements = ratio * chunk length.
compression_ratio = 0.25
measures = int(compression_ratio * time_long)

# Run tag used in the output/checkpoint paths.
CRF = 'DCS25'

X_train, X_test, clip = create_specs(
    fft=fft, time_long=time_long, step_size=step_size, log_ref=log_ref)
print('number of measures', measures)

clip 55.0
min value spec 0.0
max value spec 54.80099098502209
number of measures 128


In [8]:
#import sys
#sys.path.append('/content/deepmind-research')
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
# NOTE(review): '/content/' is the Google Colab working directory. Presumably
# `cs_gan` lived there in the original Colab setup; on SageMaker Studio Lab
# this entry likely has no effect. Confirm how `cs_gan` is found here.
sys.path.append('/content/')
from IPython.display import clear_output


# Disabled helper: delete every absl flag before (re-)declaring them.
# NOTE: the code below is wrapped in a bare string literal, so it never runs.
# Presumably it was meant to avoid absl's duplicate-flag error when this cell
# is re-run. As written it would delete *all* flags in FLAGS, including ones
# not declared in this cell.
'''
def del_all_flags(FLAGS):
    flags_dict = FLAGS._flags()
    keys_list = [keys for keys in flags_dict]
    for keys in keys_list:
        FLAGS.__delattr__(keys)

del_all_flags(flags.FLAGS)
'''





# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training script."""


import os

from absl import app
from absl import flags
from absl import logging

import tensorflow as tf
import tensorflow_probability as tfp

from cs_gan import cs
from cs_gan import file_utils
from cs_gan import utils

tfd = tfp.distributions

def _redefine_flag(define_fn, name, *args, **kwargs):
  """Define an absl flag, first deleting any existing flag with that name.

  This makes the cell safe to re-run, because absl refuses to define the same
  flag name twice. Deleting first also lets edited defaults (e.g. a changed
  `measures`) take effect on re-run.
  """
  if name in flags.FLAGS:
    delattr(flags.FLAGS, name)
  define_fn(name, *args, **kwargs)


_redefine_flag(flags.DEFINE_string,
               'mode', 'recons', 'Model mode.')
_redefine_flag(flags.DEFINE_integer,
               'num_training_iterations', 100000,
               'Number of training iterations.')
_redefine_flag(flags.DEFINE_integer,
               'batch_size', 64, 'Training batch size.')
_redefine_flag(flags.DEFINE_integer,
               'num_measurements', measures, 'The number of measurements')
_redefine_flag(flags.DEFINE_integer,
               'num_latents', 100, 'The number of latents')
_redefine_flag(flags.DEFINE_integer,
               'num_z_iters', 3, 'The number of latent optimisation steps.')
_redefine_flag(flags.DEFINE_float,
               'z_step_size', 0.01, 'Step size for latent optimisation.')
_redefine_flag(flags.DEFINE_string,
               'z_project_method', 'norm', 'The method to project z.')
_redefine_flag(flags.DEFINE_integer,
               'summary_every_step', 10000,
               'The interval at which to log debug ops.')
_redefine_flag(flags.DEFINE_integer,
               'export_every', 10,
               'The interval at which to export samples.')
_redefine_flag(flags.DEFINE_string,
               'dataset', 'mnist', 'The dataset used for learning (cifar|mnist).')
_redefine_flag(flags.DEFINE_float, 'learning_rate', 1e-4, 'Learning rate.')
# NOTE(review): '/gdrive/My Drive/' is a Colab / Google Drive path. Nothing in
# the visible cells reads FLAGS.output_dir; checkpoints use an explicit path.
_redefine_flag(flags.DEFINE_string,
               'output_dir', '/gdrive/My Drive/' + CRF,
               'Location where to save output files.')


FLAGS = flags.FLAGS

# Log info level (for Hooks).
tf.logging.set_verbosity(tf.logging.INFO)

# Presumably defined so that a `-f <connection file>` argument passed by some
# Jupyter kernel launchers does not make absl reject sys.argv below.
_redefine_flag(flags.DEFINE_string, 'f', '', 'kernel')
FLAGS(sys.argv)





W0908 14:02:59.746343 140566818527040 module_wrapper.py:139] From /home/studio-lab-user/.conda/envs/dcs2/lib/python3.7/site-packages/sonnet/python/custom_getters/restore_initializer.py:27: The name tf.GraphKeys is deprecated. Please use tf.compat.v1.GraphKeys instead.

W0908 14:02:59.753797 140566818527040 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

W0908 14:03:00.139117 140566818527040 module_wrapper.py:139] From /home/studio-lab-user/.conda/envs/dcs2/lib/python3.7/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.

W09

['/home/studio-lab-user/.conda/envs/dcs2/lib/python3.7/site-packages/ipykernel_launcher.py']

In [9]:
print(X_train.shape)
temp = np.full((X_train.shape[0], 1), -1.0)
print(temp.shape)
# Add a trailing channel axis: (N, freq, time) -> (N, freq, time, 1).
# A reshape (rather than np.newaxis) keeps this cell safe to re-run.
X_train = X_train.reshape(-1, X_train.shape[1], X_train.shape[2], 1)
X_test = X_test.reshape(-1, X_test.shape[1], X_test.shape[2], 1)
print(X_train.shape, X_test.shape)

(17612, 128, 16)
(17612, 1)
(17612, 128, 16, 1) (4403, 128, 16, 1)


In [10]:
class specData():
  """Shape bookkeeping for spectrogram chunks and the MLP generator.

  Attributes:
    fft, time_long, step_size, log_ref, clip: the constructor arguments.
    fft_step_size_ratio: int(fft / step_size).
    gen_output_shape: [fft/2, time_long/step_size, 1], the generator output
      shape (one spectrogram chunk plus a channel axis).
    gen_net_shape: MLP layer widths, i.e. `gen_hidden_sizes` followed by the
      flattened output size.
  """

  def __init__(self, fft, time_long, step_size, log_ref, clip,
               gen_hidden_sizes=(1000, 1000)):
    self.fft = fft
    self.time_long = time_long
    self.step_size = step_size
    self.fft_step_size_ratio = int(fft / step_size)
    self.clip = clip
    self.log_ref = log_ref
    self.gen_output_shape = [int(fft / 2), int(self.time_long / step_size), 1]
    # Hidden widths are configurable. The last layer must emit exactly
    # freq * time values so it can be reshaped into gen_output_shape.
    self.gen_net_shape = list(gen_hidden_sizes) + [
        self.gen_output_shape[0] * self.gen_output_shape[1]]

# Module-level shape configuration. MLPGeneratorNet and get_np_data read this
# global `spec` directly.
spec = specData(fft,time_long,step_size, log_ref, clip)

In [11]:
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GAN modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import sonnet as snt
import tensorflow.compat.v1 as tf

from cs_gan import utils


class CS1(object):
  """Compressed-sensing module with a learned measurement network.

  Combines a measurement network and a generator. The training loss is the
  measurement error between data and optimised samples, plus a RIP-style loss
  (see `_get_rip_loss`). The RIP loss pushes measurement distances to match
  signal distances.
  """

  def __init__(self, metric_net, generator,
               num_z_iters, z_step_size, z_project_method):
    """Constructs the module.

    Args:
      metric_net: the measurement network.
      generator: The generator network. A sonnet module. For examples, see
        `nets.py`.
      num_z_iters: an integer, the number of latent optimisation steps.
      z_step_size: a positive float, the initial latent optimisation step
        size. It is stored as a trainable log-step and trained together with
        the networks.
      z_project_method: the method for projecting latent after optimisation,
        a string from {'norm', 'clip'}.
    """

    self._measure = metric_net
    self.generator = generator
    self.num_z_iters = num_z_iters
    self.z_project_method = z_project_method
    # Parameterise the step size in log space so it stays positive.
    self._log_step_size_module = snt.TrainableVariable(
        [],
        initializers={'w': tf.constant_initializer(math.log(z_step_size))})
    self.z_step_size = tf.exp(self._log_step_size_module())

  def connect(self, data, generator_inputs, std):
    """Connects the components and returns the losses, outputs and debug ops.

    Args:
      data: a `tf.Tensor`: `[batch_size, ...]`. There are no constraints on the
        rank
        of this tensor, but it has to be compatible with the shapes expected
        by the discriminator.
      generator_inputs: a `tf.Tensor`: `[g_in_batch_size, ...]`. It does not
        have to have the same batch size as the `data` tensor. There are not
        constraints on the rank of this tensor, but it has to be compatible
        with the shapes the generator network supports as inputs.
      std: forwarded to `utils.optimise_and_sample`. NOTE(review): it does not
        scale the measurement noise added in this method, which is multiplied
        by 0.0 (i.e. disabled). Confirm this is intended.

    Returns:
      A `utils.ModelOutputs` instance.
    """

    samples, optimised_z = utils.optimise_and_sample(
        generator_inputs, self, data, std, is_training=True)
    optimisation_cost = utils.get_optimisation_cost(generator_inputs,
                                                    optimised_z)
    debug_ops = {}

    initial_samples = self.generator(generator_inputs, is_training=True)
    # Additive measurement noise ("ruido" is Spanish for noise). It is scaled
    # by 0.0, so it is currently disabled.
    eps = tf.random_normal(tf.shape(self._measure(data)), 0, 1, dtype=tf.float32)
    ruido = tf.multiply(0.0, eps)
    generator_loss = tf.reduce_mean(self.gen_loss_fn(data, samples, ruido))
    # compute the RIP loss
    # (\sqrt{F(x_1 - x_2)^2} - \sqrt{(x_1 - x_2)^2})^2
    # as a triplet loss for 3 pairs of images.

    r1 = self._get_rip_loss(samples, initial_samples)
    r2 = self._get_rip_loss(samples, data)
    r3 = self._get_rip_loss(initial_samples, data)
    rip_loss = tf.reduce_mean((r1 + r2 + r3) / 3.0)
    total_loss = generator_loss + rip_loss
    optimization_components = self._build_optimization_components(
        generator_loss=total_loss)
    debug_ops['rip_loss'] = rip_loss
    # Mean per-example L2 reconstruction error, divided by 4096.
    # NOTE(review): 4096 looks inherited from an earlier 4096-dim generator
    # output. The spectrogram chunks here are 128 x 16 = 2048 values. The
    # constant is kept unchanged so that logged values stay comparable.
    debug_ops['recons_loss'] = tf.reduce_mean(
        tf.norm(snt.BatchFlatten()(samples)
                - snt.BatchFlatten()(data), axis=-1)) / 4096
    debug_ops['z_step_size'] = self.z_step_size
    debug_ops['opt_cost'] = optimisation_cost
    debug_ops['gen_loss'] = generator_loss

    return utils.ModelOutputs(
        optimization_components, debug_ops)

  def _get_rip_loss(self, img1, img2):
    r"""Compute the RIP loss from two images.

      The RIP loss: (\sqrt{F(x_1 - x_2)^2} - \sqrt{(x_1 - x_2)^2})^2

    Args:
      img1: an image (x_1), 4D tensor of shape [batch_size, W, H, C].
      img2: an other image (x_2), 4D tensor of shape [batch_size, W, H, C].

    Returns:
      Per-example squared difference between the L2 distance of the
      measurements and the L2 distance of the flattened images.
    """

    m1 = self._measure(img1)
    m2 = self._measure(img2)

    img_diff_norm = tf.norm(snt.BatchFlatten()(img1)
                            - snt.BatchFlatten()(img2), axis=-1)
    m_diff_norm = tf.norm(m1 - m2, axis=-1)

    return tf.square(img_diff_norm - m_diff_norm)

  def _get_measurement_error(self, target_img, sample_img, ruido):
    """Compute the measurement error of sample images given the targets.

    Returns the per-example squared L2 distance between
    `measure(target_img) + ruido` and `measure(sample_img)`. `ruido` is
    additive measurement noise.
    """

    m_targets = self._measure(target_img)
    m_targets_con_ruido = tf.add(m_targets, ruido)
    m_samples = self._measure(sample_img)

    return tf.reduce_sum(tf.square(m_targets_con_ruido - m_samples), -1)

  def gen_loss_fn(self, data, samples, ruido):
    """Generator loss as latent optimisation's error function."""
    return self._get_measurement_error(data, samples, ruido)

  def _build_optimization_components(
      self, generator_loss=None, discriminator_loss=None):
    """Create the optimization components for this module.

    The generator, measurement network and log step-size variables are all
    optimised against `generator_loss`. There is no discriminator.
    """

    metric_vars = _get_and_check_variables(self._measure)
    generator_vars = _get_and_check_variables(self.generator)
    step_vars = _get_and_check_variables(self._log_step_size_module)

    assert discriminator_loss is None
    optimization_components = utils.OptimizationComponent(
        generator_loss, generator_vars + metric_vars + step_vars)
    return optimization_components


def _get_and_check_variables(module):
  module_variables = module.get_all_variables()
  if not module_variables:
    raise ValueError(
        'Module {} has no variables! Variables needed for training.'.format(
            module.module_name))

  # TensorFlow optimizers require lists to be passed in.
  return list(module_variables)
class MLPGeneratorNet(snt.AbstractModule):
  """MLP generator that outputs one spectrogram chunk per latent vector.

  The layer widths and output shape default to the notebook-level `spec`
  object (`spec.gen_net_shape`, `spec.gen_output_shape`), read when the
  module is built. They can instead be passed explicitly, which removes the
  dependency on that global.
  """

  def __init__(self, name='mlp_generator', net_shape=None, output_shape=None):
    super(MLPGeneratorNet, self).__init__(name=name)
    self._layer_sizes = net_shape
    self._out_shape = output_shape

  def _build(self, inputs, is_training=True):
    del is_training  # Not used by this network.
    layer_sizes = (self._layer_sizes if self._layer_sizes is not None
                   else spec.gen_net_shape)
    out_shape = (self._out_shape if self._out_shape is not None
                 else spec.gen_output_shape)
    net = snt.nets.MLP(layer_sizes, activation=tf.nn.leaky_relu)
    # tanh bounds the outputs to [-1, 1].
    out = tf.nn.tanh(net(inputs))
    return snt.BatchReshape(out_shape)(out)

In [12]:
%reload_ext tensorboard

In [10]:
# NOTE(review): summaries and checkpoints are written under
# /home/studio-lab-user/sagemaker-studiolab-notebooks/checkpoint/<CRF>.
# `--logdir .` only finds them if the working directory is that folder or a
# parent of it.
%tensorboard --logdir .

In [13]:
# Ensure the per-run checkpoint/summary directory exists.
utils.make_output_dir(
    '/home/studio-lab-user/sagemaker-studiolab-notebooks/checkpoint/' + CRF)
# Shared pre/post-processing for training data and exported samples.
data_processor = utils.DataProcessor()


def get_np_data(data_processor, dataset, split='train'):
  """Get a spectrogram split as a float32 numpy array scaled by `spec.clip`.

  Args:
    data_processor: optional processor; if given, its `preprocess` is applied
      after scaling.
    dataset: unused here; kept so the signature matches the call sites.
    split: 'train' (the global `X_train`) or 'valid' (the global `X_test`).

  Returns:
    The selected split as float32, divided by `spec.clip`, and optionally
    preprocessed.

  Raises:
    ValueError: for any other `split`.
  """
  if split == 'train':
    x = X_train
  elif split == 'valid':
    x = X_test
  else:
    raise ValueError(
        "split must be 'train' or 'valid', got {!r}".format(split))

  # clip is the rounded-up max dB value, so this scales roughly into [0, 1].
  x = x.astype(np.float32) / spec.clip

  if data_processor:
    # Normalize data if a processor is given.
    x = data_processor.preprocess(x)
  return x
def get_train_dataset2(data_processor, dataset, batch_size):
  """Build an infinite, shuffled, batched tf.data pipeline over the train split.

  Returns the `get_next()` tensor of a one-shot iterator. Each evaluation
  yields a batch of up to `batch_size` training examples.
  """
  train_array = get_np_data(data_processor, dataset, split='train')
  # The shuffle buffer (100000) is larger than the training set in the run
  # shown (17612 examples), so the whole set is shuffled. repeat() makes the
  # stream infinite, allowing multiple passes over the data.
  pipeline = (tf.data.Dataset.from_tensor_slices(train_array)
              .shuffle(100000)
              .repeat()
              .batch(batch_size))
  return pipeline.make_one_shot_iterator().get_next()

def get_real_data_for_eval2(num_eval_samples, dataset, split='valid'):
  """Return the first `num_eval_samples` examples of `split` as a tf.constant.

  No data processor is applied.
  """
  eval_array = get_np_data(data_processor=None, dataset=dataset, split=split)
  return tf.constant(eval_array[:num_eval_samples])
# NOTE(review): with FLAGS.dataset == 'mnist' this loads MNIST (see the
# printed shape), which is not the spectrogram data used for training below.
images1=utils._get_np_data(data_processor, FLAGS.dataset, split='train')
print(images1.shape)

images = get_train_dataset2(data_processor, FLAGS.dataset,
                            FLAGS.batch_size)

# %g, not %d: with %d a learning rate such as 1e-4 is logged as 0.
logging.info('Learning rate: %g', FLAGS.learning_rate)

# Single source of truth for this run's checkpoint / summary directory.
CHECKPOINT_DIR = '/home/studio-lab-user/sagemaker-studiolab-notebooks/checkpoint/' + CRF

# Construct optimizers.
optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)

# Create the networks and models.
generator = MLPGeneratorNet(FLAGS.dataset)
metric_net = utils.get_metric_net(FLAGS.dataset, FLAGS.num_measurements)
model = CS1(metric_net, generator,
            FLAGS.num_z_iters, FLAGS.z_step_size, FLAGS.z_project_method)
prior = utils.make_prior(FLAGS.num_latents)
generator_inputs = prior.sample(FLAGS.batch_size)

# NOTE(review): overrides the std = 0.0 set in the configuration cell.
std = 10.0
model_output = model.connect(images, generator_inputs, std)
optimization_components = model_output.optimization_components
debug_ops = model_output.debug_ops

reconstructions, _ = utils.optimise_and_sample(
    generator_inputs, model, images, std, is_training=False)
# Placeholder for feeding arbitrary spectrogram batches. Its batch dimension
# presumably has to match generator_inputs, because measurement errors are
# computed elementwise between data and samples. So use the flag rather than
# a hard-coded 64.
images2 = tf.placeholder(
    tf.float32,
    shape=(FLAGS.batch_size, spec.gen_output_shape[0],
           spec.gen_output_shape[1], 1))
reconstructionsa, _ = utils.optimise_and_sample(
    generator_inputs, model, images2, std, is_training=False)
global_step = tf.train.get_or_create_global_step()
update_op = optimizer.minimize(
    optimization_components.loss,
    var_list=optimization_components.vars,
    global_step=global_step)

sample_exporter = file_utils.FileExporter(
    os.path.join(CHECKPOINT_DIR, 'reconstructions'))

# Hooks.
debug_ops['it'] = global_step
# Abort training on NaNs.
nan_hook = tf.train.NanTensorHook(optimization_components.loss)

checkpoint_saver_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir=CHECKPOINT_DIR, save_steps=100000)

loss_summary_saver_hook = tf.train.SummarySaverHook(
    output_dir=CHECKPOINT_DIR,
    save_steps=5000,
    summary_op=utils.get_summaries(debug_ops))

hooks = [checkpoint_saver_hook, nan_hook, loss_summary_saver_hook]



(60000, 28, 28, 1)


W0908 14:03:45.216815 140566818527040 deprecation.py:323] From /tmp/ipykernel_265/1456660690.py:36: DatasetV1.make_one_shot_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.
W0908 14:03:47.337588 140566818527040 module_wrapper.py:139] From /home/studio-lab-user/.conda/envs/dcs2/lib/python3.7/site-packages/sonnet/python/modules/base.py:177: The name tf.make_template is deprecated. Please use tf.compat.v1.make_template instead.

W0908 14:03:47.360165 140566818527040 module_wrapper.py:139] From /home/studio-lab-user/.conda/envs/dcs2/lib/python3.7/site-packages/sonnet/python/modules/base.py:278: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph ins

In [None]:
import time

# Start training.
t0 = time.time()
with tf.train.MonitoredTrainingSession(
    checkpoint_dir='/home/studio-lab-user/sagemaker-studiolab-notebooks/checkpoint/'+CRF,
    hooks=hooks,
    save_checkpoint_secs=None,
    save_checkpoint_steps=None,
    save_summaries_steps=100,
    log_step_count_steps=None) as sess:
  logging.info('starting training')
  for i in range(FLAGS.num_training_iterations):
    sess.run(update_op)
    # Periodically clear the notebook output so it does not grow unbounded.
    if i % 2000 == 0:
      clear_output()
    if i % FLAGS.export_every != 0:
      continue
    # Export one data batch and its reconstructions, post-processed back to
    # the data processor's output space.
    reconstructions_np, data_np = sess.run([reconstructions, images])
    data_np = data_processor.postprocess(data_np)
    reconstructions_np = data_processor.postprocess(reconstructions_np)
    sample_exporter.save(reconstructions_np, 'reconstructions')
    sample_exporter.save(data_np, 'data')

print('time: {}s'.format(int(time.time() - t0)))

it [10100]
z_step_size [0.00657674577]
opt_cost [100.396896]
recons_loss [0.00115100236]
gen_loss [1.42176127]
rip_loss [4.24925089]
z_step_size [0.00653786771]
it [10200]
opt_cost [96.7141113]
recons_loss [0.0011322191]
gen_loss [1.34456801]
rip_loss [4.08750343]
it [10300]
z_step_size [0.00650236243]
opt_cost [100.076431]
recons_loss [0.0011378444]
gen_loss [1.41349697]
rip_loss [4.08294153]
it [10400]
z_step_size [0.00646156073]
opt_cost [99.1120911]
recons_loss [0.00115863734]
gen_loss [1.41787136]
rip_loss [4.33518791]
it [10500]
z_step_size [0.00642602844]
opt_cost [98.2628784]
recons_loss [0.0011413]
rip_loss [4.18857765]
gen_loss [1.33865726]
it [10600]
z_step_size [0.00639024703]
opt_cost [98.6155396]
recons_loss [0.00112214475]
gen_loss [1.36333907]
rip_loss [4.00278664]
it [10700]
z_step_size [0.00635891175]
opt_cost [101.09256]
recons_loss [0.00111240335]
gen_loss [1.33482635]
rip_loss [3.94307852]
z_step_size [0.00632436341]
it [10800]
opt_cost [99.9594193]
recons_loss [0.