In [1]:
using Flux

## Training Only Specified Parameters

Only the fields returned by `Flux.trainable` are collected as the layer's trainable parameters when `Flux.params` is called.

Define a custom layer as a struct with a convenience constructor.

In [3]:
struct Affine
  W
  b
end

In [4]:
Affine(in::Integer, out::Integer) =
  Affine(randn(out, in), randn(out))

Affine

Make the struct callable. This defines the layer's forward pass, $Wx + b$.

In [5]:
(m::Affine)(x) = m.W * x .+ m.b

Suppose only $W$ needs to be trained. Overload the trainable function to return just $W$.

In [8]:
Flux.trainable(a::Affine) = (W = a.W,)  # NamedTuple keyed by field name, the form Optimisers.jl expects

Try the layer on random input data.

In [6]:
a = Affine(5,1)

Affine([0.8608780682418017 -1.5395126915150943 … 1.1875546906171954 -0.28030155423215836], [0.4226600495054129])

In [7]:
a(rand(5,2))

1×2 Matrix{Float64}:
 0.500636  0.116417

`Flux.params` collects only the parameters returned by `trainable`, so the bias `b` is left out.

In [12]:
ps = Flux.params(a)

Params([[0.8608780682418017 -1.5395126915150943 … 1.1875546906171954 -0.28030155423215836]])

In [13]:
x = rand(5, 100)    # 100 random 5-dimensional inputs
y = rand(1, 5) * x  # targets from a random linear map

1×100 Matrix{Float64}:
 1.35773  1.14898  1.29776  1.41374  …  1.3741  1.53472  0.485464  1.35829

In [14]:
gs = Flux.gradient(()->Flux.mse(a(x), y), Flux.params(a))

Grads(...)

In [40]:
length(gs)

1

The gradient object contains only the gradient with respect to $W$.

In [16]:
gs[a.W]

1×5 Matrix{Float64}:
 -0.344647  -0.914001  -0.455823  -0.288851  -0.440177
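Passing these implicit gradients to an optimiser then updates only $W$. A minimal sketch, assuming Flux 0.13/0.14, where the legacy `Flux.Optimise` API still accepts `Params` and `Grads`; the step size `0.1`, the 100 iterations and the helper name `loss` are illustrative choices. It repeats the definitions above so it runs on its own (re-evaluating an identical `struct` definition is allowed in Julia).

```julia
using Flux

struct Affine
  W
  b
end
Affine(in::Integer, out::Integer) = Affine(randn(out, in), randn(out))
(m::Affine)(x) = m.W * x .+ m.b
Flux.trainable(a::Affine) = (W = a.W,)   # only W is exposed to Flux.params

a = Affine(5, 1)
x = rand(5, 100)
y = rand(1, 5) * x

ps = Flux.params(a)                # contains a.W only
opt = Flux.Optimise.Descent(0.1)   # plain gradient descent (legacy optimiser)
loss() = Flux.mse(a(x), y)         # hypothetical helper closing over the globals

b_before = copy(a.b)
loss_before = loss()
for _ in 1:100
    gs = Flux.gradient(loss, ps)          # implicit gradient w.r.t. W only
    Flux.Optimise.update!(opt, ps, gs)    # updates W in place; b is not in ps
end
loss_after = loss()
```

Because `b` is not in `ps`, it keeps its initial random value throughout.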

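Now consider a model with several layers, where only $W$ of the second layer should be trained. Flux's built-in `Dense` layer stores $W$ in the field `weight` (and $b$ in `bias`). Below, `Flux.trainable` is overloaded for `Chain`; in the implicit style the array can also be selected directly with `Flux.params(m[2].weight)`, without overloading anything.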

In [35]:
m = Chain(
      Dense(5 => 3, relu),
      Dense(3 => 2, relu),
      Dense(2 => 1)
    )

Chain(
  Dense(5 => 3, relu),                  # 18 parameters
  Dense(3 => 2, relu),                  # 8 parameters
  Dense(2 => 1),                        # 3 parameters
) 

In [44]:
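# Caution: this redefines `trainable` for every `Chain` in the session (type piracy), not just `m`.
Flux.trainable(m::Chain) = (weight=m[2].weight,)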

In [54]:
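ps = Flux.params(m)  # implicit style: holds only what `trainable(m)` returns, i.e. m[2].weight
# Explicit gradient w.r.t. the whole model: every layer gets a gradient, regardless of `trainable`
gs = Flux.gradient(m->Flux.mse(m(x), y), m)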

((layers = ((weight = Float32[-0.34470168 -0.38269594 … -0.21690816 -0.30961356; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], bias = Float32[-0.53195184, 0.0, 0.0], σ = nothing), (weight = Float32[-0.08672465 0.0 0.0; -0.21705872 0.0 0.0], bias = Float32[-0.39137542, -0.97955364], σ = nothing), (weight = Float32[-0.08895962 -0.050059285], bias = Float32[-1.7786287], σ = nothing)),),)

In [50]:
gs[1].layers[2].weight  # second layer's weight gradient inside the explicit result

2×3 Matrix{Float32}:
 -0.0867246  0.0  0.0
 -0.217059   0.0  0.0

In [48]:
Flux.trainable(m)

(weight = Float32[0.5643518 0.7329073 -0.8194195; 0.31757158 -0.039004058 -0.98523957],)

In [41]:
for l in m.layers
    println(l.weight, l.bias)
end

Float32[0.5018218 0.40540275 -0.6181809 -0.7249601 0.21906047; -0.009452393 -0.25717664 0.16760351 -0.4715247 -0.68426275; -0.49048844 -0.79621434 0.5623093 0.69380105 -0.19479854]Float32[0.0, 0.0, 0.0]
Float32[0.5643518 0.7329073 -0.8194195; 0.31757158 -0.039004058 -0.98523957]Float32[0.0, 0.0]
Float32[0.55017334 1.3770009]Float32[0.0]


In [53]:
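Flux.update!(Flux.setup(Adam(0.01), m), m, gs)

LoadError: Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`

The error message shows that `gs` still held implicit `Grads` when this cell ran (its execution count, 53, precedes that of the explicit-gradient cell, 54). `Flux.setup` and `Flux.update!(opt_state, model, grad)` use Optimisers.jl, which accepts only explicit gradients such as `gs[1]` from `Flux.gradient(m -> loss(m), m)`; implicit `Params`/`Grads` go with the legacy `Flux.Optimise.update!(opt, ps, gs)` instead.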
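With explicit gradients, Optimisers.jl offers another way to train only the second layer's weight: set up the optimiser state for the whole model and freeze every other leaf. A minimal sketch, assuming a fresh session (without the `Flux.trainable(::Chain)` override above, which changes what `Flux.setup` sees) and a Flux version that provides `Flux.freeze!` from Optimisers.jl, such as Flux 0.14; the learning rate and number of steps are arbitrary.

```julia
using Flux

m = Chain(
    Dense(5 => 3, relu),
    Dense(3 => 2, relu),
    Dense(2 => 1),
)
x = rand(Float32, 5, 100)
y = rand(Float32, 1, 5) * x

# The optimiser state mirrors the model: opt_state.layers[i].weight / .bias are leaves.
opt_state = Flux.setup(Adam(0.01), m)
Flux.freeze!(opt_state.layers[1])        # whole first layer
Flux.freeze!(opt_state.layers[2].bias)   # second layer: leave only `weight` trainable
Flux.freeze!(opt_state.layers[3])        # whole third layer

# Snapshot the frozen arrays to check that they do not change.
frozen_before = [copy(m[1].weight), copy(m[1].bias), copy(m[2].bias),
                 copy(m[3].weight), copy(m[3].bias)]

for _ in 1:10
    grads = Flux.gradient(model -> Flux.mse(model(x), y), m)  # explicit gradient
    Flux.update!(opt_state, m, grads[1])   # frozen leaves are skipped
end

frozen_after = [m[1].weight, m[1].bias, m[2].bias, m[3].weight, m[3].bias]
```

Unlike overloading `Flux.trainable(::Chain)`, freezing is recorded only in `opt_state`, so other models in the session are unaffected.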

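## Multi-Task Learning (MTL) with Split Layers

`Split` from Fluxperimental.jl applies each of its layers to the same input and returns their outputs as a tuple, so a shared trunk can feed one head per task.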

In [36]:
using Fluxperimental

In [37]:
model = Chain(
              Dense(3 => 5),
              Split(Dense(5 => 1, tanh), Dense(5 => 1, tanh))
        )

Chain(
  Dense(3 => 5),                        # 20 parameters
  Split(
    Tuple(
      Dense(5 => 1, tanh),              # 6 parameters
      Dense(5 => 1, tanh),              # 6 parameters
    ),
  ),
)                   # Total: 6 arrays, 32 parameters, 576 bytes.

In [40]:
xs = randn(3, 4)
ypred = model(xs)
typeof(ypred)

Tuple{Matrix{Float64}, Matrix{Float64}}
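Because `Split` returns one output per head, a multi-task loss can destructure the tuple and add one loss per task. A minimal sketch, assuming Fluxperimental.jl's `Split` as used above; the targets `y1`/`y2`, the helper `mtl_loss`, the equal weighting of the two losses and the optimiser settings are illustrative assumptions. Run it in a fresh session: with the `Flux.trainable(::Chain)` override from the previous section in place, `Flux.setup` would call that override and fail, since `model[2]` is a `Split` without a `weight` field.

```julia
using Flux, Fluxperimental

model = Chain(
    Dense(3 => 5),
    Split(Dense(5 => 1, tanh), Dense(5 => 1, tanh)),  # shared trunk, two heads
)

xs = randn(Float32, 3, 4)
y1 = randn(Float32, 1, 4)   # hypothetical targets for task 1
y2 = randn(Float32, 1, 4)   # hypothetical targets for task 2

# Sum of per-task losses: each head's output is compared with its own target.
function mtl_loss(m, x, t1, t2)
    ŷ1, ŷ2 = m(x)
    return Flux.mse(ŷ1, t1) + Flux.mse(ŷ2, t2)
end

opt_state = Flux.setup(Adam(0.01), model)
loss_before = mtl_loss(model, xs, y1, y2)
for _ in 1:200
    grads = Flux.gradient(m -> mtl_loss(m, xs, y1, y2), model)
    Flux.update!(opt_state, model, grads[1])   # trunk and both heads are updated
end
loss_after = mtl_loss(model, xs, y1, y2)
```

Weighting the task losses differently (for example `λ * Flux.mse(ŷ1, t1) + Flux.mse(ŷ2, t2)`) is a common variation when one task should dominate.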