Immutable Arrays #8

Open
avik-pal opened this issue Apr 21, 2022 · 3 comments
Labels
enhancement New feature or request

Comments

@avik-pal
Member

avik-pal commented Apr 21, 2022

Testing out the Immutable Arrays from JuliaLang/julia#44381 with #7

TL;DR: Performance is a pain point right now (broadcasting appears to be the culprit), but supporting these will be very straightforward once the functionality is available in Base.

EDIT: Code updated to work for Lux 0.4.*

Trial 1: From the Usage Example

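# ImmutableArray needs a Julia build that includes JuliaLang/julia#44381
using Lux, Random, Functors
using BenchmarkTools  # provides @benchmark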

make_immutable(x::AbstractArray) = ImmutableArray(copy(x))
make_immutable(x) = x

# Construct the layer
model = Chain(BatchNorm(128), Dense(128, 256, tanh), BatchNorm(256),
              Chain(Dense(256, 1, tanh), Dense(1, 10)))

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps)
st_immutable = fmap(make_immutable, st)

# Dummy Input
x = randn(Float32, 128, 1024)
x_immutable = make_immutable(x)

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 1296 samples with 1 evaluation.
 Range (min … max):  2.125 ms … 26.658 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     3.096 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.836 ms ±  2.313 ms  ┊ GC (mean ± σ):  2.58% ± 7.71%

    ▂█                                                        
  ▆▄██▇▆▄▄▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▂▂▁▂▂▂▂▁▁▂▂▂▂▂▁▂▂▂▂▁▂▂▂ ▃
  2.13 ms        Histogram: frequency by time        14.1 ms <

 Memory estimate: 3.60 MiB, allocs estimate: 144.

Immutable Arrays

BenchmarkTools.Trial: 41 samples with 1 evaluation.
 Range (min … max):  107.855 ms … 159.665 ms  ┊ GC (min … max): 3.98% … 2.64%
 Time  (median):     119.911 ms               ┊ GC (median):    3.54%
 Time  (mean ± σ):   123.706 ms ±  10.746 ms  ┊ GC (mean ± σ):  3.54% ± 0.67%

              ▂█▄                                                
  ▄▁▁▁▁▁▁▁▄▆▄█████▄▁▄▆▄▆▁▁▄▁▁▄▁▁▁▁▁▁▄▁▁▁▁▄▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▆ ▁
  108 ms           Histogram: frequency by time          160 ms <

 Memory estimate: 58.32 MiB, allocs estimate: 3418558.

Trial 2: Only a Dense Layer

# Construct the layer
model = Dense(128, 256)

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);

# Dummy Input
x = randn(Float32, 128, 1024);
x_immutable = make_immutable(x);

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 4469 samples with 1 evaluation.
 Range (min … max):  483.810 μs … 30.894 ms  ┊ GC (min … max): 0.00% …  0.00%
 Time  (median):     716.669 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.100 ms ±  1.501 ms  ┊ GC (mean ± σ):  5.01% ± 12.19%

  █▆▆▅▄▃▂▂▂▂▃▃▃▂▁                                              ▁
  █████████████████▇▇▇▆▇▆▅▅▃▃▄▅▅▄▃▅▁▁▆▄▅▁▃▃▃▃▅▁▃▃▃▃▁▃▁▁▃▁▁▁▁▃▅ █
  484 μs        Histogram: log(frequency) by time      7.69 ms <

 Memory estimate: 2.00 MiB, allocs estimate: 4.

Immutable Arrays

BenchmarkTools.Trial: 259 samples with 1 evaluation.
 Range (min … max):  15.392 ms … 52.229 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     17.997 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   19.327 ms ±  4.194 ms  ┊ GC (mean ± σ):  1.72% ± 4.44%

    ▃▆█ ▂                                                      
  ▃▆███▆█▇▅▇▇▄▆▃▆▄▄▅▄▄▄▄▄▃▄▄▃▂▁▃▃▂▁▃▂▁▁▂▂▁▂▂▂▁▃▁▃▂▂▁▁▁▂▂▁▂▂▁▂ ▃
  15.4 ms         Histogram: frequency by time        32.6 ms <

 Memory estimate: 7.00 MiB, allocs estimate: 262153.

It seems that a lot of time is being spent on broadcasting the bias (this looks like a problem with broadcasting in general):

julia> @benchmark $ps_immutable.weight * $x_immutable
BenchmarkTools.Trial: 4032 samples with 1 evaluation.
 Range (min … max):  346.287 μs … 51.079 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     540.489 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.224 ms ±  1.854 ms  ┊ GC (mean ± σ):  2.36% ± 8.18%

  █▆▄▄▃▁▁▁ ▂▂▁▁▁▂▂▁▁  ▁▁                                       ▁
  █████████████████████████▇▇▇▆▇▆▇▆▆▃▆▆▆▅▅▅▅▄▅▅▅▆▅▅▅▅▅▅▄▃▁▁▁▃▃ █
  346 μs        Histogram: log(frequency) by time      8.78 ms <

 Memory estimate: 1.00 MiB, allocs estimate: 5.

julia> @benchmark $ps_immutable.weight * $x_immutable .+ $ps_immutable.bias
BenchmarkTools.Trial: 338 samples with 1 evaluation.
 Range (min … max):  11.177 ms … 33.105 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     13.699 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   14.792 ms ±  3.901 ms  ┊ GC (mean ± σ):  2.43% ± 5.87%

   █▃                                                          
  ▅██▇▇▅▅▇▅▇▅▅▄▅▅▄▃▃▄▄▃▂▃▃▁▂▃▁▃▂▃▃▃▁▃▂▂▂▁▃▁▁▂▂▂▂▁▁▁▁▁▁▁▁▁▃▂▁▂ ▃
  11.2 ms         Histogram: frequency by time        30.9 ms <

 Memory estimate: 7.00 MiB, allocs estimate: 262153.
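
The allocation estimate above (262153) is close to 256 × 1024 = 262144, the number of output elements. That would be consistent with roughly one allocation per element in the broadcast, though it is only a guess from the numbers. To get a baseline for comparison, the two steps can be timed separately on plain Arrays. This is a minimal sketch, not taken from the original benchmark: the helper names matmul_only and bias_only are hypothetical, and the bias shape (256, 1) is an assumption about how the Dense layer stores it.

```julia
using BenchmarkTools

# Hypothetical helpers that split the Dense forward pass into its two steps
matmul_only(W, x) = W * x      # matrix multiplication, no broadcasting
bias_only(y, b) = y .+ b       # broadcasted bias addition only

# Plain Array stand-ins with the shapes used in Trial 2
W = randn(Float32, 256, 128)
b = randn(Float32, 256, 1)     # assumed bias shape
x = randn(Float32, 128, 1024)

y = matmul_only(W, x)

# Time each step on its own to see which one dominates
@btime matmul_only($W, $x);
@btime bias_only($y, $b);
```

Running the same two calls with the ImmutableArray versions of W, b and x (on a Julia build that has them) should show whether the bias step alone accounts for the gap.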

Trial 3: No broadcasting

model = Dense(128, 256; bias=false)

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 5501 samples with 1 evaluation.
 Range (min … max):  295.161 μs … 23.801 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     451.402 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   899.925 μs ±  1.386 ms  ┊ GC (mean ± σ):  3.10% ± 8.68%

  █▆▆▄▃▂▁▂▁▁▁▂▂▂▂▁ ▁                                           ▁
  ██████████████████▇█▇█▇▇▆▆▇▇▆▆▆▆▆▆▅▅▅▆▅▅▁▆▄▆▅▃▅▄▅▄▆▄▅▁▄▆▅▅▃▅ █
  295 μs        Histogram: log(frequency) by time      6.98 ms <

 Memory estimate: 1.00 MiB, allocs estimate: 2.

Immutable Arrays

BenchmarkTools.Trial: 5303 samples with 1 evaluation.
 Range (min … max):  311.574 μs … 26.953 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     436.316 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   930.509 μs ±  1.488 ms  ┊ GC (mean ± σ):  3.23% ± 8.75%

  █▆▅▃▂▁   ▁▁▂▁▁                                               ▁
  █████████████████▆█▇▇▆▆▆▆▆▆▆▆▆▅▅▅▅▅▅▄▄▅▅▅▅▂▅▂▄▅▄▅▄▄▃▂▃▄▄▂▃▂▃ █
  312 μs        Histogram: log(frequency) by time      7.61 ms <

 Memory estimate: 1.00 MiB, allocs estimate: 5.

Trial 4: Chain of Dense Layers without Bias

model = Chain(Dense(128, 256; bias=false),
              Chain(Dense(256, 512; bias=false), Dense(512, 10; bias=false)))

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 1372 samples with 1 evaluation.
 Range (min … max):  1.380 ms … 49.871 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     2.918 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.615 ms ±  3.116 ms  ┊ GC (mean ± σ):  2.42% ± 7.94%

  ▅█    ▃                                                     
  ███▇▆▇██▇▆▅▄▄▄▃▃▃▃▂▃▃▃▂▃▂▃▂▂▂▂▁▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▂▂ ▃
  1.38 ms        Histogram: frequency by time        15.8 ms <

 Memory estimate: 3.04 MiB, allocs estimate: 6.

Immutable Arrays

BenchmarkTools.Trial: 894 samples with 1 evaluation.
 Range (min … max):  1.505 ms … 66.104 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     4.153 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.561 ms ±  5.432 ms  ┊ GC (mean ± σ):  1.87% ± 7.54%

  █▆▅▅▅▄▅▆▆▅▄▄▂▂▂▂▁     ▁  ▁     ▁                            
  █████████████████▇█▆███▆▇█▅▆▇███▆▇█▄▇▇▇▅▄▆▅▅▁▄▁▆▄▁▅▇▅▄▄▆▁▅ █
  1.5 ms       Histogram: log(frequency) by time     23.1 ms <

 Memory estimate: 3.04 MiB, allocs estimate: 17.
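
To compare the four trials at a glance, the median times reported above can be turned into slowdown factors. The sketch below only divides numbers copied from the benchmark output in this post (converted to milliseconds by hand); the trial descriptions and variable names are illustrative.

```julia
# Median times in milliseconds, copied from the benchmark output in this post:
# (description, standard arrays, immutable arrays)
medians_ms = [
    ("Trial 1: BatchNorm + Dense chain",  3.096,    119.911),
    ("Trial 2: Dense with bias",          0.716669, 17.997),
    ("Trial 3: Dense without bias",       0.451402, 0.436316),
    ("Trial 4: Dense chain without bias", 2.918,    4.153),
]

# Slowdown factor = immutable median / standard median
slowdowns = [(name, immutable / standard) for (name, standard, immutable) in medians_ms]

for (name, factor) in slowdowns
    println(rpad(name, 36), round(factor; digits=2), "x")
end
```

The two trials that involve broadcasting come out more than an order of magnitude slower, while the bias-free trials stay much closer to parity.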

cc @ChrisRackauckas @ianatol @aviatesk

@ianatol

ianatol commented Apr 21, 2022

I think the poor broadcasting performance likely has to do with some missed chance to perform our memory optimization in the broadcast logic somewhere (i.e., we think it is unsafe to optimize in a place where it's actually safe to do so). I will take a look into this when I get a chance, but thanks for putting this together and providing a nice, realistic benchmark for performance going forward!

@ianatol

ianatol commented Apr 21, 2022

Also, a minor nit: ImmutableArray will copy by itself if we can't optimize, so I don't think copy is necessary here:

make_immutable(x::AbstractArray) = ImmutableArray(copy(x))

@avik-pal
Member Author

I added it for ReshapedArray, which doesn't seem to have a dispatch for that.

@avik-pal avik-pal added the enhancement New feature or request label Jun 26, 2022