In [1]:
import sys

sys.path.append("/homes/gf332/compression-without-quantization/code")
sys.path.append("/homes/gf332/compression-without-quantization/code/thesis_code")

import os, glob
from tqdm import tqdm as tqdm

import tensorflow.compat.v1 as tf
import tensorflow_compression as tfc
import tensorflow_probability as tfp
import tensorflow.contrib.eager as tfe
tfd = tfp.distributions
tfk = tf.keras
tfl = tf.keras.layers
tfq = tf.quantization

from binary_io import to_bit_string, from_bit_string

from architectures import ProbabilisticLadderNetwork, VariationalAutoEncoder

from miracle import create_dataset, quantize_image, read_png

from greedy_compression import code_grouped_greedy_sample, code_grouped_importance_sample

import matplotlib.pyplot as plt
import numpy as np

#tf.enable_eager_execution()

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def pln_code_image_greedy(image, 
                          latent_dist_dir,
                          latent_dist_format,
                          seed, 
                          n_steps,
                          n_bits_per_step,
                          comp_file_path,
                          backfitting_steps_level_1=0,
                          backfitting_steps_level_2=0,
                          use_log_prob=False,
                          rho=1.,
                          use_importance_sampling=True,
                          first_level_max_group_size_bits=12,
                          second_level_n_bits_per_group=20,
                          second_level_max_group_size_bits=4,
                          second_level_dim_kl_bit_limit=12,
                          outlier_index_bytes=3,
                          outlier_sample_bytes=2,
                          verbose=False):
        
        # -------------------------------------------------------------------------------------
        # Step 1: Set the latent distributions for the image
        # -------------------------------------------------------------------------------------
        
        q1_loc = np.load(latent_dist_dir + latent_dist_format.format("q1_loc"))
        q1_scale = np.load(latent_dist_dir + latent_dist_format.format("q1_scale"))

        p1_loc = np.load(latent_dist_dir + latent_dist_format.format("p1_loc"))
        p1_scale = np.load(latent_dist_dir + latent_dist_format.format("p1_scale"))

        q2_loc = np.load(latent_dist_dir + latent_dist_format.format("q2_loc"))
        q2_scale = np.load(latent_dist_dir + latent_dist_format.format("q2_scale"))

        p2_loc = np.load(latent_dist_dir + latent_dist_format.format("p2_loc"))
        p2_scale = np.load(latent_dist_dir + latent_dist_format.format("p2_scale"))
        
        q1 = tfd.Normal(loc=q1_loc,
                        scale=q1_scale)
        
        p1 = tfd.Normal(loc=p1_loc,
                        scale=p1_scale)
        
        q2 = tfd.Normal(loc=q2_loc,
                        scale=q2_scale)
        
        p2 = tfd.Normal(loc=p2_loc,
                        scale=p2_scale)
        
        first_level_shape = q1.loc.shape.as_list()
        second_level_shape = q2.loc.shape.as_list()
        
        # -------------------------------------------------------------------------------------
        # Step 2: Create a coded sample of the latent space
        # -------------------------------------------------------------------------------------
        
        if verbose: print("Coding second level")
            
        if use_importance_sampling:
            
            sample2, code2, group_indices2, outlier_extras2 = code_grouped_importance_sample(
                target=q2, 
                proposal=p2, 
                n_bits_per_group=second_level_n_bits_per_group, 
                seed=seed, 
                max_group_size_bits=second_level_max_group_size_bits,
                dim_kl_bit_limit=second_level_dim_kl_bit_limit)
            
            outlier_extras2 = list(map(lambda x: tf.reshape(x, [-1]).numpy(), outlier_extras2))
            
        else:
            sample2, code2, group_indices2 = code_grouped_greedy_sample(target=q2, 
                                                                        proposal=p2, 
                                                                        n_bits_per_step=n_bits_per_step, 
                                                                        n_steps=n_steps, 
                                                                        seed=seed, 
                                                                        max_group_size_bits=second_level_max_group_size_bits,
                                                                        adaptive=True,
                                                                        backfitting_steps=backfitting_steps_level_2,
                                                                        use_log_prob=use_log_prob,
                                                                        rho=rho)
            
        # We will encode the group differences as this will cost us less
        group_differences2 = [0]
        
        for i in range(1, len(group_indices2)):
            group_differences2.append(group_indices2[i] - group_indices2[i - 1])
        
        # We need to adjust the priors to the second stage sample
        latents = (tf.reshape(sample2, second_level_shape), latents[1])
        
        
        # Calculate the priors
        self.decode(latents)
        
        if verbose: print("Coding first level")
            
        sample1, code1, group_indices1 = code_grouped_greedy_sample(target=self.latent_posteriors[0], 
                                                                    proposal=self.latent_priors[0], 
                                                                    n_bits_per_step=n_bits_per_step, 
                                                                    n_steps=n_steps, 
                                                                    seed=seed, 
                                                                    max_group_size_bits=first_level_max_group_size_bits,
                                                                    backfitting_steps=backfitting_steps_level_1,
                                                                    use_log_prob=use_log_prob,
                                                                    adaptive=True)
        
        # We will encode the group differences as this will cost us less
        group_differences1 = [0]
        
        for i in range(1, len(group_indices1)):
            group_differences1.append(group_indices1[i] - group_indices1[i - 1])
        
        bitcode = code1 + code2
        # -------------------------------------------------------------------------------------
        # Step 3: Write the compressed file
        # -------------------------------------------------------------------------------------
        
        extras = [seed, n_steps, n_bits_per_step] + first_level_shape[1:3] + second_level_shape[1:3]
        
        var_length_extras = [group_differences1, group_differences2]
        var_length_bits = [first_level_max_group_size_bits,  
                           second_level_max_group_size_bits]
        
        if use_importance_sampling:
            
            var_length_extras += outlier_extras2
            var_length_bits += [ outlier_index_bytes * 8, outlier_sample_bytes * 8 ]
    
        write_bin_code(bitcode, 
                       comp_file_path, 
                       extras=extras,
                       var_length_extras=var_length_extras,
                       var_length_bits=var_length_bits)
        
        # -------------------------------------------------------------------------------------
        # Step 4: Some logging information
        # -------------------------------------------------------------------------------------
        
        total_kls = [tf.reduce_sum(x) for x in self.kl_divergence]
        total_kl = sum(total_kls)

        theoretical_byte_size = (total_kl + 2 * np.log(total_kl + 1)) / np.log(2) / 8
        extra_byte_size = len(group_indices1) * var_length_bits[0] // 8 + \
                          len(group_indices2) * var_length_bits[1] // 8 + 7 * 2
        actual_byte_size = os.path.getsize(comp_file_path)

        actual_no_extra = actual_byte_size - extra_byte_size
        
        first_level_theoretical = (total_kls[0] + 2 * np.log(total_kls[0] + 1)) / np.log(2) / 8
        first_level_actual_no_extra = len(code1) / 8
        first_level_extra = len(group_indices1) * var_length_bits[0] // 8

        sample1_reshaped = tf.reshape(sample1, first_level_shape)
        first_level_avg_log_lik = tf.reduce_mean(self.latent_posteriors[0].log_prob(sample1_reshaped))
        first_level_sample_avg = tf.reduce_mean(self.latent_posteriors[0].log_prob(self.latent_posteriors[0].sample()))
        
        second_level_theoretical = (total_kls[1] + 2 * np.log(total_kls[1] + 1)) / np.log(2) / 8
        second_level_actual_no_extra = len(code2) / 8
        second_level_extra = len(group_indices2) * var_length_bits[1] // 8 + 1
        
        second_bpp = (second_level_actual_no_extra + second_level_extra) * 8 / (image_shape[1] * image_shape[2]) 

        sample2_reshaped = tf.reshape(sample2, second_level_shape)
        second_level_avg_log_lik = tf.reduce_mean(self.latent_posteriors[1].log_prob(sample2_reshaped))
        second_level_sample_avg = tf.reduce_mean(self.latent_posteriors[1].log_prob(self.latent_posteriors[1].sample()))
        
        bpp = 8 * actual_byte_size / (image_shape[1] * image_shape[2]) 
        
        summaries = {
            "image_shape": image_shape,
            "theoretical_byte_size": float(theoretical_byte_size.numpy()),
            "actual_byte_size": actual_byte_size,
            "extra_byte_size": extra_byte_size,
            "actual_no_extra": actual_no_extra,
            "second_bpp": second_bpp,
            "bpp": bpp
        }
        
        if verbose:

            print("Image dimensions: {}".format(image_shape))
            print("Theoretical size: {:.2f} bytes".format(theoretical_byte_size))
            print("Actual size: {:.2f} bytes".format(actual_byte_size))
            print("Extra information size: {:.2f} bytes {:.2f}% of actual size".format(extra_byte_size, 
                                                                                       100 * extra_byte_size / actual_byte_size))
            print("Actual size without extras: {:.2f} bytes".format(actual_no_extra))
            print("Efficiency: {:.3f}".format(actual_byte_size / theoretical_byte_size))
            print("")
            
            print("First level theoretical size: {:.2f} bytes".format(first_level_theoretical))
            print("First level actual (no extras) size: {:.2f} bytes".format(first_level_actual_no_extra))
            print("First level extras size: {:.2f} bytes".format(first_level_extra))
            print("First level Efficiency: {:.3f}".format(
                (first_level_actual_no_extra + first_level_extra) / first_level_theoretical))
            
            print("First level # of groups: {}".format(len(group_indices1)))
            print("First level greedy sample average log likelihood: {:.4f}".format(first_level_avg_log_lik))
            print("First level average sample log likelihood on level 1: {:.4f}".format(first_level_sample_avg))
            print("")
           
            print("Second level theoretical size: {:.2f} bytes".format(second_level_theoretical))
            print("Second level actual (no extras) size: {:.2f} bytes".format(second_level_actual_no_extra))
            print("Second level extras size: {:.2f} bytes".format(second_level_extra))

            if use_importance_sampling:
                print("{} outliers were not compressed (higher than {} bits of KL)".format(len(outlier_extras2[0]),
                                                                                           second_level_dim_kl_bit_limit))
            print("Second level Efficiency: {:.3f}".format(
                (second_level_actual_no_extra + second_level_extra) / second_level_theoretical))
            print("Second level's contribution to bpp: {:.4f}".format(second_bpp))
            print("Second level # of groups: {}".format(len(group_indices2)))
            print("Second level greedy sample average log likelihood: {:.4f}".format(second_level_avg_log_lik))
            print("Second level average sample log likelihood on level 1: {:.4f}".format(second_level_sample_avg))
            print("")
            
            print("{:.4f} bits / pixel".format( bpp ))
        
        return (sample2, sample1), summaries

In [29]:
latent_dist_dir = "/scratch/gf332/data/kodak_cwoq/"
latent_dist_format = "pln_{}.npy"

comp_file_path = "/scratch/gf332/data/kodak_cwoq/test.miracle"

n_bits_per_step = 14
n_steps = 30
seed = 1
rho = 1.
first_level_max_group_size_bits=12
second_level_max_group_size_bits=4

q1_loc = np.load(latent_dist_dir + latent_dist_format.format("q1_loc")).flatten()[:100]
q1_scale = np.load(latent_dist_dir + latent_dist_format.format("q1_scale")).flatten()[:100]

p1_loc = np.load(latent_dist_dir + latent_dist_format.format("p1_loc")).flatten()[:100]
p1_scale = np.load(latent_dist_dir + latent_dist_format.format("p1_scale")).flatten()[:100]

q2_loc = np.load(latent_dist_dir + latent_dist_format.format("q2_loc")).flatten()
q2_scale = np.load(latent_dist_dir + latent_dist_format.format("q2_scale")).flatten()

p2_loc = np.load(latent_dist_dir + latent_dist_format.format("p2_loc")).flatten()
p2_scale = np.load(latent_dist_dir + latent_dist_format.format("p2_scale")).flatten()

q1 = tfd.Normal(loc=q1_loc,
                scale=q1_scale)

p1 = tfd.Normal(loc=p1_loc,
                scale=p1_scale)

q2 = tfd.Normal(loc=q2_loc,
                scale=q2_scale)

p2 = tfd.Normal(loc=p2_loc,
                scale=p2_scale)

In [34]:
def code_importance_sample_(t_loc,
                            t_scale,
                            p_loc,
                            p_scale,
                            n_coding_bits,
                            seed):
    
    
    target=tfd.Normal(loc=t_loc,
                      scale=t_scale)

    proposal=tfd.Normal(loc=p_loc,
                        scale=p_scale)
        
    rank = len(proposal.loc.shape.as_list())
    
    #print("Taking {} samples per step".format(n_samples))
    
    sample_index = []
    
    kls = tfd.kl_divergence(target, proposal)
    total_kl = tf.reduce_sum(kls)
    
    num_samples = tf.cast(tf.ceil(tf.exp(total_kl)), tf.int32)
    
    # Set new seed
    #samples = proposal.sample(num_samples, seed=seed)
    
    # Draw 0 mean, 1 variance samples
    samples = tf.random.stateless_normal(shape=[num_samples] + t_loc.shape.as_list(), 
                                         seed=tf.Variable([1, 42]))
    
    # Transform them to the right thing by scaling and translating appropriately
    samples = tf.tile(tf.expand_dims(p_scale, 0), [num_samples] + [1] * rank) * samples
    samples = tf.tile(tf.expand_dims(p_loc, 0), [num_samples] + [1] * rank) + samples

    importance_weights = tf.reduce_sum(target.log_prob(samples) - proposal.log_prob(samples), axis=1)

    index = tf.argmax(importance_weights)
    best_sample = samples[index:index + 1, :]
    
    #index, best_sample = sess.run([idx, best_samp])
    
#     if np.log(index + 1) / np.log(2) > n_coding_bits:
#         raise Exception("Not enough bits to code importance sample!")
    
    # Turn the index into a bitstring
    bitcode = tf.numpy_function(to_bit_string, [index, n_coding_bits], tf.string)

    return best_sample, bitcode


def decode_importance_sample_(sample_index, 
                              p_loc,
                              p_scale,
                              seed):
    
    proposal = tfd.Normal(loc=p_loc,
                          scale=p_scale)
    
    # Make sure the distributions have the correct type
    if proposal.dtype is not tf.float32:
        raise Exception("Proposal datatype must be float32!")
        
    dim = proposal.loc.shape.as_list()[0]
    
    index = tf.numpy_function(from_bit_string, [sample_index], tf.int64)
    
    #tf.random.set_random_seed(seed)
    samples = proposal.sample(tf.cast(index, tf.int32) + 1, seed=seed)
    
    return samples[index:, ...]


def code_grouped_importance_sample_(sess,
                                    target, 
                                    proposal, 
                                    seed,
                                    n_bits_per_group,
                                    max_group_size_bits=4,
                                    dim_kl_bit_limit=12):
    
    # Make sure the distributions have the correct type
    if target.dtype is not tf.float32:
        raise Exception("Target datatype must be float32!")
        
    if proposal.dtype is not tf.float32:
        raise Exception("Proposal datatype must be float32!")
        
        
    num_dimensions = np.prod(proposal.loc.shape.as_list())
    
    # rescale proposal by the proposal
    p_loc = sess.run(tf.reshape(tf.zeros_like(proposal.loc), [-1]))
    p_scale = sess.run(tf.reshape(tf.ones_like(proposal.scale), [-1]))
    
    # rescale target by the proposal
    t_loc = tf.reshape((target.loc - proposal.loc) / proposal.scale, [-1])
    t_scale = tf.reshape(target.scale / proposal.scale, [-1])
    
    # If we're going to do importance sampling, separate out dimensions with large KL,
    # we'll deal with them separately.
    kl_bits = tf.reshape(tfd.kl_divergence(target, proposal), [-1]) / np.log(2)

    t_loc = sess.run(tf.where(kl_bits <= dim_kl_bit_limit, t_loc, p_loc))
    t_scale = sess.run(tf.where(kl_bits <= dim_kl_bit_limit, t_scale, p_scale))

    # We'll send the quantized samples for dimensions with high KL
    outlier_indices = tf.where(kl_bits > dim_kl_bit_limit)

    target_samples = tf.reshape(target.sample(), [-1])

    # Select only the bits of the sample that are relevant
    outlier_samples = tf.gather_nd(target_samples, outlier_indices)

    # Halve precision
    outlier_samples = tfq.quantize(outlier_samples, -30, 30, tf.quint16).output

    outlier_extras = (outlier_indices, outlier_samples)
    
    kl_divergences = tf.reshape(
        tfd.kl_divergence(tfd.Normal(loc=t_loc, scale=t_scale), 
                          tfd.Normal(loc=p_loc, scale=p_scale)), [-1])

    kl_divs = sess.run(kl_divergences)
    group_start_indices = [0]
    group_kls = []

    total_kl_bits = np.sum(kl_divs) / np.log(2)

    print("Total KL to split up: {:.2f} bits, "
          "maximum bits per group: {}, "
          "estimated number of groups: {},"
          "coding {} dimensions".format(total_kl_bits, 
                                        n_bits_per_group,
                                        total_kl_bits // n_bits_per_group + 1,
                                        num_dimensions
                                        ))

    current_group_size = 0
    current_group_kl = 0
    
    n_nats_per_group = n_bits_per_group * np.log(2) - 1

    for idx in range(num_dimensions):

        group_bits = np.log(current_group_size + 1) / np.log(2)
        
        if group_bits >= max_group_size_bits or \
           current_group_kl + kl_divs[idx] >= n_nats_per_group or \
           idx == num_dimensions - 1:

            group_start_indices.append(idx)
            group_kls.append(current_group_kl / np.log(2))

            current_group_size = 1
            current_group_kl = kl_divs[idx]
            
        else:
            current_group_kl += kl_divs[idx]
            current_group_size += 1
        
    print("Maximum group KL: {:.3f}".format(np.max(group_kls)))
    # ====================================================================== 
    # Sample each group
    # ====================================================================== 
    
    results = []
    
    group_start_indices += [num_dimensions] 
    
    # Get the importance sampling op before looping it to avoid graph construction cost
    # The length is variable, hence the shape is [None]
    target_loc = tf.placeholder(tf.float32, shape=[None])
    target_scale = tf.placeholder(tf.float32, shape=[None])
    
    prop_loc = tf.placeholder(tf.float32, shape=[None])
    prop_scale = tf.placeholder(tf.float32, shape=[None])
    
    seed_feed = tf.placeholder(tf.int32, shape=[2])

    result_ops = code_importance_sample_(t_loc=target_loc,
                                         t_scale=target_scale,
                                         p_loc=prop_loc,
                                         p_scale=prop_scale,
                                         seed=seed_feed,
                                         n_coding_bits=n_bits_per_group)
            
    for i in tqdm(range(len(group_start_indices) - 1)):
        
        start_idx = group_start_indices[i]
        end_idx = group_start_indices[i + 1]
        
        
        result = sess.run(result_ops, feed_dict={target_loc: t_loc[start_idx:end_idx],
                                                 target_scale: t_scale[start_idx:end_idx],
                                                 prop_loc: p_loc[start_idx:end_idx],
                                                 prop_scale: p_scale[start_idx:end_idx],
                                                 seed_feed: [42, seed + i]
                                                })
        results.append(result)
        
    samples, codes = zip(*results)
    
    bitcode = tf.numpy_function(lambda code_words: ''.join([cw.decode("utf-8") for cw in code_words]), 
                                [codes], 
                                tf.string)
    
    sample = tf.concat(samples, axis=1)
    
    # Rescale the sample
    sample = tf.reshape(proposal.scale, [-1]) * sample + tf.reshape(proposal.loc, [-1])
    
    sample = tf.where(kl_bits <= dim_kl_bit_limit, tf.squeeze(sample), target_samples)
    
    sample, bitcode, outlier_extras = sess.run([sample, bitcode, outlier_extras])
    
    return sample, bitcode, group_start_indices, outlier_extras


def decode_grouped_importance_sample_(sess,
                                     bitcode, 
                                     group_start_indices,
                                     proposal, 
                                     n_bits_per_group,
                                     seed,
                                     outlier_indices,
                                     outlier_samples):
    
    # Make sure the distributions have the correct type
    if proposal.dtype is not tf.float32:
        raise Exception("Proposal datatype must be float32!")
    
    num_dimensions = np.prod(proposal.loc.shape.as_list())
    
    # ====================================================================== 
    # Decode each group
    # ====================================================================== 
                
    samples = []
    
    group_start_indices += [num_dimensions]
    
    p_loc = sess.run(tf.reshape(tf.zeros_like(proposal.loc), [-1]))
    p_scale = sess.run(tf.reshape(tf.ones_like(proposal.scale), [-1]))

    # Placeholders
    sample_index = tf.placeholder(tf.string)
    
    prop_loc = tf.placeholder(tf.float32, shape=[None])
    prop_scale = tf.placeholder(tf.float32, shape=[None])
    
    # Get decoding op
    decode_op = decode_importance_sample_(sample_index=sample_index,
                                          p_loc=prop_loc,
                                          p_scale=prop_scale,
                                          seed=seed)

    for i in tqdm(range(len(group_start_indices) - 1)):
        
        samp = sess.run(decode_op, feed_dict = {
            sample_index: bitcode[n_bits_per_group * i: n_bits_per_group * (i + 1)],
            prop_loc: p_loc[group_start_indices[i]:group_start_indices[i + 1]],
            prop_scale: p_scale[group_start_indices[i]:group_start_indices[i + 1]]
        })
        
        samples.append(samp)
        
    sample = tf.concat(samples, axis=1)
    
    # Rescale the sample
    sample = tf.reshape(proposal.scale, [-1]) * sample + tf.reshape(proposal.loc, [-1])
    sample = tf.squeeze(sample)
    
    # Dequantize outliers
    outlier_samples = tfq.dequantize(tf.cast(outlier_samples, tf.quint16), -30, 30)
    
    # Add back the quantized outliers
    outlier_indices = tf.cast(tf.reshape(outlier_indices, [-1, 1]), tf.int32)
    outlier_samples = tf.reshape(outlier_samples, [-1])
    
    updates = tf.scatter_nd(outlier_indices, 
                            outlier_samples, 
                            sample.shape.as_list())
                            
    sample = tf.where(tf.equal(updates, 0), sample, updates)
    
    return sess.run(sample)

In [35]:
n_bits_per_step = 14
n_steps = 30
seed = 1
rho = 1.
first_level_max_group_size_bits=12
second_level_max_group_size_bits=4

with tf.Session() as sess:
    
#     ops = code_importance_sample_(t_loc=q2_loc,
#                                   t_scale=q2_scale,
#                                   p_loc=p2_loc,
#                                   p_scale=p2_scale,
#                                   n_coding_bits=20,
#                                   seed=seed)
    
    
#     best_sample, bitcode = sess.run(ops)
    
        
#     sample_index = tf.placeholder(tf.string)
#     samp_op = decode_importance_sample_(sample_index, proposal=p2, seed=seed)
    
#     samp = sess.run(samp_op, feed_dict={sample_index: bitcode})

    res = code_grouped_importance_sample_(sess=sess,
                                            target=q2, 
                                            proposal=p2, 
                                            seed=seed,
                                            n_bits_per_group=20,
                                            max_group_size_bits=4,
                                            dim_kl_bit_limit=12)
    
    sample, bitcode, group_start_indices, outlier_extras = res
    
#     decoded = decode_grouped_importance_sample_(sess=sess,
#                                                  bitcode=bitcode, 
#                                                  group_start_indices=group_start_indices,
#                                                  proposal=p2, 
#                                                  n_bits_per_group=20,
#                                                  seed=seed,
#                                                  outlier_indices=outlier_extras[0],
#                                                  outlier_samples=outlier_extras[1])

#     res = code_grouped_greedy_sample_(sess=sess,
#                                     target=q1, 
#                                    proposal=p1,
#                                    n_steps=n_steps, 
#                                    n_bits_per_step=n_bits_per_step,
#                                    seed=seed,
#                                    max_group_size_bits=12,
#                                    adaptive=True,
#                                    backfitting_steps=0,
#                                    use_log_prob=False,
#                                    rho=1.)
    
#     sample, bitcode, group_start_indices = res
    
#     dec = decode_grouped_greedy_sample_(sess=sess,
#                                   bitcode=bitcode, 
#                                  group_start_indices=group_start_indices,
#                                  proposal=p1, 
#                                  n_bits_per_step=n_bits_per_step, 
#                                  n_steps=n_steps, 
#                                  seed=seed,
#                                  adaptive=True,
#                                  rho=1.)

Total KL to split up: 22277.49 bits, maximum bits per group: 20, estimated number of groups: 1114.0,coding 12288 dimensions
Maximum group KL: 18.556


ValueError: None values not supported.

In [76]:
decoded[8]

0.001891787

In [77]:
sample[8]

0.001891787

In [None]:
latent_dist_dir = "/scratch/gf332/data/kodak_cwoq/"
latent_dist_format = "pln_{}.npy"

comp_file_path = "/scratch/gf332/data/kodak_cwoq/test.miracle"

n_bits_per_step = 14
n_steps = 30
seed = 1
rho = 1.
first_level_max_group_size_bits=12
second_level_max_group_size_bits=4

pln_code_image_greedy(image=None, 
                      latent_dist_dir=latent_dist_dir,
                      latent_dist_format=latent_dist_format,
                      seed=seed, 
                      n_steps=n_steps,
                      n_bits_per_step=n_bits_per_step,
                      comp_file_path=comp_file_path,
                      backfitting_steps_level_1=0,
                      backfitting_steps_level_2=0,
                      use_log_prob=False,
                      rho=rho,
                      use_importance_sampling=True,
                      first_level_max_group_size_bits=first_level_max_group_size_bits,
                      second_level_n_bits_per_group=20,
                      second_level_max_group_size_bits=second_level_max_group_size_bits,
                      second_level_dim_kl_bit_limit=12,
                      outlier_index_bytes=3,
                      outlier_sample_bytes=2,
                      verbose=False)

In [5]:
def code_greedy_sample_(sess,
                       target, 
                       proposal, 
                       n_bits_per_step, 
                       n_steps, 
                       seed, 
                       rho=1., 
                       backfitting_steps=0,
                       use_log_prob=False):
    
    # Make sure the distributions have the correct type
    if target.dtype is not tf.float32:
        raise Exception("Target datatype must be float32!")
        
    if proposal.dtype is not tf.float32:
        raise Exception("Proposal datatype must be float32!")
        
    dim = proposal.loc.shape.as_list()[0]
    
    n_samples = int(2**n_bits_per_step)
    
    #print("Taking {} samples per step".format(n_samples))

    best_sample = tf.Variable(tf.zeros((1, dim), dtype=tf.float32))
    sess.run(tf.global_variables_initializer())
    
    sample_index = []
    
    # The scale divisor needs to be square rooted because
    # we are dealing with standard deviations and not variances
    scale_divisor = np.sqrt(n_steps)
    
    proposal_shard = tfd.Normal(loc=proposal.loc / n_steps,
                                scale=rho * proposal.scale / scale_divisor)

    for i in range(n_steps):

        # Set new seed
        samples = proposal_shard.sample(n_samples, seed=1000 * seed + i)

        test_samples = tf.tile(best_sample, [n_samples, 1]) + samples

        log_probs = tf.reduce_sum(target.log_prob(test_samples), axis=1)

        index = tf.argmax(log_probs)

        update_samp_op = best_sample.assign(test_samples[index:index + 1, :])

        idx, _ = sess.run([index, update_samp_op])
        
        sample_index.append(idx)
    
    # ----------------------------------------------------------------------
    # Perform backfitting
    # ----------------------------------------------------------------------
    
    # TODO
#     for b in range(backfitting_steps):
        
#         # Single backfitting step
#         for i in range(n_steps):

#             # Set seed to regenerate the previously generated samples here
#             samples = proposal_shard.sample(n_samples, seed=1000 * seed + i)
            
#             idx = sample_index[i]
            
#             # Undo the addition of the current sample
#             best_sample.assign_sub(samples[idx : idx + 1, :])
            
#             # Generate candidate samples
#             test_samples = tf.tile(best_sample, [n_samples, 1]) + samples

#             if use_log_prob:
#                 test_scores = tf.reduce_sum(target.log_prob(test_samples), axis=1)
#             else:
#                 test_scores = tf.reduce_sum(-((test_samples - target.loc) / target.scale)**8,
#                                            axis=1)

#             index = tf.argmax(test_scores)

#             samp_update_op = best_sample.assign(test_samples[index:index + 1, :])

#             idx, _ = sess.run([index, samp_update_op])
            
#             sample_index[i] = idx
    
    
    sample_index = list(map(lambda x: to_bit_string(x, n_bits_per_step), sample_index))
    sample_index = ''.join(sample_index)
    
    return best_sample.eval(session=sess), sample_index



def decode_greedy_sample_(sess, sample_index, proposal, n_bits_per_step, n_steps, seed, rho=1.):
    
    # Make sure the distributions have the correct type
    if proposal.dtype is not tf.float32:
        raise Exception("Proposal datatype must be float32!")
        
    dim = proposal.loc.shape.as_list()[0]
    
    indices = [from_bit_string(sample_index[i:i + n_bits_per_step]) 
               for i in range(0, n_bits_per_step * n_steps, n_bits_per_step)]
        
    # The scale divisor needs to be square rooted because
    # we are dealing with standard deviations and not variances
    scale_divisor = np.sqrt(n_steps)    
    
    proposal_shard = tfd.Normal(loc=proposal.loc / n_steps,
                                scale=rho * proposal.scale / scale_divisor)    
    
    n_samples = int(2**n_bits_per_step)
    
    sample = tf.Variable(tf.zeros((1, dim), dtype=tf.float32))
    sess.run(tf.global_variables_initializer())
    
    for i in range(n_steps):
        
        # Set new seed
        samples = tf.tile(sample, [n_samples, 1]) + proposal_shard.sample(n_samples, seed=1000 * seed + i)

        index = indices[i]

        samp_update_op = sample.assign(samples[index:index + 1, :])
        
        sess.run(samp_update_op)
    
    return sample.eval(session=sess)


def code_grouped_greedy_sample_(sess,
                                target, 
                               proposal,
                               n_steps, 
                               n_bits_per_step,
                               seed,
                               max_group_size_bits=12,
                               adaptive=True,
                               backfitting_steps=0,
                               use_log_prob=False,
                               rho=1.):
    
    # Make sure the distributions have the correct type
    if target.dtype is not tf.float32:
        raise Exception("Target datatype must be float32!")
        
    if proposal.dtype is not tf.float32:
        raise Exception("Proposal datatype must be float32!")
    
    n_bits_per_group = n_bits_per_step * n_steps
    
    num_dimensions = np.prod(proposal.loc.shape.as_list())
    
    # rescale proposal by the proposal
    p_loc = tf.reshape(tf.zeros_like(proposal.loc), [-1])
    p_scale = tf.reshape(tf.ones_like(proposal.scale), [-1])
    
    # rescale target by the proposal
    t_loc = tf.reshape((target.loc - proposal.loc) / proposal.scale, [-1])
    t_scale = tf.reshape(target.scale / proposal.scale, [-1])
    
    kl_divergences = tf.reshape(tfd.kl_divergence(target, proposal), [-1])
        
    # ====================================================================== 
    # Preprocessing step: determine groups for sampling
    # ====================================================================== 

    group_start_indices = [0]
    group_kls = []
    
    kl_divs = sess.run(kl_divergences)

    total_kl_bits = np.sum(kl_divs) / np.log(2)

    print("Total KL to split up: {:.2f} bits, "
          "maximum bits per group: {}, "
          "estimated number of groups: {},"
          "coding {} dimensions".format(total_kl_bits, 
                                        n_bits_per_group,
                                        total_kl_bits // n_bits_per_group + 1,
                                        num_dimensions
                                        ))

    current_group_size = 0
    current_group_kl = 0
    
    n_nats_per_group = n_bits_per_group * np.log(2) - 1

    for idx in range(num_dimensions):

        group_bits = np.log(current_group_size + 1) / np.log(2)
        
        if group_bits >= max_group_size_bits or \
           current_group_kl + kl_divs[idx] >= n_nats_per_group or \
           idx == num_dimensions - 1:

            group_start_indices.append(idx)
            group_kls.append(current_group_kl / np.log(2))

            current_group_size = 1
            current_group_kl = kl_divs[idx]
            
        else:
            current_group_kl += kl_divs[idx]
            current_group_size += 1
            
    # ====================================================================== 
    # Sample each group
    # ====================================================================== 
    
    results = []
    
    group_start_indices += [num_dimensions] 
    
    for i in tqdm(range(len(group_start_indices) - 1)):
        
        start_idx = group_start_indices[i]
        end_idx = group_start_indices[i + 1]
        
        result = code_greedy_sample_(
            sess=sess,
            target=tfd.Normal(loc=t_loc[start_idx:end_idx],
                              scale=t_scale[start_idx:end_idx]), 

            proposal=tfd.Normal(loc=p_loc[start_idx:end_idx],
                                scale=p_scale[start_idx:end_idx]), 

            n_bits_per_step=n_bits_per_step, 
            n_steps=n_steps, 
            seed=i + seed,
            backfitting_steps=backfitting_steps,
            use_log_prob=use_log_prob,
            rho=rho)
        
        results.append(result)
        
    samples, codes = zip(*results)
    
    bitcode = ''.join(codes)
    sample = tf.concat(samples, axis=1)
    
    # Rescale the sample
    sample = tf.reshape(proposal.scale, [-1]) * sample + tf.reshape(proposal.loc, [-1])
    
    sample = sess.run(sample)
    
    return sample, bitcode, group_start_indices
  
    
def decode_grouped_greedy_sample_(sess,
                                  bitcode, 
                                 group_start_indices,
                                 proposal, 
                                 n_bits_per_step, 
                                 n_steps, 
                                 seed,
                                 adaptive=True,
                                 rho=1.):
    
    # Make sure the distributions have the correct type
    if proposal.dtype is not tf.float32:
        raise Exception("Proposal datatype must be float32!")
    
    n_bits_per_group = n_bits_per_step * n_steps
    
    num_dimensions = np.prod(proposal.loc.shape.as_list())
    
    # ====================================================================== 
    # Decode each group
    # ====================================================================== 
                
    samples = []
    
    group_start_indices += [num_dimensions]
    
    p_loc = tf.reshape(tf.zeros_like(proposal.loc), [-1])
    p_scale = tf.reshape(tf.ones_like(proposal.scale), [-1])
    
    for i in tqdm(range(len(group_start_indices) - 1)):
        
        samples.append(decode_greedy_sample_(
            sess=sess,
            sample_index=bitcode[n_bits_per_group * i: n_bits_per_group * (i + 1)],
            
            proposal=tfd.Normal(loc=p_loc[group_start_indices[i]:group_start_indices[i + 1]],
                                scale=p_scale[group_start_indices[i]:group_start_indices[i + 1]]), 
            
            n_bits_per_step=n_bits_per_step, 
            n_steps=n_steps, 
            seed=i + seed,
            rho=rho))
        
    sample = tf.concat(samples, axis=1)
    
    # Rescale the sample
    sample = tf.reshape(proposal.scale, [-1]) * sample + tf.reshape(proposal.loc, [-1])
    
    return sess.run(sample)

In [37]:
r = tf.random.stateless_normal(shape=[3], seed=[2, 8])

with tf.Session() as sess:
    
    print(sess.run(tf.tile(tf.expand_dims(r, 0), [2,1])))

[[1.3804178 1.9624459 0.2679907]
 [1.3804178 1.9624459 0.2679907]]
