#### STOCHASTIC OPTIMIZATION OF SORTING NETWORKS VIA CONTINUOUS RELAXATIONS

This paper deals with the problem of sorting objects, a step found in many machine learning pipelines, for instance top-k multi-class classification, ranking documents for information retrieval, and multi-object target tracking in computer vision. Algorithms for these problems typically need to learn informative representations of complex, high-dimensional data, such as images, before sorting and subsequent downstream processing.

However, a pipeline that contains a sorting step cannot be optimized end to end with gradient-based methods, because the sort operator is not differentiable with respect to its inputs. The paper proposes $\textbf{NeuralSort}$, a continuous relaxation of the sort operator that is differentiable almost everywhere with respect to its inputs. This report covers the scientific aspects of NeuralSort and is organized as follows:

- $\textbf{A summary of the NeuralSort method}$;
- $\textbf{An application of this method to data}$.

#### NeuralSort method

$\textbf{How to understand it}$: in the sorting problem, the output can be viewed as a permutation matrix, i.e. a square matrix with entries in $\{0,1\}$ such that every row and every column sums to 1. NeuralSort instead works with a larger class of matrices, called unimodal row-stochastic matrices: square matrices with non-negative real entries where each row sums to 1 and has a distinct arg max (each row has a single largest entry, and no two rows have it in the same column). Every permutation matrix is a unimodal row-stochastic matrix.
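To make the two definitions concrete, the following NumPy sketch checks whether a matrix is a permutation matrix or a unimodal row-stochastic matrix. The helper names are our own (hypothetical, not taken from the NeuralSort code), and we read "distinct arg max" as: each row has a single maximizer and no two rows share the same one, so that the row-wise arg maxes form a permutation.

```python
import itertools
import numpy as np

def is_permutation_matrix(P):
    """True if P is square, 0/1-valued, and every row and column sums to 1."""
    P = np.asarray(P)
    return bool(
        P.ndim == 2 and P.shape[0] == P.shape[1]
        and np.isin(P, (0, 1)).all()
        and (P.sum(axis=0) == 1).all()
        and (P.sum(axis=1) == 1).all()
    )

def is_unimodal_row_stochastic(P, atol=1e-8):
    """True if P is square, non-negative, its rows sum to 1, and the row-wise
    arg maxes are unique within each row and distinct across rows."""
    P = np.asarray(P, dtype=float)
    if P.ndim != 2 or P.shape[0] != P.shape[1]:
        return False
    if (P < 0).any() or not np.allclose(P.sum(axis=1), 1.0, atol=atol):
        return False
    # each row must have a single maximizer (no ties for the max)
    n_max = (P == P.max(axis=1, keepdims=True)).sum(axis=1)
    if (n_max != 1).any():
        return False
    # no two rows may share an arg max, so the arg maxes form a permutation
    return len(set(P.argmax(axis=1).tolist())) == P.shape[0]

# every 3 x 3 permutation matrix is also unimodal row-stochastic
for perm in itertools.permutations(range(3)):
    P = np.eye(3)[list(perm)]
    assert is_permutation_matrix(P) and is_unimodal_row_stochastic(P)

# a "soft" matrix that is unimodal row-stochastic but not a permutation matrix
Q = np.array([[0.7, 0.2, 0.1],
              [0.1, 0.6, 0.3],
              [0.2, 0.1, 0.7]])
print(is_permutation_matrix(Q), is_unimodal_row_stochastic(Q))  # False True
```

Taking the row-wise arg max of such a "soft" matrix gives back a hard permutation, which is why this class is a natural relaxation target.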


$\textbf{How is NeuralSort trained?}$: the goal is to optimize, with gradient-based methods, training objectives that involve a sort operator.

The problem can be written in the following form:

$\mathcal{L}(\theta, s) = f(P_z; \theta)$, where $z = sort(s)$

Here, 

- $s \in \mathbb{R}^n$ denotes a vector of $n$ real-valued scores. In the stochastic setting, positive scores $s \in \mathbb{R}^n_{>0}$ parameterize a Plackett-Luce distribution over permutations, whose probability mass function for any $z \in \mathcal{Z}_n$ (the set of all permutations of $\{1, \ldots, n\}$) is given by:
$q(z|s)=\dfrac{s_{z_1}}{Z} \dfrac{s_{z_2}}{Z-s_{z_1}}\cdots \dfrac{s_{z_n}}{Z-\sum_{i=1}^{n-1}s_{z_i}}$, where $Z=\sum_{i=1}^{n}s_i$ is the normalization constant.

- $z$ is the permutation that (deterministically) sorts the scores $s$ in decreasing order. Every permutation $z$ is associated with a permutation matrix $P_z \in \{0,1\}^{n \times n}$ with $P_z[i,j]=\mathbb{1}(j=z_i)$.
$\textbf{Example}$: let $s = [9; 10; 5; 2]^T$; then $sort(s) = [2; 1; 3; 4]^T$, since the largest element is at the second index, the second-largest element is at the first index, and so on. In case of ties, elements are assigned indices in the order they appear. The sorted vector is then obtained simply as $P_{sort(s)}\,s = [10; 9; 5; 2]^T$.

- $f(\cdot)$ is an arbitrary function of interest, assumed to be differentiable w.r.t. a set of parameters $\theta$ and $z$.
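The definitions in the list above can be checked numerically. The following NumPy sketch (function names are our own, hypothetical) computes $sort(s)$ for the example $s = [9; 10; 5; 2]^T$, builds $P_z$ from $P_z[i,j]=\mathbb{1}(j=z_i)$, recovers the sorted vector $P_{sort(s)}\,s$, and evaluates the Plackett-Luce pmf, checking that it sums to 1 over all $4!$ permutations. Indices are 1-based to match the text.

```python
import itertools
import numpy as np

def sort_permutation(s):
    """1-indexed permutation z that sorts s in decreasing order.
    A stable sort keeps tied elements in their order of appearance."""
    s = np.asarray(s, dtype=float)
    return np.argsort(-s, kind="stable") + 1

def permutation_matrix(z):
    """P_z[i, j] = 1 if j == z_i (1-indexed), else 0."""
    n = len(z)
    P = np.zeros((n, n))
    P[np.arange(n), np.asarray(z) - 1] = 1.0
    return P

def plackett_luce_pmf(z, s):
    """q(z | s) for a 1-indexed permutation z and positive scores s."""
    s = np.asarray(s, dtype=float)
    ordered = s[np.asarray(z) - 1]                  # s_{z_1}, ..., s_{z_n}
    # denominators Z, Z - s_{z_1}, ..., Z - sum_{i<n} s_{z_i} (= s_{z_n})
    remaining = np.cumsum(ordered[::-1])[::-1]
    return float(np.prod(ordered / remaining))

s = np.array([9.0, 10.0, 5.0, 2.0])
z = sort_permutation(s)
P = permutation_matrix(z)
print(z)       # [2 1 3 4]
print(P @ s)   # the sorted vector: 10, 9, 5, 2

# the Plackett-Luce probabilities of all n! permutations sum to 1
total = sum(plackett_luce_pmf(p, s) for p in itertools.permutations(range(1, 5)))
print(round(total, 10))  # 1.0
```

Using a stable `argsort` on $-s$ is one simple way to get the tie-breaking rule stated in the example (ties keep their order of appearance).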

Since the sort operator is not differentiable, the authors propose to derive a relaxation of the sort operator that leads to a surrogate objective with well-defined gradients. In particular, this relaxation replaces the permutation matrix $P_z$ in the objective function above with an approximation $\hat{P}_z$ such that the surrogate objective $f(\hat{P}_z; \theta)$ is differentiable w.r.t. the scores $s$.
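For illustration, here is a NumPy sketch of the relaxation as we understand it from the NeuralSort paper (the formula is not derived in this report): row $i$ of $\hat{P}$ is $\mathrm{softmax}\big(((n+1-2i)\,s - A_s \mathbb{1})/\tau\big)$, where $A_s[j,k] = |s_j - s_k|$, $\mathbb{1}$ is the all-ones vector, and $\tau > 0$ is a temperature; smaller $\tau$ pushes $\hat{P}$ towards $P_{sort(s)}$. NumPy has no automatic differentiation, so this only shows the forward computation; every operation in it (absolute value, sums, softmax) is differentiable almost everywhere, which is what lets the same computation written in an autodiff framework provide gradients w.r.t. $s$. The function name `neuralsort` is our own.

```python
import numpy as np

def neuralsort(s, tau=1.0):
    """Relaxed (soft) sort of the scores s, in decreasing order.

    Returns an n x n matrix whose row i (1-indexed) is
    softmax(((n + 1 - 2i) * s - A_s @ 1) / tau), with A_s[j, k] = |s_j - s_k|.
    """
    s = np.asarray(s, dtype=float).reshape(-1, 1)      # column vector, shape (n, 1)
    n = s.shape[0]
    A_s = np.abs(s - s.T)                               # pairwise |s_j - s_k|
    row_sums = A_s.sum(axis=1)                          # A_s @ 1, shape (n,)
    scale = n + 1 - 2 * np.arange(1, n + 1)             # (n + 1 - 2i) for i = 1..n
    logits = (scale[:, None] * s.T - row_sums[None, :]) / tau
    logits -= logits.max(axis=1, keepdims=True)         # stabilize the softmax
    P_hat = np.exp(logits)
    return P_hat / P_hat.sum(axis=1, keepdims=True)     # each row sums to 1

s = np.array([9.0, 10.0, 5.0, 2.0])
P_hat = neuralsort(s, tau=0.1)
print(P_hat.argmax(axis=1) + 1)   # [2 1 3 4], the same permutation as sort(s)
print(np.round(P_hat @ s, 3))     # close to the sorted vector 10, 9, 5, 2
```

Each row of $\hat{P}$ is a softmax, so $\hat{P}$ is row-stochastic, and its row-wise arg max recovers $sort(s)$. In the experiments below we presume the `--tau` flag of `run_dknn.py` sets this temperature; we have not checked this against the script itself.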



#### Our implementation

In [1]:
# run the ResNet baseline algorithm on MNIST
!python run_baseline.py --dataset=mnist --nloglr=3

Namespace(k=None, tau=None, nloglr=3.0, method=None, resume=False, dataset='mnist')
Beginning epoch 0:  baseline-resnet-mnist-b3
train -0.6181631684303284
val 0.820400013923645
Saving...
Beginning epoch 1:  baseline-resnet-mnist-b3
train -0.31322935223579407
val 0.8936000159978866
Saving...
Beginning epoch 2:  baseline-resnet-mnist-b3
train -0.21430061757564545
val 0.9138000148534775
Saving...
Beginning epoch 3:  baseline-resnet-mnist-b3
train -0.11921847611665726
val 0.928800015091896
Saving...
Beginning epoch 4:  baseline-resnet-mnist-b3
train -0.3294110894203186
val 0.940600012421608
Saving...
Beginning epoch 5:  baseline-resnet-mnist-b3
train -0.2129727303981781
val 0.9464000124931335
Saving...
Beginning epoch 6:  baseline-resnet-mnist-b3
train -0.12585563957691193
val 0.954000011920929
Saving...
Beginning epoch 7:  baseline-resnet-mnist-b3
train -0.1471453607082367
val 0.9566000115871429
Saving...
Beginning epoch 8:  baseline-resnet-mnist-b3
train -0.15096329152584076
val 0.960600

In [2]:
# run the ResNet baseline algorithm on EMNIST_MNIST
!python run_baseline.py --dataset=EMNIST --nloglr=3

Namespace(k=None, tau=None, nloglr=3.0, method=None, resume=False, dataset='EMNIST')
Beginning epoch 0:  baseline-resnet-EMNIST-b3
train -0.4970678985118866
val 0.8600000161528587
Saving...
Beginning epoch 1:  baseline-resnet-EMNIST-b3
train -0.26618102192878723
val 0.9132000160217285
Saving...
Beginning epoch 2:  baseline-resnet-EMNIST-b3
train -0.19032417237758636
val 0.9278000141382218
Saving...
Beginning epoch 3:  baseline-resnet-EMNIST-b3
train -0.20466314256191254
val 0.9432000132799149
Saving...
Beginning epoch 4:  baseline-resnet-EMNIST-b3
train -0.2037983387708664
val 0.9482000126838684
Saving...
Beginning epoch 5:  baseline-resnet-EMNIST-b3
train -0.18303890526294708
val 0.9560000116825104
Saving...
Beginning epoch 6:  baseline-resnet-EMNIST-b3
train -0.1262630969285965
val 0.9596000096797943
Saving...
Beginning epoch 7:  baseline-resnet-EMNIST-b3
train -0.07033740729093552
val 0.9618000105619431
Saving...
Beginning epoch 8:  baseline-resnet-EMNIST-b3
train -0.105169206857681

In [3]:
# run the ResNet baseline algorithm on EMNIST_DIGITS
!python run_baseline.py --dataset=EMNIST_DIGITS --nloglr=3

Namespace(k=None, tau=None, nloglr=3.0, method=None, resume=False, dataset='EMNIST_DIGITS')
Beginning epoch 0:  baseline-resnet-EMNIST_DIGITS-b3
train -0.2991776764392853
val 0.9373000140190124
Saving...
Beginning epoch 1:  baseline-resnet-EMNIST_DIGITS-b3
train -0.18228211998939514
val 0.9653000090122222
Saving...
Beginning epoch 2:  baseline-resnet-EMNIST_DIGITS-b3
train -0.11829649657011032
val 0.9712000083327293
Saving...
Beginning epoch 3:  baseline-resnet-EMNIST_DIGITS-b3
train -0.03361757844686508
val 0.978400006711483
Saving...
Beginning epoch 4:  baseline-resnet-EMNIST_DIGITS-b3
train -0.09328191727399826
val 0.9811000062227249
Saving...
Beginning epoch 5:  baseline-resnet-EMNIST_DIGITS-b3
train -0.14441876113414764
val 0.9819000055789947
Saving...
Beginning epoch 6:  baseline-resnet-EMNIST_DIGITS-b3
train -0.07571319490671158
val 0.9825000052452088
Saving...
Beginning epoch 7:  baseline-resnet-EMNIST_DIGITS-b3
train -0.06098661571741104
val 0.9845000048279762
Saving...
Beginn

In [7]:
# run the ResNet baseline algorithm on CIFAR10
!python run_baseline.py --dataset=cifar10 --nloglr=3


Files already downloaded and verified
Namespace(k=None, tau=None, nloglr=3.0, method=None, resume=False, dataset='cifar10')
Beginning epoch 0:  baseline-resnet-cifar10-b3
train -1.3611418008804321
val 0.4836000067293644
Saving...
Beginning epoch 1:  baseline-resnet-cifar10-b3
train -1.1654270887374878
val 0.5826000076085329
Saving...
Beginning epoch 2:  baseline-resnet-cifar10-b3
train -1.0595035552978516
val 0.6126000070422888
Saving...
Beginning epoch 3:  baseline-resnet-cifar10-b3
train -1.1651397943496704
val 0.6934000086188317
Saving...
Beginning epoch 4:  baseline-resnet-cifar10-b3
train -0.7546650767326355
val 0.7118000074625015
Saving...
Beginning epoch 5:  baseline-resnet-cifar10-b3
train -0.5743950605392456
val 0.7268000096082687
Saving...
Beginning epoch 6:  baseline-resnet-cifar10-b3
train -0.6584057807922363
val 0.7766000124216079
Saving...
Beginning epoch 7:  baseline-resnet-cifar10-b3
train -0.4958330988883972
val 0.7858000118732452
Saving...
Beginning epoch 8:  baseline

In [None]:
# run the ResNet baseline algorithm on CIFAR100
!python run_baseline.py --dataset=cifar100 --nloglr=3

Files already downloaded and verified
Namespace(k=None, tau=None, nloglr=3.0, method=None, resume=False, dataset='cifar100')
Beginning epoch 0:  baseline-resnet-cifar100-b3
train -3.9608864784240723
val 0.07860000142455101
Saving...
Beginning epoch 1:  baseline-resnet-cifar100-b3
train -3.989877939224243
val 0.10640000207722188
Saving...
Beginning epoch 2:  baseline-resnet-cifar100-b3
train -3.689436435699463
val 0.14140000285208226
Saving...
Beginning epoch 3:  baseline-resnet-cifar100-b3
train -3.468449115753174
val 0.15660000321269035
Saving...
Beginning epoch 4:  baseline-resnet-cifar100-b3
train -3.2784814834594727
val 0.17660000374913215
Saving...
Beginning epoch 5:  baseline-resnet-cifar100-b3
train -3.1142818927764893
val 0.20220000448822975
Saving...
Beginning epoch 6:  baseline-resnet-cifar100-b3
train -3.189652442932129
val 0.24560000529885292
Saving...
Beginning epoch 7:  baseline-resnet-cifar100-b3
train -2.9700264930725098
val 0.27220000582933424
Saving...
Beginning epoch

#### NeuralSort algorithm

In [1]:
# run the NeuralSort algorithm (deterministic) on MNIST
!python run_dknn.py --k=9 --tau=85 --nloglr=3 --method=deterministic --dataset=mnist --num_epochs=30


Namespace(k=9, tau=85.0, nloglr=3.0, method='deterministic', resume=False, dataset='mnist', num_train_queries=100, num_test_queries=10, num_train_neighbors=200, num_samples=10, num_epochs=30)
Beginning epoch 0:  dknn-resnet-mnist-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.6636374803504559
Avg. val acc: 0.969
Saving...
Beginning epoch 1:  dknn-resnet-mnist-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.8815347478847304
Avg. val acc: 0.9764
Saving...
Beginning epoch 2:  dknn-resnet-mnist-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.9207631741148047
Avg. val acc: 0.982
Saving...
Beginning epoch 3:  dknn-resnet-mnist-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.9397958742006854
Avg. val acc: 0.9834
Saving...
Beginning epoch 4:  dknn-resnet-mnist-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.9495229486022334
Avg. val acc: 0.9846
Saving...
Beginning epoch 5:  dknn-resnet-mnist-deterministic-k9-t8500-b3
Avg. train 

In [None]:
# run the NeuralSort algorithm (deterministic) on EMNIST
!python run_dknn.py --k=9 --tau=85 --nloglr=3 --method=deterministic --dataset=EMNIST --num_epochs=30

Namespace(k=9, tau=85.0, nloglr=3.0, method='deterministic', resume=False, dataset='EMNIST', num_train_queries=100, num_test_queries=10, num_train_neighbors=200, num_samples=10, num_epochs=30)
Beginning epoch 0:  dknn-resnet-EMNIST-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.7044454515823206
Avg. val acc: 0.9682
Saving...
Beginning epoch 1:  dknn-resnet-EMNIST-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.8857237759985105
Avg. val acc: 0.977
Saving...
Beginning epoch 2:  dknn-resnet-EMNIST-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.9271725190287889
Avg. val acc: 0.9814
Saving...
Beginning epoch 3:  dknn-resnet-EMNIST-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.9455224536163641
Avg. val acc: 0.9844
Saving...
Beginning epoch 4:  dknn-resnet-EMNIST-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.9552618650956589
Avg. val acc: 0.9854
Saving...
Beginning epoch 7:  dknn-resnet-EMNIST-deterministic-k9-t8500-b3
Avg

In [None]:
# run the NeuralSort algorithm (deterministic) on EMNIST_DIGITS
!python run_dknn.py --k=9 --tau=85 --nloglr=3 --method=deterministic --dataset=EMNIST_DIGITS --num_epochs=30

Avg. train correctness of top k: 0.8727946271182256
Avg. val acc: 0.9884
Saving...
Beginning epoch 1:  dknn-resnet-EMNIST_DIGITS-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.9625406637514277
Avg. val acc: 0.9916
Saving...
Beginning epoch 3:  dknn-resnet-EMNIST_DIGITS-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.9772747499585721
Avg. val acc: 0.9928
Saving...
Beginning epoch 5:  dknn-resnet-EMNIST_DIGITS-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.9815984567467133
Avg. val acc: 0.9928
Beginning epoch 6:  dknn-resnet-EMNIST_DIGITS-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.9833403679483761
Avg. val acc: 0.9935
Saving...
Beginning epoch 7:  dknn-resnet-EMNIST_DIGITS-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.9848018208563607
Avg. val acc: 0.9937
Saving...
Beginning epoch 8:  dknn-resnet-EMNIST_DIGITS-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.985383667047473
Avg. val acc: 0.9937
Beginnin

In [2]:
# run the NeuralSort algorithm (deterministic) on CIFAR10
!python run_dknn.py --k=9 --tau=85 --nloglr=3 --method=deterministic --dataset=cifar10 --num_epochs=30

Files already downloaded and verified
Namespace(k=9, tau=85.0, nloglr=3.0, method='deterministic', resume=False, dataset='cifar10', num_train_queries=100, num_test_queries=10, num_train_neighbors=200, num_samples=10, num_epochs=30)
Beginning epoch 0:  dknn-resnet-cifar10-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.2107735330675855
Avg. val acc: 0.4
Saving...
Beginning epoch 1:  dknn-resnet-cifar10-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.28406008938212457
Avg. val acc: 0.4848
Saving...
Beginning epoch 2:  dknn-resnet-cifar10-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.34430799119266453
Avg. val acc: 0.542
Saving...
Beginning epoch 3:  dknn-resnet-cifar10-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.40420540138527183
Avg. val acc: 0.5846
Saving...
Beginning epoch 4:  dknn-resnet-cifar10-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.45281866485689914
Avg. val acc: 0.6252
Saving...
Beginning epoch 5:  dkn

In [3]:
# run the NeuralSort algorithm (deterministic) on CIFAR100
!python run_dknn.py --k=9 --tau=85 --nloglr=3 --method=deterministic --dataset=cifar100 --num_epochs=30

Files already downloaded and verified
Namespace(k=9, tau=85.0, nloglr=3.0, method='deterministic', resume=False, dataset='cifar100', num_train_queries=100, num_test_queries=10, num_train_neighbors=200, num_samples=10, num_epochs=30)
Beginning epoch 0:  dknn-resnet-cifar100-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.019102932608421934
Avg. val acc: 0.0888
Saving...
Beginning epoch 1:  dknn-resnet-cifar100-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.023594132113603904
Avg. val acc: 0.0938
Saving...
Beginning epoch 2:  dknn-resnet-cifar100-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.02636663934330882
Avg. val acc: 0.0964
Saving...
Beginning epoch 3:  dknn-resnet-cifar100-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.028641139121703147
Avg. val acc: 0.1064
Saving...
Beginning epoch 4:  dknn-resnet-cifar100-deterministic-k9-t8500-b3
Avg. train correctness of top k: 0.03012007193800842
Avg. val acc: 0.0968
Beginning epoch 5: 