# Differentiable bitonic sort

[Bitonic sorts](https://en.wikipedia.org/wiki/Bitonic_sorter) allow the construction of sorting networks: fixed sequences of conditional swap operations, many of which can be executed in parallel. A sorting network implements a map from $\mathbb{R}^n \rightarrow \mathbb{R}^n$, where $n=2^k$ (bitonic networks for non-power-of-2 sizes are possible but trickier).

<img src="BitonicSort1.svg.png">

*[Image: from Wikipedia, by user Bitonic, CC0](https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg)*

The bitonic sorting network for $n=2^k$ elements has $\frac{k(k+1)}{2}$ "layers". Each layer applies $\frac{n}{2}$ compare-and-swap operations in parallel, and together the layers rearrange an $n$-element vector into sorted order.
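To make the layer structure concrete, here is a minimal sketch in plain Python (not the `differentiable_sorting` package used later) that generates the compare-and-swap pairs of each layer with the standard iterative bitonic construction and applies them with exact `min`/`max`. The names `bitonic_layers` and `hard_bitonic_sort` are hypothetical, chosen for illustration.

```python
def bitonic_layers(n):
    """Return the layers of a bitonic sorting network for n = 2**k inputs.

    Each layer is a list of (lo, hi) index pairs: after the compare-and-swap,
    position lo holds the smaller value and position hi the larger one.
    """
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    layers = []
    k = 2                      # size of the bitonic blocks being merged
    while k <= n:
        j = k // 2             # distance between compared elements
        while j >= 1:
            layer = []
            for i in range(n):
                partner = i ^ j
                if partner > i:
                    if (i & k) == 0:   # ascending block
                        layer.append((i, partner))
                    else:              # descending block
                        layer.append((partner, i))
            layers.append(layer)
            j //= 2
        k *= 2
    return layers


def hard_bitonic_sort(x):
    """Apply the network with exact min/max (pairs within a layer are disjoint)."""
    x = list(x)
    for layer in bitonic_layers(len(x)):
        for lo, hi in layer:
            x[lo], x[hi] = min(x[lo], x[hi]), max(x[lo], x[hi])
    return x


print(len(bitonic_layers(16)))                      # k = 4, so 4 * 5 / 2 layers
print(hard_bitonic_sort([5, 2, 7, 1, 8, 3, 6, 4]))
```

Because the pairs within one layer never share an index, applying them one after another gives the same result as applying them simultaneously, which is what makes each layer parallelisable.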

### Differentiable compare-and-swap

Define the `softmax(a,b)` function (not the traditional "softmax" used for classification!) as a smooth approximation to the `max(a,b)` function:

$$\text{softmax}(a,b) = \log(e^a + e^b) \approx \max(a,b).$$

By symmetry, we can write `softmin(a,b)` as:

$$\text{softmin}(a,b) = -\log(e^{-a} + e^{-b}) \approx \min(a,b).$$

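These functions aren't equal to max and min, but they are close and differentiable. Since $\text{softmax}(a,b) = \max(a,b) + \log\left(1 + e^{-|a-b|}\right)$, the error is at most $\log 2$ (when $a = b$) and shrinks rapidly as $|a-b|$ grows; the same bound holds for softmin. Note that we now have a differentiable compare-and-swap operation: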

$$\text{high} = \text{softmax}(a,b), \text{low} = \text{softmin}(a,b), \text{where } \text{low}\leq \text{high}$$
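A minimal NumPy sketch of these three functions follows. It evaluates $\log(e^a + e^b)$ with `np.logaddexp`, which computes the same quantity without overflowing for large inputs; this is a substitution for the literal formula above and not necessarily what the `differentiable_sorting` package does. It also illustrates two properties that follow from the definitions: the error is at most $\log 2$, and $\text{softmin}(a,b) + \text{softmax}(a,b) = a + b$, so a soft compare-and-swap preserves the sum of its inputs.

```python
import numpy as np

def softmax(a, b):
    """Smooth max: log(e^a + e^b), evaluated stably by np.logaddexp."""
    return np.logaddexp(a, b)

def softmin(a, b):
    """Smooth min: -log(e^-a + e^-b)."""
    return -np.logaddexp(-a, -b)

def softcswap(a, b):
    """Differentiable compare-and-swap returning (low, high)."""
    return softmin(a, b), softmax(a, b)

low, high = softcswap(3.0, 1.0)
print(low, high)             # slightly below 1 and slightly above 3
print(low + high)            # a + b = 4, up to floating-point rounding
print(softmax(0.0, 0.0))     # worst case: log(2) above max(0, 0)
print(softmax(1000.0, 0.0))  # no overflow, although e^1000 is not representable
```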

## Differentiable sorting

For each layer in the sorting network, we can split all of the pairwise compare-and-swaps into left-hand and right-hand sides, which can be processed simultaneously. Any selection of elements from the vector can be written as a multiplication by a binary matrix.

For each layer, we can derive two binary matrices $L \in \mathbb{R}^{\frac{n}{2} \times n}$ and $R \in \mathbb{R}^{\frac{n}{2} \times n}$ which select the elements to be compared for the left and right hands respectively. This results in a comparison between two vectors of length $\frac{n}{2}$. We can also derive two matrices $L' \in \mathbb{R}^{n \times \frac{n}{2}}$ and $R' \in \mathbb{R}^{n \times \frac{n}{2}}$ which put the results of the compare-and-swap operation back into the right positions.
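As a sketch of how the four matrices for one layer could be assembled (the helper `layer_matrices` and the `(lo, hi)` pair format are assumptions for illustration, not the package's `bitonic_matrices` API), each comparison $p$ contributes a single 1 to each matrix. The example uses the first layer of the four-element network in the example section, with exact `min`/`max` standing in for the soft versions so the result is easy to check.

```python
import numpy as np

def layer_matrices(pairs, n):
    """Build L, R (n/2 x n) and L', R' (n x n/2) for one layer.

    `pairs` lists one (lo, hi) tuple per comparison: position lo receives the
    smaller value and position hi the larger one.
    """
    m = len(pairs)
    L, R = np.zeros((m, n)), np.zeros((m, n))
    Lp, Rp = np.zeros((n, m)), np.zeros((n, m))
    for p, (lo, hi) in enumerate(pairs):
        L[p, lo] = 1.0   # left-hand operand of comparison p
        R[p, hi] = 1.0   # right-hand operand of comparison p
        Lp[lo, p] = 1.0  # the min of comparison p is written back to lo
        Rp[hi, p] = 1.0  # the max of comparison p is written back to hi
    return L, R, Lp, Rp

# First layer of the four-element network: (0, 1) ascending, (2, 3) descending
L, R, Lp, Rp = layer_matrices([(0, 1), (3, 2)], 4)
x = np.array([4.0, 1.0, 2.0, 3.0])
a, b = L @ x, R @ x
y = Lp @ np.minimum(a, b) + Rp @ np.maximum(a, b)   # exact min/max for clarity
print(y)   # [1. 4. 3. 2.]
```

In this construction $L' = L^\top$ and $R' = R^\top$, because each comparison writes its results back to the same two positions it read from.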

Then, each layer $i$ of the sorting process is just:
$${\bf x}_{i+1} = L'_i[\text{softmin}(L_i{\bf x_i}, R_i{\bf x_i})] + R'_i[\text{softmax}(L_i{\bf x_i}, R_i{\bf x_i})]$$
$$ = L'_i\left(-\log\left(e^{-L_i{\bf x}_i} + e^{-R_i{\bf x}_i}\right)\right) +  R'_i\left(\log\left(e^{L_i{\bf x}_i} + e^{R_i{\bf x}_i}\right)\right)$$
which is clearly differentiable (though not very numerically stable: evaluated literally, $e^{\pm x}$ overflows single-precision floats once $|x|$ exceeds about 88, so the usable range of elements $x$ is quite limited).
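Putting the pieces together, here is a self-contained NumPy sketch of the layer update above. It is not the `differentiable_sorting` package's code: `bitonic_layers`, `layer_matrices` and `diff_sort` are hypothetical helpers (standard iterative bitonic construction, one `(lo, hi)` pair per comparison), and `np.logaddexp` replaces the literal $\log(e^{\cdot} + e^{\cdot})$ to sidestep the overflow problem.

```python
import numpy as np

def bitonic_layers(n):
    """(lo, hi) index pairs for each layer of a bitonic network, n = 2**k."""
    layers, k = [], 2
    while k <= n:
        j = k // 2
        while j >= 1:
            # ascending blocks put the min at the lower index, descending at the higher
            layers.append([(i, i ^ j) if (i & k) == 0 else (i ^ j, i)
                           for i in range(n) if (i ^ j) > i])
            j //= 2
        k *= 2
    return layers

def layer_matrices(pairs, n):
    """Selection matrices L, R (n/2 x n); here L' = L^T and R' = R^T."""
    L = np.zeros((len(pairs), n))
    R = np.zeros((len(pairs), n))
    for p, (lo, hi) in enumerate(pairs):
        L[p, lo] = 1.0
        R[p, hi] = 1.0
    return L, R, L.T, R.T

def diff_sort(x):
    """Apply x_{i+1} = L'_i softmin(L_i x_i, R_i x_i) + R'_i softmax(L_i x_i, R_i x_i)."""
    x = np.asarray(x, dtype=float)
    for pairs in bitonic_layers(len(x)):
        L, R, Lp, Rp = layer_matrices(pairs, len(x))
        a, b = L @ x, R @ x
        x = Lp @ (-np.logaddexp(-a, -b)) + Rp @ np.logaddexp(a, b)
    return x

x = np.array([30.0, 10.0, 70.0, 50.0, 0.0, 60.0, 20.0, 40.0])
print(np.round(diff_sort(x), 3))
```

Every step is a matrix product or `np.logaddexp`, both smooth, so the output is differentiable in `x`. With well-separated inputs like these the result is very close to the exact sort, and because each soft compare-and-swap preserves $a + b$, the total sum of the vector is unchanged.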

All that remains is to compute the matrices $L_i, R_i, L'_i, R'_i$ for each of the layers of the network. 

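This is wasteful in computation: every layer performs dense matrix-vector products costing $O(n^2)$, even though each row of $L_i$ and $R_i$ contains a single 1. It is, however, easy to implement. We could also reduce the four multiplies per layer to two, at the cost of a vector split and join in the middle (see the `woven` form later in this text).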
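The sketch below checks that a layer's selection and write-back can equally be done by integer indexing in $O(n)$; the matrix form trades that efficiency for being expressible with plain matrix multiplies. The random pairing is chosen only for illustration and is not a bitonic layer.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 8
perm = rng.permutation(n)
pairs = list(zip(perm[: n // 2], perm[n // 2 :]))   # an arbitrary disjoint pairing

# Matrix form: each product costs O(n^2) multiply-adds.
# Here L' = L.T and R' = R.T, since results go back to the positions they came from.
L = np.zeros((n // 2, n))
R = np.zeros((n // 2, n))
for p, (lo, hi) in enumerate(pairs):
    L[p, lo] = 1.0
    R[p, hi] = 1.0

x = rng.normal(size=n)
a, b = L @ x, R @ x
y_matrix = L.T @ (-np.logaddexp(-a, -b)) + R.T @ np.logaddexp(a, b)

# Index form: O(n) gathers and scatters
lo_idx = np.array([lo for lo, _ in pairs])
hi_idx = np.array([hi for _, hi in pairs])
a2, b2 = x[lo_idx], x[hi_idx]
y_index = np.empty(n)
y_index[lo_idx] = -np.logaddexp(-a2, -b2)
y_index[hi_idx] = np.logaddexp(a2, b2)

print(np.allclose(y_matrix, y_index))   # True
```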

## Example

To sort four elements, we have a network like:

    0  1  2  3
    ┕>>┙  │  │
    │  │  ┕<<┙
    ┕>>>>>┙  │
    │  ┕>>>>>┙
    ┕>>┙  │  │
    │  │  ┕>>┙

This is equivalent to:

    x[0], x[1] = cswap(x[0], x[1])
    x[3], x[2] = cswap(x[2], x[3])
    x[0], x[2] = cswap(x[0], x[2])
    x[1], x[3] = cswap(x[1], x[3])
    x[0], x[1] = cswap(x[0], x[1])
    x[2], x[3] = cswap(x[2], x[3])
    
where `cswap(a,b) = (min(a,b), max(a,b))`

Replacing the indexing with matrix multiplies, and `cswap` with `softcswap(a,b) = (softmin(a,b), softmax(a,b))`, gives the differentiable form.
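The four-element listing can be run directly. This sketch (plain Python plus `np.logaddexp`; the name `sort4` is hypothetical) applies the network with both the exact `cswap` and a soft version. It includes the comparison between positions 1 and 3 in the second layer, which a four-element bitonic network needs alongside the 0 and 2 comparison; without it, an input such as `[3, 4, 1, 2]` comes out as `[2, 4, 1, 3]`.

```python
import numpy as np

def cswap(a, b):
    return min(a, b), max(a, b)

def softcswap(a, b):
    # (softmin(a, b), softmax(a, b)), evaluated stably
    return -np.logaddexp(-a, -b), np.logaddexp(a, b)

def sort4(x, swap):
    """Four-element bitonic network; `swap` returns (low, high)."""
    x = list(x)
    x[0], x[1] = swap(x[0], x[1])   # layer 1: ascending pair
    x[3], x[2] = swap(x[2], x[3])   #          descending pair
    x[0], x[2] = swap(x[0], x[2])   # layer 2: distance-2 comparisons
    x[1], x[3] = swap(x[1], x[3])
    x[0], x[1] = swap(x[0], x[1])   # layer 3: distance-1 comparisons
    x[2], x[3] = swap(x[2], x[3])
    return x

print(sort4([3, 4, 1, 2], cswap))                 # [1, 2, 3, 4]
print(sort4([30.0, 0.0, 20.0, 10.0], softcswap))  # approximately [0, 10, 20, 30]
```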



# Test functions

```python
from bitonic_tests import bitonic_network, pretty_bitonic_network

def neat_vec(n):
    # format a vector as tab-separated values with two decimal places
    return "\t".join([f"{x:.2f}" for x in n])

# this should match the diagram at the top of the notebook
bitonic_network(16)
```

```
 0>1	 2<3	 4>5	 6<7	 8>9	10<11	12>13	14<15	
----------------------------------------------------------------
 0>2	 1>3	 4<6	 5<7	 8>10	 9>11	12<14	13<15	
 0>1	 2>3	 4<5	 6<7	 8>9	10>11	12<13	14<15	
----------------------------------------------------------------
 0>4	 1>5	 2>6	 3>7	 8<12	 9<13	10<14	11<15	
 0>2	 1>3	 4>6	 5>7	 8<10	 9<11	12<14	13<15	
 0>1	 2>3	 4>5	 6>7	 8<9	10<11	12<13	14<15	
----------------------------------------------------------------
 0>8	 1>9	 2>10	 3>11	 4>12	 5>13	 6>14	 7>15	
 0>4	 1>5	 2>6	 3>7	 8>12	 9>13	10>14	11>15	
 0>2	 1>3	 4>6	 5>7	 8>10	 9>11	12>14	13>15	
 0>1	 2>3	 4>5	 6>7	 8>9	10>11	12>13	14>15	
----------------------------------------------------------------
```


```python
pretty_bitonic_network(8)
```

```
 0  1  2  3  4  5  6  7 
 ╭──╯  │  │  │  │  │  │ 
 │  │  ╰──╮  │  │  │  │ 
 │  │  │  │  ╭──╯  │  │ 
 │  │  │  │  │  │  ╰──╮ 
 ╭─────╯  │  │  │  │  │ 
 │  ╭─────╯  │  │  │  │ 
 │  │  │  │  ╰─────╮  │ 
 │  │  │  │  │  ╰─────╮ 
 ╭──╯  │  │  │  │  │  │ 
 │  │  ╭──╯  │  │  │  │ 
 │  │  │  │  ╰──╮  │  │ 
 │  │  │  │  │  │  ╰──╮ 
 ╭───────────╯  │  │  │ 
 │  ╭───────────╯  │  │ 
 │  │  ╭───────────╯  │ 
 │  │  │  ╭───────────╯ 
 ╭─────╯  │  │  │  │  │ 
 │  ╭─────╯  │  │  │  │ 
 │  │  │  │  ╭─────╯  │ 
 │  │  │  │  │  ╭─────╯ 
 ╭──╯  │  │  │  │  │  │ 
 │  │  ╭──╯  │  │  │  │ 
 │  │  │  │  ╭──╯  │  │ 
 │  │  │  │  │  │  ╭──╯ 
```


```python
pretty_bitonic_network(16)
```

```
 0  1  2  3  4  5  6  7  8  9  10 11 12 13 14 15
 ╭──╯  │  │  │  │  │  │  │  │  │  │  │  │  │  │ 
 │  │  ╰──╮  │  │  │  │  │  │  │  │  │  │  │  │ 
 │  │  │  │  ╭──╯  │  │  │  │  │  │  │  │  │  │ 
 │  │  │  │  │  │  ╰──╮  │  │  │  │  │  │  │  │ 
 │  │  │  │  │  │  │  │  ╭──╯  │  │  │  │  │  │ 
 │  │  │  │  │  │  │  │  │  │  ╰──╮  │  │  │  │ 
 │  │  │  │  │  │  │  │  │  │  │  │  ╭──╯  │  │ 
 │  │  │  │  │  │  │  │  │  │  │  │  │  │  ╰──╮ 
 ╭─────╯  │  │  │  │  │  │  │  │  │  │  │  │  │ 
 │  ╭─────╯  │  │  │  │  │  │  │  │  │  │  │  │ 
 │  │  │  │  ╰─────╮  │  │  │  │  │  │  │  │  │ 
 │  │  │  │  │  ╰─────╮  │  │  │  │  │  │  │  │ 
 │  │  │  │  │  │  │  │  ╭─────╯  │  │  │  │  │ 
 │  │  │  │  │  │  │  │  │  ╭─────╯  │  │  │  │ 
 │  │  │  │  │  │  │  │  │  │  │  │  ╰─────╮  │ 
 │  │  │  │  │  │  │  │  │  │  │  │  │  ╰─────╮ 
 ╭──╯  │  │  │  │  │  │  │  │  │  │  │  │  │  │ 
 │  │  ╭──╯  │  │  │  │  │  │  │  │  │  │  │  │ 
 │  │  │  │  ╰──╮  │  │  │  │  │  │  │  │  │  │ 
 │  │  │  │  │  │  ╰
```

# Vectorised functions

## Testing

```python
# Test sorting
import autograd.numpy as np  # we can use plain numpy as well (but can't take grad!)

from differentiable_sorting import bitonic_matrices, diff_bisort, diff_argsort, softmax, softmin, softcswap, bisort
matrices = bitonic_matrices(8)

for i in range(10):
    # these should all be in sorted order
    test = np.random.randint(0, 200, 8)
    print(bisort(matrices, test))
```

```
[ 38.  54.  59.  67. 106. 124. 150. 192.]
[ 15.  30.  49.  85.  94. 103. 119. 174.]
[ 38.  53.  70.  83.  98. 108. 159. 195.]
[  2.  42.  65.  68.  89. 118. 161. 197.]
[ 17.  24.  29.  29.  61. 147. 164. 180.]
[ 43.  67.  91. 134. 140. 153. 185. 196.]
[  3. 138. 146. 156. 166. 170. 171. 189.]
[  0.  28.  34.  62.  94.  98. 165. 168.]
[ 21.  22.  34.  68.  93. 125. 165. 187.]
[  7.  19.  37.  39.  41.  73. 164. 180.]
```


```python
for i in range(1, 11):
    k = 2**i
    matrices = bitonic_matrices(k)
    print(f"Testing sorting for {k} elements")
    for j in range(100):
        test = np.random.randint(0, 200, k)
        assert np.allclose(bisort(matrices, test), np.sort(test))
```

```
Testing sorting for 2 elements
Testing sorting for 4 elements
Testing sorting for 8 elements
Testing sorting for 16 elements
Testing sorting for 32 elements
Testing sorting for 64 elements
Testing sorting for 128 elements
Testing sorting for 256 elements
Testing sorting for 512 elements
Testing sorting for 1024 elements
```


## Differentiable sorting test

```python
# Differentiable sorting
np.set_printoptions(precision=2)
matrices = bitonic_matrices(8)

for i in range(10):
    test = np.random.randint(-200, 200, 8)
    print("Differentiable", neat_vec(diff_bisort(matrices, test)))
    print("Exact sorting ", neat_vec(bisort(matrices, test)))
    print()
```

```
Differentiable -146.00	-70.00	17.00	31.69	33.31	50.00	142.00	182.00
Exact sorting  -146.00	-70.00	17.00	32.00	33.00	50.00	142.00	182.00

Differentiable -189.00	-116.00	-84.00	-63.01	-57.99	-48.00	75.00	87.00
Exact sorting  -189.00	-116.00	-84.00	-63.00	-58.00	-48.00	75.00	87.00

Differentiable -150.01	-144.99	-137.13	-134.87	-46.92	-45.08	9.00	16.00
Exact sorting  -150.00	-145.00	-137.00	-135.00	-46.00	-46.00	9.00	16.00

Differentiable -183.00	-162.00	-108.00	-97.00	87.00	99.00	132.00	171.00
Exact sorting  -183.00	-162.00	-108.00	-97.00	87.00	99.00	132.00	171.00

Differentiable -167.00	-155.00	15.00	39.87	42.13	97.00	144.00	177.00
Exact sorting  -167.00	-155.00	15.00	40.00	42.00	97.00	144.00	177.00

Differentiable -153.00	-105.00	-21.00	34.00	120.87	123.13	139.00	155.00
Exact sorting  -153.00	-105.00	-21.00	34.00	121.00	123.00	139.00	155.00

Differentiable -198.00	-155.00	-127.00	-117.00	-75.00	-23.00	28.00	122.00
Exact sorting  -198.00	-155.00	-127.00	-117.00	-75.00	-23.00	28.00	122.0
```

# Relaxed sorting
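We can define a slightly modified function which interpolates between `softmax(a,b)` and `mean(a,b)`. The result is a sorting function that can be relaxed from sorting to averaging: `smooth=0` gives the ordinary differentiable sort, and `smooth=1` replaces every element with the mean of the input.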
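One way such an interpolation could look is sketched below. The blending rule (a linear mix of the soft compare-and-swap with the pair mean, controlled by a hypothetical `smooth` parameter in $[0, 1]$) is an assumption and may not match what `diff_bisort_smooth` actually computes; the helper names are illustrative. Because both endpoints preserve $a + b$, every blend does too, and at `smooth=1` the four-element network used here reduces every entry to the overall mean.

```python
import numpy as np

# Layers of the four-element bitonic network as (lo, hi) pairs
LAYERS_4 = [[(0, 1), (3, 2)], [(0, 2), (1, 3)], [(0, 1), (2, 3)]]

def smooth_cswap(a, b, smooth):
    """Blend the soft compare-and-swap with the pair mean (hypothetical rule)."""
    low, high = -np.logaddexp(-a, -b), np.logaddexp(a, b)
    mean = 0.5 * (a + b)
    return (1 - smooth) * low + smooth * mean, (1 - smooth) * high + smooth * mean

def smooth_sort4(x, smooth):
    x = np.asarray(x, dtype=float).copy()
    for layer in LAYERS_4:
        for lo, hi in layer:   # pairs within a layer are disjoint
            x[lo], x[hi] = smooth_cswap(x[lo], x[hi], smooth)
    return x

x = [30.0, -10.0, 20.0, 0.0]
for s in (0.0, 0.5, 1.0):
    print(s, smooth_sort4(x, s))   # s=0: close to sorted; s=1: every entry is the mean, 10
```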

```python
from differentiable_sorting import diff_bisort_smooth
# Differentiable smoothed sorting
test = np.random.randint(-200, 200, 8)
print(f"Mean {np.mean(test):.2f}")
print()
print("Exact sorting       ", neat_vec(bisort(matrices, test)))
for smooth in np.linspace(0, 1, 8):
    print(f"Diff. smooth[{smooth:.2f}]  ", neat_vec(diff_bisort_smooth(matrices, test, smooth)))
```

```
Mean -15.50

Exact sorting        -146.00	-92.00	-78.00	-64.00	-38.00	68.00	102.00	124.00
Diff. smooth[0.00]   -146.00	-92.00	-78.00	-64.00	-38.00	68.00	102.00	124.00
Diff. smooth[0.14]   -98.67	-73.37	-55.99	-49.32	-12.67	39.16	56.54	70.31
Diff. smooth[0.29]   -67.44	-60.66	-35.65	-34.75	-3.47	17.71	26.27	33.99
Diff. smooth[0.43]   -46.54	-45.65	-24.58	-23.79	-3.31	2.75	6.77	10.35
Diff. smooth[0.57]   -32.51	-32.03	-18.81	-18.21	-7.00	-6.34	-5.11	-3.97
Diff. smooth[0.71]   -23.40	-23.07	-17.80	-17.43	-11.41	-11.10	-10.04	-9.73
Diff. smooth[0.86]   -18.01	-17.86	-17.35	-17.20	-13.66	-13.51	-13.27	-13.12
Diff. smooth[1.00]   -15.50	-15.50	-15.50	-15.50	-15.50	-15.50	-15.50	-15.50
```


```python
from autograd import jacobian
# show that we can take the derivative
jac_sort = jacobian(diff_bisort_smooth, argnum=1)
jac_sort(matrices, test, 0.05)  # slight relaxation
```

```
array([[0.  , 0.02, 0.05, 0.86, 0.02, 0.  , 0.02, 0.02],
       [0.01, 0.02, 0.22, 0.03, 0.66, 0.02, 0.02, 0.02],
       [0.02, 0.01, 0.66, 0.03, 0.22, 0.02, 0.01, 0.03],
       [0.02, 0.  , 0.02, 0.02, 0.03, 0.02, 0.02, 0.86],
       [0.86, 0.04, 0.02, 0.  , 0.  , 0.02, 0.02, 0.02],
       [0.02, 0.02, 0.02, 0.  , 0.02, 0.86, 0.02, 0.02],
       [0.02, 0.02, 0.  , 0.02, 0.02, 0.02, 0.86, 0.02],
       [0.04, 0.86, 0.  , 0.02, 0.02, 0.02, 0.02, 0.  ]])
```

## Woven form
We can "weave" the four matrices into two matrices for fewer multiplies, at the cost of having to split and join the vector at each layer.
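A sketch of one way to build the woven pair from the per-layer matrices, assuming the layout used by `diff_bisort_weave` (first half of the woven vector compared against the second half, with the `softmin` results concatenated before the `softmax` results): stack $L$ on top of $R$ to get the weave matrix, and place $L'$ beside $R'$ to get the unweave matrix. The helper name `woven_layer` is hypothetical, and the package's `bitonic_woven_matrices` may construct them differently.

```python
import numpy as np

def woven_layer(L, R, Lp, Rp):
    """Combine the four per-layer matrices into (weave, unweave), both n x n."""
    weave = np.vstack([L, R])       # rows: left operands, then right operands
    unweave = np.hstack([Lp, Rp])   # columns: softmin slots, then softmax slots
    return weave, unweave

# One layer for n = 4 with (lo, hi) comparisons (0, 1) and (3, 2)
n, pairs = 4, [(0, 1), (3, 2)]
L = np.zeros((2, n))
R = np.zeros((2, n))
for p, (lo, hi) in enumerate(pairs):
    L[p, lo] = 1.0
    R[p, hi] = 1.0
Lp, Rp = L.T, R.T

weave, unweave = woven_layer(L, R, Lp, Rp)
x = np.array([4.0, 1.0, 2.0, 3.0])
split = n // 2

# Woven form: two multiplies plus a split and a join
woven = weave @ x
a, b = woven[:split], woven[split:]
y_woven = unweave @ np.concatenate([-np.logaddexp(-a, -b), np.logaddexp(a, b)])

# Same layer computed with four separate multiplies
a4, b4 = L @ x, R @ x
y_four = Lp @ (-np.logaddexp(-a4, -b4)) + Rp @ np.logaddexp(a4, b4)
print(np.allclose(y_woven, y_four))   # True
```

In this construction every index appears in exactly one comparison, so the weave matrix is a permutation matrix and the unweave matrix is its transpose.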

```python
from differentiable_sorting import bitonic_woven_matrices

def diff_bisort_weave(matrices, x):
    """
    Given a set of bitonic sort matrices generated by bitonic_woven_matrices(n), sort
    a sequence x of length n.
    """
    split = len(x) // 2
    for weave, unweave in matrices:
        woven = weave @ x
        x = unweave @ np.concatenate(softcswap(woven[:split], woven[split:]))
    return x


woven_matrices = bitonic_woven_matrices(8)

print("Exact sorting       ", neat_vec(bisort(matrices, test)))
print("Diff. (std.)       ", neat_vec(diff_bisort(matrices, test)))
print("Diff. (woven)      ", neat_vec(diff_bisort_weave(woven_matrices, test)))
```

```
Exact sorting        -146.00	-92.00	-78.00	-64.00	-38.00	68.00	102.00	124.00
Diff. (std.)        -146.00	-92.00	-78.00	-64.00	-38.00	68.00	102.00	124.00
Diff. (woven)       -146.00	-92.00	-78.00	-64.00	-38.00	68.00	102.00	124.00
```


## Differentiable ranking / argsort
We can use a differentiable similarity measure between the input and the sorted output, e.g. an RBF kernel. Normalising this similarity matrix and applying it to the index vector `[0, 1, ..., n-1]` gives a differentiable argsort: entry $i$ is a soft estimate of the index of the $i$-th smallest input element.

As `sigma` gets larger, the result converges to giving every position the mean index $\frac{n-1}{2}$; as it goes to zero, the result converges to the exact argsort.
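Here is a minimal sketch of the idea, assuming a Gaussian (RBF) kernel $\exp\left(-(s_i - x_j)^2 / 2\sigma^2\right)$ between each sorted value $s_i$ and each input $x_j$. The kernel scaling used by `diff_argsort` and `order_matrix` may differ, so `sigma` values are not interchangeable between the two. For brevity the sorted values come from `np.sort`; in the differentiable pipeline they would come from the soft sort. `soft_argsort` is a hypothetical name.

```python
import numpy as np

def soft_argsort(x, sorted_x, sigma):
    """Soft index of the i-th smallest element, via a row-normalised RBF kernel."""
    x = np.asarray(x, dtype=float)
    s = np.asarray(sorted_x, dtype=float)
    d2 = (s[:, None] - x[None, :]) ** 2          # squared distances, shape (n, n)
    logits = -d2 / (2.0 * sigma ** 2)
    logits -= logits.max(axis=1, keepdims=True)  # stabilise before exponentiating
    weights = np.exp(logits)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ np.arange(len(x))           # weighted average of indices 0..n-1

x = np.array([5.0, -1.0, 9.5, 13.2, 16.2, 10.5, 42.0, 18.0])
for sigma in (0.1, 1.0, 1000.0):
    print(sigma, np.round(soft_argsort(x, np.sort(x), sigma), 2))
```

With a small `sigma` each row of the weight matrix is essentially one-hot on the matching input, recovering `np.argsort(x)`; with a very large `sigma` the rows become uniform and every entry approaches $\frac{n-1}{2} = 3.5$.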

```python
from differentiable_sorting import order_matrix, diff_argsort
```

```python
matrices = bitonic_matrices(8)
```

```python
x = [5.0, -1.0, 9.5, 13.2, 16.2, 10.5, 42.0, 18.0]
np.set_printoptions(suppress=True)
print(x)
# show argsort
ranks = diff_argsort(matrices, x, sigma=0.1)
print(neat_vec(ranks))
```

```
[5.0, -1.0, 9.5, 13.2, 16.2, 10.5, 42.0, 18.0]
1.00	0.00	2.00	5.00	3.00	4.00	7.00	6.00
```


```python
# the first and last entries of the soft argsort give differentiable argmin and argmax
print(np.argmin(x), int(ranks[0]))
print(np.argmax(x), int(ranks[-1]))
```

```
1 1
6 6
```


```python
print("Smoothed ranks")
test = x
for sigma in [0.1, 1, 10, 100, 1000]:
    ranks = diff_argsort(matrices, test, sigma=sigma)
    print(f"sigma={sigma:7.1f}  |", neat_vec(ranks))
```

```
Smoothed ranks
sigma=    0.1  | 1.00	0.00	2.00	5.00	3.00	4.00	7.00	6.00
sigma=    1.0  | 1.00	0.00	2.89	4.01	3.06	4.42	6.61	6.00
sigma=   10.0  | 2.14	2.70	3.09	3.24	3.46	3.71	3.89	5.91
sigma=  100.0  | 3.46	3.48	3.48	3.49	3.49	3.50	3.50	3.55
sigma= 1000.0  | 3.50	3.50	3.50	3.50	3.50	3.50	3.50	3.50
```


```python
np.set_printoptions(precision=3)
jac_rank = jacobian(diff_argsort, argnum=1)
print(jac_rank(matrices, np.array(test), 1.0))
```

```
[[ 0.    -0.    -0.    -0.    -0.    -0.    -0.    -0.   ]
 [ 0.     0.    -0.     0.     0.    -0.     0.     0.   ]
 [ 0.004  0.     0.655  0.031  0.002 -0.692  0.     0.   ]
 [ 0.003  0.    -0.64   0.113  0.003  0.522  0.     0.   ]
 [-0.    -0.    -0.007 -0.07  -0.038  0.116 -0.    -0.001]
 [ 0.     0.     0.003 -0.011  0.605  0.007  0.    -0.603]
 [ 0.     0.     0.     0.004 -0.576  0.001  0.     0.571]
 [-0.    -0.    -0.    -0.    -0.    -0.    -0.     0.   ]]
```


# PyTorch example
We can verify that this is both parallelisable on the GPU and fully differentiable.

```python
import torch
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', device)
```

```
Device: cuda:0
```

```python
from differentiable_sorting_torch import diff_bisort, bitonic_matrices, diff_argsort
matrices = bitonic_matrices(16)
torch_matrices = [[torch.from_numpy(matrix).float().to(device) for matrix in matrix_set] for matrix_set in matrices]
```

```python
test_input = np.random.normal(0, 5, 16)
# a leaf tensor that tracks gradients (replaces the deprecated Variable wrapper)
var_test_input = torch.tensor(test_input, dtype=torch.float32, device=device,
                              requires_grad=True)

result = diff_bisort(torch_matrices, var_test_input)

# compute the Jacobian of the sorting function row by row, to show we can
# differentiate through the sorting function
jac = []
for i in range(len(result)):
    jac.append(
        torch.autograd.grad(result[i], var_test_input, retain_graph=True)[0])

# 16 x 16 Jacobian of the sorting function
print(torch.stack(jac))
```

```
tensor([[6.7075e-03, 1.6772e-05, 9.5515e-01, 3.1831e-05, 2.1090e-02, 1.5687e-02,
         6.0490e-05, 9.9590e-07, 2.2132e-05, 2.8194e-05, 2.2722e-04, 3.4077e-04,
         2.9672e-05, 1.6629e-04, 2.0245e-04, 2.4289e-04],
        [1.0620e-01, 1.9133e-04, 3.4910e-02, 3.4753e-04, 4.8142e-01, 3.2459e-01,
         9.9611e-04, 1.7100e-05, 8.8798e-04, 1.1313e-03, 9.4041e-03, 1.4216e-02,
         1.1910e-03, 6.6996e-03, 8.1069e-03, 9.6946e-03],
        [3.0832e-01, 7.1638e-04, 5.1102e-03, 1.3963e-03, 2.1782e-01, 3.0067e-01,
         2.5627e-03, 4.4879e-05, 2.4988e-03, 3.1937e-03, 2.9483e-02, 4.7383e-02,
         3.5332e-03, 1.9902e-02, 2.5691e-02, 3.1670e-02],
        [2.8261e-01, 5.1883e-04, 3.5821e-03, 1.0642e-03, 1.6873e-01, 2.1003e-01,
         2.0646e-03, 4.5296e-05, 4.9956e-03, 6.3725e-03, 6.0788e-02, 9.8363e-02,
         7.0049e-03, 3.9861e-02, 5.1130e-02, 6.2835e-02],
        [1.4036e-01, 4.5704e-03, 5.2332e-04, 7.1419e-03, 4.9832e-02, 6.7240e-02,
         1.1709e-02, 3.2598e-04, 1.1847
```

```python
result = diff_argsort(torch_matrices, var_test_input)
print(result)
```

```
tensor([2.0000e+00, 4.0000e+00, 4.9946e+00, 2.6924e-36, 3.1385e-26, 1.1046e+01,
        1.5000e+01, 1.3000e+01, 1.3078e+01, 6.0000e+00, 6.0000e+00, 8.0001e+00,
        7.7404e+00, 1.0000e+00, 1.0000e+00, 7.0000e+00], device='cuda:0',
       grad_fn=<MvBackward>)
```
