# Differentiable bitonic sort

[Bitonic sorts](https://en.wikipedia.org/wiki/Bitonic_sorter) allow creation of sorting networks with a sequence of fixed conditional swapping operations. This is a map from $\mathbb{R}^n \rightarrow \mathbb{R}^n$, where $n=2^k$.

<img src="BitonicSort1.svg.png">

*[Image: from Wikipedia, by user Bitonic, CC0](https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg)*

The sorting network has $\frac{k(k+1)}{2}$ "layers" (10 for $n=16$, $k=4$), each of which applies $\frac{n}{2}$ parallel compare-and-swap operations to rearrange the vector.
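The layer count can be checked by counting directly: outer stage $i$ of the network contributes $i$ sub-steps, for $i = 1, \dots, k$. The snippet below is a minimal sketch of that count; `bitonic_step_count` is a hypothetical helper name used only for illustration.

```python
import math

def bitonic_step_count(n):
    """Count the compare-and-swap steps of a bitonic network on n = 2**k inputs."""
    k = int(math.log2(n))
    assert 2 ** k == n, "n must be a power of 2"
    # outer stage `layer` contributes `layer` sub-steps
    return sum(layer for layer in range(1, k + 1))

for n in (2, 4, 8, 16):
    k = int(math.log2(n))
    # the direct count and the closed form k(k+1)/2 should agree
    print(n, bitonic_step_count(n), k * (k + 1) // 2)
```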

### Differentiable compare-and-swap

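Define the `softmax` function (not the traditional "softmax" used for classification; this is the two-argument log-sum-exp) as:

$$\text{softmax}(a,b) = \log(e^a + e^b).$$

We can then write `softmin` as:

$$\text{softmin}(a,b) = -\log(e^{-a} + e^{-b}) = -\text{softmax}(-a,-b).$$

These aren't *exactly* equal to max and min, but they are close and differentiable: $\text{softmax}(a,b) = \max(a,b) + \log(1 + e^{-|a-b|})$, so the error is at most $\log 2 \approx 0.69$ and decays exponentially as $a$ and $b$ move apart.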

Note that we can now do a differentiable compare-and-swap operation:

$$\text{high} = \text{softmax}(a,b), \quad \text{low} = \text{softmin}(a,b), \quad \text{where } \text{low}\leq \text{high}.$$
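A minimal NumPy sketch of this soft compare-and-swap follows. It checks three properties that follow from the definitions: the outputs are ordered, each is within $\log 2$ of the true max or min, and their sum equals $a+b$ (the swap redistributes the two values but does not change their total). The values of `a` and `b` are arbitrary illustrations.

```python
import numpy as np

def softmax(a, b):
    """Two-argument log-sum-exp, a smooth upper bound on max(a, b)."""
    return np.log(np.exp(a) + np.exp(b))

def softmin(a, b):
    """Smooth lower bound on min(a, b), derived from softmax."""
    return -softmax(-a, -b)

a, b = 1.0, 1.5
low, high = softmin(a, b), softmax(a, b)

print("low, high:", low, high)
print("error vs max:", high - max(a, b))      # between 0 and log(2)
print("sum preserved:", low + high, a + b)    # equal up to rounding
```

Because close values are pushed apart by up to $\log 2$ each, nearly-equal inputs are the ones most visibly distorted by the soft swap.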

## Differentiable sorting

For each layer in the sorting network, we can split all of the comparisons into left-hand and right-hand sides. Any function that selects the relevant elements of the vector can be written as a multiply with a binary matrix.

For each layer, we can derive two binary matrices $L \in \mathbb{R}^{\frac{n}{2} \times n}$ and $R \in \mathbb{R}^{\frac{n}{2} \times n}$ which select the elements to be compared for the left and right hands respectively. This results in a comparison between two vectors of length $\frac{n}{2}$. We can also derive two matrices $L' \in \mathbb{R}^{n \times \frac{n}{2}}$ and $R' \in \mathbb{R}^{n \times \frac{n}{2}}$ which put the results of the compare-and-swap operation back into the right positions.
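To make the shapes concrete, here is a small hand-built sketch for $n=4$ with a single step that compares positions $(0,1)$ and $(2,3)$ and writes the smaller value to the lower index. The names `L_back` and `R_back` stand in for $L'$ and $R'$, and the choice of pairs is an illustration rather than an actual layer of the bitonic network. Exact `np.minimum`/`np.maximum` are used so the result is easy to verify by eye.

```python
import numpy as np

x = np.array([3.0, 1.0, 2.0, 4.0])
n = len(x)
pairs = [(0, 1), (2, 3)]  # illustrative comparator pairs (a, b)

# selection matrices: (n/2, n); scatter matrices: (n, n/2)
L, R = np.zeros((n // 2, n)), np.zeros((n // 2, n))
L_back, R_back = np.zeros((n, n // 2)), np.zeros((n, n // 2))

for row, (a, b) in enumerate(pairs):
    L[row, a] = 1        # left operand of comparison `row`
    R[row, b] = 1        # right operand of comparison `row`
    L_back[a, row] = 1   # smaller value goes back to position a
    R_back[b, row] = 1   # larger value goes back to position b

left, right = L @ x, R @ x
y = L_back @ np.minimum(left, right) + R_back @ np.maximum(left, right)
print(y)
```

Each output position receives exactly one value, because every index appears in exactly one comparator and so exactly one row across `L_back` and `R_back` is non-zero for it.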

Then, each layer $i$ of the sorting process is just:
$${\bf x}_{i+1} = L'_i\text{softmin}(L_i{\bf x}_i, R_i{\bf x}_i) + R'_i\text{softmax}(L_i{\bf x}_i, R_i{\bf x}_i)$$
$$ = L'_i\left(-\log\left(e^{-L_i{\bf x}_i} + e^{-R_i{\bf x}_i}\right)\right) +  R'_i\left(\log\left(e^{L_i{\bf x}_i} + e^{R_i{\bf x}_i}\right)\right)$$
which is clearly differentiable (though not very numerically stable -- $e^{x}$ overflows single-precision floats once $|x|$ exceeds roughly 88, so the usable range of elements $x$ is quite limited).
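One possible way around the overflow, sketched here as an assumption rather than as part of the original method, is to compute the same quantity with NumPy's `np.logaddexp`, which evaluates $\log(e^a + e^b)$ without forming the large exponentials. The names `naive_softmax`, `stable_softmax` and `stable_softmin` are hypothetical.

```python
import numpy as np

def naive_softmax(a, b):
    """Direct formula; exp() overflows for large inputs."""
    return np.log(np.exp(a) + np.exp(b))

def stable_softmax(a, b):
    """Same function via logaddexp, which avoids the overflow."""
    return np.logaddexp(a, b)

def stable_softmin(a, b):
    return -np.logaddexp(-a, -b)

a = np.array([100.0], dtype=np.float32)
b = np.array([99.0], dtype=np.float32)

with np.errstate(over="ignore"):
    naive = naive_softmax(a, b)   # exp(100) is not representable in float32

print("naive :", naive)
print("stable:", stable_softmax(a, b), stable_softmin(a, b))
```

PyTorch provides `torch.logaddexp` with the same semantics, so the same substitution should carry over to tensors.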

All that remains is to compute the matrices $L_i, R_i, L'_i, R'_i$ for each of the layers of the network.



# Test functions

In [568]:
import numpy as np

def bitonic_network(n):
    """Check the computation of a bitonic network"""
    layers = int(np.log2(n))          
    for layer in range(1, layers+1):                
        for sub in reversed(range(1, layer+1)):
            for i in range(0,n,2**sub):
                for j in range(2**(sub-1)):
                    ix = i + j
                    a, b = ix, ix+(2**(sub-1))
                    swap = "<" if (ix >> layer) & 1 else ">"
                    print(f"{a:>2}{swap}{b:<d}", end="\t")                                    
            print()
        print("-"*n*4)    
        
# this should match the diagram at the top of the notebook
bitonic_network(16)

 0>1	 2<3	 4>5	 6<7	 8>9	10<11	12>13	14<15	
----------------------------------------------------------------
 0>2	 1>3	 4<6	 5<7	 8>10	 9>11	12<14	13<15	
 0>1	 2>3	 4<5	 6<7	 8>9	10>11	12<13	14<15	
----------------------------------------------------------------
 0>4	 1>5	 2>6	 3>7	 8<12	 9<13	10<14	11<15	
 0>2	 1>3	 4>6	 5>7	 8<10	 9<11	12<14	13<15	
 0>1	 2>3	 4>5	 6>7	 8<9	10<11	12<13	14<15	
----------------------------------------------------------------
 0>8	 1>9	 2>10	 3>11	 4>12	 5>13	 6>14	 7>15	
 0>4	 1>5	 2>6	 3>7	 8>12	 9>13	10>14	11>15	
 0>2	 1>3	 4>6	 5>7	 8>10	 9>11	12>14	13>15	
 0>1	 2>3	 4>5	 6>7	 8>9	10>11	12>13	14>15	
----------------------------------------------------------------


In [569]:
def pretty_bitonic_network(n):
    """Pretty print a bitonic network,
    to check the logic is correct"""
    layers = int(np.log2(n))  
    # header
    for i in range(n):
        print(f"{i:<2d} ", end='')
    print()
    
    # layers
    for layer in range(1, layers+1):                
        for sub in reversed(range(1, layer+1)):
            for i in range(0,n,2**sub):
                for j in range(2**(sub-1)):
                    ix = i + j
                    a, b = ix, ix+(2**(sub-1))                    
                    way = "<" if (ix >> layer) & 1 else ">"
                    
                    # draw one row per comparator; use a separate column index
                    # so the outer loop variable i is not overwritten
                    for col in range(n):
                        if col == b:
                            print("┙", end='')
                        elif col == a:
                            print("┕", end='')
                        elif not (a < col < b):
                            print("│", end='')
                        else:
                            print(way, end='')
                        if a <= col < b:
                            print(way*2, end='')
                        else:
                            print(" "*2, end='')
                    print()


pretty_bitonic_network(16)

0  1  2  3  4  5  6  7  8  9  10 11 12 13 14 15 
┕>>┙  │  │  │  │  │  │  │  │  │  │  │  │  │  │  
│  │  ┕<<┙  │  │  │  │  │  │  │  │  │  │  │  │  
│  │  │  │  ┕>>┙  │  │  │  │  │  │  │  │  │  │  
│  │  │  │  │  │  ┕<<┙  │  │  │  │  │  │  │  │  
│  │  │  │  │  │  │  │  ┕>>┙  │  │  │  │  │  │  
│  │  │  │  │  │  │  │  │  │  ┕<<┙  │  │  │  │  
│  │  │  │  │  │  │  │  │  │  │  │  ┕>>┙  │  │  
│  │  │  │  │  │  │  │  │  │  │  │  │  │  ┕<<┙  
┕>>>>>┙  │  │  │  │  │  │  │  │  │  │  │  │  │  
│  ┕>>>>>┙  │  │  │  │  │  │  │  │  │  │  │  │  
│  │  │  │  ┕<<<<<┙  │  │  │  │  │  │  │  │  │  
│  │  │  │  │  ┕<<<<<┙  │  │  │  │  │  │  │  │  
│  │  │  │  │  │  │  │  ┕>>>>>┙  │  │  │  │  │  
│  │  │  │  │  │  │  │  │  ┕>>>>>┙  │  │  │  │  
│  │  │  │  │  │  │  │  │  │  │  │  ┕<<<<<┙  │  
│  │  │  │  │  │  │  │  │  │  │  │  │  ┕<<<<<┙  
┕>>┙  │  │  │  │  │  │  │  │  │  │  │  │  │  │  
│  │  ┕>>┙  │  │  │  │  │  │  │  │  │  │  │  │  
│  │  │  │  ┕<<┙  │  │  │  │  │  │  │  │  │  │  
│  │  │  │  │  │  ┕<

# Vectorised functions

In [570]:
import numpy as np

def softmax(a,b):
    """The softmaximum of softmax(a,b) = log(e^a + e^b)."""
    return np.log(np.exp(a)+np.exp(b))

def softmin(a,b):
    """
    Return the soft-minimum of a and b
    The soft-minimum can be derived directly from softmax(a,b).
    """
    return -softmax(-a, -b)

def softrank(a, b):
    """Return a,b in 'soft-sorted' order, with the smaller value first"""
    return softmin(a,b), softmax(a,b)

In [571]:
def bitonic_matrices(n):
    """Compute a set of bitonic sort matrices to sort a sequence of
    length n. n *must* be a power of 2.
    
    See: https://en.wikipedia.org/wiki/Bitonic_sorter
    
    Set k=log2(n).
    There will be k "layers", i=1, 2, ... k
    
    Each ith layer will have i sub-steps, so there are (k*(k+1)) / 2 sorting steps total.
    
    For each step, we compute 4 matrices. l and r are binary matrices of size (n/2, n) and
    map_l and map_r are matrices of size (n, n/2).
    
    l and r "interleave" the inputs into two n/2 size vectors. map_l and map_r "uninterleave" these two n/2 vectors
    back into two n sized vectors that can be summed to get the correct output.
                    
    The result is such that to apply any layer's sorting, we can perform:
    
    l, r, map_l, map_r = layer[j]
    a, b =  l @ y, r @ y                
    permuted = map_l @ np.minimum(a, b) + map_r @ np.maximum(a,b)
        
    Applying this operation for each layer in sequence sorts the input vector.
            
    """
    # number of outer layers
    layers = int(np.log2(n))
    matrices = []
    for layer in range(1, layers+1):
        # we have 1..layer sub layers
        for sub in reversed(range(1, layer+1)):
            l, r = np.zeros((n//2, n)), np.zeros((n//2, n))            
            map_l, map_r = np.zeros((n, n//2)), np.zeros((n, n//2))                        
            out = 0 
            for i in range(0,n,2**sub):
                for j in range(2**(sub-1)):
                    ix = i + j
                    a, b = ix, ix+(2**(sub-1))                    
                    l[out, a] = 1
                    r[out, b] = 1
                    if (ix >> layer) & 1:                                            
                        a, b = b,a                                                            
                    map_l[a, out] = 1
                    map_r[b, out] = 1                                        
                    out += 1
            matrices.append((l, r, map_l, map_r))            
    return matrices

In [572]:
def bisort(matrices, x):
    """
    Given a set of bitonic sort matrices generated by bitonic_matrices(n), sort 
    a sequence x of length n. Sorts exactly.
    """    
    for l, r, map_l, map_r in matrices:
        a, b = l @ x, r @ x                
        x = map_l @ np.minimum(a,b) + map_r @ np.maximum(a,b)        
    return x

def diff_bisort(matrices, x):    
    """
    Approximate differentiable sort. Takes a set of bitonic sort matrices generated by bitonic_matrices(n)
    and sorts a sequence x of length n. Values will be distorted slightly but will be ordered.
    """  
    for l, r, map_l, map_r in matrices:
        a, b = softrank(l @ x, r @ x)        
        x = map_l @ a + map_r @ b
    return x

## Testing

In [574]:
# Test sorting
matrices = bitonic_matrices(8)        
for i in range(10):
    test = np.random.randint(0,200,8)
    print(bisort(matrices, test))
    assert(np.allclose(bisort(matrices, test), np.sort(test)))

[ 23.  39.  46.  84.  87. 103. 178. 190.]
[  0.  36.  47.  66.  93. 107. 150. 171.]
[ 16.  38.  45.  63.  89. 126. 137. 170.]
[  4.  58.  76. 107. 112. 123. 160. 166.]
[  5.  14.  25.  42.  72. 103. 149. 192.]
[  1.   3.  58.  59. 114. 139. 147. 161.]
[ 13.  31.  41.  52.  56.  92. 102. 130.]
[ 23.  40.  41.  54. 112. 119. 195. 197.]
[ 13.  76.  78.  99. 124. 136. 172. 172.]
[ 13.  15.  38.  84. 103. 106. 112. 127.]


In [580]:
for i in range(1,11):
    k = 2 ** i
    matrices = bitonic_matrices(k)  
    print(f"Testing sorting for {k} elements")
    for j in range(100):
        test = np.random.randint(0, 200, k)
        assert(np.allclose(bisort(matrices, test), np.sort(test)))

Testing sorting for 2 elements
Testing sorting for 4 elements
Testing sorting for 8 elements
Testing sorting for 16 elements
Testing sorting for 32 elements
Testing sorting for 64 elements
Testing sorting for 128 elements
Testing sorting for 256 elements
Testing sorting for 512 elements
Testing sorting for 1024 elements


## Differentiable sorting test

In [582]:
# Differentiable sorting 
np.set_printoptions(precision=2)
matrices = bitonic_matrices(8) 
for i in range(10):
    print(diff_bisort(matrices, np.random.randint(-200,200,8)))

[-106.  -63.  -33.   41.   73.   91.  107.  139.]
[-114.  -77.   10.   53.   70.   76.   92.  143.]
[-132. -112.  -94.   31.   73.   84.  119.  154.]
[-172. -132.  -94.  -72.  -29.   44.  104.  130.]
[-134.  -38.    9.   83.  107.  117.  156.  175.]
[-195.   -152.   -120.    -52.    -19.     -2.     60.95   64.05]
[-193.    -15.     -2.     22.31   23.69   63.     78.     86.  ]
[-180. -159. -131.  -78.   36.   56.  137.  179.]
[-138.  -88.  -53.  -42.  -35.   52.  146.  179.]
[-105.  -29.  -23.  -16.   82.   96.  109.  151.]


# PyTorch example
We can verify that this is both parallelisable on the GPU and fully differentiable.
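Differentiability of a single soft compare-and-swap can also be checked by hand before involving autograd. The partial derivatives of $\text{softmax}(a,b)$ are the logistic weights $\frac{e^a}{e^a+e^b}$ and $\frac{e^b}{e^a+e^b}$, which are non-negative and sum to 1; the same holds for `softmin`. The sketch below compares these against central finite differences in NumPy. The helper names `analytic_grad` and `numeric_grad` are hypothetical, and `np.logaddexp` is used as a stand-in for the direct formula.

```python
import numpy as np

def softmax(a, b):
    """log(e^a + e^b), computed stably."""
    return np.logaddexp(a, b)

def analytic_grad(a, b):
    """Partial derivatives of softmax(a, b) with respect to a and b."""
    wa = 1.0 / (1.0 + np.exp(b - a))   # e^a / (e^a + e^b)
    return wa, 1.0 - wa

def numeric_grad(a, b, eps=1e-6):
    """Central finite-difference approximation of the same derivatives."""
    da = (softmax(a + eps, b) - softmax(a - eps, b)) / (2 * eps)
    db = (softmax(a, b + eps) - softmax(a, b - eps)) / (2 * eps)
    return da, db

a, b = 0.3, -1.2
print("analytic:", analytic_grad(a, b))
print("numeric :", numeric_grad(a, b))
```

Since every layer's Jacobian then has non-negative rows summing to 1, and products of such matrices keep that property, one would expect each row of the full sorting Jacobian to sum to 1 as well.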

In [583]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', device)

Device: cuda:0


In [584]:
matrices = bitonic_matrices(16)
torch_matrices = [[torch.from_numpy(matrix).float().to(device) for matrix in matrix_set] for matrix_set in matrices]

In [585]:
# override softmax to use torch tensors
def softmax(a,b):
    """The softmaximum of softmax(a,b) = log(e^a + e^b)."""
    return torch.log(torch.exp(a)+torch.exp(b))


In [586]:
test_input = np.random.normal(0, 5, 16)
var_test_input = Variable(torch.from_numpy(test_input).float().to(device), requires_grad=True)
result = diff_bisort(torch_matrices, var_test_input)

# compute the Jacobian of the sorting function, to show we can differentiate through the 
# sorting function
jac = []
for i in range(len(result)):
    jac.append(torch.autograd.grad(result[i], var_test_input, retain_graph=True)[0])

# 16 x 16 Jacobian of the sorting function
print(torch.stack(jac))    


tensor([[3.2104e-01, 3.8757e-05, 2.2428e-02, 1.9859e-01, 1.5259e-03, 3.4902e-04,
         1.6809e-03, 2.3111e-03, 2.8201e-04, 2.8588e-01, 3.6264e-02, 6.3002e-04,
         3.1380e-05, 1.2835e-01, 3.6958e-04, 2.3259e-04],
        [2.7163e-01, 3.5779e-05, 1.9345e-02, 1.7872e-01, 1.4394e-03, 3.2754e-04,
         1.6050e-03, 2.2146e-03, 3.1813e-04, 3.4291e-01, 4.1697e-02, 7.1266e-04,
         3.4758e-05, 1.3836e-01, 4.0324e-04, 2.5539e-04],
        [1.6584e-01, 9.0402e-05, 3.8121e-02, 2.4091e-01, 4.1187e-03, 8.7890e-04,
         4.6219e-03, 6.5168e-03, 5.5896e-04, 1.5907e-01, 8.6112e-02, 1.2246e-03,
         7.0083e-05, 2.9058e-01, 7.7589e-04, 5.2269e-04],
        [1.6807e-01, 8.9644e-05, 3.7758e-02, 2.4639e-01, 4.0916e-03, 8.7654e-04,
         4.5988e-03, 6.4769e-03, 5.6263e-04, 1.5814e-01, 8.6429e-02, 1.2373e-03,
         7.0285e-05, 2.8390e-01, 7.7929e-04, 5.2365e-04],
        [3.0768e-02, 7.6955e-04, 2.3025e-01, 5.5505e-02, 4.5084e-02, 9.6885e-03,
         4.8241e-02, 6.5950e-02, 6.0669