# SparseSinkhorn Solver

In [None]:
from lib.header_notebook import *
import Solvers.Sinkhorn as Sinkhorn
import lib.header_params_Sinkhorn
%matplotlib inline

In [None]:
# Verbosity flags for the individual solver stages.
# Per-iteration output ("solve_iterate") stays off: printing every iteration
# can make the notebook really slow.
paramsVerbose = {
    "solve_overview": True,
    "solve_update": True,
    "solve_kernel": True,
    "solve_iterate": False,
}

## Parameters

The full enhanced Sinkhorn algorithm requires a lot of configuration parameters:

This includes modelling aspects such as the transport model, e.g. whether to compute standard optimal transport, Wasserstein-Fisher-Rao (including the parameter that balances transport against growth), gradient flows or barycenters.

It also includes computational aspects: the data structure for the kernel (dense or truncated, and the truncation threshold), parameters for the log-stabilization scheme, for epsilon-scaling and for the multi-scale scheme, error tolerances, etc.

These parameters are stored in config files in the subdirectory cfg/ and read by this script. Several test problems have been prepared. By setting params["setup_tag"] one can choose an example (see below) and then run the rest of the script to solve it.

In [None]:
# setup parameter management

# start from the default transport parameters; the config file read in the next cell
# is merged over these via params.update(...)
params=lib.header_params_Sinkhorn.getParamsDefaultTransport()
# paramsListCFGFile is passed to ScriptTools.readParameters in the next cell;
# paramsListCommandLine is not used further in this notebook (presumably meant for a command-line driver)
paramsListCommandLine,paramsListCFGFile=lib.header_params_Sinkhorn.getParamListsTransport()

# choose setup_tag. It specifies which config file the problem parameters are loaded from
# (the next cell appends ".txt" to it).
# Several examples have been prepared. Uncomment the one you would like to try.
# Keep exactly one assignment active: if several are uncommented, the last one wins.


############################################################################
# Compare successive enhancements of simple algorithm on 64x64 test image
# One example of data for Figure 2, at eps=0.1 h^2
############################################################################
## 1: log-domain stabilization: careful: takes a frustratingly long time
#params["setup_tag"]="cfg/Sinkhorn/CompareEnhancements/1"
## 2: log-domain stabilization, epsilon scaling
#params["setup_tag"]="cfg/Sinkhorn/CompareEnhancements/2"
## 3: log-domain stabilization, epsilon scaling, truncated kernel
#params["setup_tag"]="cfg/Sinkhorn/CompareEnhancements/3"
## 4: log-domain stabilization, epsilon scaling, truncated kernel, coarse-to-fine
#params["setup_tag"]="cfg/Sinkhorn/CompareEnhancements/4"


############################################################################
# Large example, on two 256x256 images
# One example of data for Figure 3, at eps=0.1 h^2, error=1E-6
############################################################################
# Standard Wasserstein distance example, as used for Fig 3
params["setup_tag"]="cfg/Sinkhorn/ImageBenchmark/OT_256"
# For comparison, a similar example with Wasserstein-Fisher-Rao distance is also given
#params["setup_tag"]="cfg/Sinkhorn/ImageBenchmark/WF_256"


############################################################################
# Compare Wasserstein with Wasserstein-Fisher-Rao distance
# An example with moving Gaussian blobs of different mass.
# Run both and compare the resulting displacement interpolations
# Standard Optimal transport will have to transfer mass between Gaussians of different weight.
# Wasserstein-Fisher-Rao can compensate the difference by local growth / annihilation.
############################################################################
#params["setup_tag"]="cfg/Sinkhorn/Gaussians/OT_128_000-001"
#params["setup_tag"]="cfg/Sinkhorn/Gaussians/WF_128_000-001"

In [None]:
# The config file for the chosen example is "<setup_tag>.txt"; its entries
# (restricted to paramsListCFGFile) override the defaults in params.
params["setup_cfgfile"] = params["setup_tag"] + ".txt"
params.update(ScriptTools.readParameters(params["setup_cfgfile"], paramsListCFGFile))

# totalMass regulates whether the marginals are normalized: a negative value in
# the config is mapped to None (presumably "keep the image mass as is").
if params["setup_totalMass"] < 0:
    params["setup_totalMass"] = None

# Finest hierarchy level handed to the multi-scale solver.
params["hierarchy_lBottom"] = params["hierarchy_depth"] + 1

# Dump the effective configuration, one key per line in sorted order.
print("final parameter settings")
for k in sorted(params):
    print("\t", k, params[k])

## Problem Setup

In [None]:
# define problem: setup marginals

def loadProblem(filename):
    """Load a marginal image from a MATLAB ``.mat`` file.

    The image must be stored under the variable name ``"a"``; the array is
    returned exactly as produced by ``sciio.loadmat``.
    """
    return sciio.loadmat(filename)["a"]

def setupDensity(img,posScale,totalMass,constOffset,keepZero):
    """Turn a grid image into a discrete measure with rescaled positions.

    ``totalMass``, ``constOffset`` and ``keepZero`` are forwarded unchanged to
    ``OTTools.processDensity_Grid``; the positions it returns are divided by
    ``posScale``.

    Returns:
        Tuple ``(mu, pos, shape)``: the masses, the scaled positions and
        ``img.shape``. The shape is kept so that results can later be
        projected back onto the pixel grid.
    """
    mu, pos = OTTools.processDensity_Grid(
        img, totalMass=totalMass, constOffset=constOffset, keepZero=keepZero
    )
    return mu, pos / posScale, img.shape

# One (mu, pos, shape) triple per marginal, in the order setup_f1, setup_f2.
problemData = [
    setupDensity(
        loadProblem(filename),
        posScale=params["setup_posScale"],
        totalMass=params["setup_totalMass"],
        constOffset=params["setup_constOffset"],
        keepZero=False,
    )
    for filename in (params["setup_f1"], params["setup_f2"])
]


nProblems = len(problemData)

In [None]:
# visualize marginals side by side, one subplot per problem
fig = plt.figure(figsize=(4 * nProblems, 4))
for i, (mu, pos, shape) in enumerate(problemData):
    # project the (pos, mu) samples onto a shape[0] x shape[1] grid; the result
    # is converted with .toarray() so imshow can display it
    img = OTTools.ProjectInterpolation2D(pos, mu, shape[0], shape[1]).toarray()
    fig.add_subplot(1, nProblems, i + 1)
    plt.imshow(img)
plt.show()

In [None]:
# set up eps-scaling

# geometric scaling from eps_start to eps_target in eps_steps+1 steps
# NOTE(review): the arguments are passed as (eps_target, eps_start, eps_steps) while the
# comment above reads "from eps_start to eps_target" -- confirm this matches the signature
# of SetupEpsScaling_Geometric. Its returned dict is merged into params (presumably this
# is where params["eps_list"], used below, comes from).
params.update(
        Sinkhorn.Aux.SetupEpsScaling_Geometric(params["eps_target"],params["eps_start"],params["eps_steps"],\
        verbose=True))

# determine finest epsilon for each hierarchy level.
# on coarsest level it is given by params["eps_boxScale"]**params["eps_boxScale_power"]
# with each level, the finest scale params["eps_boxScale"] is effectively divided by 2
# The list has hierarchy_depth+1 entries (n=0..hierarchy_depth) followed by a final 0,
# i.e. hierarchy_depth+2 thresholds in total.
params["eps_scales"]=[(params["eps_boxScale"]/(2**n))**params["eps_boxScale_power"]\
        for n in range(params["hierarchy_depth"]+1)]+[0]
print("eps_scales:\t",params["eps_scales"])

# divide eps_list into eps_lists, one for each hierarchy scale, divisions determined by eps_scales.
# params["eps_lists"] (set by this update) is what MultiscaleSolver consumes later.
params.update(Sinkhorn.Aux.SetupEpsScaling_Scales(params["eps_list"],params["eps_scales"],\
        levelTop=params["hierarchy_lTop"], nOverlap=1))

In [None]:
## setup hierarchical partitions
# Build one hierarchical partition per marginal, export it to the C library (SolverCFC)
# and collect the raw pointers the C cost-function providers need.
partitionChildMode=HierarchicalPartition.THPMode_Tree

# constructing basic partitions (from the point positions problemData[i][1])
partitionList=[]
for i in range(nProblems):
    partition=HierarchicalPartition.GetPartition(problemData[i][1],params["hierarchy_depth"],partitionChildMode,\
            box=None, signal_pos=True, signal_radii=True,clib=SolverCFC, export=False, verbose=False,\
            finestDimsWarning=False)
    partitionList.append(partition)

# exporting partitions: store the handles returned by SolverCFC.Export as int64;
# they are released again with SolverCFC.Close in the clean-up cell further below
pointerListPartition=np.zeros((nProblems),dtype=np.int64)
for i in range(nProblems):
    pointerListPartition[i]=SolverCFC.Export(partitionList[i])

# per-partition mass signal computed from the finest-level masses problemData[i][0]
muHList=[SolverCFC.GetSignalMass(pointer,partition,aprob[0])
        for pointer,partition,aprob in zip(pointerListPartition,partitionList,problemData)]

# pointer lists
# pointerListPos / pointerListRadii hold raw buffer addresses (.ctypes.data) of the arrays
# in pointerPosList / pointerRadiiList. Those lists must stay referenced while the C side
# may use the addresses, otherwise the buffers could be freed and the addresses dangle.
# (presumably getSignalPointer returns numpy arrays -- TODO confirm)
# Radii are requested only down to layer nlayers-2, i.e. not for the finest layer.
pointerPosList=[HierarchicalPartition.getSignalPointer(partition,"pos") for partition in partitionList]
pointerRadiiList=[HierarchicalPartition.getSignalPointer(partition,"radii",lBottom=partition.nlayers-2)
        for partition in partitionList]
pointerListPos=np.array([pointerPos.ctypes.data for pointerPos in pointerPosList],dtype=np.int64)
pointerListRadii=np.array([pointerRadii.ctypes.data for pointerRadii in pointerRadiiList],dtype=np.int64)

# print a few stats on the created problem
for i,partition in enumerate(partitionList):
    print("cells in partition {:d}: ".format(i),partition.cardLayers)

## Solver Setup

The code for the algorithm is extremely "modularized". Almost every part can be controlled by supplying suitable methods as parameters. This encompasses both modelling aspects such as different transport-type problems, as well as computational aspects, such as data structures for handling the kernels or cost functions.
This requires a somewhat lengthy configuration sequence, and does not yield the fastest performance, but is very flexible and is thus useful for development.

In [None]:
# model specific stuff
# Select the model-dependent callbacks according to params["model_transportModel"]:
#   method_CostFunctionProvider(level, pointerAlpha, alphaFinest=None): sets up the
#       C-side cost-function provider for one hierarchy level.
#   method_iterate_iterate(...): runs nInnerIterations iterations of the model's update.
#   method_iterate_error(...): value compared against params["sinkhorn_error"] by
#       Sinkhorn.Method_IterateToPrecision (see the callback cell below).
# Both branches use the same argument lists so the solver can call them uniformly;
# arguments a model does not need are simply ignored.
# The lambdas look up globals (pointerListPos, partitionList, params, ...) at call time,
# so re-running earlier cells changes their behavior.
import Solvers.Sinkhorn.Models.OT as ModelOT

if params["model_transportModel"]=="ot":

    # squared Euclidean cost on the exported positions/radii of the partitions
    method_CostFunctionProvider = lambda level, pointerAlpha, alphaFinest=None :\
            Sinkhorn.CInterface.Setup_CostFunctionProvider_SquaredEuclidean(pointerListPos,\
                    partitionList[0].ndim,level,pointerListRadii,pointerAlpha,alphaFinest)

    # standard OT: only the kernel pair, scalings and marginals are used (alphaList,
    # pointer lists and eps are ignored). kernel[0]/kernel[1] presumably are K and its
    # transpose -- TODO confirm.
    method_iterate_iterate = lambda kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu,\
            eps, nInnerIterations: \
            ModelOT.Iterate(kernel[0],kernel[1],scalingList[0],scalingList[1],muList[0],muList[1],nInnerIterations)

    # error measure: ModelOT.ErrorMarginLInf (by its name, an L^inf marginal error)
    def method_iterate_error(kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu, eps):
            return ModelOT.ErrorMarginLInf(kernel[0],kernel[1],scalingList[0],scalingList[1],muList[0],muList[1])

elif params["model_transportModel"]=="wf":

    # imported only here; the post-processing cells reuse ModelWF in their "wf" branch
    import Solvers.Sinkhorn.Models.WF as ModelWF

    # Wasserstein-Fisher-Rao cost, parametrized by FR_kappa and FR_cMax from the config
    method_CostFunctionProvider = lambda level, pointerAlpha, alphaFinest=None :\
            Sinkhorn.CInterface.Setup_CostFunctionProvider_SquaredEuclideanWF(pointerListPos,\
                    partitionList[0].ndim,level,pointerListRadii,pointerAlpha,alphaFinest,\
                    FR_kappa=params["model_FR_kappa"],FR_cMax=params["model_FR_cMax"])

    # unlike OT, the WF update also needs the dual potentials alphaList, eps and kappa
    method_iterate_iterate = lambda kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu,\
            eps, nInnerIterations: \
            ModelWF.Iterate(kernel[0],kernel[1],alphaList[0],alphaList[1],scalingList[0],scalingList[1],\
                    muList[0],muList[1],eps,params["model_FR_kappa"],nInnerIterations)

    # error measure: ModelWF.ScorePDGap (presumably a primal-dual gap score -- confirm)
    method_iterate_error = lambda kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu, eps: \
            ModelWF.ScorePDGap(kernel[0],kernel[1],alphaList[0],alphaList[1],\
                    scalingList[0],scalingList[1],muList[0],muList[1],\
                    eps,params["model_FR_kappa"])

else:
    raise ValueError("model_transportModel not recognized: "+params["model_transportModel"])

In [None]:
# data structure choice for kernel
# Defines, per kernel type ("csr" or "dense"):
#   get_method_getKernel(level, muList) -> function(kernel, alpha, eps) that builds a new
#       kernel for that level (the incoming `kernel` argument is not used by either builder);
#   method_deleteKernel(kernel) -> no-op in both cases;
#   method_refineKernel(level, kernel, alphaList, muList, eps);
#   method_getKernelVariablesCount.
# All lambdas read partitionList, pointerListPartition, method_CostFunctionProvider,
# params and paramsVerbose at call time (late binding).
if params["setup_type_kernel"]=="csr":
    # truncated (sparse) kernel; entries are dropped according to params["sparsity_kThresh"]
    get_method_getKernel = lambda level, muList:\
            lambda kernel, alpha, eps:\
                    Sinkhorn.GetKernel_SparseCSR(
                            partitionList,pointerListPartition,\
                            method_CostFunctionProvider,\
                            level, alpha, eps,\
                            kThresh=params["sparsity_kThresh"],\
                            baseMeasureX=muList[0], baseMeasureY=muList[1],\
                            sanityCheck=False,\
                            verbose=paramsVerbose["solve_kernel"])

    method_deleteKernel = lambda kernel : None

    # reuses the sparsity pattern (indices, indptr) of kernel[0]; presumably to refine the
    # current support to the next hierarchy level -- confirm with RefineKernel_CSR
    method_refineKernel = lambda level, kernel, alphaList, muList, eps:\
                    Sinkhorn.RefineKernel_CSR(partitionList, pointerListPartition,\
                            method_CostFunctionProvider,\
                            level, (kernel[0].indices,kernel[0].indptr), alphaList,\
                            eps,\
                            baseMeasureX=muList[0], baseMeasureY=muList[1],\
                            verbose=paramsVerbose["solve_kernel"])

    method_getKernelVariablesCount=Sinkhorn.GetKernelVariablesCount_CSR

elif params["setup_type_kernel"]=="dense":
    # full dense kernel; truncationThresh is fixed at 1E-200 here rather than read from params
    get_method_getKernel = lambda level, muList:\
            lambda kernel, alpha, eps:\
                    Sinkhorn.GetKernel_DenseArray(
                            partitionList,pointerListPartition,\
                            method_CostFunctionProvider,\
                            level, alpha, eps,\
                            baseMeasureX=muList[0], baseMeasureY=muList[1],\
                            truncationThresh=1E-200,verbose=paramsVerbose["solve_kernel"]\
                            )


    method_deleteKernel = lambda kernel : None

    # no refinement step for dense kernels: the callback always returns None
    method_refineKernel = lambda level, kernel, alphaList, muList, eps:\
            None
    
    method_getKernelVariablesCount=Sinkhorn.GetKernelVariablesCount_Array

else:
    raise ValueError("setup_type_kernel not recognized: "+params["setup_type_kernel"])

In [None]:
# Absorption of scaling factors into the dual potentials (presumably the log-domain
# stabilization step described above) is enabled only if setup_doAbsorption == 1;
# otherwise the callback is a no-op.
if params["setup_doAbsorption"]==1:
    method_absorbScaling = lambda alphaList,scalingList,eps:\
                    Sinkhorn.Method_AbsorbScalings(alphaList,scalingList,eps,\
                            residualScaling=None,minAlpha=None,verbose=False)
else:
    method_absorbScaling = lambda alphaList,scalingList,eps: None

# Factory for the per-level update callback. Inside the inner lambda, method_getKernel,
# method_deleteKernel and method_absorbScaling refer to the factory's ARGUMENTS (supplied
# by the solver), not to the globals of the same name defined in this notebook.
get_method_update = lambda epsList, method_getKernel, method_deleteKernel, method_absorbScaling:\
        lambda status, data:\
                Sinkhorn.Update(status,data,epsList,\
                        method_getKernel, method_deleteKernel, method_absorbScaling,\
                        absorbFinalIteration=True,maxRepeats=params["sinkhorn_maxRepeats"],\
                        verbose=paramsVerbose["solve_update"]\
                        )


# Iterate until method_iterate_error drops to params["sinkhorn_error"], in blocks of
# sinkhorn_nInner inner iterations and at most sinkhorn_maxOuter outer rounds; the
# adaption_scaling* bounds are forwarded to Method_IterateToPrecision.
method_iterate = lambda status,data : Sinkhorn.Method_IterateToPrecision(status,data,\
                method_iterate=method_iterate_iterate,method_error=method_iterate_error,\
                maxError=params["sinkhorn_error"],\
                nInnerIterations=params["sinkhorn_nInner"],maxOuterIterations=params["sinkhorn_maxOuter"],\
                scalingBound=params["adaption_scalingBound"],scalingLowerBound=params["adaption_scalingLowerBound"],\
                verbose=paramsVerbose["solve_iterate"])

In [None]:
# Run the multi-scale solver from level hierarchy_lTop down to hierarchy_lBottom, using the
# per-level epsilon lists and all callbacks configured above. The returned dict is unpacked
# below into "data", "status", "setup" and "setupAux".
result=Sinkhorn.MultiscaleSolver(partitionList,pointerListPartition,muHList,params["eps_lists"],\
        get_method_getKernel,method_deleteKernel,method_absorbScaling,\
        method_iterate,get_method_update,method_refineKernel,\
        params["hierarchy_lTop"],params["hierarchy_lBottom"],\
        collectReports=True,method_getKernelVariablesCount=method_getKernelVariablesCount,\
        verbose=paramsVerbose["solve_overview"],\
        )

In [None]:
# extract results of algorithm
# split the solver output into its four components
data, status, setup, setupAux = (
    result[key] for key in ("data", "status", "setup", "setupAux")
)

In [None]:
# re-estimate kernel one last time
# Rebuild the kernel from the final dual potentials and eps. This must run before the
# partitions are closed in the next cell: the kernel builders defined above read
# partitionList / pointerListPartition (presumably setupAux["method_getKernel"] is one of
# them -- confirm).
data["kernel"]=setupAux["method_getKernel"](data["kernel"],data["alpha"],data["eps"])

In [None]:
# clean up hierarchical partitions: release every handle obtained from SolverCFC.Export.
# Nothing after this point may call the kernel/cost callbacks, which use these handles.
for pointer in pointerListPartition:
    SolverCFC.Close(pointer)

## Post-Processing

Some post-processing for fun.

### Extract Coupling

Extract coupling from solution, test marginal accuracy.

In [None]:
import Solvers.Sinkhorn.Models.Common as ModelCommon

In [None]:
# Coupling built from the first kernel component and both scaling vectors.
# NOTE(review): GetCouplingCSR presumably expects a sparse CSR kernel; with
# setup_type_kernel == "dense" data["kernel"][0] is built by GetKernel_DenseArray -- verify
# this post-processing also works for that case.
pi=ModelCommon.GetCouplingCSR(data["kernel"][0],data["scaling"][0],data["scaling"][1])

In [None]:
if params["model_transportModel"] == "ot":
    # L^1 error of each marginal of pi against the corresponding input masses
    # (only meaningful for balanced OT, where the marginals are hard constraints)
    m0, m1 = ModelCommon.GetMarginals(pi)
    errL1 = [
        np.sum(np.abs(m0 - problemData[0][0])),
        np.sum(np.abs(m1 - problemData[1][0])),
    ]
    print(errL1)

### Displacement Interpolation

In this section we compute a naive displacement interpolation from the optimal couplings. For classical optimal transport it is well known how to do this (in the continuum), for Wasserstein-Fisher-Rao / Hellinger-Kantorovich we refer to [Liero, Mielke, Savaré: 'Optimal Entropy-Transport problems and a new Hellinger-Kantorovich distance between positive measures'], see also [Chizat, Peyré, Schmitzer, Vialard: 'An Interpolating Distance between Optimal Transport and Fisher-Rao Metrics' and 'Unbalanced Optimal Transport: Geometry and Kantorovich Formulation'].

Note that handling of the discretization is done in a very simplistic way: for each discrete mass particle, the optimal continuous trajectory over the image plane is computed. Each travelling particle is then projected onto the nearest pixels, its mass being distributed according to piecewise linear interpolation. This leads to oscillation-type artifacts when mass distributions are deformed smoothly. These artifacts are purely a result of the interpolation process.

Nevertheless, these interpolations help to visualize the transport process.

In [None]:
# Both branches define `particles` and a function
# method_interpolate(particles, pos0, pos1, t) whose result is indexed as
# (positions, masses) by the interpolation cell below.
if params["model_transportModel"]=="ot":
    # extract data about travelling mass particles from coupling
    particles=ModelOT.GetParticles(pi)
    print(particles[0].shape)
    method_interpolate=ModelOT.interpolateEuclidean
elif params["model_transportModel"]=="wf":
    # extract data about travelling mass particles from coupling
    # NOTE(review): 1E-7 is presumably a mass threshold below which particles are dropped -- confirm
    particles=ModelWF.GetParticles(pi,problemData[0][0],problemData[1][0],1E-7)
    print([x[0].shape for x in particles])
    def method_interpolate(particles,pos0,pos1,t):
            """Interpolate all WF particle groups at time t and stack the results.

            The three entries of `particles` are paired with the position pairs
            (pos0 -> pos1), (pos0 -> pos0) and (pos1 -> pos1); presumably these are
            transported, vanishing and appearing mass -- TODO confirm against
            ModelWF.GetParticles. Returns (positions, masses), stacked over all groups
            with np.vstack / np.hstack.
            """
            rhoPre=[ModelWF.interpolateEuclidean(p,posx,posy,t,params["model_FR_kappa"])\
                    for p,posx,posy in zip(particles,[pos0,pos0,pos1],[pos1,pos0,pos1])]
            rhoPos=np.vstack([x[0] for x in rhoPre])
            rhoMass=np.hstack([x[1] for x in rhoPre])
            return (rhoPos,rhoMass)

In [None]:
# Sample the displacement interpolation at nFrames equispaced times in [0, 1] and
# display each frame; the projected images are collected in imgList.
nFrames=10
datT=np.linspace(0,1,nFrames)
imgList=[]
# pixel grid shape of the first marginal, used for every frame
res=problemData[0][2]
for t in datT:

    # positions and masses of all particles at time t
    rho=method_interpolate(particles,problemData[0][1],problemData[1][1],t)

    # project the particles onto the res[0] x res[1] pixel grid
    projection=OTTools.ProjectInterpolation2D(rho[0],rho[1],res[0],res[1]).toarray()
    imgList.append(projection)

    plt.imshow(projection)
    plt.show()

In [None]:
# total mass of every interpolated frame (sum over all pixels)
massList=list(map(np.sum,imgList))

In [None]:
# Total mass per frame over time. NOTE(review): presumably roughly constant for standard OT,
# while Wasserstein-Fisher-Rao may change it through local growth/annihilation -- verify.
plt.plot(massList)