### 17.2. 전이 학습

개와 고양이 이미지셋 주소
https://www.kaggle.com/competitions/dogs-vs-cats/data

In [1]:
import Metalhead, Images
using StatsBase: sample, shuffle
import Flux, NNlib
import Zygote, Optimisers, Functors
using Formatting: printfmtln
using Random: MersenneTwister

데이터셋 준비

In [2]:
function get_image_sampler(path, rng)
    if isempty(readdir(path))
        error("Empty train folder")
    end
    files = joinpath.(path, readdir(path))    
    dogs = filter(x -> occursin("dog", x), files)
    cats = filter(x -> occursin("cat", x), files) 

    function image_sampler(n = 10, size = (224, 224))
        @assert iseven(n)
        dogs_ = sample(rng, dogs, Int(n/2))
        cats_ = sample(rng, cats, Int(n/2))
        imgs_paths = shuffle(rng, vcat(dogs_, cats_))

        imgs = Images.load.(imgs_paths)
        imgs = map(img -> Images.imresize(img, size...), imgs)
        
        imgs = map(imgs) do img # [CHW] -> [WHC]
            permutedims(Images.channelview(img), (3,2,1))
        end
        imgs = cat(imgs..., dims = 4) # [WHC] => WHCN
        imgs = Float32.(imgs)

        labels = map(x -> occursin("dog", x) ? 1 : 0, imgs_paths)
        labels = Flux.onehotbatch(labels, [0,1])

        imgs, labels
    end
end

get_image_sampler (generic function with 1 method)

학습, 테스트 함수 (16장과 동일)

In [3]:
function train(loader, model, loss_fn, optimizer)
    num_batches = length(loader)
    Flux.testmode!(model, false)
    for (batch, (X, y)) in enumerate(loader)
        X, y = Flux.gpu(X), Flux.gpu(y)
        grad = Zygote.gradient(m -> loss_fn(m, X, y), model)[1]
        optimizer, model = Optimisers.update(optimizer, model, grad)
        if batch % 10 == 0
            loss = loss_fn(model, X, y)
            printfmtln("[Train] loss: {:.7f} [{:>3d}/{:>3d}]", 
                loss, batch, num_batches)
        end
    end
    model, optimizer
end

function test(loader, model, loss_fn)
    num_batches = length(loader)
    Flux.testmode!(model, true)
    acc, tot = 0, 0
    loss = 0f0
    for (X, y) in loader
        X, y = Flux.gpu(X), Flux.gpu(y)
        pred = model(X)
        acc += sum(Flux.onecold(pred) .== Flux.onecold(y))
        tot += size(X)[end]
        loss += loss_fn(model, X, y)
    end
    acc, avg_loss = acc / tot * 100, loss / num_batches
    printfmtln("[Test] Accuracy: {:.1f}, Avg loss: {:.7f}", 
        acc, avg_loss)
    acc, avg_loss
end

init(rng) = Flux.glorot_uniform(rng)

init (generic function with 1 method)

모델 정의

In [4]:
struct MyResnet
    resnet
    dense
end
function (a::MyResnet)(x)
    x = a.resnet.layers[1](x)
    x = Flux.AdaptiveMeanPool((1, 1))(x)
    x = Flux.flatten(x)
    a.dense(x)
end
Functors.@functor MyResnet

학습 및 테스트

In [5]:
function run_resnet(rng; pretrain)
    sampler = get_image_sampler("cats_dogs", rng)
    resnet = Metalhead.ResNet(18, pretrain = pretrain)
    model = MyResnet(resnet, Flux.Dense(512 => 2; init=init(rng))) 
    model = model |> Flux.gpu
    optimizer = Optimisers.setup(Optimisers.Adam(), model)
    loader = (sampler(10) for _ in 1:100) 
    loss_fn = (m, x, y) -> Flux.Losses.logitcrossentropy(m(x), y)
    model, _ = train(loader, model, loss_fn, optimizer)
    loader = (sampler(10) for _ in 1:20)
    test(loader, model, loss_fn)
end

run_resnet (generic function with 1 method)

사전 미 학습

In [6]:
rng = MersenneTwister(1)
run_resnet(rng; pretrain = false);

[Train] loss: 0.7123822 [ 10/100]
[Train] loss: 0.5388184 [ 20/100]
[Train] loss: 0.8376984 [ 30/100]
[Train] loss: 0.6657692 [ 40/100]
[Train] loss: 1.1854464 [ 50/100]
[Train] loss: 0.6098708 [ 60/100]
[Train] loss: 0.6977958 [ 70/100]
[Train] loss: 0.6554491 [ 80/100]
[Train] loss: 0.6035808 [ 90/100]
[Train] loss: 0.8463958 [100/100]
[Test] Accuracy: 50.0, Avg loss: 0.8647254


사전 학습 (전이 학습)

In [7]:
Optimisers.trainable(x::MyResnet) = (; dense = x.dense)
run_resnet(rng; pretrain = true);

[Train] loss: 0.6280646 [ 10/100]
[Train] loss: 0.3220276 [ 20/100]
[Train] loss: 0.2373100 [ 30/100]
[Train] loss: 0.1000746 [ 40/100]
[Train] loss: 0.2597194 [ 50/100]
[Train] loss: 0.0759974 [ 60/100]
[Train] loss: 0.0343409 [ 70/100]
[Train] loss: 0.2529985 [ 80/100]
[Train] loss: 0.3924147 [ 90/100]
[Train] loss: 0.2217513 [100/100]
[Test] Accuracy: 93.5, Avg loss: 0.1560094
