In [192]:
import numpy as np
import mosek
from scipy.linalg import sqrtm

In [168]:
def ot(cost,p1,p2):
    #function ot() performs standard Optimal Transport
    
    #cost[m1,m2] is the cost distance between any pair of clusters.
    # When Wasserstein-2 metric is computed, cost() is the square of the Eulcidean distance between 
    # two support points across the two distributions. The solution is in res.
    # objval is the optimized objective function value. If Wasserstein-2 metric
    # is to be solved (using the right cost defintion, see above), sqrt(objval)
    # is the Wasserstein-2 metric.
    #p1[m1] is the marginal for the first distribution, column vector
      #p2[m2] is the marginal for the second distribution, column vector

    #  cost=[0.1,0.2,1.0;0.8,0.8,0.1]; p1=[0.4;0.6]; p2=[0.2;0.3;0.5];

    
    m1 = cost.shape[0] #number of support points in the first distribution
    m2 = cost.shape[1] #number of support points in the second distribution

    if (np.sum(p1) == 0.0 or np.sum(p2) == 0.0):
      print("Warning: total probability is zero: %f, %f\n" % (np.sum(p1), np.sum(p2)))
      return

    # Normalization
    p1 = p1 / np.sum(p1)
    p2 = p2 / np.sum(p2)
    
    with mosek.Task() as task:
    
        coststack = np.reshape(cost,m1 * m2, order='F')
        c = coststack
        blx = np.zeros(m1 * m2)
        ulx = np.inf * np.ones(m1 * m2)
        bkx = [mosek.boundkey.lo] * (m1 * m2)
        a = np.zeros((m1 + m2, m1 * m2))
        blc = np.zeros(m1 + m2)
        buc = np.zeros(m1 + m2)
        bkc = [mosek.boundkey.fx] * (m1 + m2)

        # Generate subscript matrix for easy reference
        wijsub = np.zeros((m1, m2))
        k = 0
        for j in range(m2):
            for i in range(m1):
                wijsub[i,j] = k
                k = k + 1

        # Set up the constraints
        for i in range(m1):
            for j in range(m2):
                a[i, int(wijsub[i,j])] = 1.0
            buc[i] = p1[i]
            blc[i] = p1[i]


        for j in range(m2):
            for i in range(m1):
                a[j+m1, int(wijsub[i,j])] = 1.0
            buc[j+m1] = p2[j]
            blc[j+m1] = p2[j]
    
        numvar = len(blx)
        numcon = len(blc)

    
        # Append 'numcon' empty constraints.
        # The constraints will initially have no bounds.
        task.appendcons(numcon)

        # Append 'numvar' variables.
        # The variables will initially be fixed at zero (x=0).
        task.appendvars(numvar)

        for j in range(numvar):
            # Set the linear term c_j in the objective.
            task.putcj(j, c[j])

            # Set the bounds on variable j
            # blx[j] <= x_j <= bux[j]
            task.putvarbound(j, bkx[j], blx[j], ulx[j])
        
            asub = []
            aval = []
            for i in range(a.shape[0]):
                if a[i, j] != 0:
                    asub.append(i)
                    aval.append(a[i, j])

            # Input column j of A
            task.putacol(j,                  # Variable (column) index.
                         asub,            # Row index of non-zeros in column j.
                         aval)            # Non-zero Values of column j.

        for i in range(numcon):
            task.putconbound(i, bkc[i], blc[i], buc[i])
    
        task.putintparam(mosek.iparam.optimizer, mosek.optimizertype.intpnt)
        task.putintparam(mosek.iparam.log, 0)
    
        # Input the objective sense (minimize/maximize)
        task.putobjsense(mosek.objsense.minimize)
    
        # Solve the problem
        task.optimize()

        # To extract the optimized objective function value
        objval = task.getprimalobj(mosek.soltype.itr)

        # To extract the matching weights solved
        xx = task.getxx(mosek.soltype.itr)
        gammaij = np.reshape(np.array(xx)[0:m1*m2], (m1, m2), order='F')
       
    return({"objval": objval, "gammaij": gammaij})



In [181]:
cost = np.array([[0.1,0.2,1.0], [0.8,0.8,0.1]])
cost

array([[0.1, 0.2, 1. ],
       [0.8, 0.8, 0.1]])

In [182]:
p1 = np.array([0.4,0.6])
p2 = np.array([0.2, 0.3, 0.5])

In [185]:
result = ot(cost, p1, p2)

In [186]:
result["objval"]

0.19

In [187]:
result["gammaij"]

array([[0.2, 0.2, 0. ],
       [0. , 0.1, 0.5]])

In [213]:
def GaussWasserstein(d, supp1, supp2):
    # Compute the pairwise squared Wasserstein distance between each component in supp1 and each component in supp2
    # numcomponents in a distribution is size(supp1.2).
    #Suppose numcmp1=size(supp1,2), numcmp2=size(supp2,2)
    # Squared Wasserstein distance between two Gaussian:
    # \|\mu_1-\mu_2\|^2+trace(\Sigma_1+\Sigma_2-2*(\Sigma_1^{1/2}\Sigma_2\Sigma_1^{1/2})^{1/2})
    # For commutative case when \Sigma_1*\Sigma_2=\Sigma_2*\Sigma_1 (true for symmetric matrices)
    # The distance is equivalent to
    # \|\mu_1-\mu_2\|^2+\|Sigma_1^{1/2}-\Sigma_2^{1/2}\|_{Frobenius}^2
    # Frobenius norm of matrices is the Euclidean (L2) norm of the stacked
    # vector converted from the matrix
    # We use the commutative case formula to avoid more potential precision
    # errors

    numcmp1 = supp1.shape[1]
    numcmp2 = supp2.shape[1]
    pairdist = np.zeros((numcmp1, numcmp2))

    for ii in range(numcmp1):
        for jj in range(numcmp2):
            sigma1 = np.reshape(supp1[d:d+d*d,ii], (d,d), order='F')
            sigma2 = np.reshape(supp2[d:d+d*d,jj], (d,d), order='F')

            # b1=sqrtm_eig(sigma1); %use eigen value decomposition to solve squre root          
            # b2=sqrtm_eig(sigma2);
            b1 = sqrtm(sigma1)
            b2 = sqrtm(sigma2)

            mudif = supp1[0:dim,ii] - supp2[0:dim,jj]
            pairdist[ii,jj] = np.sum(mudif * mudif) + np.sum((b1 - b2) * (b1 - b2))

    return(pairdist)


In [250]:
def Mawdist(d, supp1, supp2,w1,w2):
    # Compute the MAW distance between two GMM with Gusassian component parameters specified in supp1 and supp2 and prior specified to w1 and w2

    pairdist = GaussWasserstein(d, supp1, supp2)
    result = ot(pairdist, w1, w2)
    return({"dist": result["objval"], "gammaij": result["gammaij"]})


In [251]:
import scipy.io
mouse = scipy.io.loadmat('/Users/jz259/Desktop/Prelim/MAW/test/mouse_2.mat')

In [252]:
supp = mouse["supp"]
stride = mouse["stride"]
ww = mouse["ww"][0]

In [256]:
final = Mawdist(d = 2, 
                supp1 = supp[:,0:stride[0]], 
                supp2 = supp[:,stride[0]:(stride[0]+stride[1])],
                w1 = ww[0:stride[0]],
                w2 = ww[stride[0]:(stride[0]+stride[1])])

In [257]:
final["dist"]

0.735194981136692

In [258]:
final["gammaij"]

array([[6.47204666e-03, 2.03892192e-12, 4.47391747e-02, 1.73672196e-13,
        2.04479470e-13, 1.22125031e-13, 1.10222441e-13, 1.12335695e-13,
        5.19380750e-03, 5.90498647e-13, 8.52713268e-13, 6.49004465e-13,
        8.06878139e-13, 4.96591701e-13, 2.86370811e-13, 3.92808034e-13],
       [1.12703812e-12, 5.78225218e-13, 1.10078584e-11, 1.53278710e-13,
        1.75277406e-13, 1.05872739e-13, 9.44569580e-14, 9.40233601e-14,
        3.39789316e-04, 3.90034209e-13, 5.75391566e-13, 4.22674260e-13,
        6.18964893e-13, 4.25105033e-13, 2.36971614e-13, 2.93720394e-13],
       [3.53788716e-13, 2.62081385e-13, 4.95752633e-01, 1.50942942e-13,
        2.59542909e-13, 1.70993129e-13, 1.69979592e-13, 1.64104484e-13,
        2.58864111e-12, 1.71734016e-13, 1.82345503e-13, 1.76251687e-13,
        2.45214512e-13, 2.12704666e-13, 1.57644012e-13, 1.74534773e-13],
       [3.01149525e-13, 2.72894530e-13, 1.49180162e-12, 3.24110672e-02,
        4.09834282e-02, 5.30923806e-13, 2.24290429e-13, 2.786

In [260]:
scipy.io.savemat("python_result.mat", final)