In [2]:
using Flux
using Flux: crossentropy, throttle
using Flux.Data: Tree, children, isleaf

In [4]:
using Flux: onehot
using Flux.Data.Sentiment
using Flux.Data: leaves

# Load the training trees of the sentiment dataset.
traintrees = Sentiment.train()

# Each node value is indexable: position 1 is the label, position 2 the phrase.
# Split them into two parallel collections of trees with identical shape.
labels  = map.(x -> x[1], traintrees)
phrases = map.(x -> x[2], traintrees)

# All tokens in the training set.
# `reduce(vcat, ...)` concatenates the per-tree leaf vectors without splatting
# one argument per tree into a single varargs call.
tokens = reduce(vcat, map(leaves, phrases))

# Count how many times each token appears.
freqs = Dict{String,Int}()
for t in tokens
  freqs[t] = get(freqs, t, 0) + 1
end

# Replace singleton tokens with an "unknown" marker.
# This roughly cuts our "alphabet" of tokens in half.
# Values that are not keys of `freqs` (e.g. the `nothing` held by internal
# nodes) look up as 0 and are therefore passed through unchanged.
phrases = map.(t -> get(freqs, t, 0) == 1 ? "UNK" : t, phrases)

# Our alphabet of tokens (distinct leaf values after the "UNK" replacement).
alphabet = unique(reduce(vcat, map(leaves, phrases)))

# One-hot-encode our training data with respect to the alphabet.
# `nothing` is a singleton, so it is tested by identity (`===`) rather than
# by `==`; such nodes have no token to encode and stay `nothing`.
phrases_e = map.(t -> t === nothing ? t : onehot(t, alphabet), phrases)
labels_e  = map.(t -> onehot(t, 0:4), labels)

# Pair every node's encoded phrase with its encoded label: (phrase, label).
train = map.(tuple, phrases_e, labels_e);

In [5]:
train[1]

Tree{Any}
(nothing, Bool[false, false, false, true, false])
├─ (nothing, Bool[false, false, true, false, false])
│  ├─ (Bool[true, false, false, false, false, false, false, false, false, false  …  false, false, false, false, false, false, false, false, false, false], Bool[false, false, true, false, false])
│  └─ (Bool[false, true, false, false, false, false, false, false, false, false  …  false, false, false, false, false, false, false, false, false, false], Bool[false, false, true, false, false])
└─ (nothing, Bool[false, false, false, false, true])
   ├─ (nothing, Bool[false, false, false, true, false])
   │  ├─ (Bool[false, false, true, false, false, false, false, false, false, false  …  false, false, false, false, false, false, false, false, false, false], Bool[false, false, true, false, false])
   │  └─ (nothing, Bool[false, false, false, false, true])
   │     ├─ (Bool[false, false, false, true, false, false, false, false, false, false  …  false, false, false, false, false, false,

In [None]:
#include("data.jl")

# Length of every token / phrase vector used by the model.
N = 300

# Trainable embedding matrix: N rows, one column per entry of `alphabet`.
embedding = param(randn(Float32, N, length(alphabet)))

# Composition layer: takes 2N inputs and produces N outputs through `tanh`.
W = Dense(2 * N, N, tanh)

# Merge two child phrase vectors into their parent's phrase vector by
# stacking them vertically and feeding the result through `W`.
function combine(a, b)
  return W(vcat(a, b))
end

# Classification head: N-vector -> 5 scores -> softmax.
sentiment = Chain(Dense(N, 5), softmax)

"""
    forward(tree)

Evaluate the model on `tree` from the leaves upwards.

Every node value is a 2-tuple whose second element is the node's encoded
label. For a leaf, the first element is multiplied by `embedding` to obtain
the phrase vector; for any other node the phrase vector is `combine`d from the
results of `tree[1]` and `tree[2]`.

Returns `(phrase, loss)`, where `loss` is the sum of the `crossentropy`
between `sentiment(phrase)` and the label over this node and all nodes
below it.
"""
function forward(tree)
  payload, sent = tree.value

  # Base case: a leaf's payload is its encoded token.
  if isleaf(tree)
    phrase = embedding * payload
    return phrase, crossentropy(sentiment(phrase), sent)
  end

  # Recursive case: the payload is unused; build the phrase from both children.
  left, leftloss = forward(tree[1])
  right, rightloss = forward(tree[2])
  phrase = combine(left, right)
  nodeloss = crossentropy(sentiment(phrase), sent)
  return phrase, leftloss + rightloss + nodeloss
end

# Training objective for one tree: the loss component of `forward`'s result.
loss(tree) = last(forward(tree))

# Parameters to optimise, gathered from the embedding and both layers.
ps = params(embedding, W, sentiment)
opt = ADAM()

# Progress callback: print the current loss on the first training tree.
evalcb = () -> @show loss(train[1])

# `zip(train)` yields each tree wrapped in a 1-tuple. The callback is passed
# through `throttle(evalcb, 10)` so it is rate-limited rather than run on
# every step.
Flux.train!(loss, ps, zip(train), opt; cb = throttle(evalcb, 10))

loss(train[1]) = 129.81721f0 (tracked)
loss(train[1]) = 78.744995f0 (tracked)
loss(train[1]) = 76.83441f0 (tracked)
loss(train[1]) = 65.99269f0 (tracked)
loss(train[1]) = 68.21367f0 (tracked)
loss(train[1]) = 70.987f0 (tracked)
loss(train[1]) = 60.92894f0 (tracked)
loss(train[1]) = 59.966213f0 (tracked)
loss(train[1]) = 56.20399f0 (tracked)
loss(train[1]) = 63.450035f0 (tracked)
loss(train[1]) = 56.58547f0 (tracked)
loss(train[1]) = 56.626347f0 (tracked)
loss(train[1]) = 58.17201f0 (tracked)
loss(train[1]) = 54.977734f0 (tracked)
loss(train[1]) = 57.22774f0 (tracked)
loss(train[1]) = 50.27898f0 (tracked)
loss(train[1]) = 51.113247f0 (tracked)
loss(train[1]) = 49.581978f0 (tracked)
loss(train[1]) = 52.639908f0 (tracked)
loss(train[1]) = 50.390205f0 (tracked)
loss(train[1]) = 51.81828f0 (tracked)
loss(train[1]) = 50.022552f0 (tracked)
loss(train[1]) = 49.008026f0 (tracked)
loss(train[1]) = 45.74195f0 (tracked)
loss(train[1]) = 45.52603f0 (tracked)
loss(train[1]) = 47.33557f0 (tracked)
lo

In [7]:
ps

Params([Float32[2.17359 -0.591458 … 1.06074 0.173339; 0.450242 0.827586 … -0.771939 0.288859; … ; -1.19872 -0.867446 … -0.0930106 -1.16869; -1.12999 2.1004 … -0.0562822 2.06688] (tracked), Float32[-0.00971816 0.111023 … -0.0348003 0.0508199; 0.0214682 -0.0229316 … 0.167756 -0.0172309; … ; -0.0823738 -0.0306393 … 0.0550236 -0.0172346; -0.0952964 0.00343107 … 0.0703971 -0.103234] (tracked), Float32[0.525751, -0.117234, 0.167475, -0.387744, -0.0301813, -0.0871347, 0.0651157, 0.324827, -0.393376, -0.201326  …  0.127852, 0.0528466, -0.25958, 0.0782037, 0.253156, 0.303237, -0.174336, -0.0256605, -0.673184, -0.690059] (tracked), Float32[0.0356477 0.00622747 … -0.0344168 0.0732029; 0.0875137 0.0757835 … -0.0358825 -0.0796985; … ; 0.0570049 -0.0583903 … 0.0141113 -0.0525645; 0.0623812 -0.0643861 … 0.0111935 0.0352257] (tracked), Float32[-1.62962, -0.490181, 1.45246, -0.387184, -1.56789] (tracked)])

In [11]:
labels

8544-element Array{Tree{Any},1}:
 Tree{Any}
3
├─ 2
│  ├─ 2
│  └─ 2
└─ 4
   ├─ 3
   │  ├─ 2
   │  └─ 4
   │     ├─ 2
   │     └─ 2
   │        ├─ 2
   │        │  ├─ 2
   │        │  │  ├─ 2
   │        │  │  │  ├─ 2
   │        │  │  │  └─ 2
   │        │  │  │     ├─ 2
   │        │  │  │     └─ 2
   │        │  │  │        ├─ 2
   │        │  │  │        └─ 2
   │        │  │  │           ├─ 2
   │        │  │  │           └─ 2
   │        │  │  │              ├─ 2
   │        │  │  │              │  ├─ 2
   │        │  │  │              │  └─ 2
   │        │  │  │              └─ 2
   │        │  │  │                 ├─ 3
   │        │  │  │                 └─ 2
   │        │  │  │                    ├─ 2
   │        │  │  │                    └─ 2
   │        │  │  └─ 2
   │        │  └─ 2
   │        └─ 3
   │           ├─ 2
   │           └─ 3
   │              ├─ 2
   │              └─ 3
   │                 ├─ 2
   │                 └─ 3
   │                    ├─ 2
   │       

In [12]:
phrases[1]

Tree{Any}
nothing
├─ nothing
│  ├─ "The"
│  └─ "Rock"
└─ nothing
   ├─ nothing
   │  ├─ "is"
   │  └─ nothing
   │     ├─ "destined"
   │     └─ nothing
   │        ├─ nothing
   │        │  ├─ nothing
   │        │  │  ├─ nothing
   │        │  │  │  ├─ "to"
   │        │  │  │  └─ nothing
   │        │  │  │     ├─ "be"
   │        │  │  │     └─ nothing
   │        │  │  │        ├─ "the"
   │        │  │  │        └─ nothing
   │        │  │  │           ├─ "21st"
   │        │  │  │           └─ nothing
   │        │  │  │              ├─ nothing
   │        │  │  │              │  ├─ "Century"
   │        │  │  │              │  └─ "'s"
   │        │  │  │              └─ nothing
   │        │  │  │                 ├─ "new"
   │        │  │  │                 └─ nothing
   │        │  │  │                    ├─ "``"
   │        │  │  │                    └─ "Conan"
   │        │  │  └─ "''"
   │        │  └─ "and"
   │        └─ nothing
   │           ├─ "that"
   │           └─ 

In [16]:
phrases[2][1][2][1]

Tree{Any}
nothing
├─ "of"
└─ "``"


In [22]:
"""
    render(buf::Vector{String}, tree)

Append the value of every node of `tree` whose value is not `nothing` to
`buf`, visiting a node before its children and the children in stored order.

`tree` must expose `value` and `children` fields. `buf` is mutated in place
and also returned, so the call can be chained, e.g.
`join(render(String[], tree), " ")`.
"""
function render(buf::Vector{String}, tree)
    # `nothing` is a singleton: test by identity, which can never be
    # affected by a custom `==` method on the value's type.
    if tree.value !== nothing
        push!(buf, tree.value)
    end

    for child in tree.children
        render(buf, child)
    end
    return buf
end


render (generic function with 1 method)

In [45]:
join(render(String[], phrases[21]), " ")

"You 'll probably love it ."

In [47]:
forward(train[21])

(Float32[-0.999944, -0.999753, -0.965345, -0.999998, 0.99979, 0.983871, -0.999845, 0.999993, 0.998173, -0.999967  …  -0.972358, -0.999999, -0.945017, -0.999944, -0.954841, 0.997816, -0.785472, 0.67466, -0.921032, -0.999986] (tracked), 8.622866f0 (tracked))

In [48]:
labels[21]

Tree{Any}
4
├─ 2
└─ 4
   ├─ 3
   │  ├─ 2
   │  │  ├─ 2
   │  │  └─ 2
   │  └─ 4
   │     ├─ 4
   │     └─ 2
   └─ 2
