### 18.2. 문자열 생성

셰익스피어 데이터셋 주소

https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt 

In [1]:
using Flux
using Flux.Losses: logitcrossentropy
import Zygote, Optimisers
using MLUtils: chunk, batchseq
using StatsBase: wsample
using Formatting: printfmtln
using Random: MersenneTwister

In [2]:
fpath = "tinyshakespeare.txt"
furl = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
isfile(fpath) || download(furl, fpath)

true

데이터셋 생성

In [3]:
function get_data(fpath; batch_size = 32, seq_len = 100)    
    text = collect(String(read(fpath)))
    alphabet = unique(text)
    '_' in alphabet || push!(alphabet, '_');
    
    text = map(ch -> Flux.onehot(ch, alphabet), text)
    pad = Flux.onehot('_', alphabet);

    Xs = batchseq(chunk(text, batch_size), pad)
    Xs = map(x -> hcat(x...), Xs)
    Xs = chunk(Xs; size = seq_len)#[1:end-1];

    Ys = batchseq(chunk(text[2:end], batch_size), pad)
    Ys = map(y -> hcat(y...), Ys)
    Ys = chunk(Ys; size = seq_len)#[1:end-1]

    zip(Xs, Ys), alphabet
end

get_data (generic function with 1 method)

In [4]:
fpath = "tinyshakespeare.txt"
loader, alphabet = get_data(fpath, batch_size = 32, seq_len = 100);

학습 함수 (16장 학습 함수와 동일)

In [5]:
function train(loader, model, loss_fn, optimizer)
    num_batches = length(loader)
    Flux.testmode!(model, false)
    for (batch, (X, y)) in enumerate(loader)
        X, y = Flux.gpu(X), Flux.gpu(y)
        grad = Zygote.gradient(m -> loss_fn(m, X, y), model)[1]
        optimizer, model = Optimisers.update(optimizer, model, grad)
        if batch % 100 == 0
            loss = loss_fn(model, X, y)
            printfmtln("[Train] loss: {:.7f} [{:>3d}/{:>3d}]", 
                loss, batch, num_batches)
        end
    end
    model, optimizer
end

train (generic function with 1 method)

모델 및 손실 함수

In [6]:
init(rng) = Flux.glorot_uniform(rng)

function build_model(N; rng)
    Chain(
        LSTM(N, 512; init=init(rng)),
        LSTM(512, 512; init=init(rng)),
        Dense(512, N; init=init(rng))
    )
end;

rng = MersenneTwister(1)
model = build_model(length(alphabet); rng=rng) |> gpu;
loss_fn(m, xs, ys) = sum(logitcrossentropy.([m(x) for x in xs], ys));
optimizer = Optimisers.setup(Optimisers.Adam(), model);

In [7]:
for epoch in 1:20
    Flux.reset!(model)
    println("Epoch $epoch")
    println("-------------------------------")
    global model, optimizer = train(loader, model, loss_fn, optimizer)
end

Epoch 1
-------------------------------


[Train] loss: 3

17.8969727 [100/349]


[Train] loss: 245.4123840 [200/349]


[Train] loss: 227.3763275 [300/349]


Epoch 2
-------------------------------


[Train] loss: 208.0317993 [100/349]


[Train] loss: 197.2237701 [200/349]


[Train] loss: 188.6101227 [300/349]


Epoch 3
-------------------------------


[Train] loss: 184.3034363 [100/349]


[Train] loss: 176.7043915 [200/349]


[Train] loss: 169.4081421 [300/349]


Epoch 4
-------------------------------


[Train] loss: 169.4940491 [100/349]


[Train] loss: 164.6542358 [200/349]


[Train] loss: 158.0427856 [300/349]


Epoch 5
-------------------------------


[Train] loss: 160.4414825 [100/349]


[Train] loss: 156.8576965 [200/349]


[Train] loss: 150.1932678 [300/349]


Epoch 6
-------------------------------


[Train] loss: 154.2440491 [100/349]


[Train] loss: 151.2652283 [200/349]


[Train] loss: 145.2941284 [300/349]


Epoch 7
-------------------------------


[Train] loss: 150.1058807 [100/349]


[Train] loss: 147.1708527 [200/349]


[Train] loss: 141.0655365 [300/349]


Epoch 8
-------------------------------


[Train] loss: 146.9607086 [100/349]


[Train] loss: 143.9465942 [200/349]


[Train] loss: 137.7180939 [300/349]


Epoch 9
-------------------------------


[Train] loss: 144.1071320 [100/349]


[Train] loss: 140.9292755 [200/349]


[Train] loss: 134.2348328 [300/349]


Epoch 10
-------------------------------


[Train] loss: 141.4981232 [100/349]


[Train] loss: 137.7169342 [200/349]


[Train] loss: 131.6757050 [300/349]


Epoch 11
-------------------------------


[Train] loss: 138.5637970 [100/349]


[Train] loss: 135.7661896 [200/349]


[Train] loss: 129.5307312 [300/349]


Epoch 12
-------------------------------


[Train] loss: 135.4887238 [100/349]


[Train] loss: 133.0057983 [200/349]


[Train] loss: 127.0049286 [300/349]


Epoch 13
-------------------------------


[Train] loss: 132.5238495 [100/349]


[Train] loss: 131.0608368 [200/349]


[Train] loss: 125.8164368 [300/349]


Epoch 14
-------------------------------


[Train] loss: 131.0935669 [100/349]


[Train] loss: 130.2286530 [200/349]


[Train] loss: 124.7033081 [300/349]


Epoch 15
-------------------------------


[Train] loss: 128.7349243 [100/349]


[Train] loss: 129.7062531 [200/349]


[Train] loss: 122.3825760 [300/349]


Epoch 16
-------------------------------


[Train] loss: 126.4792175 [100/349]


[Train] loss: 128.2374268 [200/349]


[Train] loss: 121.1559906 [300/349]


Epoch 17
-------------------------------


[Train] loss: 124.7345428 [100/349]


[Train] loss: 126.2337265 [200/349]


[Train] loss: 118.8970642 [300/349]


Epoch 18
-------------------------------


[Train] loss: 123.4196625 [100/349]


[Train] loss: 124.0294418 [200/349]


[Train] loss: 117.6818695 [300/349]


Epoch 19
-------------------------------


[Train] loss: 122.6253052 [100/349]


[Train] loss: 121.4550552 [200/349]


[Train] loss: 117.0097656 [300/349]


Epoch 20
-------------------------------


[Train] loss: 121.0933533 [100/349]


[Train] loss: 119.8905106 [200/349]


[Train] loss: 116.1155472 [300/349]


In [8]:
function generate(model, alphabet, init, len; rng)    
    model = model |> cpu
    Flux.reset!(model)
    generated = [init]
    for _ in 1:len
        w = softmax(model(Flux.onehot(generated[end], alphabet)))
        push!(generated, wsample(rng, alphabet, w))
    end
    text = String(generated)
    for r in split(text, '\n')
        println(r)
    end
    text
end

generate (generic function with 1 method)

In [14]:
generate(model, alphabet, 'O', 500; rng=MersenneTwister(1));

O, Yet glove against the king:
I was a father doth paph direth.

ROMEO:
Stupp'st; My brother, till you had but fury to that.

HERMIONE:
Why, thy are,
Doth hein and leart in strifferance.

SEBASTIAN:
An I all so: but hath been so disconcedded;
'Tis a thrifty is scen-willingly go.
Came heart ucts not for you, But I by no.
Ah, so! no; I, how ly Isable, nor then;
So assured, mirrily, or Claudio, lives thustraft.

KING LEWIS XI:
What is this?

First Citizen:
We are, go you at your children, sir; yet I


In [12]:
using BSON: @save
@save "tinyshakespeare.bson" model = Flux.cpu(model), alphabet