/
expressions.jl
108 lines (92 loc) · 3.5 KB
/
expressions.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#############################################################################
# expressions.jl
# Defines AbstractExpr, which is subtyped by all atoms
# Each type which subtypes AbstractExpr (Variable and Constant being exceptions)
# must have:
#
## head::Symbol -- a symbol such as :vecnorm, :+ etc
## children::(AbstractExpr,) -- The expressions on which the current expression
## -- is operated
## id_hash::Uint64 -- identifier hash, can be a hash of children
## or a unique identifier of the object
## size::(Int, Int) -- size of the resulting expression
#
# Constants and variables do not have children.
#
# In addition, each atom must implement the following functions:
## sign: returns the sign of the result of the expression
## monotonicity: The monotonicity of the function with respect to each of its
## arguments, i.e., if an argument is nondecreasing, will the function be
## nonincreasing or nondecreasing? E.g., negate(x) has Nonincreasing monotonicity
## evaluate: Evaluates the value of the expression, assuming the problem has been
## solved.
## curvature: If h(x)=f∘g(x), then (for single variable calculus)
## h''(x) = g'(x)^T f''(g(x)) g'(x) + f'(g(x))g''(x)
## curvature refers to the curvature of the first term.
## We then use this curvature to find vexity of h (see vexity function below)
## conic_form!: TODO: Fill this in after conic_form! is stable
#
#############################################################################
import Base.sign, Base.size, Base.length, Base.endof, Base.ndims, Base.zero, Base.convert
export AbstractExpr, Constraint
export vexity, sign, size, evaluate, monotonicity, curvature, zero, length, convert
export conic_form!
export endof, ndims
export Value, ValueOrNothing
export get_vectorized_size
### Abstract types
# Supertype of every expression; per the header of this file, all atoms subtype it.
abstract AbstractExpr
# Supertype for constraint objects (exported alongside AbstractExpr).
abstract Constraint
# Chain rule for h(x) = f∘g(x) (single-variable case):
#   h''(x) = g'(x)^T f''(g(x)) g'(x) + f'(g(x)) g''(x)
# The vexity of an expression is computed accordingly: start from the curvature
# of the expression's own atom, then add, for every child, the atom's
# monotonicity in that argument multiplied by the (recursively computed) vexity
# of the child. The accumulated value is returned.
function vexity(x::AbstractExpr)
    arg_monotonicities = monotonicity(x)
    total = curvature(x)
    for (idx, child) in enumerate(x.children)
        total += arg_monotonicities[idx] * vexity(child)
    end
    return total
end
# Generic fallback: every atom is expected to supply its own `monotonicity`
# method, so landing here means one is missing. Raises an error naming the
# atom's head.
monotonicity(x::AbstractExpr) = error("monotonicity not implemented for $(x.head).")
# Generic fallback: every atom is expected to supply its own `curvature`
# method, so landing here means one is missing. Raises an error naming the
# atom's head.
curvature(x::AbstractExpr) = error("curvature not implemented for $(x.head).")
# Generic fallback: every atom is expected to supply its own `evaluate`
# method, so landing here means one is missing. Raises an error naming the
# atom's head.
evaluate(x::AbstractExpr) = error("evaluate not implemented for $(x.head).")
# Generic fallback: every atom is expected to supply its own `sign` method, so
# landing here means one is missing. Raises an error naming the atom's head.
sign(x::AbstractExpr) = error("sign not implemented for $(x.head).")
# Size of an expression: simply the value stored in its `size` field.
size(x::AbstractExpr) = x.size
# Number of entries in an expression: the product of the entries of its
# `size` field.
length(x::AbstractExpr) = prod(x.size)
# Additive zero for an expression: a Constant wrapping a Float64 array of
# zeros with the same size as `x`.
function zero(x::AbstractExpr)
    dims = size(x)
    return Constant(zeros(Float64, dims))
end
### User-defined Unions
# Type aliases. They are declared `const` so the bindings cannot be rebound
# after methods have been defined against them, and so code that refers to
# them at run time does not pay for an untyped global lookup.
#
# Value: a number or any array.
const Value = Union(Number, AbstractArray)
# ValueOrNothing: a Value, or `nothing`.
const ValueOrNothing = Union(Value, Nothing)
# AbstractExprOrValue: either an expression or a Value.
const AbstractExprOrValue = Union(AbstractExpr, Value)
# Conversion to AbstractExpr.
# A Value (number or array) is wrapped in a Constant.
function convert(::Type{AbstractExpr}, x::Value)
    return Constant(x)
end
# An expression is already an AbstractExpr and is returned unchanged.
function convert(::Type{AbstractExpr}, x::AbstractExpr)
    return x
end
### Indexing Utilities
# Last linear index of an expression: the product of the first two entries of
# its `size` field.
function endof(x::AbstractExpr)
    nrows, ncols = x.size
    return nrows * ncols
end
# Size of an expression along dimension `dim`.
#   dim < 1      -> raises an error
#   dim > 2      -> 1 (dimensions past the second are reported as singleton)
#   dim in 1:2   -> the corresponding entry of size(x)
function size(x::AbstractExpr, dim::Integer)
    dim < 1 && error("dimension out of range")
    return dim > 2 ? 1 : size(x)[dim]
end
# Every expression reports exactly two dimensions, whatever its size.
function ndims(x::AbstractExpr)
    return 2
end
# Number of entries the expression has once flattened: the product of all
# entries of size(x).
function get_vectorized_size(x::AbstractExpr)
    dims = size(x)
    return reduce(*, dims)
end