# Differentiable bitonic sort

[Bitonic sorts](https://en.wikipedia.org/wiki/Bitonic_sorter) allow creation of sorting networks with a sequence of fixed conditional swapping operations executed in parallel. A sorting network implements a map from $\mathbb{R}^n \rightarrow \mathbb{R}^n$, where $n=2^k$ (sorting networks for non-power-of-2 sizes are possible but trickier).

<img src="BitonicSort1.svg.png">

*[Image: from Wikipedia, by user Bitonic, CC0](https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg)*

The sorting network for $n=2^k$ elements has $\frac{k(k+1)}{2}$ "layers" where parallel compare-and-swap operations are used to rearrange an $n$ element vector into sorted order.
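
The layer structure can be sketched in a few lines of plain Python. This is a minimal illustration and not the notebook's `bitonic_network`: the names `bitonic_layers` and `apply_network` are hypothetical, and each layer is represented here as a list of `(low, high)` index pairs, meaning the smaller value ends up at index `low`.

```python
def bitonic_layers(n):
    """Return the layers of a bitonic sorting network for n = 2**k inputs.

    Each layer is a list of (low, high) index pairs that can be processed
    in parallel: after the compare-and-swap, the smaller value sits at
    index `low` and the larger one at index `high`.
    """
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    layers = []
    block = 2
    while block <= n:              # size of the bitonic sequences being merged
        stride = block // 2
        while stride >= 1:         # distance between the compared elements
            layer = []
            for i in range(n):
                j = i ^ stride
                if j > i:
                    ascending = (i & block) == 0
                    layer.append((i, j) if ascending else (j, i))
            layers.append(layer)
            stride //= 2
        block *= 2
    return layers


def apply_network(layers, values):
    """Run the exact (non-differentiable) network on a list of values."""
    x = list(values)
    for layer in layers:
        for low, high in layer:
            x[low], x[high] = min(x[low], x[high]), max(x[low], x[high])
    return x


for n in (2, 4, 8, 16):
    print(n, "inputs:", len(bitonic_layers(n)), "layers")
print(apply_network(bitonic_layers(8), [5, 3, 8, 1, 9, 2, 7, 4]))
```

Every layer contains exactly $\frac{n}{2}$ disjoint pairs, which is what allows a whole layer to be evaluated as one vectorised operation.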

## Differentiable compare-and-swap

We define the `softmax(a,b)` function (not the traditional "softmax" used for classification!) as a continuous approximation to the `max(a,b)` function:

$$\text{softmax}(a,b) = \log(e^a + e^b) \approx \max(a,b).$$

We can then fairly obviously write `softmin(a,b)` as:

$$\text{softmin}(a,b) = -\log(e^{-a} + e^{-b}) \approx \min(a,b).$$

These functions obviously aren't equal to max and min, but they are close (each differs from the exact value by at most $\log 2$, with the largest error when $a=b$) and differentiable. Note that we now have a differentiable compare-and-swap operation:

$$\text{high} = \text{softmax}(a,b), \text{low} = \text{softmin}(a,b), \text{where } \text{low}\leq \text{high}$$
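
A minimal NumPy sketch of these three functions follows. It assumes `np.logaddexp` is used in place of the literal `log(exp(a) + exp(b))`, which computes the same quantity without overflowing for large inputs; the versions shipped in `differentiable_sorting` may be written differently.

```python
import numpy as np


def softmax(a, b):
    """Smooth upper bound on max(a, b): log(e^a + e^b)."""
    return np.logaddexp(a, b)


def softmin(a, b):
    """Smooth lower bound on min(a, b): -log(e^-a + e^-b)."""
    return -np.logaddexp(-a, -b)


def softcswap(a, b):
    """Differentiable compare-and-swap: returns (low, high)."""
    return softmin(a, b), softmax(a, b)


a = np.array([1.0, 5.0, -3.0])
b = np.array([4.0, 5.0, 2.0])
low, high = softcswap(a, b)
print("low :", low)
print("high:", high)
print("exact:", np.minimum(a, b), np.maximum(a, b))
```

Two properties are easy to check numerically: each soft value lies within $\log 2$ of the exact one, and `low + high` equals `a + b` up to rounding, so a soft swap preserves the sum of its inputs.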

## Differentiable sorting

For each layer in the sorting network, we can split all of the pairwise compare-and-swaps into left-hand and right-hand sides which can be done simultaneously. We can write any function that selects the relevant elements of the vector as a multiply with a binary matrix.

For each layer, we can derive two binary matrices $L \in \mathbb{R}^{\frac{n}{2} \times n}$ and $R \in \mathbb{R}^{\frac{n}{2} \times n}$ which select the elements to be compared for the left and right hands respectively. This will result in the comparison between two $\frac{n}{2}$ length vectors. We can also derive two matrices $L' \in \mathbb{R}^{n \times \frac{n}{2}}$ and $R' \in \mathbb{R}^{n \times \frac{n}{2}}$ which put the results of the compare-and-swap operation back into the right positions.

Then, each layer $i$ of the sorting process is just:
$${\bf x}_{i+1} = L'_i[\text{softmin}(L_i{\bf x}_i, R_i{\bf x}_i)] + R'_i[\text{softmax}(L_i{\bf x}_i, R_i{\bf x}_i)]$$
$$ = L'_i\left(-\log\left(e^{-L_i{\bf x}_i} + e^{-R_i{\bf x}_i}\right)\right) +  R'_i\left(\log\left(e^{L_i{\bf x}_i} + e^{R_i{\bf x}_i}\right)\right)$$
which is clearly differentiable (though not very numerically stable -- the usable range of elements $x$ is quite limited in single-precision floating point, where $e^x$ overflows for $x$ above roughly 88).

All that remains is to compute the matrices $L_i, R_i, L'_i, R'_i$ for each of the layers of the network. 
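
As a rough sketch of how these matrices could be built for a single layer, the code below assumes the layer is given as a list of `(low, high)` index pairs, where the soft minimum is written back to `low` and the soft maximum to `high`. The names `layer_matrices` and `soft_layer` are hypothetical, and the notebook's `bitonic_matrices` may organise the matrices differently. The selection matrices have shape $(\frac{n}{2}, n)$ so that `L @ x` picks out $\frac{n}{2}$ elements, and the write-back matrices have shape $(n, \frac{n}{2})$.

```python
import numpy as np


def layer_matrices(pairs, n):
    """Build (L, R, L_out, R_out) for one layer of n // 2 disjoint pairs."""
    m = len(pairs)
    L, R = np.zeros((m, n)), np.zeros((m, n))
    L_out, R_out = np.zeros((n, m)), np.zeros((n, m))
    for row, (low, high) in enumerate(pairs):
        L[row, low] = 1.0        # left operand of comparison `row`
        R[row, high] = 1.0       # right operand of comparison `row`
        L_out[low, row] = 1.0    # soft minimum is written to index `low`
        R_out[high, row] = 1.0   # soft maximum is written to index `high`
    return L, R, L_out, R_out


def soft_layer(x, L, R, L_out, R_out):
    """Apply one differentiable compare-and-swap layer to the vector x."""
    a, b = L @ x, R @ x
    return L_out @ (-np.logaddexp(-a, -b)) + R_out @ np.logaddexp(a, b)


# First layer of a four-element network: sort (0, 1) ascending, (2, 3) descending.
mats = layer_matrices([(0, 1), (3, 2)], n=4)
print(soft_layer(np.array([30.0, 10.0, 20.0, 40.0]), *mats))
```

With this convention the write-back matrices are simply the transposes of the selection matrices.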

This process is computationally heavy, but simple. The four matrices per layer could be reduced to two matrix multiplies, at the cost of a vector split and join in the middle.

## Example

To sort four elements, we have a network like:

    0  1  2  3  
    ┕>>┙  │  │  
    │  │  ┕<<┙  
    ┕>>>>>┙  │  
    │  ┕>>>>>┙  
    ┕>>┙  │  │  
    │  │  ┕>>┙  
    
This is equivalent to: 

    x[0], x[1] = cswap(x[0], x[1])
    x[3], x[2] = cswap(x[2], x[3])
    x[0], x[2] = cswap(x[0], x[2])
    x[1], x[3] = cswap(x[1], x[3])
    x[0], x[1] = cswap(x[0], x[1])
    x[2], x[3] = cswap(x[2], x[3])
    
where `cswap(a,b) = (min(a,b), max(a,b))`

Replacing the indexing with matrix multiplies and `cswap` with a `softcswap = (softmin(a,b), softmax(a,b))` we then have the differentiable form.
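
For illustration, here is the indexed version with `softcswap` substituted directly, before any conversion to matrix form. It is a sketch assuming the standard six-comparator bitonic network for four inputs; `soft_sort4` is a hypothetical name, not a function from `differentiable_sorting`.

```python
import numpy as np


def softcswap(a, b):
    """Differentiable compare-and-swap: (softmin, softmax)."""
    return -np.logaddexp(-a, -b), np.logaddexp(a, b)


def soft_sort4(values):
    """Soft bitonic sort of four elements, written with explicit indexing."""
    x = np.array(values, dtype=float)
    x[0], x[1] = softcswap(x[0], x[1])
    x[3], x[2] = softcswap(x[2], x[3])   # descending: low goes to index 3
    x[0], x[2] = softcswap(x[0], x[2])
    x[1], x[3] = softcswap(x[1], x[3])
    x[0], x[1] = softcswap(x[0], x[1])
    x[2], x[3] = softcswap(x[2], x[3])
    return x


print(soft_sort4([40.0, 10.0, 30.0, 20.0]))   # well separated: close to exact sort
print(soft_sort4([1.0, 0.5, 0.0, 1.5]))       # closely spaced: visibly smoothed
```

When the inputs are far apart relative to 1 the result is very close to the exact sort; when they are close together the outputs are pushed apart, since every soft swap under-estimates the minimum and over-estimates the maximum.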



# Test functions

In [None]:
from bitonic_tests import bitonic_network, pretty_bitonic_network

# this should match the diagram at the top of the notebook
bitonic_network(16)

In [None]:
pretty_bitonic_network(16)

# Vectorised functions

## Testing

In [None]:
# Test sorting
import numpy as np
from differentiable_sorting import bitonic_matrices, diff_bisort, diff_rank, softmax, softmin, softcswap, bisort
matrices = bitonic_matrices(8)

for i in range(10):
    # these should all be in sorted order
    test = np.random.randint(0, 200, 8)
    print(bisort(matrices, test))    

In [None]:
for i in range(1, 11):
    k = 2**i
    matrices = bitonic_matrices(k)
    print(f"Testing sorting for {k} elements")
    for j in range(100):
        test = np.random.randint(0, 200, k)
        assert (np.allclose(bisort(matrices, test), np.sort(test)))

## Differentiable sorting test

In [None]:
# Differentiable sorting 
np.set_printoptions(precision=2)
matrices = bitonic_matrices(8) 
def neat_vec(n):
    return "\t".join([f"{x:.2f}" for x in n])

for i in range(10):
    test = np.random.randint(-200,200,8)
    print("Differentiable", neat_vec(diff_bisort(matrices, test)))
    print("Exact sorting ", neat_vec(bisort(matrices, test)))
    print()

# Relaxed sorting
We can define a slightly modified function which interpolates between `softmax(a,b)` and `mean(a,b)`. The result is a sorting function that can be relaxed from sorting to averaging.

In [None]:
from differentiable_sorting import diff_bisort_smooth
# Differentiable smoothed sorting 
test = np.random.randint(-200,200,8)
print(f"Mean {np.mean(test)}")
print()
print("Exact sorting       ", neat_vec(bisort(matrices, test)))
for smooth in np.linspace(0, 1, 8):    
    print(f"Diff. smooth[{smooth:.2f}]  ", neat_vec(diff_bisort_smooth(matrices, test, smooth)))
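
One way such a relaxation could be written is a linear blend between the soft compare-and-swap and the pairwise mean. This is an assumption made for illustration only: `relaxed_cswap` is a hypothetical name, and `diff_bisort_smooth` may interpolate in a different way.

```python
import numpy as np


def relaxed_cswap(a, b, smooth):
    """Blend a soft compare-and-swap with the pairwise mean.

    smooth = 0 gives (softmin, softmax); smooth = 1 gives (mean, mean).
    """
    mean = 0.5 * (a + b)
    low = (1.0 - smooth) * -np.logaddexp(-a, -b) + smooth * mean
    high = (1.0 - smooth) * np.logaddexp(a, b) + smooth * mean
    return low, high


a, b = np.array([12.0]), np.array([-7.0])
for smooth in (0.0, 0.5, 1.0):
    print(smooth, relaxed_cswap(a, b, smooth))
```

Under this blend the sum `low + high` stays equal to `a + b` for every value of `smooth`, so applying it through the whole network keeps the mean of the vector fixed while the spread shrinks.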
        

## Woven form
We can "weave" the four matrices into two matrices for fewer multiplies, at the cost of having to split and join the vector at each layer.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from differentiable_sorting import diff_bisort_weave, bisort_woven_matrices

woven_matrices = bisort_woven_matrices(8)

print("Exact sorting       ", neat_vec(bisort(matrices, test)))
print("Diff. (std.)       ", neat_vec(diff_bisort(matrices, test)))
print("Diff. (woven)      ", neat_vec(diff_bisort_weave(woven_matrices, test)))
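
The idea behind weaving can be sketched as follows, assuming the two selection matrices are stacked vertically and the two write-back matrices horizontally. The names `weave` and `woven_layer` are hypothetical, and `bisort_woven_matrices` may lay the matrices out differently.

```python
import numpy as np


def weave(L, R, L_out, R_out):
    """Combine four (n/2 x n) / (n x n/2) matrices into two n x n matrices."""
    return np.vstack([L, R]), np.hstack([L_out, R_out])


def woven_layer(x, w_in, w_out):
    """One soft compare-and-swap layer using two matrix multiplies."""
    a, b = np.split(w_in @ x, 2)                       # split into left/right halves
    joined = np.concatenate([-np.logaddexp(-a, -b),    # soft minima
                             np.logaddexp(a, b)])      # soft maxima
    return w_out @ joined


# First layer of the four-element network: (0, 1) ascending, (2, 3) descending.
L = np.array([[1, 0, 0, 0], [0, 0, 0, 1]], dtype=float)   # selects x[0], x[3]
R = np.array([[0, 1, 0, 0], [0, 0, 1, 0]], dtype=float)   # selects x[1], x[2]
w_in, w_out = weave(L, R, L.T, R.T)

x = np.array([30.0, 10.0, 20.0, 40.0])
print(woven_layer(x, w_in, w_out))
```

The result is the same as the four-matrix form; the saving is in the number of multiplies per layer, not in the arithmetic of the soft swap itself.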
        

## Differentiable ranking
We can use a differentiable similarity measure between the input vector and its sorted output, e.g. an RBF kernel. We can use this to generate a normalised similarity matrix and apply it to the vector of ranks `[0, 1, 2, ..., n-1]`. This gives a differentiable ranking function.

As `sigma` gets larger, the result converges to giving all values the mean rank; as it goes to zero the result converges to the true rank.
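
A minimal sketch of such a similarity matrix is shown below. The name `rbf_order_matrix` is hypothetical and the library's `order_matrix` may differ in detail; for simplicity the example also uses the exact `np.sort` as a stand-in for the differentiable sort.

```python
import numpy as np


def rbf_order_matrix(original, sortd, sigma=0.1):
    """Row-normalised RBF similarity between each input and each sorted value."""
    original = np.asarray(original, dtype=float)
    sortd = np.asarray(sortd, dtype=float)
    sq_dist = (original[:, None] - sortd[None, :]) ** 2
    rbf = np.exp(-sq_dist / (2.0 * sigma ** 2))
    return rbf / rbf.sum(axis=1, keepdims=True)   # each row sums to 1


x = np.array([5.0, -1.0, 9.5, 13.2])
positions = np.arange(len(x))
for sigma in (0.1, 10.0, 1e6):
    ranks = rbf_order_matrix(x, np.sort(x), sigma=sigma) @ positions
    print(f"sigma={sigma:9.1f} |", ranks)
```

Row $i$ of the matrix is a soft one-hot vector marking where element $i$ of the input appears in the sorted output, so multiplying by the positions gives a weighted average of candidate ranks.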

In [None]:
from differentiable_sorting import order_matrix, diff_rank

In [None]:
matrices = bitonic_matrices(8)
test = np.random.randint(0, 200, 8)
sortd = diff_bisort(matrices, test)

In [None]:
print("Smoothed ranks")
for sigma in [0.1, 1, 10, 100, 1000]:
    similarity = order_matrix(test, sortd, sigma=sigma)    
    ranks = np.arange(len(test))    
    print(f"sigma={sigma:7.1f}  |", neat_vec(similarity @ ranks))

In [None]:
def diff_rank(matrices, x, sigma=0.1):
    """Return the smoothed, differentiable ranking of each element of x. Sigma
    specifies the smoothing of the ranking. """
    sortd = diff_bisort(matrices, x)
    return order_matrix(x, sortd, sigma=sigma) @ np.arange(len(x))

In [None]:
x = [5.0, -1.0, 9.5, 13.2, 16.2, 20.5, 42.0, 18.0]
np.set_printoptions(suppress=True)
print(neat_vec((diff_rank(matrices, x, sigma=0.1))))

# PyTorch example
We can verify that this is both parallelisable on the GPU and fully differentiable.

In [None]:
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', device)

In [None]:
from differentiable_sorting_torch import diff_bisort, bitonic_matrices, diff_rank
matrices = bitonic_matrices(16)
torch_matrices = [[torch.from_numpy(matrix).float().to(device) for matrix in matrix_set] for matrix_set in matrices]


In [None]:
test_input = np.random.normal(0, 5, 16)
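# torch.tensor(..., requires_grad=True) creates the leaf tensor directly;
# the Variable wrapper has been deprecated since PyTorch 0.4
var_test_input = torch.tensor(test_input, dtype=torch.float32, device=device,
                              requires_grad=True)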

result = diff_bisort(torch_matrices, var_test_input)

# compute the Jacobian of the sorting function, to show we can differentiate through the
# sorting function
jac = []
for i in range(len(result)):
    jac.append(
        torch.autograd.grad(result[i], var_test_input, retain_graph=True)[0])

# 16 x 16 Jacobian of the sorting function
print(torch.stack(jac))
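
The same Jacobian can be obtained in one call with `torch.autograd.functional.jacobian`. The sketch below is self-contained and therefore uses a small hand-written four-element soft sort (`soft_sort4`, a hypothetical helper) rather than `diff_bisort`; it assumes a PyTorch version that provides `torch.logaddexp` and `torch.autograd.functional.jacobian`.

```python
import torch


def soft_sort4(x):
    """Soft bitonic sort of a length-4 tensor (six soft compare-and-swaps)."""
    def cswap(a, b):
        return -torch.logaddexp(-a, -b), torch.logaddexp(a, b)

    x0, x1, x2, x3 = x.unbind()
    x0, x1 = cswap(x0, x1)
    x3, x2 = cswap(x2, x3)   # descending pair: low goes to position 3
    x0, x2 = cswap(x0, x2)
    x1, x3 = cswap(x1, x3)
    x0, x1 = cswap(x0, x1)
    x2, x3 = cswap(x2, x3)
    return torch.stack([x0, x1, x2, x3])


x = torch.tensor([0.3, -1.2, 2.0, 0.9], dtype=torch.float64)
jac = torch.autograd.functional.jacobian(soft_sort4, x)   # shape (4, 4)
print(jac)
print("row sums   :", jac.sum(dim=1))
print("column sums:", jac.sum(dim=0))
```

For this soft sort the Jacobian has non-negative entries and its rows and columns each sum to 1, so it can be read as a relaxed permutation matrix describing how much each input contributes to each sorted position.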

In [None]:
result = diff_rank(torch_matrices, var_test_input)
print(result)
