In [1]:
# Take a sequence of matrices and multiply them in the cheapest order.
#
# With zero or one argument the call is forwarded unchanged to `*`.
# Throws `ArgumentError` when the column count of a matrix differs from the
# row count of the matrix that follows it.
function fastmatmul(args::AbstractMatrix...)
    n = length(args)
    if n ≤ 1
        return *(args...)
    end
    sizes = map(size, args)
    # Every adjacent pair must be conformable: columns of left == rows of right.
    for i in 1:n-1
        if sizes[i][2] != sizes[i+1][1]
            throw(ArgumentError("matrix dimensions mismatch"))
        end
    end
    # Maps a sub-chain (from, to) to (minimal cost, split index achieving it).
    partcost = Dict{Tuple{Int,Int}, Tuple{Int, Int}}()
    solvemul(sizes, partcost, 1, n)
    return domul(args, partcost, 1, n)
end

# Find the optimal multiplication order for the sub-chain from:to.
#
# `sizes` holds the (rows, cols) of each matrix. `partcost` is the memo table,
# filled in place: partcost[(a, b)] = (minimal cost of multiplying matrices
# a..b, index of the last matrix in the left factor of the best split).
function solvemul(sizes, partcost, from, to)
    span = (from, to)
    if from == to
        # A single matrix costs nothing to "multiply".
        partcost[span] = (0, from)
        return
    end
    nrows = sizes[from][1]
    ncols = sizes[to][2]
    best = typemax(Int)
    bestsplit = -1
    for split in from:to-1
        left = (from, split)
        right = (split + 1, to)
        # Solve each half only if it is not already memoized.
        if !haskey(partcost, left)
            solvemul(sizes, partcost, from, split)
        end
        if !haskey(partcost, right)
            solvemul(sizes, partcost, split + 1, to)
        end
        # Cost of the final product plus the cost of building both halves.
        total = nrows * sizes[split][2] * ncols + partcost[left][1] + partcost[right][1]
        if total < best
            bestsplit = split
            best = total
        end
    end
    partcost[span] = (best, bestsplit)
end

# Multiply matrices from:to of `args` in the order precomputed in `partcost`.
#
# partcost[(from, to)][2] is read as the index of the last matrix belonging to
# the left factor; chains of length one or two never consult the table.
function domul(args, partcost, from, to)
    if from == to
        return args[from]
    elseif to == from + 1
        return args[from] * args[to]
    end
    split = partcost[(from, to)][2]
    left = domul(args, partcost, from, split)
    right = domul(args, partcost, split + 1, to)
    return left * right
end


domul (generic function with 1 method)

In [2]:
# BenchmarkTools supplies the @btime macro.
using BenchmarkTools
# A is 5×5000 and B is 5000×5, so A*B is a small 5×5 product while B*A is 5000×5000.
A = ones(5, 5000);
B = ones(5000, 5);

┌ Info: Recompiling stale cache file /home/nakada/.julia/compiled/v1.2/BenchmarkTools/ZXPQo.ji for BenchmarkTools [6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf]
└ @ Base loading.jl:1240


In [3]:
# Baseline: multiply the 20-matrix chain A*B*A*B*... with the built-in n-ary `*`.
# NOTE(review): A and B are globals that are not `$`-interpolated, and the
# `repeat` call plus the splat sit inside the timed expression.
@btime *(repeat([A,B], outer=10)...);

  635.300 μs (33 allocations: 1.72 MiB)


In [4]:
# Same 20-matrix chain A*B*A*B*..., multiplied through fastmatmul.
# NOTE(review): A and B are globals that are not `$`-interpolated, and the
# `repeat` call plus the splat sit inside the timed expression.
@btime fastmatmul(repeat([A,B], outer=10)...);

  401.600 μs (43 allocations: 52.25 KiB)


In [5]:
# Rewrite a multiplication call such as `a * b * c` into `fastmatmul(a, b, c)`.
#
# Throws `ArgumentError` at expansion time when the expression is not a call,
# or when the called function is not `*`.
#
# Only the operands are escaped, so they are evaluated in the caller's scope,
# while `fastmatmul` itself resolves in the module that defines this macro.
# Escaping the whole call would instead require `fastmatmul` to be visible
# wherever the macro is used.
macro fastmatmul(ex::Expr)
    ex.head == :call || throw(ArgumentError("expression must be a call"))
    ex.args[1] == :(*) || throw(ArgumentError("only multiplication is allowed"))
    # Build a fresh call expression; the input expression is left untouched.
    operands = map(esc, ex.args[2:end])
    return Expr(:call, :fastmatmul, operands...)
end


@fastmatmul (macro with 1 method)

In [6]:
# Multiply a 2×3, a 3×4 and a 4×5 matrix of ones through the macro.
@fastmatmul ones(2,3)*ones(3,4)*ones(4,5)

2×5 Array{Float64,2}:
 12.0  12.0  12.0  12.0  12.0
 12.0  12.0  12.0  12.0  12.0