In [1]:
using Flux
using Random
using Statistics
using Printf: @printf

In [2]:
relu(x::Number) = max(0, x)
sigmoid(x::Number) = 1. / (1. + exp(-x));

In [3]:
# for reproducibility
rng = MersenneTwister(1983);

In [4]:
batch_size = 64
n_inputs = 100
input = randn(batch_size, n_inputs);

truth = zeros(batch_size)
truth[rand(rng, Float64, batch_size) .> 0.5] .= 1
classes = unique(truth)
n_outputs = length(classes)

n_hidden = 8
w_ih = randn(rng, Float64, (n_inputs, n_hidden));
w_ho = randn(rng, Float64, (n_hidden, n_outputs))

y = Flux.onehotbatch(truth, classes)';

In [5]:
y_h = relu.(input * w_ih)
y_hat = sigmoid.(y_h * w_ho)
loss = Flux.logitcrossentropy(y_hat, y);
@printf("Loss: %g.", loss)

Loss: 135.668.

In [6]:
pars = Flux.Params([w_ih, w_ho])
loss,grads = Flux.withgradient(pars) do
    y_hid = relu.(input * w_ih)
    y_hat = sigmoid.(y_hid * w_ho)
    Flux.logitcrossentropy(y_hat, y)
end
@printf("Loss: %g.", loss)

Loss: 135.668.

In [7]:
for p in pars
    println(size(grads[p]))
end

(100, 8)
(8, 2)
