# Differentiable bitonic sort

[Bitonic sorts](https://en.wikipedia.org/wiki/Bitonic_sorter) allow the creation of sorting networks: fixed sequences of conditional swap operations, many of which can be executed in parallel. A sorting network implements a map from $\mathbb{R}^n \rightarrow \mathbb{R}^n$, where $n=2^k$ (sorting networks for non-power-of-2 sizes are possible, but trickier).

<img src="BitonicSort1.svg.png">

*[Image: from Wikipedia, by user Bitonic, CC0](https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg)*

The sorting network for $n = 2^k$ elements has $\frac{k(k+1)}{2}$ "layers" (steps). In each layer, $\frac{n}{2}$ compare-and-swap operations run in parallel to rearrange the $n$-element vector, and after the final layer it is in sorted order.
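As a quick check on these counts, the sketch below (a hypothetical helper, not part of the notebook) computes the number of steps and comparators for $n = 2^k$ inputs, using the fact that outer layer $i$ of the network has $i$ sub-steps.

```python
def bitonic_step_counts(n):
    """Return (number of parallel steps, comparators per step) for n = 2**k inputs."""
    k = n.bit_length() - 1
    assert n == 2**k, "n must be a power of 2"
    # outer layer i (i = 1..k) contributes i sub-steps: 1 + 2 + ... + k
    steps = sum(range(1, k + 1))
    # every step pairs up all n elements, so it has n/2 comparators
    return steps, n // 2


for n in (4, 8, 16, 1024):
    steps, per_step = bitonic_step_counts(n)
    k = n.bit_length() - 1
    assert steps == k * (k + 1) // 2
    print(f"n={n:5d}: {steps} steps x {per_step} comparators")
```

For $n=16$ this gives 10 steps of 8 comparators each, matching the 10 rows printed by `bitonic_network(16)` below.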

### Differentiable compare-and-swap

Define the `softmax(a,b)` function (not the traditional "softmax" used for classification!) as a smooth approximation to the `max(a,b)` function:

$$\text{softmax}(a,b) = \log(e^a + e^b) \approx \max(a,b).$$

Negating the inputs and the output then gives `softmin(a,b)`:

$$\text{softmin}(a,b) = -\log(e^{-a} + e^{-b}) \approx \min(a,b).$$

These functions are not equal to max and min, but they are close, and differentiable: since $\text{softmax}(a,b) = \max(a,b) + \log(1 + e^{-|a-b|})$, the approximation error is at most $\log 2$ (reached when $a = b$) and shrinks rapidly as $|a-b|$ grows. Together they give a differentiable compare-and-swap operation:

$$\text{high} = \text{softmax}(a,b), \text{low} = \text{softmin}(a,b), \text{where } \text{low}\leq \text{high}$$
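A quick numerical check of these properties (a standalone sketch; the two functions match the definitions above). It also checks a useful identity: $\text{softmin}(a,b) + \text{softmax}(a,b) = \log(e^a+e^b) - \log(e^{-a}+e^{-b}) = \log e^{a+b} = a+b$. So each soft compare-and-swap preserves the sum of its inputs, and hence so does a whole differentiable sorting network.

```python
import numpy as np


def softmax(a, b):
    """Soft maximum: log(e^a + e^b)."""
    return np.log(np.exp(a) + np.exp(b))


def softmin(a, b):
    """Soft minimum: -log(e^-a + e^-b)."""
    return -softmax(-a, -b)


a = np.array([0.0, 1.0, 5.0, -3.0])
b = np.array([0.0, 3.0, 1.0, 10.0])
low, high = softmin(a, b), softmax(a, b)

# max(a,b) <= softmax(a,b) <= max(a,b) + log 2, with the gap largest when a == b
assert np.all(np.maximum(a, b) <= high)
assert np.all(high <= np.maximum(a, b) + np.log(2) + 1e-12)
# the mirrored bound holds for softmin, so low <= min(a,b) <= max(a,b) <= high
assert np.all(low <= np.minimum(a, b))
# the pair sum is preserved: low + high == a + b
assert np.allclose(low + high, a + b)
print(np.round(low, 3), np.round(high, 3))
```

The sum identity explains a pattern in the differentiable sorting test later in the notebook: close pairs such as -50 and -46 come out as -50.02 and -45.98, pushed apart but with the same total.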

## Differentiable sorting

For each layer in the sorting network, we can split all of the pairwise compare-and-swaps into left-hand and right-hand inputs, all of which can be processed simultaneously. Any function that selects specific elements of a vector can be written as a multiplication by a binary matrix.

For each layer, we can derive two binary matrices $L \in \mathbb{R}^{\frac{n}{2} \times n}$ and $R \in \mathbb{R}^{\frac{n}{2} \times n}$ which select the elements to be compared for the left and right hands respectively. This reduces the layer to an elementwise comparison between two $\frac{n}{2}$-length vectors. We can also derive two matrices $L' \in \mathbb{R}^{n \times \frac{n}{2}}$ and $R' \in \mathbb{R}^{n \times \frac{n}{2}}$ which put the results of the compare-and-swap operation back into the correct positions.
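To make this concrete, here are the four matrices for the first step of the 4-element network in the example below, written out by hand (an illustrative sketch; the notebook generates them with `bitonic_matrices`). That step compares elements (0, 1) in ascending order and (2, 3) in descending order.

```python
import numpy as np

# Row p of L / R selects the left / right input of comparator p  -> shape (n/2, n)
L = np.array([[1, 0, 0, 0],
              [0, 0, 1, 0]])
R = np.array([[0, 1, 0, 0],
              [0, 0, 0, 1]])
# Column p of L' / R' says where comparator p's min / max is written  -> shape (n, n/2)
# Comparator 1 sorts descending, so its min goes to index 3 and its max to index 2.
L_prime = np.array([[1, 0],
                    [0, 0],
                    [0, 0],
                    [0, 1]])
R_prime = np.array([[0, 0],
                    [1, 0],
                    [0, 1],
                    [0, 0]])

x = np.array([3, 1, 2, 4])
a, b = L @ x, R @ x                      # a = [3, 2], b = [1, 4]
x_next = L_prime @ np.minimum(a, b) + R_prime @ np.maximum(a, b)
print(x_next)                            # [1 3 4 2]
```

The result `[1 3 4 2]` is an ascending pair followed by a descending pair, i.e. the bitonic sequence the next steps of the network merge.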

Then, each layer $i$ of the sorting process is just:
$${\bf x}_{i+1} = L'_i[\text{softmin}(L_i{\bf x}_i, R_i{\bf x}_i)] + R'_i[\text{softmax}(L_i{\bf x}_i, R_i{\bf x}_i)]$$
$$ = L'_i\left(-\log\left(e^{-L_i{\bf x}_i} + e^{-R_i{\bf x}_i}\right)\right) + R'_i\left(\log\left(e^{L_i{\bf x}_i} + e^{R_i{\bf x}_i}\right)\right)$$
where $\exp$ and $\log$ are applied elementwise. This is clearly differentiable, though not very numerically stable: $e^{x}$ overflows single-precision floats once $x$ exceeds about 88, so the usable range of elements $x$ is quite limited in float32.
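The overflow can be avoided by rewriting $\log(e^a+e^b)$ as $\max(a,b) + \log(1+e^{-|a-b|})$, which never exponentiates a large positive number. This is the standard log-sum-exp trick, and NumPy already provides it as `np.logaddexp`. The sketch below compares it with the direct formula in float32; `softmax_naive` and `softmax_stable` are hypothetical names, not used elsewhere in the notebook.

```python
import numpy as np


def softmax_naive(a, b):
    """log(e^a + e^b) computed directly; overflows for large inputs."""
    return np.log(np.exp(a) + np.exp(b))


def softmax_stable(a, b):
    """Same function as max(a,b) + log(1 + e^-|a-b|); exp only sees arguments <= 0."""
    return np.maximum(a, b) + np.log1p(np.exp(-np.abs(a - b)))


x = np.float32(100.0)
with np.errstate(over="ignore"):
    print(softmax_naive(x, x))   # inf: e^100 does not fit in float32
print(softmax_stable(x, x))      # approximately 100 + log 2
# NumPy's built-in computes the same quantity stably
print(np.logaddexp(x, x))
```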

All that remains is to compute the matrices $L_i, R_i, L'_i, R'_i$ for each of the layers of the network. 

## Example

To sort four elements, we have a network like:

    0  1  2  3  
    ┕>>┙  │  │  
    │  │  ┕<<┙  
    ┕>>>>>┙  │  
    │  ┕>>>>>┙  
    ┕>>┙  │  │  
    │  │  ┕>>┙  
    
This is equivalent to: 

    x[0], x[1] = cswap(x[0], x[1])
    x[3], x[2] = cswap(x[2], x[3])
    x[0], x[2] = cswap(x[0], x[2])
    x[1], x[3] = cswap(x[1], x[3])
    x[0], x[1] = cswap(x[0], x[1])
    x[2], x[3] = cswap(x[2], x[3])
    
where `cswap(a,b) = (min(a,b), max(a,b))`

Replacing the indexing with matrix multiplies, and `cswap` with `softcswap(a,b) = (softmin(a,b), softmax(a,b))`, gives the differentiable form. (In the code below, `softcswap` is called `softrank`.)
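The listing above can be run directly. The sketch below is a plain-Python transcription of the six compare-and-swaps of the 4-element network (two per step); `sort4` and `soft_cswap` are illustrative names, not from the notebook. Passing `soft_cswap` in place of `cswap` gives the differentiable version.

```python
import itertools
import math


def cswap(a, b):
    """Exact compare-and-swap: returns (min, max)."""
    return min(a, b), max(a, b)


def soft_cswap(a, b):
    """Soft compare-and-swap: returns (softmin, softmax)."""
    hi = math.log(math.exp(a) + math.exp(b))
    lo = -math.log(math.exp(-a) + math.exp(-b))
    return lo, hi


def sort4(x, swap=cswap):
    """Apply the 4-element bitonic network to a copy of x."""
    x = list(x)
    # step 1: build a bitonic sequence (ascending pair, descending pair)
    x[0], x[1] = swap(x[0], x[1])
    x[3], x[2] = swap(x[2], x[3])  # descending: larger value goes to index 2
    # step 2: compare elements 2 apart
    x[0], x[2] = swap(x[0], x[2])
    x[1], x[3] = swap(x[1], x[3])
    # step 3: compare neighbours
    x[0], x[1] = swap(x[0], x[1])
    x[2], x[3] = swap(x[2], x[3])
    return x


# the exact network sorts every permutation of four distinct values
for perm in itertools.permutations([3, 1, 4, 2]):
    assert sort4(perm) == [1, 2, 3, 4]

print(sort4([3.0, 1.0, 4.0, 2.0], swap=soft_cswap))
```

With inputs only 1 apart the soft result is visibly distorted: each soft compare-and-swap is off by $\log(1+e^{-|a-b|})$, which is $\log 2$ for equal inputs and shrinks rapidly as the gap grows.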



# Test functions

In [1]:
import numpy as np

def bitonic_network(n):
    """Print the compare-and-swap pairs for each step of an n-input bitonic network.
    "a>b" sends the larger value to b; "a<b" sends it to a."""
    layers = int(np.log2(n))
    for layer in range(1, layers + 1):
        for sub in reversed(range(1, layer + 1)):
            for i in range(0, n, 2**sub):
                for j in range(2**(sub - 1)):
                    ix = i + j
                    a, b = ix, ix + (2**(sub - 1))
                    swap = "<" if (ix >> layer) & 1 else ">"
                    print(f"{a:>2}{swap}{b:<d}", end="\t")
            print()
        print("-" * n * 4)


# this should match the diagram at the top of the notebook
bitonic_network(16)

 0>1	 2<3	 4>5	 6<7	 8>9	10<11	12>13	14<15	
----------------------------------------------------------------
 0>2	 1>3	 4<6	 5<7	 8>10	 9>11	12<14	13<15	
 0>1	 2>3	 4<5	 6<7	 8>9	10>11	12<13	14<15	
----------------------------------------------------------------
 0>4	 1>5	 2>6	 3>7	 8<12	 9<13	10<14	11<15	
 0>2	 1>3	 4>6	 5>7	 8<10	 9<11	12<14	13<15	
 0>1	 2>3	 4>5	 6>7	 8<9	10<11	12<13	14<15	
----------------------------------------------------------------
 0>8	 1>9	 2>10	 3>11	 4>12	 5>13	 6>14	 7>15	
 0>4	 1>5	 2>6	 3>7	 8>12	 9>13	10>14	11>15	
 0>2	 1>3	 4>6	 5>7	 8>10	 9>11	12>14	13>15	
 0>1	 2>3	 4>5	 6>7	 8>9	10>11	12>13	14>15	
----------------------------------------------------------------


In [2]:
def pretty_bitonic_network(n):
    """Pretty print a bitonic network,
    to check the logic is correct"""
    layers = int(np.log2(n))
    # header
    for i in range(n):
        print(f"{i:<2d} ", end='')
    print()

    # layers
    for layer in range(1, layers + 1):
        for sub in reversed(range(1, layer + 1)):
            for i in range(0, n, 2**sub):
                for j in range(2**(sub - 1)):
                    ix = i + j
                    a, b = ix, ix + (2**(sub - 1))
                    way = "<" if (ix >> layer) & 1 else ">"

                    # draw one row: a wire "│" per column, with the comparator
                    # spanning columns a..b. `col` must not shadow the outer `i`,
                    # which is still needed to compute `ix` for the next j.
                    for col in range(n):
                        if col == b:
                            print("┙", end='')
                        elif col == a:
                            print("┕", end='')
                        elif not (a < col < b):
                            print("│", end='')
                        else:
                            print(way, end='')
                        if a <= col < b:
                            print(way * 2, end='')
                        else:
                            print(" " * 2, end='')
                    print()


pretty_bitonic_network(16)

0  1  2  3  4  5  6  7  8  9  10 11 12 13 14 15 
┕>>┙  │  │  │  │  │  │  │  │  │  │  │  │  │  │  
│  │  ┕<<┙  │  │  │  │  │  │  │  │  │  │  │  │  
│  │  │  │  ┕>>┙  │  │  │  │  │  │  │  │  │  │  
│  │  │  │  │  │  ┕<<┙  │  │  │  │  │  │  │  │  
│  │  │  │  │  │  │  │  ┕>>┙  │  │  │  │  │  │  
│  │  │  │  │  │  │  │  │  │  ┕<<┙  │  │  │  │  
│  │  │  │  │  │  │  │  │  │  │  │  ┕>>┙  │  │  
│  │  │  │  │  │  │  │  │  │  │  │  │  │  ┕<<┙  
┕>>>>>┙  │  │  │  │  │  │  │  │  │  │  │  │  │  
│  ┕>>>>>┙  │  │  │  │  │  │  │  │  │  │  │  │  
│  │  │  │  ┕<<<<<┙  │  │  │  │  │  │  │  │  │  
│  │  │  │  │  ┕<<<<<┙  │  │  │  │  │  │  │  │  
│  │  │  │  │  │  │  │  ┕>>>>>┙  │  │  │  │  │  
│  │  │  │  │  │  │  │  │  ┕>>>>>┙  │  │  │  │  
│  │  │  │  │  │  │  │  │  │  │  │  ┕<<<<<┙  │  
│  │  │  │  │  │  │  │  │  │  │  │  │  ┕<<<<<┙  
┕>>┙  │  │  │  │  │  │  │  │  │  │  │  │  │  │  
│  │  ┕>>┙  │  │  │  │  │  │  │  │  │  │  │  │  
│  │  │  │  ┕<<┙  │  │  │  │  │  │  │  │  │  │  
│  │  │  │  │  │  ┕<

# Vectorised functions

In [3]:
import numpy as np


def softmax(a, b):
    """Return the soft maximum of a and b: log(e^a + e^b)."""
    return np.log(np.exp(a) + np.exp(b))


def softmin(a, b):
    """
    Return the soft minimum of a and b.
    It is derived directly from softmax: softmin(a,b) = -softmax(-a,-b).
    """
    return -softmax(-a, -b)


def softrank(a, b):
    """Return a,b in 'soft-sorted' order, with the smaller value first"""
    return softmin(a, b), softmax(a, b)
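Why these functions are useful for learning: the partial derivatives of $\text{softmax}(a,b)$ are $\partial/\partial a = e^a/(e^a+e^b) = \sigma(a-b)$ and $\partial/\partial b = \sigma(b-a)$, where $\sigma$ is the logistic sigmoid. These are the classification-style softmax weights over $(a, b)$. The sketch below (a standalone check, not part of the notebook; `sigmoid` is a helper defined here) confirms this with central finite differences.

```python
import numpy as np


def softmax(a, b):
    return np.log(np.exp(a) + np.exp(b))


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


# Analytic partial derivatives of softmax(a, b) = log(e^a + e^b):
#   d/da = e^a / (e^a + e^b) = sigmoid(a - b)
#   d/db = e^b / (e^a + e^b) = sigmoid(b - a)
a, b, h = 1.3, 0.4, 1e-6
fd_a = (softmax(a + h, b) - softmax(a - h, b)) / (2 * h)  # central difference in a
fd_b = (softmax(a, b + h) - softmax(a, b - h)) / (2 * h)  # central difference in b
print(fd_a, sigmoid(a - b))
print(fd_b, sigmoid(b - a))
```

Since the two weights are positive and sum to 1, the gradient flowing back through each soft compare-and-swap is split between its two inputs, rather than routed entirely to the larger one as it would be for an exact `max`.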

In [4]:
def bitonic_matrices(n):
    """Compute a set of bitonic sort matrices to sort a sequence of
    length n. n *must* be a power of 2.

    See: https://en.wikipedia.org/wiki/Bitonic_sorter

    Set k=log2(n).
    There will be k "layers", i=1, 2, ... k

    Each ith layer will have i sub-steps, so there are (k*(k+1)) / 2 sorting steps total.

    For each step, we compute 4 matrices. l and r are binary matrices of size (n/2, n) and
    inv_l and inv_r are binary matrices of size (n, n/2).

    l and r select the left and right inputs of the n/2 comparators, giving two n/2 size
    vectors. inv_l and inv_r scatter the two n/2 size result vectors back into two n sized
    vectors that can be summed to get the correct output.

    The result is such that to apply the j-th step's sorting, we can perform:

    l, r, inv_l, inv_r = matrices[j]
    a, b = l @ y, r @ y
    permuted = inv_l @ np.minimum(a, b) + inv_r @ np.maximum(a, b)

    Applying this operation for each step in sequence sorts the input vector.
    """
    # number of outer layers
    layers = int(np.log2(n))
    matrices = []
    for layer in range(1, layers + 1):
        # we have 1..layer sub layers
        for sub in reversed(range(1, layer + 1)):
            l, r = np.zeros((n // 2, n)), np.zeros((n // 2, n))
            inv_l, inv_r = np.zeros((n, n // 2)), np.zeros((n, n // 2))
            out = 0
            for i in range(0, n, 2**sub):
                for j in range(2**(sub - 1)):
                    ix = i + j
                    a, b = ix, ix + (2**(sub - 1))
                    l[out, a] = 1
                    r[out, b] = 1
                    if (ix >> layer) & 1:
                        a, b = b, a
                    inv_l[a, out] = 1
                    inv_r[b, out] = 1
                    out += 1
            matrices.append((l, r, inv_l, inv_r))
    return matrices
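Because each row of `l` and `r` (and each column of `inv_l` and `inv_r`) contains a single 1, the matrix products are really gathers and scatters. The sketch below checks that property for $n=16$ and then runs an exact sort with index arrays instead of $O(n^2)$ matrix products. It repeats `bitonic_matrices` so it can run on its own; `to_indices` and `index_sort` are hypothetical helpers, not part of the notebook.

```python
import numpy as np


def bitonic_matrices(n):
    """Copy of the notebook's bitonic_matrices (n must be a power of 2)."""
    layers = int(np.log2(n))
    matrices = []
    for layer in range(1, layers + 1):
        for sub in reversed(range(1, layer + 1)):
            l, r = np.zeros((n // 2, n)), np.zeros((n // 2, n))
            inv_l, inv_r = np.zeros((n, n // 2)), np.zeros((n, n // 2))
            out = 0
            for i in range(0, n, 2**sub):
                for j in range(2**(sub - 1)):
                    ix = i + j
                    a, b = ix, ix + 2**(sub - 1)
                    l[out, a] = 1
                    r[out, b] = 1
                    if (ix >> layer) & 1:
                        a, b = b, a
                    inv_l[a, out] = 1
                    inv_r[b, out] = 1
                    out += 1
            matrices.append((l, r, inv_l, inv_r))
    return matrices


def to_indices(matrices):
    """Turn each 0/1 selection matrix into an index array."""
    steps = []
    for l, r, inv_l, inv_r in matrices:
        # row p of l / r has its single 1 at the input index of comparator p
        left, right = l.argmax(axis=1), r.argmax(axis=1)
        # column p of inv_l / inv_r has its single 1 at comparator p's output index
        dest_min, dest_max = inv_l.argmax(axis=0), inv_r.argmax(axis=0)
        steps.append((left, right, dest_min, dest_max))
    return steps


def index_sort(steps, x):
    """Exact bitonic sort using gathers/scatters instead of matrix products."""
    x = np.asarray(x, dtype=float)
    for left, right, dest_min, dest_max in steps:
        a, b = x[left], x[right]
        y = np.empty_like(x)
        y[dest_min] = np.minimum(a, b)
        y[dest_max] = np.maximum(a, b)
        x = y
    return x


n = 16
matrices = bitonic_matrices(n)
for l, r, inv_l, inv_r in matrices:
    # every input position is read exactly once per step...
    assert np.array_equal(l.sum(axis=0) + r.sum(axis=0), np.ones(n))
    # ...and written exactly once, so each step permutes positions
    assert np.array_equal(inv_l.sum(axis=1) + inv_r.sum(axis=1), np.ones(n))

steps = to_indices(matrices)
x = np.random.default_rng(0).integers(0, 200, n)
print(index_sort(steps, x))
```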


In [5]:
def bisort(matrices, x):
    """
    Given a set of bitonic sort matrices generated by bitonic_matrices(n), sort 
    a sequence x of length n. Sorts exactly.
    """
    for l, r, map_l, map_r in matrices:
        a, b = l @ x, r @ x
        x = map_l @ np.minimum(a, b) + map_r @ np.maximum(a, b)
    return x


def diff_bisort(matrices, x):
    """
    Approximate differentiable sort. Takes a set of bitonic sort matrices generated by
    bitonic_matrices(n) and sorts a sequence x of length n. Values will be distorted
    slightly (most noticeably where inputs are close together) but will be ordered.
    """
    for l, r, map_l, map_r in matrices:
        a, b = softrank(l @ x, r @ x)
        x = map_l @ a + map_r @ b
    return x

## Testing

In [6]:
# Test sorting
matrices = bitonic_matrices(8)

for i in range(10):
    # these should all be in sorted order
    test = np.random.randint(0, 200, 8)
    print(bisort(matrices, test))    

[  1.  35.  39. 111. 122. 128. 164. 186.]
[  3.  27.  38.  93. 124. 148. 168. 195.]
[ 32.  36.  38.  45.  72.  72.  87. 180.]
[  3.   5.  18.  58.  66. 106. 107. 127.]
[ 11.  20.  55.  59. 129. 162. 178. 198.]
[ 15.  24.  52.  90. 147. 152. 162. 179.]
[  9.  18.  66.  81. 116. 168. 186. 192.]
[  6.  32.  42.  49.  59. 107. 107. 145.]
[ 23.  34.  60.  66. 108. 128. 135. 152.]
[ 26.  34.  64.  73.  99. 107. 141. 197.]


In [7]:
for i in range(1, 11):
    k = 2**i
    matrices = bitonic_matrices(k)
    print(f"Testing sorting for {k} elements")
    for j in range(100):
        test = np.random.randint(0, 200, k)
        assert np.allclose(bisort(matrices, test), np.sort(test))

Testing sorting for 2 elements
Testing sorting for 4 elements
Testing sorting for 8 elements
Testing sorting for 16 elements
Testing sorting for 32 elements
Testing sorting for 64 elements
Testing sorting for 128 elements
Testing sorting for 256 elements
Testing sorting for 512 elements
Testing sorting for 1024 elements


## Differentiable sorting test

In [8]:
# Differentiable sorting 
np.set_printoptions(precision=2)
matrices = bitonic_matrices(8) 
def neat_vec(n):
    return "\t".join([f"{x:.2f}" for x in n])

for i in range(10):
    test = np.random.randint(-200,200,8)
    print("Differentiable", neat_vec(diff_bisort(matrices, test)))
    print("Exact sorting ", neat_vec(bisort(matrices, test)))
    print()

Differentiable -190.00	-160.00	-50.02	-45.98	-37.00	13.00	153.00	161.00
Exact sorting  -190.00	-160.00	-50.00	-46.00	-37.00	13.00	153.00	161.00

Differentiable -184.00	-162.00	-123.00	-102.00	-34.00	-21.00	-14.00	75.00
Exact sorting  -184.00	-162.00	-123.00	-102.00	-34.00	-21.00	-14.00	75.00

Differentiable -123.02	-118.98	-66.00	-32.00	44.00	63.00	94.87	97.13
Exact sorting  -123.00	-119.00	-66.00	-32.00	44.00	63.00	95.00	97.00

Differentiable -186.00	-145.00	-64.31	-62.69	-26.00	38.00	136.00	175.00
Exact sorting  -186.00	-145.00	-64.00	-63.00	-26.00	38.00	136.00	175.00

Differentiable -200.00	-17.00	38.00	70.00	123.00	130.00	144.00	152.00
Exact sorting  -200.00	-17.00	38.00	70.00	123.00	130.00	144.00	152.00

Differentiable -142.00	-133.00	54.00	82.00	99.99	104.96	108.05	120.00
Exact sorting  -142.00	-133.00	54.00	82.00	100.00	105.00	108.00	120.00

Differentiable -185.00	-156.13	-153.87	-140.00	-88.00	-67.00	104.00	134.00
Exact sorting  -185.00	-156.00	-154.00	-140.00	-88.00	-67.00	104

# PyTorch example
We can verify that this runs on the GPU, where each step is just a few matrix multiplies and elementwise operations, and that it is fully differentiable.
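The Jacobian computed below has a predictable structure, which makes a useful sanity check. Since $\text{softmax}(a+t, b+t) = \text{softmax}(a,b) + t$ (and likewise for softmin), adding $t$ to every input adds $t$ to every output, so each row of the Jacobian sums to 1. Since each soft compare-and-swap preserves $a+b$, the outputs always sum to the inputs' sum, so each column sums to 1 too. With all entries non-negative, the Jacobian is a doubly stochastic matrix. The sketch below checks this in plain NumPy with finite differences; `soft_bitonic_sort` and `numerical_jacobian` are hypothetical helpers that re-implement the network with index pairs and `np.logaddexp`, rather than with the notebook's matrices.

```python
import numpy as np


def soft_bitonic_sort(x):
    """Differentiable bitonic sort using index pairs instead of matrices.
    Same pairing and direction rule as bitonic_matrices; len(x) must be a power of 2."""
    x = np.array(x, dtype=float)
    n = len(x)
    layers = int(np.log2(n))
    for layer in range(1, layers + 1):
        for sub in reversed(range(1, layer + 1)):
            half = 2**(sub - 1)
            y = x.copy()
            for i in range(0, n, 2**sub):
                for j in range(half):
                    a, b = i + j, i + j + half
                    if ((i + j) >> layer) & 1:
                        lo_pos, hi_pos = b, a   # descending comparator
                    else:
                        lo_pos, hi_pos = a, b   # ascending comparator
                    y[lo_pos] = -np.logaddexp(-x[a], -x[b])  # softmin
                    y[hi_pos] = np.logaddexp(x[a], x[b])     # softmax
            x = y
    return x


def numerical_jacobian(f, x, h=1e-5):
    """Central-difference Jacobian: J[i, k] = d f(x)[i] / d x[k]."""
    x = np.asarray(x, dtype=float)
    cols = []
    for k in range(len(x)):
        e = np.zeros_like(x)
        e[k] = h
        cols.append((f(x + e) - f(x - e)) / (2 * h))
    return np.stack(cols, axis=1)


x = np.random.default_rng(0).normal(0, 5, 16)
J = numerical_jacobian(soft_bitonic_sort, x)
print(np.round(J.sum(axis=1), 4))  # row sums
print(np.round(J.sum(axis=0), 4))  # column sums
```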

In [9]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', device)

Device: cuda:0


In [10]:
matrices = bitonic_matrices(16)
torch_matrices = [[torch.from_numpy(matrix).float().to(device) for matrix in matrix_set] for matrix_set in matrices]

In [11]:
# override softmax to use torch tensors; softmin and softrank look up softmax
# by name, so they use this version automatically
def softmax(a, b):
    """The softmaximum of softmax(a,b) = log(e^a + e^b)."""
    return torch.log(torch.exp(a) + torch.exp(b))

In [12]:
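test_input = np.random.normal(0, 5, 16)
# a tensor created with requires_grad=True tracks gradients directly
# (torch.autograd.Variable is deprecated)
var_test_input = torch.tensor(test_input, dtype=torch.float32, device=device,
                              requires_grad=True)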
result = diff_bisort(torch_matrices, var_test_input)

# compute the Jacobian of the sorting function, to show we can differentiate through the
# sorting function
jac = []
for i in range(len(result)):
    jac.append(
        torch.autograd.grad(result[i], var_test_input, retain_graph=True)[0])

# 16 x 16 Jacobian of the sorting function (row i holds d result[i] / d input)
print(torch.stack(jac))

tensor([[4.5800e-08, 1.8749e-08, 3.3806e-05, 4.6596e-07, 1.1410e-04, 5.9777e-03,
         1.8987e-04, 8.2044e-02, 2.2655e-04, 4.6479e-06, 9.8265e-04, 9.1464e-07,
         4.2966e-05, 9.1038e-01, 2.3056e-07, 4.2127e-09],
        [3.3739e-07, 1.3609e-07, 2.1835e-04, 3.2274e-06, 9.4283e-04, 4.3112e-02,
         1.6340e-03, 8.6150e-01, 6.8937e-04, 1.8300e-05, 5.1409e-03, 3.4280e-06,
         1.4065e-04, 8.6599e-02, 6.8678e-07, 1.2289e-08],
        [8.3022e-06, 3.3489e-06, 6.4338e-03, 8.6423e-05, 1.5105e-02, 6.9114e-01,
         2.2909e-02, 3.7742e-02, 4.0939e-02, 7.8800e-04, 1.7737e-01, 1.4511e-04,
         5.3704e-03, 1.9177e-03, 3.9163e-05, 6.7272e-07],
        [1.5567e-05, 6.2741e-06, 1.2063e-02, 1.6179e-04, 2.8746e-02, 2.0295e-01,
         4.3566e-02, 1.5883e-02, 8.3621e-02, 1.7831e-03, 5.9980e-01, 3.4132e-04,
         1.0057e-02, 9.4583e-04, 6.1377e-05, 1.0538e-06],
        [1.4671e-04, 5.4957e-05, 9.7047e-02, 1.6404e-03, 1.6013e-01, 1.7751e-02,
         2.1403e-01, 9.3366e-04, 3.2895