In [35]:
import tensorflow as tf
tfl = tf.keras.layers

import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

import numpy as np

In [2]:
def gumbel_sample(shape, eps=1e-20):
    
    u = tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32)
    
    return -tf.math.log(-tf.math.log(u + eps))

In [3]:
class SinkhornNormalization(tfl.Layer):
    
    def __init__(self, num_samples=1, temperature=0.1, iters=20, noise=0.1, name="sinkhorn_layer", **kwargs):
        
        super(SinkhornNormalization, self).__init__(name=name, **kwargs)
        
        self.num_samples = num_samples
        self.iters = iters
        self.noise = noise
        self.temperature = tf.Variable(temperature, dtype=tf.float32, trainable=False, name="temperature")
        
    def call(self, inputs, training=False):
        
        n = inputs.shape[1]
        
        
        # Reshape to batch of square matrices
        inputs = tf.reshape(inputs, [-1, n, n])
        batch_size = inputs.shape[0]
        
        inputs = tf.tile(inputs, [self.num_samples if training else 1, 1, 1])
        
        inputs = inputs + gumbel_sample(inputs.shape) * self.noise
        
        inputs = inputs / self.temperature
    
        for _ in range(self.iters):
            
            # row normalization
            inputs -= tf.reshape(tf.reduce_logsumexp(inputs, axis=1), [-1, 1, n])
            
            # column normalization
            inputs -= tf.reshape(tf.reduce_logsumexp(inputs, axis=2), [-1, n, 1])
            
        soft_perm = tf.exp(inputs)
        
        soft_perm = tf.reshape(soft_perm, [-1, batch_size, n, n])
        soft_perm = tf.transpose(soft_perm, [1, 0, 2, 3])
        
        return soft_perm

# Sorting

In [4]:
class SortingNetwork(tf.keras.Model):
    
    def __init__(self, 
                 units, 
                 output_size, 
                 temperature=0.5, 
                 iters=5, 
                 num_gumbel_samples=1,
                 noise=0.1,
                 name="sorting_network", 
                 **kwargs):
        
        super(SortingNetwork, self).__init__(name=name, **kwargs)
        
        self.units = units
        self.output_size = output_size
        
        # Create Layers
        
        self.layer1 = tfl.Dense(self.units,
                                activation="relu")
        
        self.layer2 = tfl.Dense(self.output_size,
                                activation=None)
        
        self.sinkhorn_layer = SinkhornNormalization(temperature=temperature, 
                                                    iters=iters,
                                                    num_samples=num_gumbel_samples,
                                                    noise=noise)
        
        
    def call(self, inputs, training=False):
        
        n = inputs.shape[1]
        inputs = tf.reshape(inputs, [-1, n])
        
        batch_size = inputs.shape[0]
        
        flattened_inputs = tf.reshape(inputs, [-1, 1])
        
        activations = self.layer1(flattened_inputs)
        activations = self.layer2(activations)
        
        activations = tf.reshape(activations, [batch_size, self.output_size, self.output_size])
        
        soft_perm = self.sinkhorn_layer(activations, training=training)

        return soft_perm

In [5]:
def create_sorting_dataset(num_examples=1000, size=10, lower=0., upper=1., batch_size=32, shuffle_buffer=1000):
    
    shuffled = tf.random.uniform(shape=(num_examples, size),
                                 minval=lower,
                                 maxval=upper)
    
    sort = tf.sort(shuffled, axis=1)
    
    ds = tf.data.Dataset.from_tensor_slices((shuffled, sort))
    ds = ds.shuffle(shuffle_buffer)
    ds = ds.batch(batch_size)
    ds = ds.repeat()
    
    return ds

In [41]:
train_iters = 10000
output_size = 6
noise=0.5
num_samples=10
temperature=0.3

optimizer = tf.optimizers.Adam(1e-4)

dataset = create_sorting_dataset(num_examples=50000,
                                 size=output_size,
                                 batch_size=num_samples)

counter = 0

sorting_net = SortingNetwork(units=32, 
                             output_size=output_size, 
                             noise=noise, 
                             num_gumbel_samples=num_samples,
                             temperature=temperature)

for shuffled, sort in tqdm(dataset.take(train_iters), total=train_iters):
    counter += 1

    with tf.GradientTape() as tape:
        
        soft_perms = sorting_net(shuffled, training=True)
        
        tiled_shuffled = tf.reshape(tf.tile(shuffled, [num_samples, 1]), [num_samples, num_samples, output_size])
        tiled_shuffled = tf.transpose(tiled_shuffled, [1, 0, 2])
        
        tiled_sorted = tf.reshape(tf.tile(sort, [num_samples, 1]), [num_samples, num_samples, output_size])
        tiled_sorted = tf.transpose(tiled_sorted, [1, 0, 2])
        
        inv_soft_perms = tf.transpose(soft_perms, [0, 1, 3, 2])
        
        tiled_unshuffled = tf.einsum("ijkl, ijl -> ijk", inv_soft_perms, tiled_shuffled)
        
        loss = tf.reduce_mean(tf.math.squared_difference(tiled_unshuffled, tiled_sorted))
        
    gradients = tape.gradient(loss, sorting_net.trainable_variables)
    optimizer.apply_gradients(zip(gradients, sorting_net.trainable_variables))
    
    if counter % 1000 == 0:
        print(f"counter: {counter}, loss: {loss}")

HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))

counter: 1000, loss: 0.07081587612628937
counter: 2000, loss: 0.04847860336303711
counter: 3000, loss: 0.025779765099287033
counter: 4000, loss: 0.02271021530032158
counter: 5000, loss: 0.017460448667407036
counter: 6000, loss: 0.015931902453303337
counter: 7000, loss: 0.013747588731348515
counter: 8000, loss: 0.01528224628418684
counter: 9000, loss: 0.01077294535934925
counter: 10000, loss: 0.012659034691751003



In [40]:
x = tf.random.uniform((1, 6))

soft_perm = tf.squeeze(sorting_net(x, training=False))
perm = tf.argmax(soft_perm, axis=1).numpy()

print(perm)
print(x.numpy())
print(tf.matmul(x, soft_perm)[0].numpy())
print(x.numpy()[0, perm])

[4 0 5 1 3 2]
[[0.7391434  0.13792253 0.9787675  0.22269034 0.806396   0.36929846]]
[0.18392637 0.4347278  0.40168625 0.78888756 0.8034184  0.6415718 ]
[0.806396   0.7391434  0.36929846 0.13792253 0.22269034 0.9787675 ]


# Latent Matching