Skip to content

Commit

Permalink
demat expression are now fast
Browse files Browse the repository at this point in the history
  • Loading branch information
kk49 committed May 10, 2012
1 parent 2a782f0 commit 6a1338f
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 204 deletions.
28 changes: 9 additions & 19 deletions demat_base.jl
Expand Up @@ -43,37 +43,27 @@ de_promote(x,xs...) = tuple(de_promote(x)...,de_promote(xs...)...)

typealias BinOpParams Union((DeEle,Number),(Number,DeEle),(DeEle,DeEle))

type DeOpNull end
type DeOpAdd end
type DeOpSub end
type DeOpMulEle end
type DeOpDivEle end

const deBinOpList = (:+,:-,:.*,:./);

function de_op_to_type(op::Symbol)
if op == :+ return DeOpAdd
elseif op == :- return DeOpSub
elseif op == :.* return DeOpMulEle
elseif op == :./ return DeOpDivEle
else return DeOpNull
end
end
de_op_to_type = dict((:+,),(DeOpAdd,))
de_op_to_type[:-] = DeOpSub
de_op_to_type[:.*] = DeOpMulEle
de_op_to_type[:./] = DeOpDivEle

function de_op_to_scaler(op::Symbol)
if op == :+ return :+
elseif op == :- return :-
elseif op == :.* return :*
elseif op == :./ return :/
else return DeOpNull
end
end
de_op_to_scaler = dict((:+,),(:+,))
de_op_to_scaler[:-] = :-
de_op_to_scaler[:.*] = :*
de_op_to_scaler[:./] = :/

for op = deBinOpList
opType = de_op_to_type(op);
opType = de_op_to_type[op];
@eval ($op)(a::DeEle,b::Number) = DeBinOp{$opType}(de_promote(a,b)...)
@eval ($op)(a::Number,b::DeEle) = DeBinOp{$opType}(de_promote(a,b)...)
@eval ($op)(a::DeEle,b::DeEle) = DeBinOp{$opType}(de_promote(a,b)...)
#@eval ($op)(ps...::BinOpParams) = DeBinOp($op,de_promote(ps...)...)
end

81 changes: 31 additions & 50 deletions demat_be_julia.jl
Expand Up @@ -16,21 +16,21 @@ size(a::DeArrJulia,dim) = size(a.data,dim)
typealias DeVecJ{T} DeArrJulia{T,1}
typealias DeMatJ{T} DeArrJulia{T,2}

function de_check_dims(a::DeConst)
function de_jl_check_dims(a::DeConst)
()
end

function de_check_dims(a::DeReadOp)
function de_jl_check_dims(a::DeReadOp)
size(a.p1)
end

function de_check_dims(a::DeUniOp)
de_check_dims(a.p1)
function de_jl_check_dims(a::DeUniOp)
de_jl_check_dims(a.p1)
end

function de_check_dims(a::DeBinOp)
r1 = de_check_dims(a.p1)
r2 = de_check_dims(a.p2)
function de_jl_check_dims(a::DeBinOp)
r1 = de_jl_check_dims(a.p1)
r2 = de_jl_check_dims(a.p2)

if length(r1) == 0
return r2
Expand All @@ -43,21 +43,20 @@ function de_check_dims(a::DeBinOp)
end
end

# de_eval returns a 3-tuple that contains
# de_jl_eval returns a 3-tuple that contains
# symbol that contains the value of the extression
# the quoted preable code
# the quoted kernal code
#de_eval(a,idxSym) = de_eval(a,idxSym,x->x)

function de_eval(a::DeConst,idxSym)
function de_jl_eval(a::DeConst,idxSym)
@gensym r
( r
, quote ($r) = ($(a.p1)) end
, quote end
)
end

function de_eval(a::DeReadOp,idxSym)
function de_jl_eval(a::DeReadOp,idxSym)
@gensym r src
( r
, quote ($src) = ($(a.p1.data)) end
Expand All @@ -74,44 +73,38 @@ end
# @eval function de_do_op(T::DeBinOp{$opType},a,b) ($opS)(a,b) end
#end

function de_do_op(T::DeBinOp{DeOpAdd},a,b) +(a,b) end
function de_do_op(T::DeBinOp{DeOpMulEle},a,b) *(a,b) end
function de_jl_eval(v::DeBinOp,idxSym)
@gensym r
p1 = de_jl_eval(v.p1,idxSym)
p2 = de_jl_eval(v.p2,idxSym)
preamble = quote $(p1[2]);$(p2[2]) end
kernel = quote $(p1[3]);$(p2[3]);($r) = de_jl_do_op($v,$(p1[1]),$(p2[1])) end
( r
, preamble
, kernel
)
end

for op = deBinOpList
opType = de_op_to_type(op);
opSingle = de_op_to_scaler(op);
println(opSingle)
@eval function de_eval(v::DeBinOp{$opType},idxSym)
@gensym r
p1 = de_eval(v.p1,idxSym)
p2 = de_eval(v.p2,idxSym)
preamble = quote $(p1[2]);$(p2[2]) end
#kernel = quote $(p1[3]);$(p2[3]);($r) = ($($opSingle))($(p1[1]),$(p2[1])) end
kernel = quote $(p1[3]);$(p2[3]);($r) = ($($op))($(p1[1]),$(p2[1])) end
#kernel = quote $(p1[3]);$(p2[3]);($r) = de_do_op($v,$(p1[1]),$(p2[1])) end
#kernel = quote $(p1[3]);$(p2[3]);($r) = +($(p1[1]),$(p2[1])) end
#kernel = quote end
( r
, preamble
, kernel
)
end
opType = de_op_to_type[op];
opSingle = de_op_to_scaler[op];
@eval de_jl_do_op(v::DeBinOp{$opType},a,b) = ($opSingle)(a,b)
#@eval function de_jl_do_op(v::DeBinOp{$opType},a,b) = ($opSingle)(a,b) end # does not work?
end

function assign(lhs::DeArrJulia,rhs::DeExpr)
#println("Delayed Expression Setup Time:")
rhsSz = de_check_dims(rhs)
function assign(lhs::DeVecJ,rhs::DeExpr)
rhsSz = de_jl_check_dims(rhs)
lhsSz = size(lhs)

if rhsSz != lhsSz
error("src & dst size does not match. NOT IMPLEMENTED FOR SCALARS FIX")
end

@gensym i hiddenFunc
(rhsResult,rhsPreamble,rhsKernel) = de_eval(rhs,i)
@gensym i
(rhsResult,rhsPreamble,rhsKernel) = de_jl_eval(rhs,i)
rhsType = typeof(rhs);

ex = quote function ($hiddenFunc)(plhs::DeArrJulia,prhs::($rhsType))
@eval function hiddenFunc(plhs::DeVecJ,prhs::($rhsType))
N = size(plhs,1)
lhsData = plhs.data
$rhsPreamble
Expand All @@ -120,20 +113,8 @@ function assign(lhs::DeArrJulia,rhs::DeExpr)
lhsData[($i)] = ($rhsResult)
end
end
end

#println("---- rhsResult ----")
#println(rhsResult)
#println("---- rhsPreamble----")
#println(rhsPreamble)
#println("---- rhsKernel ----")
#println(rhsKernel)
#println()
#println(ex)

eval(ex)
hf = eval(hiddenFunc)
@time hf(lhs,rhs)
hiddenFunc(lhs,rhs)
end

assign(lhs::DeArrJulia,rhs::DeEle) = assign(lhs,de_promote(rhs)...)
Expand Down
158 changes: 23 additions & 135 deletions demat_demo.jl
Expand Up @@ -16,149 +16,37 @@ function demat_test()
cd = DeVecJ{Float64}(c)
dd = DeVecJ{Float64}(d)

r1 = 0
for i = 1:1
println("-------------------")
println("#1 Delayed Expression:")
@time ad[] = bd+cd.*dd + 1.0
#@time ad[] = bd + 1.0

println("#2 Standard Julia Vector:")
@time a = b+c.*d + 1.0
#@time a = b + 1.0

println("#3 Standard Julia For Loop:")
@time for j = 1:N
tN = 1
println("-------------------")
println("#1 Delayed Expression:")
t1 = @elapsed for i = 1:tN ad[] = bd+cd.*dd + 1.0 end
#@time ad[] = bd + 1.0
println("Elapsed time: ",t1)

println("#2 Standard Julia Vector:")
t2 = @elapsed for i = 1:tN a = b+c.*d + 1.0 end
#@time a = b + 1.0
println("Elapsed time: ",t2)

println("#3 Standard Julia For Loop:")
t3 = @elapsed for i = 1:tN
for j = 1:N
a[j] = b[j] + c[j] * d[j] + 1.0
#a[j] = b[j] + 1.0
end

println()
println("error(sum((#3 - #1).^2) / abs(sum(#3)) == ",sum((a-ad.data).^2) / sum(a))
end

r1
end

#demat_test()

function simple_test()
#--------------------------------------
function test1(x::Array{Float64})
local s = 0
local i
for i = 1:1000000
s += x[i] * i
end
s
end
#--------------------------------------
@gensym ns ti
@eval function test2(x::Array{Float64})
($ns) = x;
local s = 0
local ($ti)
for ($ti) = 1:1000000
s += ($ns)[($ti)] * ($ti)
end
s
end
#--------------------------------------
function extract_x3(x)
@gensym r

(r,quote ($r) = ($x) end)
end

function test3(x::Array{Float64})
@gensym ti

(rv,ex) = extract_x3(x)

@eval function hf()
local s = 0
local ($ti)
$ex
for ($ti) = 1:1000000
s += ($rv)[($ti)] * ($ti)
end
s
end

eval(hf)() #if eval is not here it returns the results for the last call to test3
end
#--------------------------------------
function extract_x4(xi,idx)
@gensym r src

(r,quote ($src) = ($x) end,quote ($r) = ($src)[($idx)] end)
end
println("Elapsed time: ",t3)

function test4(x::Array{Float64})
@gensym ti

(rv,ex,bd) = extract_x4(x,ti)

@eval function hf()
local s = 0
local ($ti)
$ex
for ($ti) = 1:1000000
$bd
s += ($rv) * ($ti)
end
s
end

eval(hf)() #if eval is not here it returns the results for the last call to test3
end
#--------------------------------------
error = sum((a-ad.data).^2) / sum(a)

x = randn(1000000)
n = 20

local s1,s2,s3,s4

t1time = @elapsed for i = 1:n s1 = test1(x) end
t2time = @elapsed for i = 1:n s2 = test2(x) end
t3time = @elapsed for i = 1:n s3 = test3(x) end
t4time = @elapsed for i = 1:n s4 = test4(x) end
println()
println("Estimated overhead per expression == ",(t1-t3)/tN)

println(" t1time: ",t1time," s: ",s1)
println(" t2time: ",t2time," s: ",s2)
println(" t3time: ",t3time," s: ",s3)
println(" t4time: ",t4time," s: ",s4)
println()
end
println("error(sum((#3 - #1).^2) / abs(sum(#3)) == ",error)

function stest()
@gensym test1 test2

@eval function ($test1)(a)
s = 0
for i = 1:size(a,1)
s += *(a[i],i)
end
s
end
error
end

@gensym op
op = de_op_to_scaler(:.*)
@eval function ($test2)(a)
s = 0
for i = 1:size(a,1)
s += ($op)(a[i],i)
end
s
end

x = randn(1000000)
N = 1
local s1,s2
t1 = @elapsed for i = 1:N s1 = eval(test1)(x) end
t2 = @elapsed for i = 1:N s2 = eval(test2)(x) end
demat_test()

println("test func 1: time = ",t1," result = ",s1);
println("test func 2: time = ",t2," result = ",s2);

end

0 comments on commit 6a1338f

Please sign in to comment.