<a href="https://colab.research.google.com/github/johannnamr/Discrepancy-based-inference-using-QMC/blob/main/Helper-functions/ot_sink.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Source code from the OT library for Sinkhorn divergence

Adjusted to always use SciPy (`scipy.spatial.distance.cdist`) to compute the pairwise distances

https://pythonot.github.io/_modules/ot/bregman.html#empirical_sinkhorn_divergence

In [None]:
def empirical_sinkhorn_divergence_adj(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
                                  numIterMax=10000, stopThr=1e-9,
                                  verbose=False, log=False, warn=True, **kwargs):
    r'''
    Compute the sinkhorn divergence loss from empirical data

    The function solves the following optimization problems and returns the
    sinkhorn divergence :math:`S`:

    .. math::

        W &= \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot\Omega(\gamma)

        W_a &= \min_{\gamma_a} \quad \langle \gamma_a, \mathbf{M_a} \rangle_F +
        \mathrm{reg} \cdot\Omega(\gamma_a)

        W_b &= \min_{\gamma_b} \quad \langle \gamma_b, \mathbf{M_b} \rangle_F +
        \mathrm{reg} \cdot\Omega(\gamma_b)

        S &= W - \frac{W_a + W_b}{2}

    .. math::
        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

             \gamma_a \mathbf{1} &= \mathbf{a}

             \gamma_a^T \mathbf{1} &= \mathbf{a}

             \gamma_a &\geq 0

             \gamma_b \mathbf{1} &= \mathbf{b}

             \gamma_b^T \mathbf{1} &= \mathbf{b}

             \gamma_b &\geq 0
    where :

    - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`)
      is the (`n_samples_a`, `n_samples_b`) metric cost matrix
      (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`))
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)


    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    reg : float
        Regularization term >0
    a : array-like, shape (n_samples_a,)
        samples weights in the source domain
    b : array-like, shape (n_samples_b,)
        samples weights in the target domain
    metric : str | callable, optional
        Metric used to build the cost matrices; forwarded unchanged to
        `empirical_sinkhorn2_adj`
    numIterMax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0); forwarded to each of the three solves
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm does not converge.
    **kwargs
        Additional keyword arguments forwarded to `empirical_sinkhorn2_adj`

    Returns
    -------
    W : (1,) array-like
        Optimal transportation symmetrized loss for the given parameters,
        clipped from below at 0
    log : dict
        log dictionary return only if log==True in parameters. It holds the
        three losses (`sinkhorn_loss_ab`, `sinkhorn_loss_a`,
        `sinkhorn_loss_b`) and the three solver logs (`log_sinkhorn_ab`,
        `log_sinkhorn_a`, `log_sinkhorn_b`)

    Examples
    --------
    >>> n_samples_a = 2
    >>> n_samples_b = 4
    >>> reg = 0.1
    >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
    >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
    >>> empirical_sinkhorn_divergence_adj(X_s, X_t, reg)  # doctest: +ELLIPSIS
    1.499887176049052


    References
    ----------
    .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative
        Models with Sinkhorn Divergences,  Proceedings of the Twenty-First
        International Conference on Artificial Intelligence and Statistics,
        (AISTATS) 21, 2018
    '''
    # Options shared by the three entropic OT solves. ``stopThr`` is forwarded
    # as given by the caller instead of being pinned to a hard-coded 1e-9.
    solver_opts = dict(metric=metric, numIterMax=numIterMax, stopThr=stopThr,
                       verbose=verbose, log=log, warn=warn, **kwargs)

    if log:
        # With log=True every solve returns a (loss, log) pair.
        sinkhorn_loss_ab, log_ab = empirical_sinkhorn2_adj(X_s, X_t, reg, a, b, **solver_opts)
        sinkhorn_loss_a, log_a = empirical_sinkhorn2_adj(X_s, X_s, reg, a, a, **solver_opts)
        sinkhorn_loss_b, log_b = empirical_sinkhorn2_adj(X_t, X_t, reg, b, b, **solver_opts)

        sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)

        # Separate name so the boolean ``log`` flag is not shadowed.
        log_dict = {
            'sinkhorn_loss_ab': sinkhorn_loss_ab,
            'sinkhorn_loss_a': sinkhorn_loss_a,
            'sinkhorn_loss_b': sinkhorn_loss_b,
            'log_sinkhorn_ab': log_ab,
            'log_sinkhorn_a': log_a,
            'log_sinkhorn_b': log_b,
        }

        # Clip at 0 so small negative values are not returned.
        return max(0, sinkhorn_div), log_dict

    sinkhorn_loss_ab = empirical_sinkhorn2_adj(X_s, X_t, reg, a, b, **solver_opts)
    sinkhorn_loss_a = empirical_sinkhorn2_adj(X_s, X_s, reg, a, a, **solver_opts)
    sinkhorn_loss_b = empirical_sinkhorn2_adj(X_t, X_t, reg, b, b, **solver_opts)

    sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
    return max(0, sinkhorn_div)

In [None]:
def empirical_sinkhorn2_adj(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
                        numIterMax=10000, stopThr=1e-9, isLazy=False,
                        batchSize=100, verbose=False, log=False, warn=True, **kwargs):
    r'''
    Solve the entropic regularization optimal transport problem from empirical
    data and return the OT loss


    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0
    where :

    - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)

    The cost matrix (or its blocks, when `isLazy` is True) is computed with
    `dist_adj` on numpy copies of the samples.


    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    reg : float
        Regularization term >0
    a : array-like, shape (n_samples_a,)
        samples weights in the source domain (uniform weights if None)
    b : array-like, shape (n_samples_b,)
        samples weights in the target domain (uniform weights if None)
    metric : str | callable, optional
        Metric passed to `dist_adj` (and to `ot.bregman.empirical_sinkhorn`
        when `isLazy` is True)
    numIterMax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    isLazy: boolean, optional
        If True, the dual potentials are obtained from
        `ot.bregman.empirical_sinkhorn` and the loss is accumulated over
        blocks of rows of the cost matrix (to save memory). If False, the
        full cost matrix is computed and passed to `ot.bregman.sinkhorn2`.
    batchSize: int or tuple of 2 int, optional
        Size of the batches used to compute the sinkhorn update without memory overhead.
        When a tuple is provided it sets the size of the left/right batches;
        the loss accumulation here uses the first entry as row-block size.
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm does not converge.
    **kwargs
        Additional keyword arguments forwarded to `ot.bregman.sinkhorn2`
        (only used when `isLazy` is False)


    Returns
    -------
    W : (n_hists) array-like or float
        Optimal transportation loss for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------

    >>> n_samples_a = 2
    >>> n_samples_b = 2
    >>> reg = 0.1
    >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
    >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
    >>> b = np.full((n_samples_b, 3), 1/n_samples_b)
    >>> empirical_sinkhorn2_adj(X_s, X_t, b=b, reg=reg, verbose=False)
    array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05])


    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information
        Processing Systems (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling
        Algorithms for Entropy Regularized Transport Problems.
        arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems.
        arXiv preprint arXiv:1607.05816.
    '''

    X_s, X_t = ot.utils.list_to_array(X_s, X_t)

    nx = ot.backend.get_backend(X_s, X_t)

    ns, nt = X_s.shape[0], X_t.shape[0]
    # Default to uniform weights. ``unif`` is reached through ``ot.utils`` like
    # every other library helper here, so no separate import is required.
    if a is None:
        a = nx.from_numpy(ot.utils.unif(ns), type_as=X_s)
    if b is None:
        b = nx.from_numpy(ot.utils.unif(nt), type_as=X_s)

    if isLazy:
        # The dual potentials come from the library solver; note that
        # **kwargs is not forwarded on this path.
        solver_out = ot.bregman.empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
                                                   numIterMax=numIterMax,
                                                   stopThr=stopThr,
                                                   isLazy=isLazy,
                                                   batchSize=batchSize,
                                                   verbose=verbose, log=log,
                                                   warn=warn)
        if log:
            f, g, dict_log = solver_out
        else:
            f, g = solver_out

        # Row-block size: the int itself, or the first entry of the tuple.
        bs = batchSize if isinstance(batchSize, int) else batchSize[0]

        X_s_np = nx.to_numpy(X_s)
        X_t_np = nx.to_numpy(X_t)

        loss = 0
        for i in range(0, ns, bs):
            # Only one block of rows of the cost matrix is held at a time.
            M_block = dist_adj(X_s_np[i:i + bs, :], X_t_np, metric=metric)
            M_block = nx.from_numpy(M_block, type_as=a)
            pi_block = nx.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
            loss += nx.sum(M_block * pi_block)

        if log:
            return loss, dict_log
        return loss

    # Dense path: build the full cost matrix with the SciPy-based helper.
    M = dist_adj(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
    M = nx.from_numpy(M, type_as=a)

    if log:
        sinkhorn_loss, dict_log = ot.bregman.sinkhorn2(a, b, M, reg, numItermax=numIterMax,
                                                       stopThr=stopThr, verbose=verbose,
                                                       log=log, warn=warn, **kwargs)
        return sinkhorn_loss, dict_log

    return ot.bregman.sinkhorn2(a, b, M, reg, numItermax=numIterMax,
                                stopThr=stopThr, verbose=verbose, log=log,
                                warn=warn, **kwargs)

In [None]:
def dist_adj(x1, x2=None, metric='sqeuclidean'):
    r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`

    The distances are always computed with `distance.cdist`
    (presumably `scipy.spatial.distance.cdist` - the import is not part of
    this function).

    .. note:: Only the numpy backend is supported. Inputs resolved to any
        other backend raise `NotImplementedError`.

    Parameters
    ----------

    x1 : array-like, shape (n1,d)
        matrix with `n1` samples of size `d`
    x2 : array-like, shape (n2,d), optional
        matrix with `n2` samples of size `d` (if None then :math:`\mathbf{x_2} = \mathbf{x_1}`)
    metric : str | callable, optional
        Metric passed unchanged to `cdist`, e.g. 'sqeuclidean' (default),
        'euclidean', 'cityblock', 'chebyshev', 'cosine', ... The set of
        accepted names is defined by the installed `cdist` implementation.


    Returns
    -------

    M : array-like, shape (`n1`, `n2`)
        distance matrix computed with given metric

    Raises
    ------
    NotImplementedError
        If the backend detected for `x1` and `x2` is not numpy.

    """
    if x2 is None:
        x2 = x1

    backend_name = ot.backend.get_backend(x1, x2).__name__
    if backend_name != 'numpy':
        raise NotImplementedError(
            "dist_adj only supports the numpy backend, got '{}'".format(backend_name)
        )

    return distance.cdist(x1, x2, metric=metric)