# SparseSinkhorn Solver: Barycenters

In [None]:
from lib.header_notebook import *
import Solvers.Sinkhorn as Sinkhorn
import Solvers.Sinkhorn.Barycenter as Barycenter
import lib.header_params_Sinkhorn
%matplotlib inline

In [None]:
# Verbosity switches handed to the solver callbacks below.
# Per-iteration output ("solve_iterate") is off because printing every
# iteration can make notebooks very slow.
paramsVerbose=dict(
        solve_overview=True,    # verbose flag of Barycenter.MultiscaleSolver
        solve_update=True,      # verbose flag of Sinkhorn.Update
        solve_kernel=True,      # verbose flag of kernel construction/refinement
        solve_iterate=False,    # verbose flag of Sinkhorn.Method_IterateToPrecision
        )

## Parameters

In [None]:
# default parameter values and the lists of parameter names for the
# command line and for the config file
params=lib.header_params_Sinkhorn.getParamsDefaultBarycenter()
(paramsListCommandLine,
        paramsListCFGFile)=lib.header_params_Sinkhorn.getParamListsBarycenter()

# problem setup: path of the config file without its ".txt" extension
# (swap the comment to switch to the other setup)
params["setup_tag"]="cfg/Sinkhorn/Barycenter/OT_simple_1-2-1"
#params["setup_tag"]="cfg/Sinkhorn/Barycenter/WF_three_1-2-1"

In [None]:
params["setup_cfgfile"]=params["setup_tag"]+".txt"

# overwrite defaults with the values read from the config file
cfgValues=ScriptTools.readParameters(params["setup_cfgfile"],paramsListCFGFile)
params.update(cfgValues)

# interpret some of the parameters

# setup_totalMass controls whether the marginals are normalized; a negative
# value in the config is translated to None (presumably: no normalization)
if params["setup_totalMass"]<0:
    params["setup_totalMass"]=None
# finest level used by the multi-scale algorithm
params["hierarchy_lBottom"]=params["hierarchy_depth"]+1


print("final parameter settings")
for key in sorted(params):
    print("\t",key,params[key])

## Problem Setup

In [None]:
def loadProblem(filename):
    """Return the array stored under the variable name "a" in the MATLAB file `filename`."""
    matContents=sciio.loadmat(filename)
    return matContents["a"]

def setupDensity(img,constOffset=0,posScale=1.,keepZero=False,totalMass=None):
    """Turn a grid image into a discrete measure.

    `constOffset`, `keepZero` and `totalMass` are forwarded to
    OTTools.processDensity_Grid. Returns the tuple (mu, pos, shape): the masses
    and positions produced by processDensity_Grid, with the positions divided by
    `posScale`, and the shape of `img`.
    """
    mu,pos=OTTools.processDensity_Grid(img,totalMass=totalMass,constOffset=constOffset,keepZero=keepZero)
    scaledPos=pos/posScale
    return mu,scaledPos,img.shape

# one (mu, pos, shape) tuple per input file
problemData=[
        setupDensity(
                loadProblem(filename),
                posScale=params["setup_posScale"],
                constOffset=params["setup_constOffset"],
                keepZero=False,
                totalMass=params["setup_totalMass"])
        for filename in params["setup_fileList"]]

# density for the center (barycenter) problem: a constant image of shape setup_centerRes
problemDataCenter=setupDensity(np.ones(params["setup_centerRes"],dtype=np.double),posScale=params["setup_posScale"])
nProblems=len(problemData)

In [None]:
# visualize the marginals side by side, one subplot per input density
fig=plt.figure(figsize=(4*nProblems,4))
for i,(mu,pos,shape) in enumerate(problemData):
    img=OTTools.ProjectInterpolation2D(pos,mu,shape[0],shape[1]).toarray()
    fig.add_subplot(1,nProblems,i+1)
    plt.imshow(img)
plt.show()

In [None]:
# set up eps-scaling

# geometric eps-scaling between eps_start and eps_target (number of steps set by
# eps_steps); the returned entries are merged into params (presumably including
# the "eps_list" consumed below)
params.update(
        Sinkhorn.Aux.SetupEpsScaling_Geometric(
                params["eps_target"],params["eps_start"],params["eps_steps"],verbose=True))

# finest epsilon for hierarchy levels n=0..hierarchy_depth:
# (eps_boxScale/2**n)**eps_boxScale_power, i.e. the box scale is halved per level.
# A trailing 0 is appended for the finest level hierarchy_lBottom=hierarchy_depth+1
# (presumably so that it has no lower bound on eps).
params["eps_scales"]=[(params["eps_boxScale"]/2**n)**params["eps_boxScale_power"]
        for n in range(params["hierarchy_depth"]+1)]
params["eps_scales"].append(0)
print("eps_scales:\t",params["eps_scales"])

# split eps_list into eps_lists, one sub-list per hierarchy level starting at
# hierarchy_lTop, with the divisions determined by eps_scales
params.update(Sinkhorn.Aux.SetupEpsScaling_Scales(params["eps_list"],params["eps_scales"],
        levelTop=params["hierarchy_lTop"],nOverlap=1))

In [None]:
# construct hierarchical representation
partitionChildMode=HierarchicalPartition.THPMode_Tree

def _buildPartition(pos):
    """Return a hierarchical partition of the point positions `pos`.

    The marginal partitions and the center partition are all built through this
    helper, so they share the depth params["hierarchy_depth"], the child mode
    and all other options.
    """
    return HierarchicalPartition.GetPartition(pos,params["hierarchy_depth"],partitionChildMode,\
            box=None, signal_pos=True, signal_radii=True,clib=SolverCFC, export=False, verbose=False,\
            finestDimsWarning=False)

# constructing basic partitions: one per marginal, one for the center
partitionList=[_buildPartition(aprob[1]) for aprob in problemData]
partitionCenter=_buildPartition(problemDataCenter[1])

# exporting partitions; the returned handles are released with SolverCFC.Close in the clean-up cell
pointerListPartition=np.zeros((nProblems),dtype=np.int64)
for i,partition in enumerate(partitionList):
    pointerListPartition[i]=SolverCFC.Export(partition)

pointerPartitionCenter=SolverCFC.Export(partitionCenter)

# masses of each measure as returned by SolverCFC.GetSignalMass
# (muHCenterList holds the single result for the center, despite its name)
muHList=[SolverCFC.GetSignalMass(pointer,partition,aprob[0])
        for pointer,partition,aprob in zip(pointerListPartition,partitionList,problemData)]

muHCenterList=SolverCFC.GetSignalMass(pointerPartitionCenter,partitionCenter,problemDataCenter[0])

# pointer lists
# Keep pointerPosList/pointerRadiiList and pointerPosCenter/pointerRadiiCenter referenced:
# only their raw addresses (.ctypes.data) are stored in the int64 pointer arrays below,
# and those addresses do not keep the arrays alive.
pointerPosList=[HierarchicalPartition.getSignalPointer(partition,"pos") for partition in partitionList]
# radii are requested down to layer nlayers-2 of each partition
pointerRadiiList=[HierarchicalPartition.getSignalPointer(partition,"radii",lBottom=partition.nlayers-2)
        for partition in partitionList]

pointerPosCenter=HierarchicalPartition.getSignalPointer(partitionCenter,"pos")
# lBottom must come from partitionCenter itself, not from one of the marginal partitions
pointerRadiiCenter=HierarchicalPartition.getSignalPointer(partitionCenter,"radii",lBottom=partitionCenter.nlayers-2)

pointerListPos=np.array([pointerPos.ctypes.data for pointerPos in pointerPosList],dtype=np.int64)
pointerListRadii=np.array([pointerRadii.ctypes.data for pointerRadii in pointerRadiiList],dtype=np.int64)


# pairwise pointer lists: entry i pairs marginal i with the center, in the order (marginal, center)
pointerListPosPairs=[np.array([pointerListPos[i],pointerPosCenter.ctypes.data],dtype=np.int64)
        for i in range(nProblems)]
pointerListRadiiPairs=[np.array([pointerListRadii[i],pointerRadiiCenter.ctypes.data],dtype=np.int64)
        for i in range(nProblems)]
pointerListPairsPartition=[np.array([pointerListPartition[i],pointerPartitionCenter],dtype=np.int64)
        for i in range(nProblems)]

# print a few stats on the created problem
for i,partition in enumerate(partitionList):
    print("cells in partition {:d}: ".format(i),partition.cardLayers)
print("cells in central partition: ",partitionCenter.cardLayers)

## Solver Setup

In [None]:
# Model-specific callbacks. Every branch defines
#   get_method_CostFunctionProviderPair(index): factory for the cost function provider
#       between marginal `index` and the center,
#   method_iterate_iterate: performs the model's Iterate with nInnerIterations,
#   method_iterate_error:   error measure passed to Method_IterateToPrecision.
# pointerListScaling/pointerListMu belong to the callback signature but are unused here.
if params["model_transportModel"]=="ot":
    import Solvers.Sinkhorn.Models.BarycenterOT as ModelBarycenterOT

    def get_method_CostFunctionProviderPair(index):
        # squared Euclidean cost between marginal `index` and the center
        def method_CostFunctionProvider(level, pointerAlpha, alphaFinest=None):
            return Sinkhorn.CInterface.Setup_CostFunctionProvider_SquaredEuclidean(
                    pointerListPosPairs[index],partitionList[0].ndim,level,
                    pointerListRadiiPairs[index],pointerAlpha,alphaFinest)
        return method_CostFunctionProvider

    def method_iterate_iterate(kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu,
            eps, nInnerIterations):
        return ModelBarycenterOT.Iterate(kernel,alphaList,scalingList,muList,eps,nInnerIterations,
                params["setup_weightList"],zeroHandling=True,setZeroInf=True)

    def method_iterate_error(kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu, eps):
        # marginal error as computed by ErrorMarginLInf
        return ModelBarycenterOT.ErrorMarginLInf(kernel,scalingList,muList,zeroHandling=True)

elif params["model_transportModel"] in ("wf","ghk"):
    # "wf" and "ghk" share the BarycenterWF iteration and primal-dual-gap error;
    # they differ only in the cost function provider
    import Solvers.Sinkhorn.Models.BarycenterWF as ModelBarycenterWF

    if params["model_transportModel"]=="wf":
        def get_method_CostFunctionProviderPair(index):
            def method_CostFunctionProvider(level, pointerAlpha, alphaFinest=None):
                return Sinkhorn.CInterface.Setup_CostFunctionProvider_SquaredEuclideanWF(
                        pointerListPosPairs[index],partitionList[0].ndim,level,
                        pointerListRadiiPairs[index],pointerAlpha,alphaFinest,
                        FR_kappa=params["model_FR_kappa"])
            return method_CostFunctionProvider
    else:
        def get_method_CostFunctionProviderPair(index):
            def method_CostFunctionProvider(level, pointerAlpha, alphaFinest=None):
                return Sinkhorn.CInterface.Setup_CostFunctionProvider_SquaredEuclidean(
                        pointerListPosPairs[index],partitionList[0].ndim,level,
                        pointerListRadiiPairs[index],pointerAlpha,alphaFinest)
            return method_CostFunctionProvider

    def method_iterate_iterate(kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu,
            eps, nInnerIterations):
        return ModelBarycenterWF.Iterate(kernel,alphaList,scalingList,muList,eps,nInnerIterations,
                params["setup_weightList"], params["model_FR_kappa"],zeroHandling=True)

    def method_iterate_error(kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu, eps):
        # primal-dual gap
        return ModelBarycenterWF.ScorePDGap(kernel,alphaList,scalingList,muList,eps,
                params["setup_weightList"],params["model_FR_kappa"])

else:
    raise ValueError("model_transportModel not recognized: "+params["model_transportModel"])

In [None]:
def get_method_getKernel(level, muList, muCenter):
    """Return the kernel constructor for hierarchy level `level`.

    The returned callable ignores its `kernel` argument and builds a new kernel with
    Barycenter.GetKernel_SparseCSR from the dual variables `alpha` and `eps`,
    passing params["sparsity_kThresh"] as kThresh.
    """
    def method_getKernel(kernel, alpha, eps):
        return Barycenter.GetKernel_SparseCSR(partitionList,pointerListPartition,
                partitionCenter,pointerPartitionCenter,
                get_method_CostFunctionProviderPair,
                level, alpha, eps, muList, muCenter,
                kThresh=params["sparsity_kThresh"],
                verbose=paramsVerbose["solve_kernel"])
    return method_getKernel


def method_deleteKernel(kernel):
    # no explicit clean-up is performed for these kernels
    return None


def method_refineKernel(level, kernel, alphaList, muList, muCenter, eps):
    """Call Barycenter.RefineKernel_CSR for `level` with the current dual variables."""
    return Barycenter.RefineKernel_CSR(partitionList, pointerListPartition, partitionCenter, pointerPartitionCenter,
            get_method_CostFunctionProviderPair,
            level, kernel, alphaList, eps, muList, muCenter,
            verbose=paramsVerbose["solve_kernel"])


method_getKernelVariablesCount=Barycenter.GetKernelVariablesCount_CSR


def method_absorbScaling(alphaList,scalingList,eps):
    """Absorb the scalings into the duals via Sinkhorn.Method_AbsorbScalings.

    minAlpha holds 2*nProblems entries of -1E5 (presumably one per dual variable of
    each marginal/center pair -- confirm against Method_AbsorbScalings).
    """
    lowerBounds=[-1E5]*(2*nProblems)
    return Sinkhorn.Method_AbsorbScalings(alphaList,scalingList,eps,
            residualScaling=None,minAlpha=lowerBounds)


def get_method_update(epsList, method_getKernel, method_deleteKernel, method_absorbScaling):
    """Return an update callback that runs Sinkhorn.Update with `epsList` and the given kernel callbacks."""
    def method_update(status, data):
        return Sinkhorn.Update(status,data,epsList,
                method_getKernel, method_deleteKernel, method_absorbScaling,
                verbose=paramsVerbose["solve_update"],absorbFinalIteration=True,maxRepeats=20)
    return method_update


def method_iterate(status,data):
    """Run Sinkhorn.Method_IterateToPrecision with the model callbacks and the limits from params."""
    return Sinkhorn.Method_IterateToPrecision(status,data,
            method_iterate=method_iterate_iterate,
            method_error=method_iterate_error,
            maxError=params["sinkhorn_error"],
            nInnerIterations=params["sinkhorn_nInner"],maxOuterIterations=params["sinkhorn_maxOuter"],
            scalingBound=params["adaption_scalingBound"],scalingLowerBound=params["adaption_scalingLowerBound"],
            verbose=paramsVerbose["solve_iterate"])

In [None]:
# run the multi-scale barycenter solver on levels hierarchy_lTop .. hierarchy_lBottom
# (hierarchy_lBottom is the finest level)
result=Barycenter.MultiscaleSolver(
        partitionList,pointerListPartition,
        partitionCenter,pointerPartitionCenter,
        muHList,muHCenterList,
        params["eps_lists"],
        get_method_getKernel,method_deleteKernel,method_absorbScaling,
        method_iterate,get_method_update,method_refineKernel,
        levelTop=params["hierarchy_lTop"],
        levelBottom=params["hierarchy_lBottom"],
        verbose=paramsVerbose["solve_overview"],
        collectReports=True,
        method_getKernelVariablesCount=method_getKernelVariablesCount)

In [None]:
# unpack the solver result
data,status,setup,setupAux=(result[key] for key in ("data","status","setup","setupAux"))

In [None]:
# Rebuild data["kernel"] from the final dual variables data["alpha"] at data["eps"]
# using the kernel constructor stored in setupAux. NOTE(review): presumably this gives a
# kernel consistent with the final duals after the last scalings were absorbed
# (absorbFinalIteration=True in Sinkhorn.Update) -- confirm. If setupAux["method_getKernel"]
# is built by get_method_getKernel, it reads the exported partitions, so this must run
# before they are closed in the clean-up cell.
data["kernel"]=setupAux["method_getKernel"](data["kernel"],data["alpha"],data["eps"])

In [None]:
# clean up the exported hierarchical partitions: first those of the marginals ...
for partitionPointer in pointerListPartition:
    SolverCFC.Close(partitionPointer)
# ... then the one of the center
SolverCFC.Close(pointerPartitionCenter)

## Post-Processing

In [None]:
# extract all marginals of the scaled kernels
# margs is a pair of lists: margs[0] holds the marginals on the reference measures,
# margs[1][i] holds the marginal of problem i on the common central (barycenter) support
margs=Barycenter.GetMarginals(data["kernel"],data["scaling"])

In [None]:
# Report solution quality: the marginal error for "ot", and the primal-dual gap
# together with the primal and dual scores for "wf"/"ghk".
# All values are computed first, then printed.
if params["model_transportModel"]=="ot":
    scores=[
            ("error",method_iterate_error(data["kernel"],data["alpha"],data["scaling"],data["mu"],None,None,data["eps"])),
            ]
elif params["model_transportModel"] in ["wf","ghk"]:
    scores=[
            ("PD gap",method_iterate_error(data["kernel"],data["alpha"],data["scaling"],data["mu"],None,None,data["eps"])),
            ("primal",ModelBarycenterWF.ScorePrimal(data["kernel"],data["alpha"],data["scaling"],data["mu"],data["eps"],
                    params["setup_weightList"],params["model_FR_kappa"])),
            ("dual",ModelBarycenterWF.ScoreDual(data["kernel"],data["alpha"],data["scaling"],data["mu"],data["eps"],
                    params["setup_weightList"],params["model_FR_kappa"])),
            ]
else:
    scores=[]

for f,v in scores:
    print(f," : ",v)

In [None]:
# barycenter image: sum of the central marginals margs[1][k], weighted by
# params["setup_weightList"][k] (a weighted average if the weights sum to 1),
# reshaped to the center grid
img=np.zeros(problemDataCenter[2],dtype=np.double)
for k in range(nProblems):
    img+=params["setup_weightList"][k]*margs[1][k].reshape(problemDataCenter[2])
plt.imshow(img)
plt.show()