# POT Backend

This page gives an overview of using the [Python Optimal Transport](https://pythonot.github.io/) (POT) library to solve transport problems.

## Distance Function

Use `POTDistance` as the distance function. It takes a `distance_metric` string, which can name any distance metric supported by the POT library. See the [`ot.dist` documentation](https://pythonot.github.io/all.html#ot.dist) for the full list.

## Cost Function

Two POT-based cost functions are provided: `EarthMoversCost` and `SinkhornCost`. `EarthMoversCost` uses `ot.emd2`, while `SinkhornCost` uses `ot.sinkhorn`.

## Extending

For new distance functions:

`POTDistance` should cover all distance metrics supported by POT. However, its `gaussian_distance` and `gaussian_distance_from_stats` methods are only valid for the Euclidean distance metric. If a Gaussian approximation is derived for another distance metric, it can be added by subclassing `POTDistance` and overriding `gaussian_distance` and `gaussian_distance_from_stats`.
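For context on why the Gaussian functions are tied to the Euclidean metric, the sketch below computes the closed-form squared 2-Wasserstein (Fréchet) distance between two Gaussians, which holds for a squared Euclidean ground cost. This is an assumption about what a Gaussian approximation looks like, not a copy of the library's `gaussian_distance_from_stats`; the function names `gaussian_w2_squared` and `_sqrtm_psd` are hypothetical.

```python
import numpy as np


def _sqrtm_psd(a):
    """Square root of a symmetric positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(a)
    vals = np.clip(vals, 0.0, None)  # guard against tiny negative eigenvalues
    return (vecs * np.sqrt(vals)) @ vecs.T


def gaussian_w2_squared(mu1, sigma1, mu2, sigma2):
    """Squared 2-Wasserstein distance between N(mu1, sigma1) and N(mu2, sigma2).

    Closed form (Euclidean ground metric only):
    ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1^{1/2} S2 S1^{1/2})^{1/2})
    """
    root1 = _sqrtm_psd(sigma1)
    cross = _sqrtm_psd(root1 @ sigma2 @ root1)
    mean_term = np.sum((mu1 - mu2) ** 2)
    cov_term = np.trace(sigma1 + sigma2 - 2.0 * cross)
    return mean_term + cov_term
```

A different ground metric generally has no such closed form, which is why a new derivation would be needed before overriding the Gaussian methods in a subclass.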

For new cost functions:

New cost functions using other solvers in POT should subclass `CostFunction` and override the `cost_function` method. Consider the code for `EarthMoversCost` and `SinkhornCost`:

```python
import ot

class EarthMoversCost(CostFunction):
    def __init__(self, distance_function, default_max_iter=100000):
        super().__init__(distance_function, default_max_iter)
        
    def cost_function(self, x_weights, y_weights, M_dist, max_iter):
        
        if x_weights is None:
            x_weights = self.get_sample_weights(M_dist.shape[0])
        
        if y_weights is None:
            y_weights = self.get_sample_weights(M_dist.shape[1])
        
        max_iter = self.get_iter(max_iter)
        output = ot.emd2(x_weights, y_weights, M_dist, numItermax=max_iter, return_matrix=True)
    
        cost = output[0]
        coupling = output[1]['G']
        
        return cost, coupling

class SinkhornCost(CostFunction):
    def __init__(self, distance_function, entropy, default_max_iter=1000, method='sinkhorn'):
        super().__init__(distance_function, default_max_iter=default_max_iter)
        
        self.entropy = entropy
        self.method = method  # use the requested variant instead of a hard-coded value
        
    def cost_function(self, x_weights, y_weights, M_dist, max_iter):
        
        if x_weights is None:
            x_weights = self.get_sample_weights(M_dist.shape[0])
        
        if y_weights is None:
            y_weights = self.get_sample_weights(M_dist.shape[1])
        
        max_iter = self.get_iter(max_iter)
        
        output = ot.sinkhorn(x_weights, y_weights, M_dist, self.entropy,
                            method=self.method, numItermax=max_iter, log=True)
        
        coupling = output[0]
        cost = (coupling*M_dist).sum()
        
        return cost, coupling
```

Any POT solver can be added by updating `cost_function` to work with that particular solver.