# Optimal Transport Dataset Distances

This notebook builds on the [Optimal Transport](https://github.com/kheyer/OTDD/blob/main/Algorithms/Optimal%20Transport.ipynb) notebook, which defines the Optimal Transport problem and the Sinkhorn algorithm.


## Overview

The [Optimal Transport Dataset Distances](https://arxiv.org/pdf/2002.02923.pdf) algorithm is designed to incorporate class/label information into the optimal transport problem.

Consider the following labeled datasets:

<img src="https://raw.githubusercontent.com/kheyer/OTDD/main/media/otdd_1.png" height="80%" width="80%">

Here the points are the same, but the label distribution is different. We would like to augment the optimal transport problem so that it accounts for this label information.

To do this, we take a hierarchical approach to the transport problem. First we calculate the transport distance between label-based subsets of the data distributions. Then we incorporate these label-to-label distances into the transport problem over the full data measures.

## Label to Label Distances

We define our two samples as discrete labeled measures in $\mathbb{R}^{d}$ with locations $x_{1}, ..., x_{n} \in \chi_{x}$, $y_{1}, ..., y_{m} \in \chi_{y}$, along with weights $\alpha \in P(\chi_{x})$, $\beta \in P(\chi_{y})$.

For these labeled measures, we define $\omega_{c}(\chi_{x}) = \text{P}(\chi_{x} ~|~ \text{label} = c)$, i.e. the distribution of the points that carry label $c$.

We can then define the label-to-label transport cost as $OT(\omega_{c_{x}}, \omega_{c_{y}})$, following the Kantorovich formulation:

\begin{align}
\text{OT}(\alpha, \beta) ~=~ \min_{\pi\in\Pi(\alpha, \beta)}\int_{\chi_{x} \times \chi_{y}} c(x,y)d\pi(x,y)
\end{align}
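
For the discrete samples used here, the integral reduces to a finite sum. The form below is a sketch of that reduction; the symbols $a$, $b$, $C$, $P$ and $U(a, b)$ are introduced for illustration and do not appear elsewhere in this notebook. Assume $\alpha = \sum_{i} a_{i}\delta_{x_{i}}$ and $\beta = \sum_{j} b_{j}\delta_{y_{j}}$ with nonnegative weights that each sum to one, and let $C_{ij} = c(x_{i}, y_{j})$.

\begin{align}
\text{OT}(\alpha, \beta) ~=~ \min_{P \in U(a, b)} \sum_{i=1}^{n}\sum_{j=1}^{m} C_{ij}P_{ij}, \qquad U(a, b) = \{P \in \mathbb{R}_{+}^{n \times m} ~:~ P\mathbb{1}_{m} = a, ~ P^{T}\mathbb{1}_{n} = b\}
\end{align}

Each entry $P_{ij}$ is the amount of mass moved from $x_{i}$ to $y_{j}$, and the two constraints require the rows and columns of $P$ to match the weights of the source and target measures.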

From here, the algorithms from the [Optimal Transport](https://github.com/kheyer/OTDD/blob/main/Algorithms/Optimal%20Transport.ipynb) notebook apply towards solving this problem.

We iterate over all classes in each dataset to build a matrix of label-to-label distances. If $\chi_{x}$ has $I$ distinct classes and $\chi_{y}$ has $J$ distinct classes, then we will have an $I \times J$ matrix of label-to-label distances $L$ such that $L_{i,j} = OT(\omega_{i}, \omega_{j})$.
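
The loop over class pairs can be sketched in NumPy as follows. This is a minimal illustration, not the implementation used in this repository: the function names `sinkhorn_cost` and `label_distances`, the Euclidean ground cost, the uniform weights within each class, and the values of `eps` and `n_iters` are all assumptions. The solver is a log-domain Sinkhorn iteration, so each entry is an entropy-regularized approximation of the transport cost.

```python
import numpy as np


def _logsumexp(m, axis):
    # Numerically stable log(sum(exp(m))) along one axis.
    mx = m.max(axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.exp(m - mx).sum(axis=axis))


def sinkhorn_cost(x, y, eps=0.05, n_iters=200):
    """Approximate OT cost between two uniformly weighted point clouds.

    Assumptions (not from the notebook): Euclidean ground cost,
    uniform weights, fixed iteration count.
    """
    cost = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)
    log_a = np.full(len(x), -np.log(len(x)))  # log of uniform weights on x
    log_b = np.full(len(y), -np.log(len(y)))  # log of uniform weights on y
    f = np.zeros(len(x))  # dual potential for x
    g = np.zeros(len(y))  # dual potential for y
    for _ in range(n_iters):
        f = -eps * _logsumexp((g[None, :] - cost) / eps + log_b[None, :], axis=1)
        g = -eps * _logsumexp((f[:, None] - cost) / eps + log_a[:, None], axis=0)
    log_plan = (f[:, None] + g[None, :] - cost) / eps + log_a[:, None] + log_b[None, :]
    return float((np.exp(log_plan) * cost).sum())


def label_distances(x, labels_x, y, labels_y):
    """Matrix L with L[i, j] = approximate OT cost between class i of x and class j of y."""
    classes_x, classes_y = np.unique(labels_x), np.unique(labels_y)
    L = np.zeros((len(classes_x), len(classes_y)))
    for i, cx in enumerate(classes_x):
        for j, cy in enumerate(classes_y):
            L[i, j] = sinkhorn_cost(x[labels_x == cx], y[labels_y == cy])
    return L


# Toy data: two well-separated classes in each dataset.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(0.0, 0.3, (20, 2)), rng.normal(5.0, 0.3, (20, 2))])
labels_x = np.repeat([0, 1], 20)
y = np.concatenate([rng.normal(0.5, 0.3, (15, 2)), rng.normal(5.5, 0.3, (15, 2))])
labels_y = np.repeat([0, 1], 15)

L = label_distances(x, labels_x, y, labels_y)
print(L.round(2))
```

On this toy data the matching classes sit close together, so the diagonal entries of `L` should come out much smaller than the off-diagonal ones.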

## OTDD

Once we have a matrix of label-to-label distances, we can move towards solving the full transport problem.

The standard transport problem solves:

\begin{align}
\text{OT}(\alpha, \beta) ~=~ \min_{\pi\in\Pi(\alpha, \beta)}\int_{\chi_{x} \times \chi_{y}} c(x,y)d\pi(x,y)
\end{align}

Here $c(x,y)$ is the ground cost function for transport between a point in $\chi_{x}$ and a point in $\chi_{y}$. For the OTDD algorithm, we augment this cost with the label-to-label distance between the classes $c_{x}$ and $c_{y}$ of the two points:

\begin{align}
\eta((x,y), (c_{x}, c_{y})) = (c(x,y)^2 + OT(\omega_{c_{x}}, \omega_{c_{y}})^2)^{1/2}
\end{align}
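
As a small illustration of this formula, the sketch below builds the augmented cost matrix from point locations, integer class labels, and a precomputed label-to-label matrix. The function name `augmented_cost`, the Euclidean feature cost, and the requirement that labels are integer indices into the rows and columns of `label_dist` are assumptions made for this example.

```python
import numpy as np


def augmented_cost(x, labels_x, y, labels_y, label_dist):
    """Combine feature cost and label-to-label distance into one cost matrix.

    label_dist[i, j] is assumed to hold the distance between class i of the
    first dataset and class j of the second dataset.
    """
    # Pairwise Euclidean cost between every point in x and every point in y.
    feature_cost = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)
    # Look up the label-to-label distance for every pair of points.
    label_cost = label_dist[labels_x[:, None], labels_y[None, :]]
    return np.sqrt(feature_cost ** 2 + label_cost ** 2)


# Two source points at the same location but with different labels.
x = np.array([[0.0, 0.0], [0.0, 0.0]])
labels_x = np.array([0, 1])
y = np.array([[3.0, 0.0]])
labels_y = np.array([1])
label_dist = np.array([[0.0, 4.0], [4.0, 0.0]])

eta = augmented_cost(x, labels_x, y, labels_y, label_dist)
print(eta)
```

Both source points are a distance of 3 from the target point, but the one whose class differs from the target's class pays the extra label cost, giving $\sqrt{3^2 + 4^2} = 5$ instead of 3.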

We then solve the transport problem with $\eta$ as the ground cost:

\begin{align}
\text{OTDD}(\alpha, \beta) ~=~ \min_{\pi\in\Pi(\alpha, \beta)}\int_{\chi_{x} \times \chi_{y}} \eta((x,y), (c_{x}, c_{y}))d\pi(x,y)
\end{align}

## Example

We can see the impact this has on the computed transport plan. As an example problem, we will calculate a transport plan between the MNIST and USPS handwritten digit datasets. The full example can be found [here](https://github.com/kheyer/OTDD/blob/main/Examples/MNIST%20USPS.ipynb).

Both datasets have images of handwritten digits from 0-9. If we calculate the transport plan between 5000 MNIST images and 5000 USPS images, we get the following:

<img src="https://raw.githubusercontent.com/kheyer/OTDD/main/media/mnist_OT.png" height="40%" width="40%">

Now we compute the label-to-label distances for the different digit classes. This gives us the following label-to-label distance matrix:

<img src="https://raw.githubusercontent.com/kheyer/OTDD/main/media/mnist_class.png" height="40%" width="40%">

Next we augment the cost matrix with these label-to-label distances and solve the transport problem again, generating a new coupling:

<img src="https://raw.githubusercontent.com/kheyer/OTDD/main/media/mnist_otdd.png" height="40%" width="40%">

We can see that coupling along the diagonal (i.e., between matching classes in the two datasets) is much stronger when label-to-label distances are incorporated.
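
One way to put a number on this observation is to sum the mass of a point-level transport plan into a class-to-class matrix and measure how much of it lies on the diagonal. The sketch below is an assumed approach, not necessarily how the figures above were produced; the function name `class_coupling` and the toy plan are hypothetical, and labels are assumed to be integer class indices.

```python
import numpy as np


def class_coupling(plan, labels_x, labels_y, n_classes_x, n_classes_y):
    """Sum an (n, m) transport plan into an (I, J) class-to-class mass matrix."""
    mass = np.zeros((n_classes_x, n_classes_y))
    # Add each plan entry to the cell for its (source class, target class) pair.
    np.add.at(mass, (labels_x[:, None], labels_y[None, :]), plan)
    return mass


# Toy plan between 3 source points and 2 target points.
plan = np.array([[0.25, 0.0],
                 [0.0, 0.25],
                 [0.25, 0.25]])
labels_x = np.array([0, 1, 1])
labels_y = np.array([0, 1])

mass = class_coupling(plan, labels_x, labels_y, 2, 2)
diagonal_fraction = np.trace(mass) / mass.sum()
print(mass)
print(diagonal_fraction)
```

Comparing `diagonal_fraction` for the plan computed with the plain ground cost against the plan computed with the augmented cost would give a single figure for how much the label information sharpens class matching.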