/
reshape.jl
39 lines (30 loc) · 974 Bytes
/
reshape.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import Base.reshape, Base.vec
"""
    ReshapeAtom(x::AbstractExpr, m::Int, n::Int)

Expression node that presents its single child `x` with size `(m, n)`.

# Fields
- `head`: always `:reshape`.
- `id_hash`: `objectid(x)` of the wrapped child.
- `children`: one-element tuple holding `x`.
- `size`: the requested `(m, n)`.

# Throws
- `ErrorException` if `m` or `n` is negative, or if `m * n != length(x)`.
"""
struct ReshapeAtom <: AbstractExpr
    head::Symbol
    id_hash::UInt64
    children::Tuple{AbstractExpr}
    size::Tuple{Int, Int}

    function ReshapeAtom(x::AbstractExpr, m::Int, n::Int)
        # Two negative dimensions can multiply to the right element count
        # (e.g. (-2, -3) for 6 elements), so the product check alone is not
        # enough: reject negative dimensions explicitly.
        if m < 0 || n < 0 || m * n != length(x)
            error("Cannot reshape expression of size $(x.size) to ($(m), $(n))")
        end
        return new(:reshape, objectid(x), (x,), (m, n))
    end
end
"""
    sign(x::ReshapeAtom)

Return the sign of the wrapped child expression, unchanged.
"""
sign(x::ReshapeAtom) = sign(first(x.children))
"""
    monotonicity(x::ReshapeAtom)

Return a one-element tuple containing `Nondecreasing()`, one entry for the
atom's single child.
"""
monotonicity(x::ReshapeAtom) = (Nondecreasing(),)
"""
    curvature(x::ReshapeAtom)

Return `ConstVexity()` as the curvature of the reshape atom itself.
"""
curvature(x::ReshapeAtom) = ConstVexity()
"""
    evaluate(x::ReshapeAtom)

Evaluate the wrapped child and reshape the result to the atom's stored
`size`.
"""
function evaluate(x::ReshapeAtom)
    rows, cols = x.size
    child_value = evaluate(first(x.children))
    return reshape(child_value, rows, cols)
end
"""
    conic_form!(x::ReshapeAtom, unique_conic_forms::UniqueConicForms)

Delegate to the child's `conic_form!` and return its result as-is; the atom
adds nothing of its own.
"""
function conic_form!(x::ReshapeAtom, unique_conic_forms::UniqueConicForms)
    child = first(x.children)
    return conic_form!(child, unique_conic_forms)
end
"""
    reshape(x::AbstractExpr, m::Int, n::Int)

Wrap `x` in a `ReshapeAtom` of size `(m, n)`. Validation is done by the
`ReshapeAtom` constructor.
"""
function reshape(x::AbstractExpr, m::Int, n::Int)
    return ReshapeAtom(x, m, n)
end
"""
    vec(x::AbstractExpr)

Reshape `x` into a single column with `length(x)` rows.
"""
function vec(x::AbstractExpr)
    total = length(x)
    return reshape(x, total, 1)
end