# API Overview

This document gives a high-level overview of the major abstractions and plotting utilities.

## Datasets

Datasets contain a set of feature vectors and, optionally, a corresponding set of labels. Labels should be integers, since they are used for indexing.

```
data = ArrayDataset(my_features, my_labels)
```
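If your labels are not already integers (for example, class-name strings), one generic way to convert them is NumPy's `np.unique` with `return_inverse=True`. This is a plain-NumPy preprocessing sketch, not part of the `ArrayDataset` API, and `raw_labels` is hypothetical example data.

```python
import numpy as np

# Hypothetical raw labels that are not integers (e.g. class names)
raw_labels = np.array(["cat", "dog", "cat", "bird", "dog"])

# classes holds the sorted unique labels; int_labels holds each item's
# integer index into classes
classes, int_labels = np.unique(raw_labels, return_inverse=True)

print(classes)     # ['bird' 'cat' 'dog']
print(int_labels)  # [1 2 1 0 2]
```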

Datasets support sampling features with and without labels.

```
sample_features = data.sample(500)
sample_features, sample_labels = data.sample_with_label(500)
```

Datasets can also create sub-samples. `subsample` returns another `ArrayDataset` object with the requested number of elements.

```
data_subsample = data.subsample(1000, equal_classes=True)
```
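To make the sampling semantics concrete, here is a small stand-in class. `ToyArrayDataset` is hypothetical and only mirrors the method names above. We assume sampling is uniform without replacement and that `equal_classes=True` draws `n // num_classes` items from every class, which may differ from what `ArrayDataset` actually does.

```python
import numpy as np


class ToyArrayDataset:
    """Hypothetical stand-in that mirrors the sampling methods above."""

    def __init__(self, features, labels=None, seed=0):
        self.features = np.asarray(features)
        self.labels = None if labels is None else np.asarray(labels)
        self.rng = np.random.default_rng(seed)

    def _indices(self, n):
        # Uniform random indices without replacement (an assumption)
        return self.rng.choice(len(self.features), size=n, replace=False)

    def sample(self, n):
        return self.features[self._indices(n)]

    def sample_with_label(self, n):
        idx = self._indices(n)
        return self.features[idx], self.labels[idx]

    def subsample(self, n, equal_classes=False):
        if equal_classes:
            # Assumed behavior: n // num_classes items from every class
            classes = np.unique(self.labels)
            per_class = n // len(classes)
            idx = np.concatenate([
                self.rng.choice(np.flatnonzero(self.labels == c), size=per_class, replace=False)
                for c in classes
            ])
        else:
            idx = self._indices(n)
        labels = None if self.labels is None else self.labels[idx]
        return ToyArrayDataset(self.features[idx], labels)


data = ToyArrayDataset(np.arange(20).reshape(10, 2), labels=[0] * 5 + [1] * 5)
sample_features, sample_labels = data.sample_with_label(3)
data_subsample = data.subsample(4, equal_classes=True)
print(np.bincount(data_subsample.labels))  # [2 2]
```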

## Distance Functions

Distance functions subclass `DistanceFunction`. The purpose of this class is to generate a cost matrix. Custom distance functions should override the `__call__` method.

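```
class CustomDistance(DistanceFunction):
    def __init__(self):
        super().__init__()

    def __call__(self, x_vals, y_vals=None, mask_diagonal=False):
        # my_cost_metric is a placeholder for your own metric; it should
        # return the pairwise cost matrix between x_vals and y_vals
        return my_cost_metric(x_vals, y_vals)
```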

If `mask_diagonal` is True, the diagonal of the cost matrix is masked with large values. This allows for computing transport plans within a dataset rather than between datasets.
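A minimal sketch of a Euclidean cost matrix with optional diagonal masking, written as a plain function rather than a `DistanceFunction` subclass so it runs on its own. We assume `y_vals=None` means "compare `x_vals` with itself"; `euclidean_cost` and `mask_value` are hypothetical names, and `1e6` is an arbitrary choice of "large".

```python
import numpy as np


def euclidean_cost(x_vals, y_vals=None, mask_diagonal=False, mask_value=1e6):
    """Pairwise Euclidean cost matrix; y_vals=None compares x_vals with itself."""
    if y_vals is None:
        y_vals = x_vals
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, clipped to remove tiny negatives
    sq = (x_vals ** 2).sum(1)[:, None] + (y_vals ** 2).sum(1)[None, :] - 2 * x_vals @ y_vals.T
    cost = np.sqrt(np.clip(sq, 0.0, None))
    if mask_diagonal:
        # A large self-cost stops the solver from matching a point to itself
        np.fill_diagonal(cost, mask_value)
    return cost


x_vals = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
cost_matrix = euclidean_cost(x_vals)
masked_cost_matrix = euclidean_cost(x_vals, mask_diagonal=True)
```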

Distance functions support computing transport distances using the Gaussian approximation of the 2-Wasserstein distance (see the [Algorithms](https://github.com/kheyer/OTDD/blob/main/Algorithms/Optimal%20Transport.ipynb) notebook).

```
gaussian_transport_distance = distance_function.gaussian_distance(x, y)
```

This method is not implemented in the base `DistanceFunction` class because the standard Gaussian approximation assumes a Euclidean distance metric, which a custom distance function may not use.


## Cost Functions

Cost functions subclass `CostFunction`. A cost function uses a given `DistanceFunction` to solve a transport problem between two datasets.


#### Optimal Transport Distance

Calculating the optimal transport between two sets of feature vectors takes two steps:
1. Compute pairwise distances between the elements of the two datasets to generate a cost matrix
2. Use an optimal transport cost function to calculate the optimal coupling between data items

The `distance` method returns the transport cost, the coupling matrix, and the ground cost matrix.

```
distance_function = POTDistance(distance_metric='euclidean')
cost_function = SinkhornCost(distance_function, entropy=0.2)
cost, coupling, M_dist = cost_function.distance(x_vals, y_vals)
```

`cost` is the transport cost between the datasets, `coupling` is the coupling matrix solved by `cost_function`, and `M_dist` is the distance matrix calculated by `distance_function`.
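As a rough sketch of these two steps (we have not checked `SinkhornCost`'s internals), here is a log-domain Sinkhorn solver in NumPy with uniform weights on both datasets. We assume `entropy` plays the role of the entropic regularization strength and that the returned cost is $\langle P, M \rangle$ without the entropy term; `sinkhorn` and `ot_distance` are hypothetical helpers.

```python
import numpy as np


def logsumexp(a, axis):
    # Numerically stable log(sum(exp(a))) along one axis
    m = a.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(a - m).sum(axis=axis))


def sinkhorn(M, entropy=0.2, n_iters=500):
    """Entropic OT between uniform weights; returns (transport cost, coupling)."""
    n, m = M.shape
    log_a = np.full(n, -np.log(n))  # log of uniform weights on the x items
    log_b = np.full(m, -np.log(m))  # log of uniform weights on the y items
    f, g = np.zeros(n), np.zeros(m)  # dual potentials
    for _ in range(n_iters):
        # Alternate updates: f fixes the row sums, g fixes the column sums
        f = entropy * (log_a - logsumexp((g[None, :] - M) / entropy, axis=1))
        g = entropy * (log_b - logsumexp((f[:, None] - M) / entropy, axis=0))
    coupling = np.exp((f[:, None] + g[None, :] - M) / entropy)
    return float((coupling * M).sum()), coupling


def ot_distance(x_vals, y_vals, entropy=0.2):
    # Step 1: pairwise Euclidean ground cost between the two datasets
    M = np.linalg.norm(x_vals[:, None, :] - y_vals[None, :, :], axis=-1)
    # Step 2: entropic optimal coupling under that cost
    cost, coupling = sinkhorn(M, entropy)
    return cost, coupling, M


rng = np.random.default_rng(0)
x_vals = rng.normal(size=(50, 2))
y_vals = rng.normal(loc=1.0, size=(60, 2))
cost, coupling, M_dist = ot_distance(x_vals, y_vals)
```

Working in the log domain keeps the updates stable for small `entropy` values, where `exp(-M / entropy)` would otherwise underflow.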

#### Optimal Transport Dataset Distance

The `distance_with_labels` method calculates the transport cost using the OTDD algorithm from [Geometric Dataset Distances via Optimal Transport](https://arxiv.org/pdf/2002.02923.pdf). Label-to-label distances are computed as the optimal transport cost between the per-label subsets of the two datasets. The full transport problem is then solved with the label-augmented cost:

<img src="https://render.githubusercontent.com/render/math?math=d_{Z}\bigl((x,y), (x',y') \bigr) \triangleq \bigl( d_{X}(x,x')^p  %2B \text{W}_p^p(\alpha_y, \alpha_{y'}) \bigr)^{\frac{1}{p}}">

To calculate the OTDD distance between two datasets:
1. Compute pairwise distances between the elements of the two datasets to generate a cost matrix
2. Compute label-to-label optimal transport distances
3. Update the cost matrix with label distances
4. Use an optimal transport cost function to calculate optimal coupling between data items

```
cost, coupling, OTDD_matrix, class_distances, class_x_dict, class_y_dict = cost_function.distance_with_labels(x_vals, y_vals, x_labels, y_labels)
```

`cost` is the OTDD transport cost between the datasets, `coupling` is the coupling matrix solved by `cost_function`, and `OTDD_matrix` is the distance matrix calculated by `distance_function`, augmented with label-to-label distances. `class_distances` is the matrix of label-to-label distances between the datasets. `class_x_dict` and `class_y_dict` map the label values in `x_labels` and `y_labels` to indices in `class_distances`.
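Step 3 can be sketched as an element-wise combination of the feature cost and the label cost, following the formula above. This is a hypothetical helper, not the library's code; we assume `class_distances` stores $W_p$ (not $W_p^p$) and that `class_x_dict` and `class_y_dict` map labels to rows and columns respectively.

```python
import numpy as np


def augment_cost(M, class_distances, x_labels, y_labels, class_x_dict, class_y_dict, p=2):
    """Label-augmented cost (d_X^p + W_p^p)^(1/p), computed element-wise."""
    # Map each item's label to its row/column index in class_distances
    xi = np.array([class_x_dict[label] for label in x_labels])
    yi = np.array([class_y_dict[label] for label in y_labels])
    # Look up the label-to-label distance for every (x item, y item) pair
    label_cost = class_distances[xi[:, None], yi[None, :]]
    return (M ** p + label_cost ** p) ** (1.0 / p)


M = np.array([[0.0, 3.0], [4.0, 0.0]])               # feature distances
class_distances = np.array([[0.0, 4.0], [3.0, 0.0]])  # label-to-label W_p
x_labels, y_labels = [7, 9], [7, 9]
class_x_dict, class_y_dict = {7: 0, 9: 1}, {7: 0, 9: 1}
OTDD_matrix = augment_cost(M, class_distances, x_labels, y_labels, class_x_dict, class_y_dict)
print(OTDD_matrix)  # [[0. 5.] [5. 0.]], since sqrt(3^2 + 4^2) = 5
```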

#### Gaussian Approximation

For large datasets, computing the optimal transport cost can be prohibitive. The transport cost can instead be approximated by the closed-form 2-Wasserstein distance between two Gaussians fitted to the data, also called the Fréchet distance:

<img src="https://render.githubusercontent.com/render/math?math=\text{W}_2^2(\alpha, \beta) = \| \mu_{\alpha} - \mu_{\beta} \|_2^2 %2B  \| \Sigma^\frac{1}{2}_\alpha - \Sigma_{\beta}^\frac{1}{2} \|_{F}^2">

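Note that this derivation assumes a Euclidean cost metric. The Frobenius-norm term above is exact when $\Sigma_\alpha$ and $\Sigma_\beta$ commute (for example, when both are diagonal); in general the covariance term is $\operatorname{tr}\bigl(\Sigma_\alpha + \Sigma_\beta - 2(\Sigma_\alpha^{1/2}\Sigma_\beta\Sigma_\alpha^{1/2})^{1/2}\bigr)$, which the Frobenius form upper-bounds.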
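A minimal NumPy sketch of the Gaussian approximation for `(n, d)` feature arrays. It uses the general Bures form of the covariance term, so it does not rely on the covariances commuting; `psd_sqrt` and `gaussian_w2_squared` are hypothetical names, and we have not verified which form `gaussian_distance` uses. Note that `np.cov` uses the unbiased (`ddof=1`) estimator by default.

```python
import numpy as np


def psd_sqrt(S):
    # Square root of a symmetric PSD matrix via eigendecomposition
    # (symmetrize first to absorb floating-point asymmetry)
    w, V = np.linalg.eigh((S + S.T) / 2)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T


def gaussian_w2_squared(x_vals, y_vals):
    """Squared 2-Wasserstein distance between Gaussians fitted to two (n, d) arrays."""
    mean_term = float(((x_vals.mean(axis=0) - y_vals.mean(axis=0)) ** 2).sum())
    cov_x = np.atleast_2d(np.cov(x_vals, rowvar=False))
    cov_y = np.atleast_2d(np.cov(y_vals, rowvar=False))
    root_x = psd_sqrt(cov_x)
    # Bures term: tr(Sigma_x + Sigma_y - 2 (Sigma_x^1/2 Sigma_y Sigma_x^1/2)^1/2)
    cross = psd_sqrt(root_x @ cov_y @ root_x)
    bures_term = float(np.trace(cov_x + cov_y - 2 * cross))
    return mean_term + bures_term


rng = np.random.default_rng(0)
x_vals = rng.normal(size=(500, 3))
y_vals = x_vals + np.array([3.0, 4.0, 0.0])  # same covariance, mean shifted by (3, 4, 0)
print(gaussian_w2_squared(x_vals, y_vals))  # approximately 25.0
```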

To calculate the Gaussian approximation distance between two datasets (i.e., without labels):

```
cost = distance_function.gaussian_distance(x_vals, y_vals)
```

To calculate the transport distance with labels, there are two approaches. The first is to calculate the label-to-label distances with the Gaussian approximation, then solve the optimal transport coupling between the two datasets:

```
cost, coupling, OTDD_matrix, class_distances, class_x_dict, class_y_dict = cost_function.distance_with_labels(x_vals, y_vals, x_labels, y_labels, gaussian_class_distance=True)
```
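With `gaussian_class_distance=True`, each label-to-label distance presumably comes from fitting a Gaussian to each class rather than solving a full transport problem per label pair. The sketch below is our assumption about that step, with hypothetical helper names; it also shows how `class_x_dict` and `class_y_dict` relate label values to matrix indices.

```python
import numpy as np


def psd_sqrt(S):
    # Square root of a symmetric PSD matrix via eigendecomposition
    w, V = np.linalg.eigh((S + S.T) / 2)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T


def gaussian_w2(x, y):
    # 2-Wasserstein distance between Gaussians fitted to x and y (Bures form)
    cov_x = np.atleast_2d(np.cov(x, rowvar=False))
    cov_y = np.atleast_2d(np.cov(y, rowvar=False))
    root_x = psd_sqrt(cov_x)
    cross = psd_sqrt(root_x @ cov_y @ root_x)
    w2_sq = ((x.mean(axis=0) - y.mean(axis=0)) ** 2).sum() + np.trace(cov_x + cov_y - 2 * cross)
    return float(np.sqrt(max(w2_sq, 0.0)))


def gaussian_class_distances(x_vals, y_vals, x_labels, y_labels):
    """Label-to-label distance matrix plus label -> index dictionaries."""
    class_x_dict = {c.item(): i for i, c in enumerate(np.unique(x_labels))}
    class_y_dict = {c.item(): j for j, c in enumerate(np.unique(y_labels))}
    class_distances = np.zeros((len(class_x_dict), len(class_y_dict)))
    for label_x, i in class_x_dict.items():
        for label_y, j in class_y_dict.items():
            class_distances[i, j] = gaussian_w2(x_vals[x_labels == label_x],
                                                y_vals[y_labels == label_y])
    return class_distances, class_x_dict, class_y_dict


rng = np.random.default_rng(0)
x_vals = rng.normal(size=(200, 2))
x_labels = np.repeat([0, 1], 100)
x_vals[x_labels == 1] += 10.0  # move class 1 far away from class 0
y_vals, y_labels = x_vals.copy(), x_labels.copy()
class_distances, class_x_dict, class_y_dict = gaussian_class_distances(
    x_vals, y_vals, x_labels, y_labels)
```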

An even quicker approach is to calculate the transport cost using the class distance matrix, solving the optimal transport problem over a `c1 x c2` space (the numbers of classes in each dataset) rather than an `n1 x n2` space (the numbers of data points):

```
cost, coupling, OTDD_matrix, class_distances, class_x_dict, class_y_dict = cost_function.distance_with_labels(x_vals, y_vals, x_labels, y_labels, gaussian_class_distance=True, gaussian_data_distance=True)
```
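We have not checked exactly what `gaussian_data_distance=True` computes. One plausible reading of "solving over a `c1 x c2` space" is transport between the two datasets' class-frequency histograms, with `class_distances` as the ground cost. The hypothetical `class_level_ot` below sketches that reading, assuming the rows and columns of `class_distances` follow sorted label order; the `entropy` default is an arbitrary choice.

```python
import numpy as np


def logsumexp(a, axis):
    # Numerically stable log(sum(exp(a))) along one axis
    m = a.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(a - m).sum(axis=axis))


def class_level_ot(class_distances, x_labels, y_labels, entropy=0.05, n_iters=2000):
    """Entropic OT between class histograms; class_distances has shape (c1, c2)."""
    _, x_counts = np.unique(x_labels, return_counts=True)
    _, y_counts = np.unique(y_labels, return_counts=True)
    # Class proportions act as the source and target weights
    log_a = np.log(x_counts / x_counts.sum())
    log_b = np.log(y_counts / y_counts.sum())
    M = class_distances
    f, g = np.zeros(len(log_a)), np.zeros(len(log_b))
    for _ in range(n_iters):
        f = entropy * (log_a - logsumexp((g[None, :] - M) / entropy, axis=1))
        g = entropy * (log_b - logsumexp((f[:, None] - M) / entropy, axis=0))
    coupling = np.exp((f[:, None] + g[None, :] - M) / entropy)
    return float((coupling * M).sum()), coupling


class_distances = np.array([[0.0, 1.0], [1.0, 0.0]])
x_labels = np.array([0, 0, 1, 1])  # class weights (0.5, 0.5)
y_labels = np.array([0, 1, 1, 1])  # class weights (0.25, 0.75)
cost, class_coupling = class_level_ot(class_distances, x_labels, y_labels)
```

Here a quarter of the mass must move from class 0 to class 1 at unit cost, so `cost` comes out close to 0.25. The work scales with the number of classes, not the number of data points.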


#### Bootstrapping

Another way to reduce the compute needed for transport costs is to calculate the transport cost between bootstrapped samples drawn from the data.

For bootstrapping standard transport:

```
distances = cost_function.bootstrap_distance(num_iterations, dataset_x, sample_size_x, 
                                      dataset_y, sample_size_y)
```
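A sketch of the bootstrapping loop, assuming each iteration draws a fresh random subsample of each dataset and evaluates a distance on it. The `distance_fn` argument and the stand-in `mean_gap` distance are hypothetical (the library's method uses its own cost function), and we sample without replacement within each draw, which may differ from the library.

```python
import numpy as np


def bootstrap_distance(num_iterations, x_vals, sample_size_x, y_vals, sample_size_y,
                       distance_fn, seed=0):
    """Evaluate distance_fn on repeated random subsamples; returns one value per iteration."""
    rng = np.random.default_rng(seed)
    distances = []
    for _ in range(num_iterations):
        # Fresh subsample of each dataset (no repeats within a single draw)
        xi = rng.choice(len(x_vals), size=sample_size_x, replace=False)
        yi = rng.choice(len(y_vals), size=sample_size_y, replace=False)
        distances.append(distance_fn(x_vals[xi], y_vals[yi]))
    return np.array(distances)


def mean_gap(x, y):
    # Cheap stand-in distance for illustration: gap between the sample means
    return float(np.linalg.norm(x.mean(axis=0) - y.mean(axis=0)))


rng = np.random.default_rng(1)
x_vals = rng.normal(size=(1000, 2))
y_vals = rng.normal(loc=2.0, size=(800, 2))
distances = bootstrap_distance(20, x_vals, 100, y_vals, 100, mean_gap)
print(distances.mean(), distances.std())
```

Reporting the spread (`distances.std()`) alongside the mean gives a sense of how stable the estimate is at a given sample size.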

For bootstrapping transport with labels:

```
distances = cost_function.bootstrap_label_distance(num_iterations, dataset_x, sample_size_x, 
                                      dataset_y, sample_size_y)
```

In general, Gaussian approximations underestimate the transport cost, while bootstrapping overestimates it.
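A toy 1D check of both tendencies. Two draws from the same distribution have a true distance of zero, yet the transport cost between the finite samples is positive. The Gaussian approximation built from the same means and variances can never exceed the 2-Wasserstein distance between the two samples (the Gelbrich bound). In 1D with equal sample sizes, the exact $W_2^2$ is the mean squared gap between the sorted samples.

```python
import numpy as np

rng = np.random.default_rng(0)
# Two independent samples from the SAME distribution: the true W2 between the
# underlying distributions is 0
x = rng.normal(size=200)
y = rng.normal(size=200)

# Exact empirical W2^2 in 1D (equal sizes, uniform weights): match sorted values
empirical_w2_sq = float(np.mean((np.sort(x) - np.sort(y)) ** 2))

# Gaussian approximation in 1D: (mean gap)^2 + (std gap)^2, using ddof=0 so the
# fitted Gaussians share the samples' exact first and second moments
gaussian_w2_sq = float((x.mean() - y.mean()) ** 2 + (x.std() - y.std()) ** 2)

print(empirical_w2_sq, gaussian_w2_sq)
```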


#### Intra-Dataset Distance

To calculate the intra-dataset distance (i.e., $W_{p}^{p}(\alpha, \alpha)$), pass `mask_diagonal=True` to the distance methods. The diagonal of the cost matrix (the self-distances) is then masked with a large value, so each point must be transported to other points in the same dataset.
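A self-contained sketch of intra-dataset transport: the entropic cost of moving a dataset onto itself when the diagonal of the cost matrix is masked. It uses a log-domain Sinkhorn solver; `intra_dataset_cost`, the `entropy` value, and the `mask_value` of `1e6` are hypothetical choices, not the library's.

```python
import numpy as np


def logsumexp(a, axis):
    # Numerically stable log(sum(exp(a))) along one axis
    m = a.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(a - m).sum(axis=axis))


def intra_dataset_cost(x_vals, entropy=0.2, mask_value=1e6, n_iters=500):
    """Entropic transport cost of x_vals onto itself with self-matches masked out."""
    n = len(x_vals)
    M = np.linalg.norm(x_vals[:, None, :] - x_vals[None, :, :], axis=-1)
    # Mask the diagonal so no point can send mass to itself
    np.fill_diagonal(M, mask_value)
    log_w = np.full(n, -np.log(n))  # uniform weights on both sides
    f, g = np.zeros(n), np.zeros(n)
    for _ in range(n_iters):
        f = entropy * (log_w - logsumexp((g[None, :] - M) / entropy, axis=1))
        g = entropy * (log_w - logsumexp((f[:, None] - M) / entropy, axis=0))
    coupling = np.exp((f[:, None] + g[None, :] - M) / entropy)
    return float((coupling * M).sum()), coupling


rng = np.random.default_rng(0)
x_vals = rng.normal(size=(40, 2))
cost, coupling = intra_dataset_cost(x_vals)
```

Without the mask, the optimal plan would simply keep every point in place and the cost would be zero; masking forces mass onto neighboring points, so the cost reflects how spread out the dataset is.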


## Plotting

The functions in `plot.py` provide several plotting approaches using [Matplotlib](https://github.com/matplotlib/matplotlib), [HoloViews](https://github.com/holoviz/holoviews) and [Datashader](https://github.com/holoviz/datashader). Datashader methods are recommended for large datasets.

Example with MNIST and USPS digit datasets:

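```
import numpy as np
from sklearn.manifold import TSNE

mnist = ArrayDataset(mnist_vecs, labels=mnist_labels)
usps = ArrayDataset(usps_vecs, labels=usps_labels)

# Class-balanced subsamples keep the transport problem small
# (1000 is an arbitrary example size)
mnist_sample = mnist.subsample(1000, equal_classes=True)
usps_sample = usps.subsample(1000, equal_classes=True)

cost_fn = SinkhornCost(POTDistance(distance_metric='euclidean'), entropy=0.2)
outputs = cost_fn.distance_with_labels(mnist_sample.features, usps_sample.features,
                                       mnist_sample.labels, usps_sample.labels,
                                       gaussian_class_distance=True)

cost, coupling, OTDD_matrix, class_distances, class_x_dict, class_y_dict = outputs

# Embed both subsamples jointly so they share one 2D space, then split them
emb = TSNE().fit_transform(np.concatenate([mnist_sample.features, usps_sample.features]))
mnist_emb = emb[:mnist_sample.features.shape[0]]
usps_emb = emb[mnist_sample.features.shape[0]:]
```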

From here, we can plot the coupling matrix:

```
plot_coupling(coupling, OTDD_matrix, mnist_sample.labels, usps_sample.labels,
              mnist.classes, usps.classes, figsize=(8,8))
```

![coupling plot](https://raw.githubusercontent.com/kheyer/OTDD/main/media/coupling.png)

We can also plot a heatmap of the class distances:

```
plot_class_distances(class_distances, mnist.classes, usps.classes, 
                            cmap='OrRd', figsize=(10,8))
```

![class distance heatmap](https://raw.githubusercontent.com/kheyer/OTDD/main/media/heatmap.png)

Or the coupling network drawn on the 2D embeddings:

```
plot_coupling_network(mnist_emb, usps_emb, mnist_sample.labels, 
                      usps_sample.labels, coupling, plot_type='hv')
```

<img src="https://raw.githubusercontent.com/kheyer/OTDD/main/media/connectivity.png" width="500">

If the coupling network is too dense to plot clearly, we can plot only the k strongest connections:

```
plot_network_k_connections(mnist_emb, usps_emb, mnist_sample.labels, 
                    usps_sample.labels, coupling, 1000, plot_type='hv')
```

<img src="https://raw.githubusercontent.com/kheyer/OTDD/main/media/k_connectivity.png" width="500">
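We have not checked how `plot_network_k_connections` chooses its edges; presumably it keeps the `k` largest entries of the coupling matrix. A NumPy sketch of that selection step (plotting omitted), with `top_k_connections` as a hypothetical helper. Each returned `(row, col)` pair would be drawn as a line from `mnist_emb[row]` to `usps_emb[col]`.

```python
import numpy as np


def top_k_connections(coupling, k):
    """Return (rows, cols, weights) for the k largest entries of a coupling matrix."""
    flat = coupling.ravel()
    k = min(k, flat.size)
    # argpartition finds the k largest entries without a full sort,
    # then only those k are sorted by weight (largest first)
    top = np.argpartition(flat, -k)[-k:]
    top = top[np.argsort(flat[top])[::-1]]
    rows, cols = np.unravel_index(top, coupling.shape)
    return rows, cols, flat[top]


coupling = np.array([[0.1, 0.4], [0.3, 0.2]])
rows, cols, weights = top_k_connections(coupling, 2)
print(rows, cols, weights)  # [0 1] [1 0] [0.4 0.3]
```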
