/
flatten.jl
162 lines (139 loc) · 4.15 KB
/
flatten.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
## Flatten: flatten nested operations that are associative
# e.g. (a + b + (c + d)) --> (a + b + c + d)
# FlatT because Flat is already a symbol
#typealias FlatT Union{Mxpr{:Plus},Mxpr{:Times},Mxpr{:And},Mxpr{:Or}, Mxpr{:LCM}, Mxpr{:GCD} }
# We do not rely only on the Flat attribute. We use FlatT in the hope that Julia compiles
# efficient code for each type in the Union.
# It might be faster to interleave the terms, because they may then need less
# ordering. As it is, this is very fast for flattening two long sums. We tried:
#   m1 = Apply(Plus,Table(x^i,[i,1000]));
#   m2 = Apply(Plus,Table(x^i,[i,1000]));
#   m3 = mxpr(:Plus,symval(:m1),symval(:m2))
# flatten!(m3) --> 6e-5 s; canonexpr!(m3) (i.e. ordering) --> 0.011 s.
# Running canonexpr! again (i.e. when the terms are already sorted and
# combined) takes 0.00925 s, which is only a little faster. So optimizing here
# is probably not worth anything at the moment. At the Symata CLI,
# m3 = m1 + m2 --> 0.04 s. Calling setfixed() after canonexpr! cuts this time
# to 0.01 s, but this cannot be done in general. Maxima generates the two sums
# much more slowly, but adds them much more quickly.
# Flatten one level only.
# Every direct argument of `mx` whose type is exactly the type `T` of `mx`
# (for an Mxpr{H}, i.e. the same head H) is replaced by its own arguments.
# Deeper nesting is left alone. A new expression with the head of `mx` is
# built from the spliced arguments. If no argument matches, `mx` itself is
# returned unchanged.
function flatten!{T}(mx::T)
    args = margs(mx)
    any(a -> isa(a, T), args) || return mx
    flat = newargs()
    for a in args
        if isa(a, T)
            append!(flat, margs(a))
        else
            push!(flat, a)
        end
    end
    return mxpr(mhead(mx), flat)
end
# Flatten one level (via flatten!) only when the argument is a FlatT.
# FlatT is presumably the Union of Mxpr types whose heads carry the Flat
# attribute (compare the commented-out typealias at the top of the file).
# Anything else is passed through untouched.
maybeflatten!(x) = x
function maybeflatten!(mx::FlatT)
    return flatten!(mx)
end
# Flatten to all levels.
# Works like flatten!, except that each argument of the same type `T` as `mx`
# is first flattened recursively (by this same function) before its arguments
# are spliced in. An arbitrarily deep chain of same-typed subexpressions
# therefore collapses into a single level. Arguments of any other type are
# not descended into. Returns `mx` itself when no argument matches.
function flatten_recursive!{T}(mx::T)
    args = margs(mx)
    any(a -> isa(a, T), args) || return mx
    flat = newargs()
    for a in args
        if isa(a, T)
            append!(flat, margs(flatten_recursive!(a)))
        else
            push!(flat, a)
        end
    end
    return mxpr(mhead(mx), flat)
end
# State threaded through `_flatten_recursive!` during a leveled flatten.
#   level    -- current nesting depth. The constructors in this file start it
#               at 1, and `_flatten_recursive!` increments it before each
#               recursive call and decrements it afterwards.
#   maxlevel -- the deepest level to flatten (the `n` of Flatten(expr, n)).
#   head     -- only subexpressions matched by `is_Mxpr(x, head)` are flattened.
# Declared with `type` (mutable) because `level` is modified during recursion.
type FlattenData
level::Int
maxlevel::Int
head::Symbol # We may want type Any, for heads that are not Symbols
end
# Flatten from level 1 to level n.
# Only subexpressions whose head equals the head `T` of `mx` are flattened.
function flatten_recursive!{T}(mx::Mxpr{T}, n::Integer)
    return _flatten_recursive!(mx, FlattenData(1, n, T))
end
# Flatten, from level 1 to level n, only the subexpressions whose head is
# `headtype` (which need not be the head of `mx`).
flatten_recursive!(mx::Mxpr, n::Integer, headtype::Symbol) =
    _flatten_recursive!(mx, FlattenData(1, n, headtype))
# Worker for the leveled forms of flatten_recursive!.
# Each direct argument of `mx` matched by `is_Mxpr(x, d.head)` is replaced by
# its own arguments:
#   - while d.level < d.maxlevel, the argument is first flattened recursively
#     with d.level one higher (restored afterwards);
#   - once d.level >= d.maxlevel, its arguments are spliced in as they are.
# Other arguments are kept as-is and not descended into. Returns `mx` itself
# when no argument matches; otherwise returns a new expression with mx's head.
function _flatten_recursive!(mx::Mxpr, d::FlattenData)
    args = margs(mx)
    any(a -> is_Mxpr(a, d.head), args) || return mx
    flat = newargs()
    for a in args
        if !is_Mxpr(a, d.head)
            push!(flat, a)
        elseif d.level < d.maxlevel
            d.level += 1
            append!(flat, margs(_flatten_recursive!(a, d)))
            d.level -= 1
        else
            append!(flat, margs(a))
        end
    end
    return mxpr(mhead(mx), flat)
end
# TODO: implement the "transpose" case
# Register the Flatten rule; its methods are supplied by the @doap
# definitions that follow.
# NOTE(review): `:nodefault => true` presumably suppresses @mkapprule's generic
# fallback method, so the @doap methods in this file must cover every argument
# pattern -- confirm against the @mkapprule definition.
@mkapprule Flatten :nodefault => true
# User-facing help text for Flatten (registered via @sjdoc).
@sjdoc Flatten """
Flatten(expr)
remove braces from `Lists` at all levels of `expr`.
Flatten(expr,n)
flatten only down to level `n`.
Flatten(expr,n,h)
flatten only expressions with head `h`.
"""
# Flatten(expr): flatten every level of nested subexpressions that share the
# type (head) of expr.
@doap function Flatten(x::Mxpr)
    return flatten_recursive!(x)
end
# Flatten of a non-Mxpr (an atom) returns it unchanged.
@doap function Flatten(x)
    return x
end
# Flatten(expr, n): flatten levels 1 through n of subexpressions with the head
# of expr. A level of 0 leaves expr unchanged.
@doap Flatten(x::Mxpr, n::Integer) = n == 0 ? x : flatten_recursive!(x, n)
# Flatten(expr, n, h): flatten only subexpressions with head `h`, from level 1
# down to level n.
#   - n == 0 leaves expr unchanged.
#   - n == Infinity means all levels (mapped to typemax(Int)).
#   - Any other non-integer level cannot be handled: flatten_recursive! only has
#     methods for an Integer level. Instead of letting that surface as a Julia
#     MethodError, return the Flatten call unevaluated (`mx`, the expression
#     that @doap binds, as the list-spec method in this file also returns).
@doap function Flatten(x::Mxpr, n, headtype::Symbol)
    n == 0 && return x
    n = (n == Infinity ? typemax(Int) : n)
    isa(n, Integer) || return mx
    return flatten_recursive!(x, n, headtype)
end
# Flatten(expr, Infinity): flatten at all levels. Infinity is DirectedInfinity
# with direction 1; any other direction leaves expr unchanged.
@doap function Flatten(x::Mxpr, i::Mxpr{:DirectedInfinity})
    if i[1] == 1
        return flatten_recursive!(x)
    end
    return x
end
# Flatten(expr, list): list-valued level specifications (presumably the
# "transpose" form named in the TODO) are not implemented. Emit a warning and
# return the Flatten call unevaluated (`mx`, bound by @doap).
@doap function Flatten(x::Mxpr, spec::ListT)
    warn("Flatten: unimplemented feature")
    return mx
end