# Minimal character-level Vanilla RNN model.
Original version by Andrej Karpathy (@karpathy), BSD License.
The original Python code is available at this link [https://gist.github.com/karpathy/d4dee566867f8291f086](https://gist.github.com/karpathy/d4dee566867f8291f086).
A more complex model (with better results) is available [here](https://colab.research.google.com/github/trekhleb/machine-learning-experiments/blob/master/experiments/text_generation_shakespeare_rnn/text_generation_shakespeare_rnn.ipynb#scrollTo=bui0MyTjv1Mp).

Julia version by David Métivier [david.metivier@inrae.fr](mailto:david.metivier@inrae.fr).

This Julia version uses only two packages:
- `HTTP`, to download the Shakespeare dataset (not needed for Proust);
- `JLD2`, to save and load parameters if needed.

It should therefore work on a freshly installed Julia with almost no extra packages.
**Note:** This is not how one would write good (and fast) Julia code. However, it closely mimics the Python version and is very readable.

For Julia newcomers, here are some tips to understand the code:
- To install Julia, use
  - the Microsoft Store on Windows. That's it!
  - `curl -fsSL https://install.julialang.org | sh` in a terminal on Linux/Mac. That's it!
- To use Jupyter notebooks, install the Jupyter and Julia extensions for VS Code. That's it!
- To add a package from a Jupyter notebook, you can run `import Pkg; Pkg.add("HTTP")`.
- For more explanations on installing packages and similar topics, see [Modern Julia Workflows](https://modernjuliaworkflows.org/writing/).
- Julia can use Unicode to make variable names look nice: `\partial + TAB` -> `∂` or `\_t + TAB` -> `ₜ`. Some like it, some don't!
- Julia uses dot notation `.` to broadcast an operation over every element of an array. For example, `[1,2,3].^2 == [1,4,9]` and `cos.([0, 2π]) == [1.0, 1.0]`. In R and in Python with NumPy, this broadcasting is generally implicit.
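Here are a few broadcasting examples to try in the REPL. The variable names are arbitrary. `@.` is a Base macro that adds a dot to every call and operator in an expression.

```julia
x = [1, 2, 3]

squares = x .^ 2            # element-wise power: [1, 4, 9]
cosines = cos.([0, 2π])     # cos applied to each element: [1.0, 1.0]

# Without dots, `x * x` is not defined for two vectors and throws a MethodError
threw_method_error = try
    x * x
    false
catch err
    err isa MethodError
end

y = @. 2x + 1               # `@.` adds every dot (same as 2 .* x .+ 1), giving [3, 5, 7]

z = zeros(3)
z .= x .+ 1                 # in-place broadcast into existing memory, like `mem .= mem .+ ∂param .^ 2` later on
```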

# Data

## Choose the dataset
The Shakespeare dataset seems a bit easier to train on than the Proust dataset. Two plausible reasons:
- it has fewer unique characters;
- its verse structure may be simpler than Proust's long French sentences and complex grammar.

### Shakespeare
A subset of Shakespeare's works concatenated into a single file (about 1.1MB).

In [1]:
using HTTP
url = "https://raw.githubusercontent.com/weixsong/min-char-rnn/refs/heads/master/input.txt"
response = HTTP.get(url)
data = String(response.body)

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we" ⋯ 1114856 bytes ⋯ "speak'st\nOut of thy sleep. What is it thou didst say?\nThis is a strange repose, to be asleep\nWith eyes wide open; standing, speaking, moving,\nAnd yet so fast asleep.\n\nANTONIO:\nNoble Sebastian,\nThou let'st thy fortune sleep--die, rather; wink'st\nWhiles thou art waking."

Extract from the beginning

In [2]:
println(data[1:463])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.



### A la Recherche du temps perdu
All the volumes of *À la recherche du temps perdu* by Marcel Proust, concatenated into a single file (7MB).
Most special characters, such as "à" and "é", have been replaced by unaccented equivalents to simplify training.
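One possible way to do this kind of simplification in Julia is `Unicode.normalize` from the `Unicode` standard library. It decomposes each character and strips the combining marks. This is only a sketch, not necessarily how `proust_simplified.txt` was produced. Two differences:
- The file appears to map accented capitals to lowercase letters (e.g. `CoTe`, `PREMIeRE` in the output below). This sketch does not do that.
- Ligatures such as "œ" are left unchanged.

`simplify_accents` is a hypothetical helper name.

```julia
using Unicode

# Hypothetical helper: decompose characters and remove combining marks (accents)
simplify_accents(s::AbstractString) = Unicode.normalize(s, stripmark = true)

simplify_accents("À la recherche du temps perdu")   # "A la recherche du temps perdu"
simplify_accents("Du côté de chez Swann")           # "Du cote de chez Swann"
```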

In [3]:
data = read("proust_simplified.txt", String)

"MARCEL PROUST\r\na LA RECHERCHE\r\n\r\nDU TEMPS PERDU\r\nI\r\n\r\nDU CoTe DE CHEZ SWANN\r\n\r\n(PREMIeRE PARTIE)\r\n\r\nI\r\nLongtemps, je me suis couche de bonne heure. Parfois, a peine ma bougie eteinte, mes yeux se fermaient si vite que je n'avais pas le temps de me dire : \" Je m'endors." ⋯ 7250873 bytes ⋯ " est reservee dans l'espace, une place, au contraire, prolongee sans mesure, puisqu'ils touchent simultanement, comme des geants, plonges dans les annees, a des epoques vecues par eux, si distantes - entre lesquelles tant de jours sont venus se placer - dans le Temps."

Extract from the beginning

In [4]:
println(data[102:102+641])


Longtemps, je me suis couche de bonne heure. Parfois, a peine ma bougie eteinte, mes yeux se fermaient si vite que je n'avais pas le temps de me dire : " Je m'endors. " Et, une demi-heure apres, la pensee qu'il etait temps de chercher le sommeil m'eveillait ; je voulais poser le volume que je croyais avoir encore dans les mains et souffler ma lumiere ; je n'avais pas cesse en dormant de faire des reflexions sur ce que je venais de lire, mais ces reflexions avaient pris un tour un peu particulier ; il me semblait que j'etais moi-meme ce dont parlait l'ouvrage : une eglise, un quatuor, la rivalite de Francois Ier et de Charles-Quint. 


## Extract characters from the data

In [5]:
chars = collect(Set(data))
data_size, vocab_size = length(data), length(chars)
println("data has $data_size characters, $vocab_size unique.")
char_to_ix = Dict(ch => i for (i, ch) in enumerate(chars))
ix_to_char = Dict(i => ch for (i, ch) in enumerate(chars))

data has 7251410 characters, 77 unique.


Dict{Int64, Char} with 77 entries:
  5  => '4'
  56 => 'm'
  16 => 'a'
  20 => '3'
  35 => '9'
  55 => 'k'
  60 => 'Y'
  30 => '?'
  19 => 'D'
  32 => 'v'
  49 => 'B'
  6  => 's'
  67 => 'J'
  45 => 'z'
  44 => 'C'
  9  => 'r'
  31 => 'f'
  73 => '('
  74 => 'x'
  ⋮  => ⋮
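`collect(Set(data))` returns the characters in whatever order the `Set` iterates. That order is an implementation detail of hashing, not a documented guarantee. If you save weights and reload them in another session, a deterministic vocabulary is safer, for example one sorted by code point.

Below is a minimal sketch on a toy string rather than `data`. The helper names (`build_vocab`, `encode`, `decode`) are hypothetical, chosen so they do not overwrite the notebook's `char_to_ix` and `ix_to_char`.

```julia
# Hypothetical helper: a vocabulary whose order does not depend on Set iteration order
function build_vocab(text::AbstractString)
    vocab = sort(unique(text))                      # Vector{Char}, sorted by code point
    c2i = Dict(ch => i for (i, ch) in enumerate(vocab))
    i2c = Dict(i => ch for (i, ch) in enumerate(vocab))
    return vocab, c2i, i2c
end

encode(c2i, s) = [c2i[ch] for ch in s]              # String -> indices in 1:length(vocab)
decode(i2c, ixs) = join(i2c[i] for i in ixs)        # indices -> String

vocab, c2i, i2c = build_vocab("to be or not to be")
ixs = encode(c2i, "to be")
decode(i2c, ixs)   # "to be"
```

With a vocabulary like this, also save it next to the weights, for instance as an extra keyword argument to `jldsave`. That keeps each character's index consistent with the trained matrices.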

Hyperparameters

In [12]:
hidden_size = 300  # Size of hidden layer. >200 starts to be slow.
seq_length = 25  # Number of steps to unroll
learning_rate = 1e-1

0.1
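As a rough guide to the cost of `hidden_size` (H) for a vocabulary of size V: the model has H·V + H·H + V·H + H + V parameters (Wₓ, Wₕ, Wᵧ, bₕ, bᵧ). Each time step performs matrix-vector products of those sizes.

The sketch below counts them. `n_params` is a hypothetical name, and V = 77 is the Proust vocabulary reported above.

```julia
# Number of trainable parameters: Wₓ (H×V) + Wₕ (H×H) + Wᵧ (V×H) + bₕ (H) + bᵧ (V)
n_params(H, V) = H * V + H * H + V * H + H + V

n_params(100, 77)   # 25577
n_params(200, 77)   # 71077
n_params(300, 77)   # 136577
```

At H = 300, the H² term from Wₕ is the largest contribution (90000 of the 136577 parameters), and it grows quadratically. This is consistent with the observation that large hidden sizes become slow.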

Model parameters: Initialization

In [13]:
Wₓ = 0.01 * randn(hidden_size, vocab_size)  # Input to hidden
Wₕ = 0.01 * randn(hidden_size, hidden_size)  # Hidden to hidden
Wᵧ = 0.01 * randn(vocab_size, hidden_size)  # Hidden to output
bₕ = zeros(hidden_size)  # Hidden bias
bᵧ = zeros(vocab_size)  # Output bias

77-element Vector{Float64}:
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 ⋮
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0

# Training and sampling functions
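For reference, here are the forward and backward passes that `loss_and_grad` implements. The symbols are introduced for this summary:
- $a_t$ is the pre-activation;
- $c_t$ is `targets[t]`;
- $e_c$ is the one-hot vector of index $c$;
- $\odot$ is the element-wise product.

$$
\begin{aligned}
a_t &= W_x x_t + W_h h_{t-1} + b_h, \qquad h_t = \tanh(a_t), \\
y_t &= W_y h_t + b_y, \qquad p_t = \frac{\exp(y_t)}{\sum_k \exp(y_{t,k})}, \\
L &= -\sum_{t=1}^{T} \log p_{t,\,c_t}.
\end{aligned}
$$

$$
\begin{aligned}
\frac{\partial L}{\partial y_t} &= p_t - e_{c_t}, \\
\frac{\partial L}{\partial h_t} &= W_y^\top \frac{\partial L}{\partial y_t} + W_h^\top \frac{\partial L}{\partial a_{t+1}}, \qquad \frac{\partial L}{\partial a_{T+1}} = 0, \\
\frac{\partial L}{\partial a_t} &= (1 - h_t \odot h_t) \odot \frac{\partial L}{\partial h_t}, \\
\frac{\partial L}{\partial W_y} &= \sum_{t} \frac{\partial L}{\partial y_t} h_t^\top, \qquad
\frac{\partial L}{\partial b_y} = \sum_{t} \frac{\partial L}{\partial y_t}, \\
\frac{\partial L}{\partial W_x} &= \sum_{t} \frac{\partial L}{\partial a_t} x_t^\top, \qquad
\frac{\partial L}{\partial W_h} = \sum_{t} \frac{\partial L}{\partial a_t} h_{t-1}^\top, \qquad
\frac{\partial L}{\partial b_h} = \sum_{t} \frac{\partial L}{\partial a_t}.
\end{aligned}
$$

In the code:
- `∂yₜ` is $\partial L/\partial y_t$;
- `∂hₜraw` is $\partial L/\partial a_t$;
- `∂hₜ₊₁` is $W_h^\top \partial L/\partial a_{t+1}$.

Gradients are only propagated within the current window of `seq_length` characters (truncated backpropagation through time). The last hidden state is carried over to the next window, but no gradient flows back into $h_0$.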

In [14]:
function loss_and_grad(inputs, targets, h₀)
    T = length(inputs)

    xₜ, hₜ, yₜ, pₜ = Dict(), Dict(), Dict(), Dict()
    hₜ[0] = h₀
    loss = 0.0

    # Forward pass
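    for t in 1:T
        xₜ[t] = zeros(vocab_size)                        # one-hot encoding of the input character
        xₜ[t][inputs[t]] = 1
        hₜ[t] = tanh.(Wₓ * xₜ[t] + Wₕ * hₜ[t-1] + bₕ)    # new hidden state
        yₜ[t] = Wᵧ * hₜ[t] + bᵧ                          # unnormalized log-probabilities of the next character
        pₜ[t] = exp.(yₜ[t]) ./ sum(exp.(yₜ[t]))          # softmax: probabilities of the next character
        loss += -log(pₜ[t][targets[t]])                  # cross-entropy loss
    end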

    # Backward pass
    ∂Wₓ, ∂Wₕ, ∂Wᵧ = zero(Wₓ), zero(Wₕ), zero(Wᵧ)
    ∂bₕ, ∂bᵧ = zero(bₕ), zero(bᵧ)
    ∂hₜ₊₁ = zero(hₜ[1])

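    for t in reverse(1:T)
        ∂yₜ = copy(pₜ[t])
        ∂yₜ[targets[t]] -= 1                  # softmax + cross-entropy gradient: p - onehot(target)
        ∂Wᵧ .+= ∂yₜ * hₜ[t]'
        ∂bᵧ .+= ∂yₜ
        ∂hₜ = Wᵧ' * ∂yₜ + ∂hₜ₊₁              # gradient from the output and from step t+1
        ∂hₜraw = (1 .- hₜ[t] .^ 2) .* ∂hₜ     # backprop through tanh: tanh'(a) = 1 - tanh(a)^2
        ∂bₕ .+= ∂hₜraw
        ∂Wₓ .+= ∂hₜraw * xₜ[t]'
        ∂Wₕ .+= ∂hₜraw * hₜ[t-1]'
        ∂hₜ₊₁ = Wₕ' * ∂hₜraw                  # passed back to step t-1
    end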

    # Clip gradients to avoid exploding gradients
    for ∂param in [∂Wₓ, ∂Wₕ, ∂Wᵧ, ∂bₕ, ∂bᵧ]
        ∂param .= clamp.(∂param, -5, 5)
    end
    # return loss and gradients and latest memory state
    return loss, ∂Wₓ, ∂Wₕ, ∂Wᵧ, ∂bₕ, ∂bᵧ, hₜ[T]
end

loss_and_grad (generic function with 1 method)
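A finite-difference check is a common way to gain confidence in a hand-written backward pass. The sketch below uses hypothetical helper names and is independent of the notebook's globals. It checks only the two local derivatives used inside the loop:
- the softmax + cross-entropy gradient;
- the tanh derivative.

A full check would perturb entries of `Wₓ`, `Wₕ`, ... and compare with `loss_and_grad`. Disable the clipping first: it changes the gradients whenever an entry exceeds ±5.

```julia
using Random

softmax(y) = exp.(y) ./ sum(exp.(y))
xent(y, c) = -log(softmax(y)[c])              # cross-entropy loss for target index c

# Central finite differences of a scalar function f at point v
function numgrad(f, v::Vector{Float64}; ε = 1e-6)
    g = zeros(length(v))
    for i in eachindex(v)
        vp = copy(v); vp[i] += ε
        vm = copy(v); vm[i] -= ε
        g[i] = (f(vp) - f(vm)) / (2ε)
    end
    return g
end

function gradient_check(; seed = 42, n = 5, c = 3)
    rng = MersenneTwister(seed)

    # Softmax + cross-entropy: analytic gradient p - onehot(c), as in `∂yₜ[targets[t]] -= 1`
    y = randn(rng, n)
    g_analytic = softmax(y)
    g_analytic[c] -= 1
    err_softmax = maximum(abs.(g_analytic .- numgrad(v -> xent(v, c), y)))

    # tanh: (1 - tanh(a)^2) times an upstream gradient w, as in `(1 .- hₜ[t] .^ 2) .* ∂hₜ`
    a = randn(rng, n)
    w = randn(rng, n)
    g_tanh = (1 .- tanh.(a) .^ 2) .* w
    err_tanh = maximum(abs.(g_tanh .- numgrad(v -> sum(w .* tanh.(v)), a)))

    return err_softmax, err_tanh
end

err_softmax, err_tanh = gradient_check()   # both should be close to zero
```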

In [15]:
function sampleRNN(h, seed_ix, n)
    x = zeros(vocab_size)
    x[seed_ix] = 1
    ixes = Int[]

    for t in 1:n
        h = tanh.(Wₓ * x + Wₕ * h + bₕ)
        y = Wᵧ * h + bᵧ
        p = exp.(y) ./ sum(exp.(y)) # softmax
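        # Inverse-CDF sampling: draw an index at random with weights p (same distribution as rand(Categorical(p)) from Distributions.jl).
        # `min` guards against rounding errors that can make cumsum(p)[end] slightly smaller than 1.
        ix = min(searchsortedfirst(cumsum(p), rand()), vocab_size)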
        x .= 0
        x[ix] = 1
        push!(ixes, ix)
    end
    return ixes
end

sampleRNN (generic function with 1 method)
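`sampleRNN` draws each character by inverse-CDF sampling. It picks the first index whose cumulative probability reaches a uniform random number in [0, 1).

The standalone sketch below checks the empirical frequencies against the probabilities. Its names (`sample_index`, `empirical_frequencies`, `probs`) are hypothetical; `probs` is used instead of `p`, which is the data pointer in the training loop.

```julia
using Random

# Hypothetical helper: inverse-CDF sampling, the idea behind
# `searchsortedfirst(cumsum(p), rand())` in `sampleRNN`
function sample_index(rng::AbstractRNG, probs::Vector{Float64})
    ix = searchsortedfirst(cumsum(probs), rand(rng))  # first index with cumulative probability ≥ u
    return min(ix, length(probs))                     # guard against cumsum(probs)[end] < 1 from rounding
end

# Empirical frequencies over many draws should be close to the target probabilities
function empirical_frequencies(probs; n_draws = 100_000, seed = 1)
    rng = MersenneTwister(seed)
    counts = zeros(Int, length(probs))
    for _ in 1:n_draws
        counts[sample_index(rng, probs)] += 1
    end
    return counts ./ n_draws
end

probs = [0.1, 0.2, 0.7]
freqs = empirical_frequencies(probs)   # close to [0.1, 0.2, 0.7]
```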

# Training

## Load trained parameters
To start from previously saved parameters instead of random ones, uncomment the following code. This requires a weights file saved earlier, for example by the save cell further down; adjust the file name to match. Otherwise, skip this part.

In [16]:
# using JLD2
# θ = JLD2.load("weights_proust.jld2")
# Wₓ = θ["Wₓ"]
# Wₕ = θ["Wₕ"]
# Wᵧ = θ["Wᵧ"]
# bₕ = θ["bₕ"]
# bᵧ = θ["bᵧ"]

## Training will begin

In [17]:
iter = 0
p = 0
epoch = 0
mWₓ, mWₕ, mWᵧ = zero(Wₓ), zero(Wₕ), zero(Wᵧ)
mbₕ, mbᵧ = zero(bₕ), zero(bᵧ)
losses = Float64[]
smooth_loss = -log(1.0 / vocab_size) * seq_length

108.5951355463421
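The initial value of `smooth_loss` is roughly the loss of an untrained model, whose small random weights give almost uniform predictions: probability 1/`vocab_size` for every character. The sketch below shows that arithmetic and the smoothing used in the training loop. `uniform_loss` and `ema` are hypothetical names.

```julia
# Loss of one sequence of T characters when the model predicts the uniform
# distribution over V characters: each prediction costs -log(1/V) = log(V).
uniform_loss(V, T) = T * log(V)

uniform_loss(77, 25)   # ≈ 108.595, with V = vocab_size and T = seq_length as above

# smooth_loss is an exponential moving average: with β = 0.999, each new loss has
# weight 0.001, so it averages over roughly the last 1/(1 - β) = 1000 iterations.
ema(previous, x; β = 0.999) = β * previous + (1 - β) * x
```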

In [None]:
batch_size = 2*400_000 # number of training iterations; each one processes a single sequence of seq_length characters

In [18]:
@time for n in 1:batch_size
    if p + seq_length + 1 >= data_size || iter == 0
        global hprev = zeros(hidden_size)
        p = 0
        epoch += 1
    end
    inputs = [char_to_ix[ch] for ch in data[p+1:p+seq_length]]
    targets = [char_to_ix[ch] for ch in data[p+2:p+seq_length+1]]

    if n % 1000 == 0
        sample_ix = sampleRNN(hprev, inputs[1], 200)
        txt = join([ix_to_char[ix] for ix in sample_ix])
        println("\n--Epoch = $(epoch), iter = $(iter) --- train data -> $(100round(p/data_size, digits = 2))%\n$txt\n------\n")
    end

    loss, ∂Wₓ, ∂Wₕ, ∂Wᵧ, ∂bₕ, ∂bᵧ, hprev = loss_and_grad(inputs, targets, hprev)
    smooth_loss = smooth_loss * 0.999 + loss * 0.001 # arbitrary smoothing

    if n % 1000 == 0
        push!(losses, smooth_loss)
        println("loss: $smooth_loss \n----------------\n\n")
        # To plot the training loss every 1000 iterations, run `using Plots` and uncomment:
        # plt = plot(losses)
        # display(plt)
    end

    # Parameter update with Adagrad method https://paperswithcode.com/method/adagrad
    for (param, ∂param, mem) in zip([Wₓ, Wₕ, Wᵧ, bₕ, bᵧ],
        [∂Wₓ, ∂Wₕ, ∂Wᵧ, ∂bₕ, ∂bᵧ],
        [mWₓ, mWₕ, mWᵧ, mbₕ, mbᵧ])
        mem .= mem .+ ∂param .^ 2
        param .+= -learning_rate * ∂param ./ sqrt.(mem .+ 1e-8)
    end
    p += seq_length
    iter += 1
end


--Epoch = 1, iter = 999 --- train data -> 0.0%
ssue' it*nvaiatau tesen u panrl  iedpeeuun. vdere uvtde scslerCqurt,etueZepanlnotJirtteve e oessn naur yces eatqvit pvs ipvosnla rur.  onr cius,';f"nns,u srie r:r pss'raiqcn rrishr e  cobi onaK ovs ae
------

loss: 97.51503079754681 
----------------



--Epoch = 1, iter = 1999 --- train data -> 1.0%
t el'cesemaber e quu denre les oe jnire Due t tonn O'en s de3cemauose"tdas mde rete q'rai s6 de pq seqt enrse   o
 dildnl lous eou? juos iss
''it . de acev jet der  as q etiise m Gen, Mw lespn sere re
------

loss: 80.4158340915369 
----------------



--Epoch = 1, iter = 2999 --- train data -> 1.0%
xoi, Qtqsqnstecure te as laneque. rlx de dai chie aj,sines me me s de ranveute 2ur lincamoudeunt gans laintoWet pom'e jnas coutrs du redet p cie vounat es arvant 
------

loss: 70.07535234048575 
----------------



--Epoch = 1, iter = 3999 --- train data -> 1.0%
ut na et jig ltga de cohce qhit qupebdos roicnt l'oves, e quu, potomoutrerit d'e vetb
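Adagrad divides each parameter's step by the square root of its accumulated squared gradients. Coordinates with large gradients therefore take proportionally smaller steps.

The toy sketch below applies the same update as the training loop to a hypothetical quadratic whose two coordinates have curvatures differing by a factor of 100. `adagrad_demo` is a hypothetical name.

```julia
# Hypothetical toy problem: f(θ) = sum(scale .* θ .^ 2), minimum at θ = [0, 0]
function adagrad_demo(; lr = 0.1, steps = 500)
    scale = [1.0, 100.0]
    θ = [1.0, 1.0]
    mem = zeros(2)                             # running sum of squared gradients, like mWₓ etc.
    for _ in 1:steps
        g = 2 .* scale .* θ                    # gradient of f
        mem .= mem .+ g .^ 2                   # same accumulation as in the training loop
        θ .+= -lr .* g ./ sqrt.(mem .+ 1e-8)   # per-coordinate step lr * g / sqrt(mem)
    end
    return θ
end

θ_final = adagrad_demo()   # both coordinates end up very close to 0
```

For comparison, plain gradient descent with the same `lr = 0.1` would multiply the second coordinate by 1 - 0.1·200 = -19 at each step and diverge. In this toy case the gradient scale cancels in Adagrad's update. Both coordinates therefore follow the same trajectory, up to the `1e-8` term.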

## Save trained parameters & generate

In [19]:
using JLD2
jldsave("weights.jld2"; Wₓ, Wₕ, Wᵧ, bₕ, bᵧ)

Generate a sequence with the latest parameters.

In [34]:
begin
    seed_char = ix_to_char[11]  # any character of the vocabulary can be used as a seed
    nb_of_generated_char = 400
    sample_ix = sampleRNN(hprev, char_to_ix[seed_char], nb_of_generated_char)
    txt = join(vcat(seed_char, [ix_to_char[ix] for ix in sample_ix]))
    println("\n----\n$txt\n----\n")
end


----
lister de Ctez de Ces que M. de ux vid'il un pesqueller servelte albempure, ascetie, rophesse et plais fait a la meroili, tout " doyent ou hangonsanc. Passignes, ces chandait mon funait-et a vie, l'avac-n'aissi qu'en nemme d'un prosser du suemes, ause attifues aux qu'a dites touveiller que Mme de parsant et orraire. Il confer d'ixsasere chore compre, oraction. De ronde de je elliit seulereut oan qu
----



---

*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*