/
diagm.jl
69 lines (57 loc) · 2 KB
/
diagm.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#############################################################################
# diagm.jl
# Converts a vector of size n into an n x n diagonal matrix.
# All expressions and atoms are subtypes of AbstractExpr.
# Please read expressions.jl first.
#############################################################################
import LinearAlgebra.diagm, LinearAlgebra.Diagonal
"""
    DiagMatrixAtom(x::AbstractExpr)

Atom for the square matrix whose main diagonal holds the entries of the vector
expression `x`. `x` may be a row (1 × n) or a column (n × 1) vector; the resulting
atom has size `(n, n)`. Throws an `ArgumentError` when `x` has more than one row
and more than one column.
"""
struct DiagMatrixAtom <: AbstractExpr
    head::Symbol
    id_hash::UInt64
    children::Tuple{AbstractExpr}
    size::Tuple{Int, Int}

    function DiagMatrixAtom(x::AbstractExpr)
        nrows, ncols = x.size
        # The diagonal length is the "long" dimension of the vector: the column count
        # for a row vector, the row count for a column vector.
        n = if nrows == 1
            ncols
        elseif ncols == 1
            nrows
        else
            throw(ArgumentError("Only vectors are allowed for diagm/Diagonal. Did you mean to use diag?"))
        end
        kids = (x,)
        return new(:diagm, hash(kids), kids, (n, n))
    end
end
# The sign of the diagonal-matrix atom is reported as the sign of its single child,
# the vector placed on the diagonal.
sign(x::DiagMatrixAtom) = sign(first(x.children))
# The monotonicity
function monotonicity(x::DiagMatrixAtom)
return (Nondecreasing(),)
end
# Curvature of the atom itself. For a composition h(x) = f(g(x)) the chain rule gives
#     h''(x) = g'(x)^T f''(g(x)) g'(x) + f'(g(x)) g''(x);
# this method describes f'' (the first term). Building a diagonal matrix from a
# vector is linear, so its curvature is constant (affine).
curvature(x::DiagMatrixAtom) = ConstVexity()
# Flatten an evaluated child value into the vector of diagonal entries.
# A 1×1 child is accepted by the DiagMatrixAtom constructor and may evaluate to a plain
# Number (NOTE(review): confirm against `evaluate` for 1×1 expressions). `vec` has no
# method for numbers, so wrap those in a one-element vector instead.
_diag_entries(v::Number) = [v]
_diag_entries(v::AbstractArray) = vec(v)

"""
    evaluate(x::DiagMatrixAtom)

Return the numeric value of `x` as a `Diagonal` matrix whose diagonal is the evaluated
child, flattened to a vector. A row vector, a column vector, and a scalar child value
are all supported.
"""
function evaluate(x::DiagMatrixAtom)
    return Diagonal(_diag_entries(evaluate(x.children[1])))
end
"""
    diagm(d => x::AbstractExpr)
    diagm(x::AbstractExpr)

Return the square matrix expression with the vector expression `x` on its main
diagonal. Only the main diagonal is supported: a pair with `d != 0` throws an
`ArgumentError`. The single-argument form mirrors `LinearAlgebra.diagm(v)` and is
equivalent to `diagm(0 => x)`.
"""
function diagm((d, x)::Pair{<:Integer, <:AbstractExpr})
    d == 0 || throw(ArgumentError("only the main diagonal is supported"))
    return DiagMatrixAtom(x)
end

# Single-argument form, matching LinearAlgebra's `diagm(v::AbstractVector)`.
diagm(x::AbstractExpr) = diagm(0 => x)
# `Diagonal(x)` for an expression builds the same atom as `diagm(0 => x)`: a square
# matrix expression with the vector expression `x` on its main diagonal.
function Diagonal(x::AbstractExpr)
    return DiagMatrixAtom(x)
end
# Build (or fetch from the cache) the conic form of the diagonal-matrix atom.
# The atom's vectorized value equals `selector * vec(child)`. `selector` is an
# (n^2 × n) sparse 0/1 matrix that places child entry k at the column-major linear
# index of entry (k, k) of an n × n matrix.
function conic_form!(x::DiagMatrixAtom, unique_conic_forms::UniqueConicForms)
    if has_conic_form(unique_conic_forms, x)
        return get_conic_form(unique_conic_forms, x)
    end
    n = x.size[1]
    # Entry (k, k) has linear index (k - 1) * n + k, i.e. every (n + 1)-th slot from 1.
    rows = 1:(n + 1):(n * n)
    cols = 1:n
    selector = sparse(rows, cols, 1.0, n * n, n)
    child_form = conic_form!(x.children[1], unique_conic_forms)
    cache_conic_form!(unique_conic_forms, x, selector * child_form)
    return get_conic_form(unique_conic_forms, x)
end