Machine learning today often relies on models that are first trained on a general task and then tuned to perform especially well on a narrower subset of the problem.

This is one of the key ways in which models, large or small, are used in practice: they are trained on a general problem until they achieve good results on its test set, and are then tuned on specialised datasets.

In this process, our model is already well trained on the general problem, so we don’t need to train it again from scratch. Often we only need to tune the last couple of layers to get most of the performance out of our models. Exactly how many layers to tune depends on the problem setup and the expected outcome, but a common tip is to train only the last few Dense layers of a more complicated model.

So let’s try to simulate the problem in Flux.

We’ll tune a pretrained ResNet from Metalhead as a proxy. We will tune the Dense layers in there on a new set of images.
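Before loading the real network, here is a minimal sketch of what "tune only the last layers" means in Flux terms. It assumes the same generation of Flux as the rest of this notebook (the implicit `Flux.params` API and the positional `Dense(in, out)` constructor). The `backbone`, `head` and `toy` names and all layer sizes are made up for illustration and are not part of the ResNet used below.

```julia
using Flux

# Toy stand-ins (hypothetical sizes): a "pretrained" backbone and a fresh head.
backbone = Chain(Dense(8, 4, relu))
head     = Chain(Dense(4, 2))
toy      = Chain(backbone, head)

# Collecting parameters from the whole model picks up every weight and bias...
ps_all = Flux.params(toy)

# ...while slicing the Chain first collects only the head's weight and bias.
ps_head = Flux.params(toy[2:end])

length(ps_all), length(ps_head)
```

Passing only `ps_head` to the training loop leaves the backbone's arrays untouched, which is the same idea we apply to the ResNet later on.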

In [1]:
using Flux, Metalhead
using Flux: @epochs
using Metalhead.Images



In [4]:
resnet = ResNet().layers

Chain(
  Conv((7, 7), 3 => 64, pad=3, stride=2),  # 9_472 parameters
  MaxPool((3, 3), pad=1, stride=2),
  ResidualBlock(
    Tuple(
      Conv((1, 1), 64 => 64),           # 4_160 parameters
      Conv((3, 3), 64 => 64, pad=1),    # 36_928 parameters
      Conv((1, 1), 64 => 256),          # 16_640 parameters
    ),
    Tuple(
      BatchNorm(64),                    # 128 parameters, plus 128
      BatchNorm(64),                    # 128 parameters, plus 128
      BatchNorm(256),                   # 512 parameters, plus 512
    ),
    Chain(
      Conv((1, 1), 64 => 256),          # 16_640 parameters
      BatchNorm(256),                   # 512 parameters, plus 512
    ),
  ),
  ResidualBlock(
    Tuple(
      Conv((1, 1), 256 => 64),          # 16_448 parameters
      Conv((3, 3), 64 => 64, pad=1),    # 36_928 parameters
      Conv((1,

In [5]:
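# hypothetical sketch: `size_we_want`, `reshaped_input_features` and `n_classes`
# are placeholders that are never defined, so this cell is not meant to run
model = Chain(
    resnet[1:end-2],               # keep everything except the last two layers, which we replace with our own classifier
    x -> reshape(x, size_we_want), # or a global average pooling layer
    Dense(reshaped_input_features, n_classes)
  )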

LoadError: UndefVarError: reshaped_input_features not defined

In [10]:
using Flux, Images
using StatsBase: sample
using Random: shuffle   # shuffle lives in the Random standard library

const PATH = joinpath(@__DIR__, "train")
const FILES = joinpath.(PATH, readdir(PATH))
if isempty(FILES)
  error("Empty train folder - perhaps you need to download and extract the kaggle dataset.")
end

const DOGS = filter(x -> occursin("dog", x), FILES)
const CATS = filter(x -> occursin("cat", x), FILES)

# Load a batch with n ÷ 2 dog images and n ÷ 2 cat images, each resized to `nsize`.
function load_batch(n = 10, nsize = (224,224); path = PATH)
  imgs_paths = shuffle(vcat(sample(DOGS, n ÷ 2), sample(CATS, n ÷ 2)))
  labels = map(x -> occursin("dog.", x) ? 1 : 0, imgs_paths)  # 1 = dog, 0 = cat
  labels = Flux.onehotbatch(labels, [0,1])
  imgs = Images.load.(imgs_paths)
  imgs = map(img -> Images.imresize(img, nsize...), imgs)
  # for an RGB image: (channel, height, width) -> (width, height, channel)
  imgs = map(img -> permutedims(channelview(img), (3,2,1)), imgs)
  imgs = cat(imgs..., dims = 4)  # stack the images along a fourth (batch) dimension
  Float32.(imgs), labels
end

load_batch (generic function with 3 methods)
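To make the array shuffling inside `load_batch` easier to follow, here is a small sketch that uses only Base Julia. The random arrays are hypothetical stand-ins for two decoded images in the channel-first layout that `channelview` produces, and the tiny 4×5 size is chosen only so the dimensions are easy to tell apart.

```julia
# Hypothetical stand-ins for two RGB images: 3 channels, height 4, width 5.
chw = [rand(Float32, 3, 4, 5) for _ in 1:2]

# Same permutation as in load_batch: (channel, height, width) -> (width, height, channel).
whc = map(a -> permutedims(a, (3, 2, 1)), chw)

# Stack the images along a new fourth dimension to form the batch.
batch = cat(whc..., dims = 4)

size(batch)  # (5, 4, 3, 2): width, height, channels, batch size
```

With the default `nsize` of `(224, 224)` and `n = 10`, the same steps give a `224×224×3×10` array, which is the layout Flux's convolution layers expect.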

In [16]:
model = Chain(
  resnet[1:end-2],      # pretrained layers; we get 2048 features out of these
  Dense(2048, 1000),
  Dense(1000, 256),
  Dense(256, 2),        # and we have 2 classes: cat and dog
)

Chain(
  Chain(
    Conv((7, 7), 3 => 64, pad=3, stride=2),  # 9_472 parameters
    MaxPool((3, 3), pad=1, stride=2),
    ResidualBlock(
      Tuple(
        Conv((1, 1), 64 => 64),         # 4_160 parameters
        Conv((3, 3), 64 => 64, pad=1),  # 36_928 parameters
        Conv((1, 1), 64 => 256),        # 16_640 parameters
      ),
      Tuple(
        BatchNorm(64),                  # 128 parameters, plus 128
        BatchNorm(64),                  # 128 parameters, plus 128
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
      Chain(
        Conv((1, 1), 64 => 256),        # 16_640 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
    ),
    ResidualBlock(
      Tuple(
        Conv((1, 1), 256 => 64),        # 16_448 parameters
        Conv((3, 3), 64 => 64, pad=1),  # 36_9
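As a quick sanity check on the new layers, we can push a batch of made-up feature vectors through a head with the same sizes and look at the output shape. This is only a sketch: the random `features` array stands in for the real ResNet output, and `head_relu` is a hypothetical variant, not something the notebook trains.

```julia
using Flux

# Same layer sizes as the new head; `features` is random stand-in data
# for 10 images with 2048 features each.
head = Chain(Dense(2048, 1000), Dense(1000, 256), Dense(256, 2))
features = rand(Float32, 2048, 10)

size(head(features))  # (2, 10): one pair of class scores per image

# Hypothetical variant: without activations, stacked Dense layers compose
# into a single affine map, so a nonlinearity such as relu is often added.
head_relu = Chain(Dense(2048, 1000, relu), Dense(1000, 256, relu), Dense(256, 2))
size(head_relu(features))
```

Whether the extra nonlinearity helps on this dataset is not something the notebook measures, so treat the variant as an option to experiment with.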

In [25]:
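model = model |> gpu
# Metalhead exports a function named `dataset`, so assigning to that name fails with
# "cannot assign a value to variable Metalhead.dataset from module Main".
# We bind the batches to a different name instead.
train_batches = [gpu.(load_batch(10)) for i in 1:10]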

In [12]:
opt = ADAM()
loss(x,y) = Flux.Losses.logitcrossentropy(model(x), y)

loss (generic function with 1 method)
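The loss above works on raw scores (logits), which is why the model has no `softmax` at the end. The sketch below, on made-up numbers, checks that `logitcrossentropy` on logits agrees with `crossentropy` on the softmaxed scores. The `logits` and `y` values are illustrative assumptions, not outputs of the model.

```julia
using Flux

# Made-up scores: 2 classes (rows) × 2 samples (columns).
logits = Float32[2.0 -1.0; 0.5 1.5]

# True classes for the two samples, encoded with the same labels as load_batch.
y = Flux.onehotbatch([0, 1], [0, 1])

l1 = Flux.Losses.logitcrossentropy(logits, y)
l2 = Flux.Losses.crossentropy(softmax(logits), y)

# The two agree up to a small numerical tolerance.
isapprox(l1, l2; atol = 1e-4)
```

Working directly on logits is generally the more numerically stable of the two, which is a reasonable motivation for the choice made here.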

In [13]:
ps = Flux.params(model[2:end])  # ignore the already trained layers of the ResNet

Params([Float32[-0.031635202 0.021174062 … 0.039401326 0.019845864; -0.028860807 -0.018728338 … 0.027546868 0.036203995; … ; -0.042369913 0.031618297 … -0.009289181 -0.008160443; 0.024158053 0.043600507 … -0.010321118 -0.028277107], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.038418807 0.04389513 … -0.06675846 0.030356195; 0.0025776026 -0.055093326 … 0.009703058 -0.05153177; … ; 0.06616131 -0.050281897 … 0.021569693 -0.032436736; 0.0032285412 -0.050683416 … -0.003926493 0.036584653], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.01469563 0.06130809 … -0.074142456 0.01752065; 0.0527549 0.085296914 … 0.045016725 0.015742391], Float32[0.0, 0.0]])

In [14]:
# `train_batches` is the vector of (images, labels) batches built earlier. It must not be
# called `dataset`: Metalhead exports a function with that name, and passing that function
# to `train!` fails with "MethodError: no method matching iterate(::typeof(dataset))".
@epochs 2 Flux.train!(loss, ps, train_batches, opt)

┌ Info: Epoch 1
└ @ Main C:\Users\joaof\.julia\packages\Flux\Zz9RI\src\optimise\train.jl:138

In [None]:
imgs, labels = gpu.(load_batch(10))
display(model(imgs))

labels
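The model output displayed above is a matrix of raw scores, one column per image. A minimal sketch of turning such scores into class predictions and comparing them with one-hot labels follows. The `logits` and `y` values are made up, and for a model on the GPU the output would first need to be moved back with `cpu`.

```julia
using Flux
using Statistics: mean

# Made-up scores for 2 images (columns) over the classes [0, 1] (rows).
logits = Float32[2.0 -1.0; 0.5 1.5]
y = Flux.onehotbatch([0, 1], [0, 1])

probs = softmax(logits)               # each column now sums to 1
preds = Flux.onecold(probs, [0, 1])   # label with the highest score per column
truth = Flux.onecold(y, [0, 1])       # decode the one-hot labels

accuracy = mean(preds .== truth)
```

Here the first column scores class 0 higher and the second scores class 1 higher, so both predictions match the labels.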