In [37]:
using Statistics
using Flux, Flux.Optimise
using Flux: onehotbatch, onecold
using Flux: crossentropy, Momentum
using Base.Iterators: partition
using Metal
using MLDatasets
using MLUtils
using Images.ImageCore
using BenchmarkTools: @btime

In [55]:
Flux.OneHotMatrix(X.targets, 10) |> gpu

10×50000 OneHotMatrix(::MtlVector{Int64, Private}) with eltype Bool:
 ⋅  ⋅  ⋅  ⋅  1  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  1  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  1  1
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅     ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  1  1     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  1  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅

In [62]:
# Prepare data
X = CIFAR10(:train)
labels = Flux.OneHotMatrix(X.targets, 10) |> gpu
imgs = X.features |> gpu

# Load data to gpu and cpu
train_cpu = [(imgs[:,:,:,i], labels[:,i]) for i in partition(1:50000, 1000)] |> cpu
train_gpu = [(imgs[:,:,:,i], labels[:,i]) for i in partition(1:50000, 1000)] |> gpu
;

In [63]:
# Define neural networks for both gpu and cpu
m_gpu = Chain(
  Conv((5,5), 3=>16, relu),
  MaxPool((2,2)),
  Conv((5,5), 16=>8, relu),
  MaxPool((2,2)),
  flatten,
  Dense(200, 120),
  Dense(120, 84),
  Dense(84, 10),
  softmax) |> gpu

Chain(
  Conv((5, 5), 3 => 16, relu),          [90m# 1_216 parameters[39m
  MaxPool((2, 2)),
  Conv((5, 5), 16 => 8, relu),          [90m# 3_208 parameters[39m
  MaxPool((2, 2)),
  MLUtils.flatten,
  Dense(200 => 120),                    [90m# 24_120 parameters[39m
  Dense(120 => 84),                     [90m# 10_164 parameters[39m
  Dense(84 => 10),                      [90m# 850 parameters[39m
  NNlib.softmax,
) [90m                  # Total: 10 arrays, [39m39_558 parameters, 1.758 KiB.

In [64]:
m_cpu = Chain(
  Conv((5,5), 3=>16, relu),
  MaxPool((2,2)),
  Conv((5,5), 16=>8, relu),
  MaxPool((2,2)),
  flatten,
  Dense(200, 120),
  Dense(120, 84),
  Dense(84, 10),
  softmax) |> cpu

Chain(
  Conv((5, 5), 3 => 16, relu),          [90m# 1_216 parameters[39m
  MaxPool((2, 2)),
  Conv((5, 5), 16 => 8, relu),          [90m# 3_208 parameters[39m
  MaxPool((2, 2)),
  MLUtils.flatten,
  Dense(200 => 120),                    [90m# 24_120 parameters[39m
  Dense(120 => 84),                     [90m# 10_164 parameters[39m
  Dense(84 => 10),                      [90m# 850 parameters[39m
  NNlib.softmax,
) [90m                  # Total: 10 arrays, [39m39_558 parameters, 155.852 KiB.

In [65]:
# Define loss and optimizer
loss_gpu(m, x, y) = sum(crossentropy(m(x), y))
opt_gpu = Momentum(0.01)

loss_cpu(x, y) = sum(crossentropy(m_cpu(x), y))
opt_cpu = Momentum(0.01)

Momentum(0.01, 0.9, IdDict{Any, Any}())

In [66]:
# Set number of training iterations
epochs = 1

1

In [68]:
for epoch in 1:epochs
    for (x, y) in train_gpu
        grads = gradient(m -> loss_gpu(m, x, y), m_gpu)
        Flux.update!(opt_state, model, grads[1])
    end
end

CompositeException: TaskFailedException

    nested task error: TaskFailedException
    
        nested task error: Scalar indexing is disallowed.
        Invocation of getindex resulted in scalar indexing of a GPU array.
        This is typically caused by calling an iterating implementation of a method.
        Such implementations *do not* execute on the GPU, but very slowly on the CPU,
        and therefore should be avoided.
        
        If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
        to enable scalar iteration globally or for the operations in question.
        Stacktrace:
          [1] error(s::String)
            @ Base ./error.jl:35
          [2] errorscalar(op::String)
            @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:155
          [3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
            @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:128
          [4] assertscalar(op::String)
            @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:116
          [5] getindex
            @ ~/.julia/packages/GPUArrays/bbZD0/src/host/indexing.jl:50 [inlined]
          [6] scalar_getindex
            @ ~/.julia/packages/GPUArrays/bbZD0/src/host/indexing.jl:36 [inlined]
          [7] _getindex
            @ ~/.julia/packages/GPUArrays/bbZD0/src/host/indexing.jl:19 [inlined]
          [8] getindex
            @ ~/.julia/packages/GPUArrays/bbZD0/src/host/indexing.jl:17 [inlined]
          [9] getindex
            @ ./subarray.jl:290 [inlined]
         [10] im2col!(col::MtlMatrix{Float32, Private}, x::SubArray{Float32, 4, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Int64}, true}, cdims::DenseConvDims{3, 3, 3, 6, 3})
            @ NNlib ~/.julia/packages/NNlib/PmySZ/src/impl/conv_im2col.jl:238
         [11] (::NNlib.var"#640#641"{MtlArray{Float32, 3, Private}, Float32, Float32, SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, MtlArray{Float32, 5, Private}, DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64, UnitRange{Int64}, Int64})()
            @ NNlib ~/.julia/packages/NNlib/PmySZ/src/impl/conv_im2col.jl:54
    Stacktrace:
     [1] sync_end(c::Channel{Any})
       @ Base ./task.jl:448
     [2] macro expansion
       @ ./task.jl:480 [inlined]
     [3] conv_im2col!(y::SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::MtlArray{Float32, 5, Private}, cdims::DenseConvDims{3, 3, 3, 6, 3}; col::MtlArray{Float32, 3, Private}, alpha::Float32, beta::Float32, ntasks::Int64)
       @ NNlib ~/.julia/packages/NNlib/PmySZ/src/impl/conv_im2col.jl:50
     [4] conv_im2col!(y::SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::MtlArray{Float32, 5, Private}, cdims::DenseConvDims{3, 3, 3, 6, 3})
       @ NNlib ~/.julia/packages/NNlib/PmySZ/src/impl/conv_im2col.jl:23
     [5] (::NNlib.var"#298#302"{@Kwargs{}, DenseConvDims{3, 3, 3, 6, 3}, SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, MtlArray{Float32, 5, Private}, SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
       @ NNlib ~/.julia/packages/NNlib/PmySZ/src/conv.jl:209

In [46]:
# GPU benchmark
# @btime for epoch = 1:epochs
  for d in train_gpu
    gs = gradient(Flux.params(m_gpu)) do
      l = loss_gpu(d...)
    end
    update!(opt_gpu, params(m_gpu), gs)
  end
# end

CompositeException: TaskFailedException

    nested task error: TaskFailedException
    
        nested task error: Scalar indexing is disallowed.
        Invocation of getindex resulted in scalar indexing of a GPU array.
        This is typically caused by calling an iterating implementation of a method.
        Such implementations *do not* execute on the GPU, but very slowly on the CPU,
        and therefore should be avoided.
        
        If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
        to enable scalar iteration globally or for the operations in question.
        Stacktrace:
          [1] error(s::String)
            @ Base ./error.jl:35
          [2] errorscalar(op::String)
            @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:155
          [3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
            @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:128
          [4] assertscalar(op::String)
            @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:116
          [5] getindex
            @ ~/.julia/packages/GPUArrays/bbZD0/src/host/indexing.jl:50 [inlined]
          [6] scalar_getindex
            @ ~/.julia/packages/GPUArrays/bbZD0/src/host/indexing.jl:36 [inlined]
          [7] _getindex
            @ ~/.julia/packages/GPUArrays/bbZD0/src/host/indexing.jl:19 [inlined]
          [8] getindex
            @ ~/.julia/packages/GPUArrays/bbZD0/src/host/indexing.jl:17 [inlined]
          [9] getindex
            @ ./subarray.jl:290 [inlined]
         [10] im2col!(col::MtlMatrix{Float32, Private}, x::SubArray{Float32, 4, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Int64}, true}, cdims::DenseConvDims{3, 3, 3, 6, 3})
            @ NNlib ~/.julia/packages/NNlib/PmySZ/src/impl/conv_im2col.jl:238
         [11] (::NNlib.var"#640#641"{MtlArray{Float32, 3, Private}, Float32, Float32, SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, MtlArray{Float32, 5, Private}, DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64, UnitRange{Int64}, Int64})()
            @ NNlib ~/.julia/packages/NNlib/PmySZ/src/impl/conv_im2col.jl:54
    Stacktrace:
     [1] sync_end(c::Channel{Any})
       @ Base ./task.jl:448
     [2] macro expansion
       @ ./task.jl:480 [inlined]
     [3] conv_im2col!(y::SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::MtlArray{Float32, 5, Private}, cdims::DenseConvDims{3, 3, 3, 6, 3}; col::MtlArray{Float32, 3, Private}, alpha::Float32, beta::Float32, ntasks::Int64)
       @ NNlib ~/.julia/packages/NNlib/PmySZ/src/impl/conv_im2col.jl:50
     [4] conv_im2col!(y::SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::MtlArray{Float32, 5, Private}, cdims::DenseConvDims{3, 3, 3, 6, 3})
       @ NNlib ~/.julia/packages/NNlib/PmySZ/src/impl/conv_im2col.jl:23
     [5] (::NNlib.var"#298#302"{@Kwargs{}, DenseConvDims{3, 3, 3, 6, 3}, SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, MtlArray{Float32, 5, Private}, SubArray{Float32, 5, MtlArray{Float32, 5, Private}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
       @ NNlib ~/.julia/packages/NNlib/PmySZ/src/conv.jl:209

In [44]:
# CPU benchmark
for epoch = 1:epochs
    for d in train_cpu
      gs = gradient(Flux.params(m_cpu)) do
        l = loss_cpu(d...)
      end
      update!(opt_cpu, Flux.params(m_cpu), gs)
    end
  end