In [None]:
using Pkg
# Activate the project environment in ../env/ before loading any packages, so that
# Flux and CUDA below are loaded from that environment. The relative path is resolved
# against the current working directory (the notebook's directory under IJulia —
# NOTE(review): confirm if this is ever run from elsewhere).
Pkg.activate("../env/")
using Flux, CUDA
using Flux: @layer
# Device mover: `gpu` when CUDA reports a usable GPU, otherwise `cpu`.
# It is applied with `|> dev` to the data loader, model and probe input.
dev = CUDA.has_cuda_gpu() ? gpu : cpu

In [None]:
# PReLU (parametric ReLU) layer, with per-channel slopes held in a 1×1 depthwise convolution
"""
    preluweights(ch_in::Int)

Build the learnable slope layer used by `ConvPReLU`: a `Chain` holding a single
bias-free 1×1 `DepthwiseConv` mapping `ch_in => ch_in`, with weights drawn by `rand32`.
Because the kernel is 1×1 and depthwise, each channel is scaled by its own weight.
"""
function preluweights(ch_in::Int)
    slopes = DepthwiseConv((1, 1), ch_in => ch_in; init=rand32, bias=false)
    return Chain(slopes)
end

"""
    ConvPReLU{C<:Chain}

PReLU-style layer: positive inputs pass through unchanged, and the negative part of the
input is sent through `conv`, which holds the learnable per-channel slopes.

# Fields
- `conv::C`: the `Chain` applied to the negative part of the input
  (`ConvPReLU(ch_in)` builds it with `preluweights(ch_in)`).
"""
struct ConvPReLU{C<:Chain}
    # Parametrised on the concrete Chain type. A bare `conv::Chain` field is abstractly
    # typed (`Chain` is a UnionAll), so `m.conv` could not be concretely inferred.
    conv::C
end
# Register with Flux/Functors so `gpu`/`cpu`, `Flux.setup` and `destructure` see `conv`.
@layer ConvPReLU

"""
    ConvPReLU(ch_in::Int)

Construct a `ConvPReLU` for inputs with `ch_in` channels, with its slope layer
created by `preluweights(ch_in)`.
"""
ConvPReLU(ch_in::Int) = ConvPReLU(preluweights(ch_in))

# Forward pass: computes max.(x, 0) .+ m.conv(min.(x, 0)).
# The non-positive part of `x` goes through the slope layer `m.conv`, and the
# non-negative part is added back unchanged in a single fused broadcast.
function (m::ConvPReLU)(x)
    scaled_negatives = m.conv(min.(x, 0))
    return max.(x, 0) .+ scaled_negatives
end

In [None]:
# Synthetic data: 10 random Float32 samples of size k×k×c with random Bool targets
# of the same size.
k, c = 8, 2
Xs = randn(Float32, k, k, c, 10)
ys = rand(Bool, size(Xs))
# Optional: soften the targets to 0.9/0.1 instead of true/false by enabling the line below.
# ys = ifelse.(ys, 0.9f0, 0.1f0)

# Wrap the (inputs, targets) pair in a DataLoader with default settings and move it to `dev`.
data = dev(Flux.DataLoader((Xs, ys)))

In [None]:
# One PReLU layer with a slope for each of the `c` channels, placed on `dev`.
model = dev(ConvPReLU(c))

In [None]:
# Mean-squared-error objective; `Flux.train!` calls it as loss(model, x, y) for each batch.
function loss(m, x, y)
    prediction = m(x)
    return Flux.mse(prediction, y)
end
# Adam with default hyperparameters; `opt_state` mirrors the trainable structure of `model`.
opt = Flux.Adam()
opt_state = Flux.setup(opt, model)

In [None]:
# A single fixed k×k×c probe sample, used to compare model outputs before and after training.
x = dev(randn(Float32, k, k, c, 1))

In [None]:
# Record the output and the flattened parameters before training.
# `destructure` copies the parameters into a new flat vector, so later in-place
# updates made by training do not alter `weights_before`.
y_before = model(x)
weights_before = first(Flux.destructure(model))

In [None]:
# Ten epochs; each `train!` call makes one pass over `data`, updating `model` in place.
for epoch in 1:10
    Flux.train!(loss, model, data, opt_state)
end

In [None]:
# Sanity checks: training must have changed both the parameters and the model output.
# Uses an explicit `cond || error(msg)` rather than `@assert cond || error(msg)`.
# `@assert` may be disabled at some optimisation levels, and disabling it here
# would also skip the `error` call, letting the check pass silently.
weights_after = Flux.destructure(model)[1]
weights_before != weights_after ||
    error("Weights should have been updated after training")

y_after = model(x)
y_before != y_after || error("Model should have been changed after training")