/
dot_sort.jl
77 lines (68 loc) · 2.19 KB
/
dot_sort.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#############################################################################
# dot_sort.jl
# dot_sort(a,b) computes dot(sort(a), sort(b))
# All expressions and atoms are subtypes of AbstractExpr.
# Please read expressions.jl first.
#############################################################################
export dot_sort
export sign, curvature, monotonicity, evaluate
# Atom for dot(sort(x), sort(w)), where x is an expression and w is a constant.
# For example, if w = [1 1 1 0 0 0 ... 0], it computes the sum of the 3 largest
# elements of x.
#
# The constructor requires length(w) to equal the vectorized size of x and
# stores w reshaped to that length. The atom itself always has size (1, 1).
type DotSortAtom <: AbstractExpr
    head::Symbol
    id_hash::Uint64
    children::@compat Tuple{AbstractExpr}
    size::@compat Tuple{Int, Int}
    w::Value

    function DotSortAtom(x::AbstractExpr, w::Value)
        # One weight is needed per entry of x.
        if length(w) != get_vectorized_size(x)
            error("x and w must be the same size")
        end
        flat_w = reshape(w, get_vectorized_size(x))
        kids = (x,)
        # The id hash is derived from both the child tuple and the flattened weights.
        return new(:dot_sort, hash((kids, flat_w)), kids, (1,1), flat_w)
    end
end
# Sign of dot(sort(x), sort(w)), determined from the constant weights w:
#   * all weights >= 0: every product keeps the sign of the argument, so the
#     result has the argument's sign;
#   * all weights <= 0: every product has the opposite sign, so the result has
#     the negated sign of the argument;
#   * mixed weights: nothing can be concluded, so NoSign() is returned.
# If w is entirely zero the first branch is taken.
function sign(x::DotSortAtom)
    if all(x.w .>= 0)
        return sign(x.children[1])
    elseif all(x.w .<= 0)
        # Previously this branch returned the argument's sign unchanged, which
        # mislabels e.g. a nonnegative argument with nonpositive weights.
        # NOTE(review): relies on unary minus being defined for sign values
        # elsewhere in the package -- confirm.
        return -sign(x.children[1])
    else
        return NoSign()
    end
end
# Monotonicity of the atom in its single argument, returned as a 1-tuple.
# The atom is reported as nondecreasing only when every weight is nonnegative;
# for any other weight vector no monotonicity is claimed.
function monotonicity(x::DotSortAtom)
    if !all(x.w .>= 0)
        return (NoMonotonicity(), )
    end
    return (Nondecreasing(), )
end
# The atom is always reported as convex, regardless of the weights.
curvature(x::DotSortAtom) = ConvexVexity()
# Numerically evaluates the atom: the evaluated argument and the weights are
# each flattened and sorted in descending order, multiplied elementwise, and
# summed.
function evaluate(x::DotSortAtom)
    sorted_vals = sort(vec(evaluate(x.children[1])), rev=true)
    sorted_w = sort(vec(x.w), rev=true)
    return sum(sorted_vals .* sorted_w)
end
# Builds (and caches) the conic form of the atom.
#
# The value of the atom is given by the solution to
#     minimize   sum(mu) + sum(nu)
#     subject to y*w' <= onesvec*nu' + mu*onesvec'
# where y is the (vectorized) argument, w the constant weights, and mu, nu are
# auxiliary variables of the same size as y.
#
# If a conic form for x is already stored in unique_conic_forms, nothing is
# rebuilt; in either case the stored form is returned.
function conic_form!(x::DotSortAtom, unique_conic_forms::UniqueConicForms)
    if !has_conic_form(unique_conic_forms, x)
        y = x.children[1]
        w = x.w
        # Flatten a genuine matrix argument (every dimension > 1) so it lines up
        # with the vectorized weights. The previous test, `all(size(y)) > 1`,
        # compared the result of `all` with 1 instead of testing each dimension,
        # so it could never be true and matrices were never flattened.
        if all(d -> d > 1, size(y))
            y = vec(y)
        end
        mu = Variable(size(y))
        nu = Variable(size(y))
        onesvec = ones(size(y))
        # Objective of the auxiliary problem described above.
        objective = conic_form!(sum(mu) + sum(nu), unique_conic_forms)
        # Register the coupling constraint between y, w, mu and nu.
        conic_form!(y*w' <= onesvec*nu' + mu*onesvec', unique_conic_forms)
        cache_conic_form!(unique_conic_forms, x, objective)
    end
    return get_conic_form(unique_conic_forms, x)
end
# dot_sort builds a DotSortAtom from one expression and one constant Value.
# The two arguments may be given in either order; the expression is always
# passed to the atom first and the constant second.
function dot_sort(a::AbstractExpr, b::Value)
    return DotSortAtom(a, b)
end

function dot_sort(b::Value, a::AbstractExpr)
    return DotSortAtom(a, b)
end