Add DiscreteMeasure (#3)
devmotion committed Dec 11, 2020
1 parent 73c394c commit 3a839f6
Showing 4 changed files with 122 additions and 92 deletions.
58 changes: 58 additions & 0 deletions src/StochasticOptimalTransport.jl
@@ -9,4 +9,62 @@ import Statistics
include("utils.jl")
include("semidiscrete.jl")

@doc raw"""
    wasserstein([rng, ], c, μ, ν[, ε; kwargs...])

Estimate the (entropically regularized) Wasserstein distance
```math
W_{ε}(μ, ν) = \min_{π ∈ Π(μ,ν)} \int c(x, y) \,π(\mathrm{d}(x,y)) +
ε \mathrm{KL}(π \,|\, μ ⊗ ν)
```
with respect to the cost function `c` using stochastic optimization.

If `μ` is an arbitrary measure from which samples can be drawn with
`rand(rng, μ)` and `ν` is a [`DiscreteMeasure`](@ref), the Wasserstein
distance is approximated with stochastic gradient descent with averaging (SGA).
If `ε` is `nothing` (the default), the unregularized Wasserstein distance is
approximated; otherwise, the entropically regularized distance with `ε > 0` is estimated.

The SGA algorithm uses the step size schedule
```math
τᵢ = \frac{τ₁}{1 + \sqrt{(i - 1) / w}}
```
for the ``i``th iteration, where ``τ₁`` is the initial step size and ``w`` is the
number of iterations that serve as a warm-up phase.

# Keyword arguments

- `maxiters::Int = 10_000`: maximum number of gradient steps
- `initial_stepsize = 1`: initial step size ``τ₁``
- `warmup_phase = 1`: warm-up phase ``w``
- `atol = 0`: absolute tolerance of the SGA algorithm
- `rtol = iszero(atol) ? typeof(float(atol))(1 // 10_000) : 0`: relative tolerance of the
  SGA algorithm
- `montecarlo_samples = 10_000`: number of Monte Carlo samples from `μ` for approximating
  an expectation with respect to `μ`

# References

Genevay et al. (2016). Stochastic Optimization for Large-Scale Optimal Transport. Advances in Neural Information Processing Systems (NIPS 2016), 29:3440-3448.

Peyré & Cuturi (2019). Computational Optimal Transport. Foundations and Trends in Machine Learning, 11(5-6):355-607.
"""
wasserstein(args...; kwargs...) = wasserstein(Random.GLOBAL_RNG, args...; kwargs...)
function wasserstein(
    rng::Random.AbstractRNG,
    c,
    μ,
    ν,
    ε::Union{Real,Nothing} = nothing;
    kwargs...,
)
    # approximate solution `v` of the dual problem
    v = dual_v(rng, c, μ, ν, ε; kwargs...)

    # compute the Wasserstein distance from the dual solution
    cost = dual_cost(rng, c, v, μ, ν, ε; kwargs...)

    return cost
end

end
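
Note: a minimal usage sketch of the new `wasserstein` API (not part of this commit). The `Normal` sampler from Distributions.jl, the cost function, the `SOT` alias, and all values are illustrative assumptions mirroring the test setup below.

```julia
using Distributions, Random
using StochasticOptimalTransport
const SOT = StochasticOptimalTransport  # alias as in the package tests

rng = Random.MersenneTwister(42)

# continuous source measure: anything that supports `rand(rng, μ)`
μ = Normal(0.0, 1.0)

# discrete target measure: 50 support points with uniform weights
ν = SOT.DiscreteMeasure(randn(rng, 50), fill(1 / 50, 50))

c(x, y) = abs(x - y)  # illustrative cost function

W = SOT.wasserstein(rng, c, μ, ν)           # unregularized estimate
Wreg = SOT.wasserstein(rng, c, μ, ν, 1e-2)  # entropic regularization with ε = 0.01
```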
94 changes: 23 additions & 71 deletions src/semidiscrete.jl
@@ -1,73 +1,27 @@
-@doc raw"""
-    wasserstein_SGA([rng, ], c, μ, ν, ys[, ε; kwargs...])
-Estimate the (entropic regularization of the) Wasserstein distance
-```math
-W_{ε}(μ, ν) = \min_{π ∈ Π(μ,ν)} \int c(x, y) \,π(\mathrm{d}(x,y)) +
-ε \mathrm{KL}(π \,|\, μ ⊗ ν)
-```
-with respect to cost function `c` using stochastic gradient descent with averaging (SGA).
-Measure `μ` can be an arbitrary measure for which samples can be obtained with
-`rand(rng, μ)`. The inputs `ν` and `ys` have to be `AbstractVector`s and define
-a discrete measure with support `ys`.
-If `ε` is `nothing` (the default), then the unregularized Wasserstein distance is
-approximated. Otherwise, the entropic regularization with `ε > 0` is estimated.
-The SGA algorithm uses the step size schedule
-```math
-τᵢ = \frac{τ₁}{1 + \sqrt{(i - 1) / w}}
-```
-for the ``i``th iteration, where ``τ₁`` corresponds to the initial step size and ``w``
-indicates the number of iterations serving as a warm-up phase.
-# Keyword arguments
-- `maxiters::Int = 10_000`: maximum number of gradient steps
-- `initial_stepsize = 1`: initial step size ``τ₁``
-- `warmup_phase = 1`: warm-up phase ``w``
-- `atol = 0`: absolute tolerance of the SGA algorithm
-- `rtol = iszero(atol) ? typeof(float(atol))(1 // 10_000) : 0`: relative tolerance of the
-  SGA algorithm
-- `montecarlo_samples = 10_000`: Number of Monte Carlo samples from `μ` for approximating
-  an expectation with respect to `μ`
-# References
-Genevay et al. (2016). Stochastic Optimization for Large-Scale Optimal Transport. Advances in Neural Information Processing Systems (NIPS 2016), 29:3440-3448.
-Peyré, Gabriel, & Marco Cuturi (2019). Computational Optimal Transport. Foundations and Trends in Machine Learning, 11(5-6):355-607.
-"""
-wasserstein_SGA(args...; kwargs...) = wasserstein_SGA(Random.GLOBAL_RNG, args...; kwargs...)
-function wasserstein_SGA(
+function dual_cost(
     rng::Random.AbstractRNG,
     c,
+    v,
     μ,
-    ν::AbstractVector,
-    ys::AbstractVector,
-    ε::Union{Real,Nothing} = nothing;
+    ν::DiscreteMeasure,
+    ε;
     montecarlo_samples = 10_000,
     kwargs...,
 )
-    # approximate solution `v` of the dual problem
-    v = dual_v_SGA(rng, c, μ, ν, ys, ε; kwargs...)
-
     # compute MC estimate of the expected c-transform with respect to `μ`
     mean_ctransform = Statistics.mean(
-        ctransform(c, v, rand(rng, μ), ys, ν, ε) for _ in 1:montecarlo_samples
+        ctransform(c, v, rand(rng, μ), ν, ε) for _ in 1:montecarlo_samples
     )

-    return LinearAlgebra.dot(v, ν) + mean_ctransform
+    return LinearAlgebra.dot(v, ν.ps) + mean_ctransform
 end

-dual_v_SGA(args...; kwargs...) = dual_v_SGA(Random.GLOBAL_RNG, args...; kwargs...)
-function dual_v_SGA(
+function dual_v(
     rng::Random.AbstractRNG,
     c,
     μ,
-    ν::AbstractVector,
-    ys::AbstractVector,
-    ε::Union{Real,Nothing} = nothing;
+    ν::DiscreteMeasure,
+    ε;
     maxiters::Int = 10_000,
     initial_stepsize = 1,
     warmup_phase = 1,
@@ -78,7 +32,7 @@ function dual_v_SGA(
     k = 1
     x = rand(rng, μ)
     τ = initial_stepsize / (1 + sqrt(0 / warmup_phase))
-    ṽ = gradient_step(c, τ, ν, x, ys, ε)
+    ṽ = gradient_step(c, τ, ν, x, ε)

     # initial dual solution
     v = copy(ṽ)
@@ -93,7 +47,7 @@ function dual_v_SGA(
         k += 1
         x = rand(rng, μ)
         τ = initial_stepsize / (1 + sqrt((k - 1) / warmup_phase))
-        gradient_step!(c, ṽ, τ, ṽ, ν, x, ys, Δv, ε)
+        gradient_step!(c, ṽ, τ, ṽ, ν, x, Δv, ε)

         # estimate error
         @. Δv = (ṽ - v) / k
@@ -112,19 +66,19 @@ function dual_v_SGA(
 end

 # initial gradient step (unregularized subgradient)
-function gradient_step(c, τ, ν::AbstractVector, x, ys::AbstractVector, ::Nothing)
-    tmp = @. c((x,), ys)
+function gradient_step(c, τ, ν::DiscreteMeasure, x, ::Nothing)
+    tmp = @. c((x,), ν.xs)
     i = argmin(tmp)
-    z = τ .* ν
+    z = τ .* ν.ps
     z[i] -= τ
     return z
 end

 # initial gradient step (regularized gradient)
-function gradient_step(c, τ, ν::AbstractVector, x, ys::AbstractVector, ε::Real)
-    tmp = @. -c((x,), ys) / ε
+function gradient_step(c, τ, ν::DiscreteMeasure, x, ε::Real)
+    tmp = @. -c((x,), ν.xs) / ε
     StatsFuns.softmax!(tmp)
-    z = @. τ * (ν - tmp)
+    z = @. τ * (ν.ps - tmp)
     return z
 end

@@ -134,14 +88,13 @@ function gradient_step!(
     z::AbstractVector,
     τ,
     v::AbstractVector,
-    ν::AbstractVector,
+    ν::DiscreteMeasure,
     x,
-    ys::AbstractVector,
     tmp::AbstractVector,
     ::Nothing,
 )
-    @. tmp = c((x,), ys) - v
-    @. z += τ * ν
+    @. tmp = c((x,), ν.xs) - v
+    @. z += τ * ν.ps
     z[argmin(tmp)] -= τ
     return z
 end

@@ -152,14 +105,13 @@ function gradient_step!(
     z::AbstractVector,
     τ,
     v::AbstractVector,
-    ν::AbstractVector,
+    ν::DiscreteMeasure,
     x,
-    ys::AbstractVector,
     tmp::AbstractVector,
     ε::Real,
 )
-    @. tmp = (v - c((x,), ys)) / ε
+    @. tmp = (v - c((x,), ν.xs)) / ε
     StatsFuns.softmax!(tmp)
-    @. z += τ * (ν - tmp)
+    @. z += τ * (ν.ps - tmp)
     return z
 end
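
Note: for intuition about `dual_v` above, here is a self-contained sketch of stochastic gradient ascent with iterate averaging under the step size schedule from the docstring. The function name, the placeholder `grad`, and all defaults are illustrative, not package API.

```julia
# Sketch of SGA: ascend with stochastic gradients, return the averaged iterate.
# Step size schedule: τᵢ = τ₁ / (1 + sqrt((i - 1) / w)).
function sga_average(grad, v₁::AbstractVector; maxiters = 10_000, τ₁ = 1.0, w = 1)
    ṽ = copy(v₁)  # current iterate
    v = copy(v₁)  # running average of the iterates (the returned solution)
    for k in 2:maxiters
        τ = τ₁ / (1 + sqrt((k - 1) / w))
        ṽ .+= τ .* grad(ṽ)   # ascent step on the (semi-)dual objective
        @. v += (ṽ - v) / k  # online update of the running mean
    end
    return v
end
```

Returning the average of the iterates instead of the last iterate is what makes this kind of `1/√k` schedule effective for the nonsmooth semi-dual objective (Genevay et al., 2016).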
36 changes: 27 additions & 9 deletions src/utils.jl
@@ -1,34 +1,52 @@
+struct DiscreteMeasure{X<:AbstractVector,P<:AbstractVector}
+    xs::X
+    ps::P
+
+    function DiscreteMeasure{X,P}(xs::X, ps::P) where {X,P}
+        length(xs) == length(ps) ||
+            error("length of support `xs` and probabilities `ps` must be equal")
+        new{X,P}(xs, ps)
+    end
+end
+
+"""
+    DiscreteMeasure(xs::AbstractVector, ps::AbstractVector)
+
+Construct a discrete measure with support `xs` and corresponding weights `ps`.
+"""
+function DiscreteMeasure(xs::AbstractVector, ps::AbstractVector)
+    return DiscreteMeasure{typeof(xs),typeof(ps)}(xs, ps)
+end

 @doc raw"""
-    ctransform(c, v, x, ys, ν, ε)
+    ctransform(c, v, x, ν, ε)

 Compute the c-transform
 ```math
 v^{c,ε}(x) = \begin{cases}
-- ε \log\bigg(\sum_{i=1}^n \exp{\Big(\frac{v[i] - c(x, ys[i])}{ε}\Big)} ν[i]\bigg) & \text{if } ε > 0,\\
-\min_{i} c(x, y[i]) - v[i] & \text{otherwise}.
+- ε \log\bigg(\int \exp{\Big(\frac{v_y - c(x, y)}{ε}\Big)} \, ν(\mathrm{d}y)\bigg) & \text{if } ε > 0,\\
+\min_{y} c(x, y) - v_y & \text{otherwise}.
 \end{cases}
 ```
 """
 function ctransform(
     c,
     v::AbstractVector,
     x,
-    ys::AbstractVector,
-    ν::AbstractVector,
+    ν::DiscreteMeasure,
     ::Nothing,
 )
-    return minimum(c(x, yᵢ) - vᵢ for (vᵢ, yᵢ) in zip(v, ys))
+    return minimum(c(x, yᵢ) - vᵢ for (vᵢ, yᵢ) in zip(v, ν.xs))
 end
 function ctransform(
     c,
     v::AbstractVector,
     x,
-    ys::AbstractVector,
-    ν::AbstractVector,
+    ν::DiscreteMeasure,
     ε::Real,
 )
     t = StatsFuns.logsumexp(
-        (vᵢ - c(x, yᵢ)) / ε + log(νᵢ) for (vᵢ, yᵢ, νᵢ) in zip(v, ys, ν)
+        (vᵢ - c(x, yᵢ)) / ε + log(νᵢ) for (vᵢ, yᵢ, νᵢ) in zip(v, ν.xs, ν.ps)
     )
     return -ε * (t + 1)
 end
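Note: a small sketch of how `DiscreteMeasure` and `ctransform` fit together (values and cost are illustrative; the `SOT` alias is an assumption):

```julia
const SOT = StochasticOptimalTransport  # assumed to be loaded

ν = SOT.DiscreteMeasure([-1.0, 0.0, 1.0], [0.25, 0.5, 0.25])
v = zeros(3)  # dual potential on the support of ν
c(x, y) = abs(x - y)

# unregularized case: minᵢ c(x, xsᵢ) - vᵢ = min(1.3, 0.3, 0.7) = 0.3
SOT.ctransform(c, v, 0.3, ν, nothing)

# regularized case: a smooth softmin-style analogue of the same quantity
SOT.ctransform(c, v, 0.3, ν, 0.1)
```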
26 changes: 14 additions & 12 deletions test/semidiscrete.jl
@@ -3,20 +3,21 @@
         c(x, y) = abs(x - y)
         τ = rand()
         x = randn()
-        ys = randn(100)
+        xs = randn(100)
+        ps = rand(100)
+        ps ./= sum(ps)
+        ν = SOT.DiscreteMeasure(xs, ps)
         v = zeros(100)
-        ν = rand(100)
-        ν ./= sum(ν)

         # unregularized and regularized approach
         for ε in (nothing, abs(randn()))
             # out-of-place method
-            z0 = SOT.gradient_step(c, τ, ν, x, ys, ε)
+            z0 = SOT.gradient_step(c, τ, ν, x, ε)

             # in-place method with `v = 0`
             z = zero(z0)
             tmp = similar(z)
-            SOT.gradient_step!(c, z, τ, v, ν, x, ys, tmp, ε)
+            SOT.gradient_step!(c, z, τ, v, ν, x, tmp, ε)
             @test z ≈ z0
         end
     end
@@ -25,12 +26,13 @@
         c(x, y) = abs(x - y)

         # equal source and target distribution
-        ys = randn(3)
-        ν = rand(3)
-        ν ./= sum(ν)
-        μ = DiscreteNonParametric(ys, ν)
-        @test SOT.wasserstein_SGA(c, μ, ν, ys) ≈ 0 atol=2e-2
-        @test SOT.wasserstein_SGA(c, μ, ν, ys, 1e-6) ≈ 0 atol=2e-2
-        @test SOT.wasserstein_SGA(c, μ, ν, ys, 1e-3) ≈ 0 atol=2e-2
+        xs = randn(3)
+        ps = rand(3)
+        ps ./= sum(ps)
+        μ = DiscreteNonParametric(xs, ps)
+        ν = SOT.DiscreteMeasure(xs, ps)
+        @test SOT.wasserstein(c, μ, ν) ≈ 0 atol=2e-2
+        @test SOT.wasserstein(c, μ, ν, 1e-6) ≈ 0 atol=2e-2
+        @test SOT.wasserstein(c, μ, ν, 1e-3) ≈ 0 atol=2e-2
     end
 end
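
Note: beyond the zero-distance checks above, a hand-checkable sketch (illustrative values; the tolerance is a guess, since SGA is stochastic). For `c(x, y) = abs(x - y)`, the monotone coupling transports `0 ↦ 2` and `1 ↦ 3`, so the exact distance is 2.

```julia
using Distributions: DiscreteNonParametric
using Test
const SOT = StochasticOptimalTransport  # assumed to be loaded

c(x, y) = abs(x - y)
μ = DiscreteNonParametric([0.0, 1.0], [0.5, 0.5])  # source: uniform on {0, 1}
ν = SOT.DiscreteMeasure([2.0, 3.0], [0.5, 0.5])    # target: uniform on {2, 3}

@test SOT.wasserstein(c, μ, ν) ≈ 2 atol=5e-2
```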
