TL;DR: Performance is a pain point right now (broadcasting seems to be the culprit), but supporting these arrays is very straightforward once the functionality is available in Base.
EDIT: Code updated to work for Lux 0.4.*
Trial 1: From the Usage Example
using Lux, Random, Functors, BenchmarkTools
# ImmutableArray is provided by the JuliaLang/julia#44381 branch
make_immutable(x::AbstractArray) = ImmutableArray(copy(x))
make_immutable(x) = x
# Construct the layer
model = Chain(BatchNorm(128), Dense(128, 256, tanh), BatchNorm(256),
              Chain(Dense(256, 1, tanh), Dense(1, 10)))
# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps)
st_immutable = fmap(make_immutable, st)
# Dummy Input
x = randn(Float32, 128, 1024)
x_immutable = make_immutable(x)
# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)
Standard Abstract Arrays
BenchmarkTools.Trial: 1296 samples with 1 evaluation.
Range (min … max): 2.125 ms … 26.658 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 3.096 ms ┊ GC (median): 0.00%
Time (mean ± σ): 3.836 ms ± 2.313 ms ┊ GC (mean ± σ): 2.58% ± 7.71%
▂█
▆▄██▇▆▄▄▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▂▂▁▂▂▂▂▁▁▂▂▂▂▂▁▂▂▂▂▁▂▂▂ ▃
2.13 ms Histogram: frequency by time 14.1 ms <
Memory estimate: 3.60 MiB, allocs estimate: 144.
Immutable Arrays
BenchmarkTools.Trial: 41 samples with 1 evaluation.
Range (min … max): 107.855 ms … 159.665 ms ┊ GC (min … max): 3.98% … 2.64%
Time (median): 119.911 ms ┊ GC (median): 3.54%
Time (mean ± σ): 123.706 ms ± 10.746 ms ┊ GC (mean ± σ): 3.54% ± 0.67%
▂█▄
▄▁▁▁▁▁▁▁▄▆▄█████▄▁▄▆▄▆▁▁▄▁▁▄▁▁▁▁▁▁▄▁▁▁▁▄▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▆ ▁
108 ms Histogram: frequency by time 160 ms <
Memory estimate: 58.32 MiB, allocs estimate: 3418558.
Trial 2: Only a Dense Layer
# Construct the layer
model = Dense(128, 256)
# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);
# Dummy Input
x = randn(Float32, 128, 1024);
x_immutable = make_immutable(x);
# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)
Standard Abstract Arrays
BenchmarkTools.Trial: 4469 samples with 1 evaluation.
Range (min … max): 483.810 μs … 30.894 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 716.669 μs ┊ GC (median): 0.00%
Time (mean ± σ): 1.100 ms ± 1.501 ms ┊ GC (mean ± σ): 5.01% ± 12.19%
█▆▆▅▄▃▂▂▂▂▃▃▃▂▁ ▁
█████████████████▇▇▇▆▇▆▅▅▃▃▄▅▅▄▃▅▁▁▆▄▅▁▃▃▃▃▅▁▃▃▃▃▁▃▁▁▃▁▁▁▁▃▅ █
484 μs Histogram: log(frequency) by time 7.69 ms <
Memory estimate: 2.00 MiB, allocs estimate: 4.
Immutable Arrays
BenchmarkTools.Trial: 259 samples with 1 evaluation.
Range (min … max): 15.392 ms … 52.229 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 17.997 ms ┊ GC (median): 0.00%
Time (mean ± σ): 19.327 ms ± 4.194 ms ┊ GC (mean ± σ): 1.72% ± 4.44%
▃▆█ ▂
▃▆███▆█▇▅▇▇▄▆▃▆▄▄▅▄▄▄▄▄▃▄▄▃▂▁▃▃▂▁▃▂▁▁▂▂▁▂▂▂▁▃▁▃▂▂▁▁▁▂▂▁▂▂▁▂ ▃
15.4 ms Histogram: frequency by time 32.6 ms <
Memory estimate: 7.00 MiB, allocs estimate: 262153.
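Most of the time appears to be spent broadcasting the bias, which looks like a problem with broadcasting in general rather than with the Dense layer itself. Comparing the matrix multiply alone against the multiply plus the bias broadcast:
julia> @benchmark $ps_immutable.weight * $x_immutable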
BenchmarkTools.Trial: 4032 samples with 1 evaluation.
Range (min … max): 346.287 μs … 51.079 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 540.489 μs ┊ GC (median): 0.00%
Time (mean ± σ): 1.224 ms ± 1.854 ms ┊ GC (mean ± σ): 2.36% ± 8.18%
█▆▄▄▃▁▁▁ ▂▂▁▁▁▂▂▁▁ ▁▁ ▁
█████████████████████████▇▇▇▆▇▆▇▆▆▃▆▆▆▅▅▅▅▄▅▅▅▆▅▅▅▅▅▅▄▃▁▁▁▃▃ █
346 μs Histogram: log(frequency) by time 8.78 ms <
Memory estimate: 1.00 MiB, allocs estimate: 5.
julia> @benchmark $ps_immutable.weight * $x_immutable .+ $ps_immutable.bias
BenchmarkTools.Trial: 338 samples with 1 evaluation.
Range (min … max): 11.177 ms … 33.105 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 13.699 ms ┊ GC (median): 0.00%
Time (mean ± σ): 14.792 ms ± 3.901 ms ┊ GC (mean ± σ): 2.43% ± 5.87%
█▃
▅██▇▇▅▅▇▅▇▅▅▄▅▅▄▃▃▄▄▃▂▃▃▁▂▃▁▃▂▃▃▃▁▃▂▂▂▁▃▁▁▂▂▂▂▁▁▁▁▁▁▁▁▁▃▂▁▂ ▃
11.2 ms Histogram: frequency by time 30.9 ms <
Memory estimate: 7.00 MiB, allocs estimate: 262153.
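A quick bit of arithmetic on the allocation count may help narrow this down. The sketch below uses only the shapes and counts reported above (a 256×128 weight applied to a 128×1024 input, and 262153 allocations for the broadcasted version). The variable names are mine, and reading the result as "one allocation per output element" is an assumption drawn from the numbers, not something confirmed in the broadcast internals.

```julia
# Shapes taken from the Dense(128, 256) trial with a 128×1024 input.
out_rows, batch = 256, 1024
n_out = out_rows * batch            # number of elements in the broadcast result

# Allocation counts copied from the benchmark outputs above.
allocs_matmul_only = 5              # weight * x
allocs_with_bias   = 262153         # weight * x .+ bias

# Allocations attributable to the broadcast, and how they compare to n_out.
allocs_broadcast = allocs_with_bias - allocs_matmul_only
leftover = allocs_with_bias - n_out

println("output elements:       ", n_out)
println("broadcast allocations: ", allocs_broadcast)
println("allocs minus elements: ", leftover)
```

The allocation count exceeds the number of output elements by only a handful, which is consistent with (though not proof of) the broadcast allocating once per element instead of once per output array.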
Trial 3: No broadcasting
model = Dense(128, 256; bias=false)
# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);
# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)
Standard Abstract Arrays
BenchmarkTools.Trial: 5501 samples with 1 evaluation.
Range (min … max): 295.161 μs … 23.801 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 451.402 μs ┊ GC (median): 0.00%
Time (mean ± σ): 899.925 μs ± 1.386 ms ┊ GC (mean ± σ): 3.10% ± 8.68%
█▆▆▄▃▂▁▂▁▁▁▂▂▂▂▁ ▁ ▁
██████████████████▇█▇█▇▇▆▆▇▇▆▆▆▆▆▆▅▅▅▆▅▅▁▆▄▆▅▃▅▄▅▄▆▄▅▁▄▆▅▅▃▅ █
295 μs Histogram: log(frequency) by time 6.98 ms <
Memory estimate: 1.00 MiB, allocs estimate: 2.
Immutable Arrays
BenchmarkTools.Trial: 5303 samples with 1 evaluation.
Range (min … max): 311.574 μs … 26.953 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 436.316 μs ┊ GC (median): 0.00%
Time (mean ± σ): 930.509 μs ± 1.488 ms ┊ GC (mean ± σ): 3.23% ± 8.75%
█▆▅▃▂▁ ▁▁▂▁▁ ▁
█████████████████▆█▇▇▆▆▆▆▆▆▆▆▆▅▅▅▅▅▅▄▄▅▅▅▅▂▅▂▄▅▄▅▄▄▃▂▃▄▄▂▃▂▃ █
312 μs Histogram: log(frequency) by time 7.61 ms <
Memory estimate: 1.00 MiB, allocs estimate: 5.
Trial 4
model = Chain(Dense(128, 256; bias=false), Chain(Dense(256, 512; bias=false),
                                                 Dense(512, 10; bias=false)))
# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);
# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)
Standard Abstract Arrays
BenchmarkTools.Trial: 1372 samples with 1 evaluation.
Range (min … max): 1.380 ms … 49.871 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 2.918 ms ┊ GC (median): 0.00%
Time (mean ± σ): 3.615 ms ± 3.116 ms ┊ GC (mean ± σ): 2.42% ± 7.94%
▅█ ▃
███▇▆▇██▇▆▅▄▄▄▃▃▃▃▂▃▃▃▂▃▂▃▂▂▂▂▁▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▂▂ ▃
1.38 ms Histogram: frequency by time 15.8 ms <
Memory estimate: 3.04 MiB, allocs estimate: 6.
Immutable Arrays
BenchmarkTools.Trial: 894 samples with 1 evaluation.
Range (min … max): 1.505 ms … 66.104 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 4.153 ms ┊ GC (median): 0.00%
Time (mean ± σ): 5.561 ms ± 5.432 ms ┊ GC (mean ± σ): 1.87% ± 7.54%
█▆▅▅▅▄▅▆▆▅▄▄▂▂▂▂▁ ▁ ▁ ▁
█████████████████▇█▆███▆▇█▅▆▇███▆▇█▄▇▇▇▅▄▆▅▅▁▄▁▆▄▁▅▇▅▄▄▆▁▅ █
1.5 ms Histogram: log(frequency) by time 23.1 ms <
Memory estimate: 3.04 MiB, allocs estimate: 17.
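To put the four trials side by side, here is a small sketch that computes the slowdown of the immutable run relative to the standard run from the median times reported above. The trial labels and variable names are mine, added for illustration. Medians are converted to milliseconds by hand, and these are single-machine numbers, so the ratios are only a rough guide.

```julia
# (label, standard median, immutable median), all in milliseconds,
# copied from the BenchmarkTools outputs above.
medians_ms = [
    ("Trial 1: usage example",      3.096,    119.911),
    ("Trial 2: Dense with bias",    0.716669, 17.997),
    ("Trial 3: Dense without bias", 0.451402, 0.436316),
    ("Trial 4: bias-free Chain",    2.918,    4.153),
]

# Slowdown factor = immutable median / standard median.
slowdowns = [(name, round(immutable / standard; digits=2))
             for (name, standard, immutable) in medians_ms]

for (name, factor) in slowdowns
    println(rpad(name, 30), factor, "x")
end
```

The two trials that involve broadcasting (bias and activations) are slower by more than an order of magnitude, while the bias-free trials stay within roughly 1.5x of the standard arrays.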
I think the poor broadcasting performance is likely due to a missed opportunity to apply our memory optimization somewhere in the broadcast logic (i.e., we treat a case as unsafe to optimize when it is actually safe). I will look into this when I get a chance. Thanks for putting this together and providing a nice, realistic benchmark for tracking performance going forward!
Testing out the Immutable Arrays from JuliaLang/julia#44381 with #7
cc @ChrisRackauckas @ianatol @aviatesk