# Differentiable bitonic sort

[Bitonic sorts](https://en.wikipedia.org/wiki/Bitonic_sorter) allow creation of sorting networks with a sequence of fixed conditional swapping operations executed in parallel. A sorting network implements a map from $\mathbb{R}^n \rightarrow \mathbb{R}^n$, where $n=2^k$ (sorting networks for non-power-of-2 sizes are possible but trickier).

<img src="BitonicSort1.svg.png">

*[Image: from Wikipedia, by user Bitonic, CC0](https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg)*

The sorting network for $n=2^k$ elements has $\frac{k(k+1)}{2}$ "layers" where parallel compare-and-swap operations are used to rearrange an $n$ element vector into sorted order.
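
The layer count can be checked by enumerating the merge phases directly. The sketch below is illustrative only; `bitonic_stage_sizes` is a hypothetical helper name, not part of the code later in this notebook. It assumes that merge phase $i$ contributes $i$ parallel stages, as in the standard bitonic construction.

```python
def bitonic_stage_sizes(n):
    """Number of parallel compare-and-swap stages in each merge phase."""
    k = n.bit_length() - 1
    if n < 1 or 2**k != n:
        raise ValueError("n must be a power of 2")
    # merge phase i (1-based) is made up of i parallel stages
    return list(range(1, k + 1))


sizes = bitonic_stage_sizes(16)
print(sizes, sum(sizes))     # [1, 2, 3, 4] 10
print(sum(sizes) * 16 // 2)  # 80 compare-and-swap operations in total
```

Each stage performs $\frac{n}{2}$ independent compare-and-swaps, which is what makes the network easy to run in parallel.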

## Differentiable compare-and-swap

Define the `softmax(a,b)` function (not the traditional "softmax" used for classification!) as a continuous approximation to the `max(a,b)` function:

$$\text{softmax}(a,b) = \log(e^a + e^b) \approx \max(a,b).$$

We can then fairly obviously write `softmin(a,b)` as:

$$\text{softmin}(a,b) = -\log(e^{-a} + e^{-b}) \approx \min(a,b).$$

These functions obviously aren't equal to max and min, but are relatively close, and differentiable. Note that we now have a differentiable compare-and-swap operation:

$$\text{high} = \text{softmax}(a,b), \text{low} = \text{softmin}(a,b), \text{where } \text{low}\leq \text{high}$$
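
As a quick numerical check of this soft compare-and-swap, here is a minimal sketch. The helper names `soft_max2` and `soft_min2` are illustrative and are not the names used in the code later in this notebook. It shows two properties that follow from the definitions: the soft maximum overshoots the true maximum by $\log(1 + e^{-|a-b|})$, which is at most $\log 2$, and the soft pair preserves the sum $a + b$.

```python
import numpy as np


def soft_max2(a, b):
    """Illustrative helper: log(e^a + e^b)."""
    return np.log(np.exp(a) + np.exp(b))


def soft_min2(a, b):
    """Illustrative helper: -log(e^-a + e^-b)."""
    return -soft_max2(-a, -b)


a, b = 1.0, 3.0
low, high = soft_min2(a, b), soft_max2(a, b)
print(low, high)  # roughly 0.873 and 3.127

# the overshoot of the soft maximum is log(1 + e^{-|a-b|}), at most log 2
overshoot = high - max(a, b)
# the soft pair preserves the sum of its inputs
total = low + high
print(overshoot, total)
```

The approximation is tightest when $a$ and $b$ are far apart and loosest when they are equal.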

## Differentiable sorting

For each layer in the sorting network, we can split all of the pairwise comparison-and-swaps into left-hand and right-hand sides which can be done simultaneously. We can write any function that selects the relevant elements of the vector as a multiplication by a binary matrix.

For each layer, we can derive two binary matrices $L \in \mathbb{R}^{\frac{n}{2} \times n}$ and $R \in \mathbb{R}^{\frac{n}{2} \times n}$ which select the elements to be compared for the left and right hands respectively. This will result in the comparison between two $\frac{n}{2}$ length vectors. We can also derive two matrices $L' \in \mathbb{R}^{n \times \frac{n}{2}}$ and $R' \in \mathbb{R}^{n \times \frac{n}{2}}$ which put the results of the compare-and-swap operation back into the right positions.

Then, each layer $i$ of the sorting process is just:
$${\bf x}_{i+1} = L'_i[\text{softmin}(L_i{\bf x}_i, R_i{\bf x}_i)] + R'_i[\text{softmax}(L_i{\bf x}_i, R_i{\bf x}_i)]$$
$$ = L'_i\left(-\log\left(e^{-L_i{\bf x}_i} + e^{-R_i{\bf x}_i}\right)\right) +  R'_i\left(\log\left(e^{L_i{\bf x}_i} + e^{R_i{\bf x}_i}\right)\right)$$
which is clearly differentiable (though not very numerically stable -- the usable range of elements $x$ is quite limited in single float precision).
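
The limited range noted above comes from evaluating the exponentials directly. One possible mitigation, sketched here as an assumption rather than something the rest of this notebook uses, is NumPy's `np.logaddexp`, which computes $\log(e^a + e^b)$ without forming the large intermediate values. The names `naive_softmax`, `stable_softmax` and `stable_softmin` are hypothetical.

```python
import numpy as np


def naive_softmax(a, b):
    """Direct evaluation of log(e^a + e^b); overflows for large inputs."""
    return np.log(np.exp(a) + np.exp(b))


def stable_softmax(a, b):
    """Same quantity via np.logaddexp, which avoids the overflow."""
    return np.logaddexp(a, b)


def stable_softmin(a, b):
    """-log(e^-a + e^-b), again via np.logaddexp."""
    return -np.logaddexp(-a, -b)


a = np.array([100.0, -3.0], dtype=np.float32)
b = np.array([101.0, 2.0], dtype=np.float32)

with np.errstate(over="ignore"):
    naive = naive_softmax(a, b)  # e^100 is not representable in float32
stable = stable_softmax(a, b)

print(naive)   # first entry is inf
print(stable)  # both entries finite
```

Both versions agree wherever the naive form does not overflow, so the swap would not change the results for well-scaled inputs.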

All that remains is to compute the matrices $L_i, R_i, L'_i, R'_i$ for each of the layers of the network. 

## Example

To sort four elements, we have a network like:

    0  1  2  3  
    ┕>>┙  │  │  
    │  │  ┕<<┙  
    ┕>>>>>┙  │  
    │  ┕>>>>>┙  
    ┕>>┙  │  │  
    │  │  ┕>>┙  

This is equivalent to: 

    x[0], x[1] = cswap(x[0], x[1])
    x[3], x[2] = cswap(x[2], x[3])
    x[0], x[2] = cswap(x[0], x[2])
    x[1], x[3] = cswap(x[1], x[3])
    x[0], x[1] = cswap(x[0], x[1])
    x[2], x[3] = cswap(x[2], x[3])
    
where `cswap(a,b) = (min(a,b), max(a,b))`

Replacing the indexing with matrix multiplies and `cswap` with a `softcswap = (softmin(a,b), softmax(a,b))` we then have the differentiable form.
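
To make the matrix form concrete, here is a minimal sketch of the first stage of the four-element network, with the selection and scatter matrices written out by hand. The names `L_back`, `R_back` and this local `softcswap` are illustrative and chosen for this example only. The first pair is ordered low-to-high and the second pair high-to-low, matching the first two `cswap` lines above.

```python
import numpy as np

# Selection matrices for the first stage: pairs (0, 1) and (2, 3) are compared.
L = np.array([[1, 0, 0, 0],
              [0, 0, 1, 0]], dtype=float)  # picks x[0] and x[2]
R = np.array([[0, 1, 0, 0],
              [0, 0, 0, 1]], dtype=float)  # picks x[1] and x[3]

# Scatter matrices: minima go to positions 0 and 3, maxima to positions 1 and 2.
L_back = np.array([[1, 0], [0, 0], [0, 0], [0, 1]], dtype=float)
R_back = np.array([[0, 0], [1, 0], [0, 1], [0, 0]], dtype=float)


def softcswap(a, b):
    """Soft compare-and-swap: (softmin, softmax) of each pair."""
    return -np.log(np.exp(-a) + np.exp(-b)), np.log(np.exp(a) + np.exp(b))


x = np.array([4.0, 1.0, 2.0, 3.0])
a, b = L @ x, R @ x

# exact stage, using min and max
hard = L_back @ np.minimum(a, b) + R_back @ np.maximum(a, b)

# differentiable stage, using softmin and softmax
lo, hi = softcswap(a, b)
soft = L_back @ lo + R_back @ hi

print(hard)  # [1. 4. 3. 2.]
print(soft)
```

The soft result has the same ordering pattern as the exact one, with each value pushed slightly away from its partner.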



# Test functions

In [None]:
import numpy as np

def bitonic_network(n):
    """Check the computation of a bitonic network"""
    layers = int(np.log2(n))
    for layer in range(1, layers + 1):
        for sub in reversed(range(1, layer + 1)):
            for i in range(0, n, 2**sub):
                for j in range(2**(sub - 1)):
                    ix = i + j
                    a, b = ix, ix + (2**(sub - 1))
                    swap = "<" if (ix >> layer) & 1 else ">"
                    print(f"{a:>2}{swap}{b:<d}", end="\t")
            print()
        print("-" * n * 4)


# this should match the diagram at the top of the notebook
bitonic_network(16)

In [None]:
def pretty_bitonic_network(n):
    """Pretty print a bitonic network,
    to check the logic is correct"""
    layers = int(np.log2(n))
    # header
    for i in range(n):
        print(f"{i:<2d} ", end='')
    print()

    # layers
    for layer in range(1, layers + 1):
        for sub in reversed(range(1, layer + 1)):
            for i in range(0, n, 2**sub):
                for j in range(2**(sub - 1)):
                    ix = i + j
                    a, b = ix, ix + (2**(sub - 1))
                    way = "<" if (ix >> layer) & 1 else ">"

                    # draw one row of the diagram; a separate column index is used
                    # so that the outer loop variable i is not overwritten
                    for col in range(n):
                        if col == b:
                            print("┙", end='')
                        elif col == a:
                            print("┕", end='')
                        elif not (a < col < b):
                            print("│", end='')
                        else:
                            print(way, end='')
                        if a <= col < b:
                            print(way * 2, end='')
                        else:
                            print(" " * 2, end='')
                    print()


pretty_bitonic_network(16)

# Vectorised functions

In [None]:
import numpy as np


def softmax(a, b):
    """The softmaximum of softmax(a,b) = log(e^a + e^b)."""
    return np.log(np.exp(a) + np.exp(b))


def softmin(a, b):
    """
    Return the soft-minimum of a and b
    The soft-minimum can be derived directly from softmax(a,b).
    """
    return -softmax(-a, -b)


def softrank(a, b):
    """Return a,b in 'soft-sorted' order, with the smaller value first"""
    return softmin(a, b), softmax(a, b)

In [None]:
def bitonic_matrices(n):
    """Compute a set of bitonic sort matrices to sort a sequence of
    length n. n *must* be a power of 2.
    
    See: https://en.wikipedia.org/wiki/Bitonic_sorter
    
    Set k=log2(n).
    There will be k "layers", i=1, 2, ... k
    
    Each ith layer will have i sub-steps, so there are (k*(k+1)) / 2 sorting steps total.
    
    For each step, we compute 4 matrices. l and r are binary matrices of size (n/2, n) and
    inv_l and inv_r are matrices of size (n, n/2).
    
    l and r "interleave" the inputs into two n/2 size vectors. inv_l and inv_r "uninterleave" these two n/2 vectors
    back into two n sized vectors that can be summed to get the correct output.
                    
    The result is such that to apply any layer's sorting, we can perform:
    
    l, r, inv_l, inv_r = layer[j]
    a, b =  l @ y, r @ y                
    permuted = inv_l @ np.minimum(a, b) + inv_r @ np.maximum(a,b)
        
    Applying this operation for each layer in sequence sorts the input vector.
            
    """
    # number of outer layers
    layers = int(np.log2(n))
    matrices = []
    for layer in range(1, layers + 1):
        # we have 1..layer sub layers
        for sub in reversed(range(1, layer + 1)):
            l, r = np.zeros((n // 2, n)), np.zeros((n // 2, n))
            inv_l, inv_r = np.zeros((n, n // 2)), np.zeros((n, n // 2))
            out = 0
            for i in range(0, n, 2**sub):
                for j in range(2**(sub - 1)):
                    ix = i + j
                    a, b = ix, ix + (2**(sub - 1))
                    l[out, a] = 1
                    r[out, b] = 1
                    if (ix >> layer) & 1:
                        a, b = b, a
                    inv_l[a, out] = 1
                    inv_r[b, out] = 1
                    out += 1
            matrices.append((l, r, inv_l, inv_r))
    return matrices


In [None]:
def bisort(matrices, x):
    """
    Given a set of bitonic sort matrices generated by bitonic_matrices(n), sort 
    a sequence x of length n. Sorts exactly.
    """
    for l, r, map_l, map_r in matrices:
        a, b = l @ x, r @ x
        x = map_l @ np.minimum(a, b) + map_r @ np.maximum(a, b)
    return x


def diff_bisort(matrices, x):
    """
    Approximate differentiable sort. Takes a set of bitonic sort matrices generated by bitonic_matrices(n) and sorts
    a sequence x of length n. Values will be distorted slightly but will be ordered.
    """
    for l, r, map_l, map_r in matrices:
        a, b = softrank(l @ x, r @ x)
        x = map_l @ a + map_r @ b
    return x

## Testing

In [None]:
# Test sorting
matrices = bitonic_matrices(8)

for i in range(10):
    # these should all be in sorted order
    test = np.random.randint(0, 200, 8)
    print(bisort(matrices, test))    

In [None]:
for i in range(1, 11):
    k = 2**i
    matrices = bitonic_matrices(k)
    print(f"Testing sorting for {k} elements")
    for j in range(100):
        test = np.random.randint(0, 200, k)
        assert (np.allclose(bisort(matrices, test), np.sort(test)))

## Differentiable sorting test

In [None]:
# Differentiable sorting 
np.set_printoptions(precision=2)
matrices = bitonic_matrices(8) 
def neat_vec(n):
    return "\t".join([f"{x:.2f}" for x in n])

for i in range(10):
    test = np.random.randint(-200,200,8)
    print("Differentiable", neat_vec(diff_bisort(matrices, test)))
    print("Exact sorting ", neat_vec(bisort(matrices, test)))
    print()

# Relaxed sorting
We can define a slightly modified function which interpolates between `softmax(a,b)` and `mean(a,b)`. The result is a sorting function that can be relaxed from sorting to averaging.
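
The interpolation can be written out explicitly. Using $s$ for the smoothing parameter (a symbol chosen here for illustration; the code calls it `smooth`) and $t = \frac{s}{2}$, the form implemented in the next cell is:

$$\text{softmax}_s(a,b) = \log\left(e^{(1-t)a + tb} + e^{(1-t)b + ta}\right) - \log(1+s)$$

At $s=0$ this reduces to $\log(e^a + e^b)$, the original `softmax(a,b)`. At $s=1$ we have $t=\frac{1}{2}$, both exponents equal $\frac{a+b}{2}$, and so

$$\text{softmax}_1(a,b) = \log\left(2e^{\frac{a+b}{2}}\right) - \log 2 = \frac{a+b}{2},$$

which is the mean of $a$ and $b$.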

In [None]:
def softmax_smooth(a, b, smooth=0):
    """The smoothed softmaximum of softmax(a,b) = log(e^a + e^b).
    With smooth=0.0, is softmax; with smooth=1.0, averages a and b"""
    t = smooth / 2.0
    return np.log(np.exp((1-t) * a + b * t) + np.exp((1-t)*b + t *a) ) - np.log(1+smooth)

def softrank_smooth(a, b, smooth=0):
    """The smoothed compare and swap of a and b
    With smooth=0, is softrank; with smooth=1.0, averages a and b"""
    return -softmax_smooth(-a, -b, smooth), softmax_smooth(a, b, smooth)

def diff_bisort_smooth(matrices, x, smooth=0):
    """
    Approximate differentiable sort with adjustable smoothing. Takes a set of bitonic sort matrices
    generated by bitonic_matrices(n) and sorts a sequence x of length n. With smooth=0 this matches diff_bisort;
    as smooth approaches 1, each compare-and-swap moves towards averaging its two inputs.
    """
    for l, r, map_l, map_r in matrices:
        a, b = softrank_smooth(l @ x, r @ x, smooth)
        x = map_l @ a + map_r @ b
    return x


In [None]:
# Differentiable smoothed sorting 
test = np.random.randint(-200,200,8)
print(f"Mean {np.mean(test)}")
print()
print("Exact sorting       ", neat_vec(bisort(matrices, test)))
for smooth in np.linspace(0, 1, 8):    
    print(f"Diff. smooth[{smooth:.2f}]  ", neat_vec(diff_bisort_smooth(matrices, test, smooth)))
        

## Smooth ranking
We can use a differentiable similarity measure between the input vector and the sorted output vector, e.g. an RBF kernel. This gives a similarity matrix which we can then apply (traditional) softmax to.
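
To see what such a similarity matrix encodes, the sketch below builds a row-normalised RBF similarity between a vector and its sorted copy. To stay self-contained it uses the exact `np.sort` in place of the differentiable sort, and the variable names are illustrative. With well-separated values each row peaks at the position the element takes in the sorted output, so the matrix acts as a soft permutation.

```python
import numpy as np

original = np.array([18.0, 1.0, 9.0, 2.0])
sorted_vals = np.sort(original)  # exact sort, used here only for illustration

sigma = 1.0
# squared pairwise differences: rows index the original, columns the sorted values
sq_diff = np.subtract.outer(original, sorted_vals) ** 2
rbf = np.exp(-sq_diff / (2 * sigma**2))

# normalise each row so that it sums to 1
soft_perm = rbf / rbf.sum(axis=1, keepdims=True)

# the largest entry in each row gives the element's position in sorted order
ranks = soft_perm.argmax(axis=1)
print(ranks)  # [3 0 2 1]
```

Smaller values of `sigma` sharpen each row towards a one-hot vector, while larger values spread the weight across neighbouring positions.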

In [None]:
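def order_matrix(original, sortd, sigma=0.1):
    """Row-normalised RBF similarity between the original and sorted values."""
    # squared pairwise differences between original and sorted elements
    diff_sq = np.subtract.outer(original, sortd)**2
    rbf = np.exp(-diff_sq / (2*sigma**2))
    # normalise each row to sum to 1
    return (rbf.T / np.sum(rbf, axis=1)).T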

# this is the traditional softmax e^x_i / (sum e^x_i) applied to rows
def softmax_matrix(matrix):
    e = np.exp(matrix)
    return (e.T / np.sum(e, axis=1)).T

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
matrices = bitonic_matrices(4)
test = [18, 1, 9, 2]
sortd = diff_bisort(matrices, test)
similarity = order_matrix(test, sortd, sigma=1)
print(similarity)
plt.imshow(similarity)

In [None]:
print(softmax_matrix(similarity))
print(np.sum(softmax_matrix(similarity), axis=1))
plt.imshow(softmax_matrix(similarity), vmin=0, vmax=1)

# PyTorch example
We can verify that this is both parallelisable on the GPU and fully differentiable.

In [None]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', device)

In [None]:
matrices = bitonic_matrices(16)
torch_matrices = [[torch.from_numpy(matrix).float().to(device) for matrix in matrix_set] for matrix_set in matrices]

In [None]:
# override softmax to use torch tensors
def softmax(a, b):
    """The softmaximum of softmax(a,b) = log(e^a + e^b)."""
    return torch.log(torch.exp(a) + torch.exp(b))

In [None]:
test_input = np.random.normal(0, 5, 16)
var_test_input = torch.tensor(test_input, dtype=torch.float32, device=device,
                              requires_grad=True)
result = diff_bisort(torch_matrices, var_test_input)

# compute the Jacobian of the sorting function, to show we can differentiate through the
# sorting function
jac = []
for i in range(len(result)):
    jac.append(
        torch.autograd.grad(result[i], var_test_input, retain_graph=True)[0])

# 16 x 16 Jacobian of the sorting function
print(torch.stack(jac))