In [1]:
using Zygote

In [2]:
include("ngrams.jl")
include("datasetloader.jl")

In [60]:
function normalize!(P::Dict{Vector{UInt8}, Float32})
    vs = sum(values(P))
    map!(v -> v / vs, values(P))
    P
end
sigmoid(x) = 1 / (1 + exp(-x))
logit(x) = -log((1 / x) - 1)
log_sigmoid(x) = -log1p(exp(-x)) # -softplus(-x)


log_sigmoid (generic function with 1 method)

In [4]:
const PACKAGE_PATH = "."
ALL_LANGUAGES = [f[1:end-4] for f in readdir(joinpath(PACKAGE_PATH, "ngrams"))]
lang2index = Dict(lang => i for (i, lang) in enumerate(ALL_LANGUAGES))

Dict{String, Int64} with 2 entries:
  "zho" => 2
  "eng" => 1

In [30]:
Qs = Vector{Dict{Vector{UInt8}, Float32}}()
for lang in ALL_LANGUAGES
    h, D = load_ngram_table(joinpath("ngrams", lang * ".txt"))
    D = Dict(D)
    normalize!(D)
    map!(logit, values(D))
    push!(Qs, D)
end
DEFAULT_Q::Float32 =  minimum(minimum.(values.(Qs)))


-13.900675f0

In [73]:
function loglikelihood(P, Q, default_q)
    sc = 0.0
    for (code, p) in P
        q = haskey(Q, code) ? Q[code] : default_q
        sc += p * log_sigmoid(q)
    end
    sc
end
function loss(lls, ys)
    sum(ys .* (log.(sum(exp.(lls))) .- lls)) # softmax & cross entropy
end
function get_loss(params, p, ys)
    lls = [loglikelihood(p, Q, params.default_q) for Q in params.Qs]
    loss(lls, ys)
end
function loss_and_grad(params, p::AbstractDict, ys::AbstractVector)
    withgradient(Params([params])) do 
        get_loss(params, p, ys)
    end
end
function loss_and_grad(params, x::AbstractString, y::AbstractString)
    yi = lang2index[y]
    onehot = zeros(Float32, length(ALL_LANGUAGES))
    onehot[yi] = 1.0
    p = count_all_ngrams(x, 5) |> normalize!
    loss_and_grad(params, p, onehot)
end

loss_and_grad (generic function with 2 methods)

In [74]:
text = "hello world loglikelihood"
p = count_all_ngrams(text, 5)
normalize!(p)

params = (Qs=Qs, default_q=DEFAULT_Q)
val, grad = loss_and_grad(params, p, [1.0, 0.0])

(val = 0.07368449732666171, grad = Grads(...))

In [40]:
grad[params].Qs

2-element Vector{Dict{Any, Any}}:
 Dict(UInt8[0x6c, 0x69, 0x6b, 0x65] => -0.00056824373f0, UInt8[0x68, 0x6f] => -0.00056795677f0, UInt8[0x65, 0x6c, 0x69] => -0.0005682644f0, UInt8[0x6f, 0x6f, 0x64] => -0.00056825485f0, UInt8[0x6c, 0x6c, 0x6f, 0x20] => -0.00056829443f0, UInt8[0x6f, 0x64, 0x20] => -0.0005682523f0, UInt8[0x67, 0x6c, 0x69] => -0.00056828774f0, UInt8[0x68, 0x6f, 0x6f] => -0.00056827115f0, UInt8[0x77, 0x6f, 0x72] => -0.0005682291f0, UInt8[0x72, 0x6c, 0x64] => -0.0005682818f0…)
 Dict(UInt8[0x72, 0x6c] => 0.0005682938f0, UInt8[0x67, 0x6c] => 0.0005682945f0, UInt8[0x6f, 0x6f] => 0.00056829123f0, UInt8[0x77] => 0.0005682352f0, UInt8[0x6f, 0x20] => 0.00056826556f0, UInt8[0x68, 0x6f] => 0.0005682883f0, UInt8[0x65, 0x6c] => 0.0011365679f0, UInt8[0x6b] => 0.0005682598f0, UInt8[0x20, 0x77] => 0.0005682805f0, UInt8[0x6c] => 0.0034090604f0…)

In [42]:
WV = WikiDataSet("corpus/wikipedia/test", langs=ALL_LANGUAGES)

400-element WikiDataSet:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               

In [43]:
x, y = WV[rand(1:length(WV))]

val, grad = loss_and_grad(params, x, y)

(val = 0.05836776099029637, grad = Grads(...))

In [46]:
function step!(grad, lr=1e-3)
    global DEFAULT_Q -= lr * grad[params].default_q
    for (D1, D2) in zip(params.Qs, grad[params].Qs)
        mergewith!((v1, v2) -> v1 - lr * v2, D1, D2)
    end
end

step! (generic function with 2 methods)

In [54]:
step!(grad)
params

(Qs = Dict{Vector{UInt8}, Float32}[Dict([0x6e, 0x74, 0x20, 0x69, 0x6e, 0x74] => -13.408537, [0x6c, 0x65, 0x20, 0x61, 0x72] => -12.615268, [0x65, 0x65, 0x69] => -12.780086, [0x61, 0x64, 0x6f, 0x70, 0x74, 0x65] => -13.538923, [0x67, 0x65, 0x72, 0x69, 0x61, 0x6e, 0x73] => -13.684971, [0x72, 0x65, 0x6d, 0x69, 0x65, 0x72] => -12.913396, [0x6b, 0x20, 0x75, 0x70] => -12.88174, [0x68, 0x61, 0x74, 0x20, 0x6e, 0x65] => -13.169395, [0x20, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64] => -11.852873, [0x64, 0x20, 0x77, 0x61, 0x69, 0x74] => -13.835136…), Dict([0x98, 0xe6, 0xb4] => -13.081837, [0xbf, 0x90, 0xe5, 0x9f, 0x8e, 0xe5] => -11.856253, [0xaa, 0xe4, 0xba, 0x8e] => -13.079651, [0x81, 0xe5, 0x8c, 0x96] => -13.15991, [0x9b, 0xae, 0xe7, 0x9a, 0x84, 0xe5] => -13.119659, [0x88, 0x9b] => -11.285851, [0xba, 0xa6, 0xe5, 0x9c] => -13.412644, [0xaf, 0x86, 0xe5, 0xba, 0xa6, 0x20] => -12.980187, [0xbb, 0xba, 0xe7, 0xaf] => -12.284087, [0x8f, 0x90, 0xe4, 0xbe, 0x9b, 0x20] => -13.447279…)], default_q = -13.900675f0)

In [55]:
DEFAULT_Q

-13.900751f0