In [1]:
using Flux
using Flux: onehot, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition


### Examples

In [147]:
# One-hot encode 'a' against the vocabulary ['y', 'a', 'i']; only slot 2 is true.
onehot('a', collect("yai"))

3-element Flux.OneHotVector:
 false
  true
 false

In [148]:
# Split six values into 3 contiguous chunks of equal length.
x = chunk(collect(1:6), 3)

3-element Array{Array{Int64,1},1}:
 [1, 2]
 [3, 4]
 [5, 6]

In [149]:
# Regroup the chunks by position: the i-th output holds the i-th element of every chunk.
batchseq(x)

2-element Array{Array{Int64,1},1}:
 [1, 3, 5]
 [2, 4, 6]

In [150]:
# Ten values do not divide evenly by 3, so the last chunk comes out shorter.
x = chunk(collect(1:10), 3)

3-element Array{Array{Int64,1},1}:
 [1, 2, 3, 4]
 [5, 6, 7, 8]
 [9, 10]     

In [151]:
# The chunks now have unequal lengths (4, 4, 2). With no pad value supplied, the
# positions missing from the short chunk show up as `nothing` in the result.
batchseq(x)

4-element Array{Array{T,1} where T,1}:
 [1, 5, 9]                           
 [2, 6, 10]                          
 Union{Nothing, Int64}[3, 7, nothing]
 Union{Nothing, Int64}[4, 8, nothing]

In [152]:
# Same regrouping, but fill the positions missing from the short chunk with 100.
batchseq(x, 100)

4-element Array{Array{Int64,1},1}:
 [1, 5, 9]  
 [2, 6, 10] 
 [3, 7, 100]
 [4, 8, 100]

In [222]:
# Draw one character at random according to the given weights, write it into an
# in-memory buffer, and read the buffer back out as a String.
# NOTE(review): the weights sum to 1.05, not 1 — this ran without error, so wsample
# evidently treats them as relative weights; confirm against the StatsBase docs.
u = IOBuffer()
t = wsample(collect("akl"), [0.2, 0.8, 0.05])
write(u, t)
String(take!(u))

"k"

###########################################################################################

In [3]:
# Work from this file's directory so "input.txt" is resolved next to it.
cd(@__DIR__)

# Fetch the corpus only when no local copy exists yet.
if !isfile("input.txt")
    download("https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
             "input.txt")
end

# Whole corpus as a vector of characters.
text = collect(read("input.txt", String));
# Every distinct character in the corpus plus '_' as an extra symbol.
alphabet = vcat(unique(text), '_');
# One one-hot vector (length = length(alphabet)) per character of the corpus.
text = [onehot(ch, alphabet) for ch in text]
# One-hot encoding of the extra '_' symbol; used as the pad value when batching.
stop = onehot('_', alphabet);

In [3]:
# Split the encoded corpus into 50 contiguous chunks and inspect the nesting:
# 50 chunks -> up to 91,467 one-hot vectors per chunk (the last chunk is a little
# shorter) -> 68 entries per one-hot vector.
o=chunk(text, 50);
@show size(o)
@show size(o[1])
@show size(o[end])
@show size(o[1][1])


size(o) = (50,)
size(o[1]) = (91467,)
size(o[end]) = (91455,)
size((o[1])[1]) = (68,)


(68,)

In [4]:
# Regroup the 50 chunks by position, padding the shorter chunks with `stop`.
# The result has one entry per position (91,467 of them); each entry is a 68x50
# matrix, i.e. one one-hot column per chunk.
b=batchseq( chunk(text, 50),stop)
@show size(b)
@show size(b[1])
@show size(b[end])

size(b) = (91467,)
size(b[1]) = (68, 50)
size(b[end]) = (68, 50)


(68, 50)

In [5]:
# Cut the 91,467 positions into runs of 50: ceil(91467 / 50) = 1830 runs, the last one partial.
d = collect(partition(b, 50));

################################################################################################################

In [133]:
# Vocabulary size, number of time steps per training sequence, and number of
# parallel streams per batch.
N = length(alphabet)
seqlen = 50
nbatch = 40

# Inputs are built from the corpus itself and targets from the corpus shifted by one
# character, so each target is the character that follows the matching input.
# Both go through the same pipeline: split into `nbatch` streams, regroup by position
# (padding with `stop`), then cut into runs of `seqlen` positions.
Xs, Ys = map((text, text[2:end])) do seq
    collect(partition(batchseq(chunk(seq, nbatch), stop), seqlen))
end;



In [147]:
# Number of times `ff` has been called so far.
global counter_ip = 0

0

In [148]:
# Debugging pass-through: print the running call count (global `counter_ip`) and the
# size of the input, bump the count, and hand the input back unchanged.
function ff(ii)
    global counter_ip
    println(counter_ip, " ", size(ii))
    counter_ip += 1
    return ii
end

ff (generic function with 1 method)

In [149]:
# Number of times `loss` has been evaluated so far.
global counter = 0

# `ff` is a pass-through debugging layer that logs every input the model receives.
# After it: two stacked LSTM layers with 128 units, a dense layer back to the
# N-symbol vocabulary, and a softmax. The whole chain is passed through `gpu`.
m = gpu(Chain(
  ff,
  LSTM(N, 128),
  LSTM(128, 128),
  Dense(128, N),
  softmax))

# Training loss: run the model `m` on every element of `xs`, take the cross-entropy
# against the matching element of `ys`, and sum the results.
# For debugging, each call first bumps the global `counter` and prints it together with
# the top-left 2x3 corner of the second entry of `params(m)`, so successive calls show
# how that parameter changes. The print happens before the model is run.
function loss(xs, ys)
  global counter
  counter += 1
  ws = Tracker.data(params(m)[2])
  println(counter, " ", ws[1:2, 1:3])
  return sum(crossentropy.(m.(gpu.(xs)), gpu.(ys)))
end

# ADAM over all model parameters, constructed with 0.01 as its second argument.
opt = ADAM(params(m), 0.01);
# Fifth input/target sequence moved through `gpu`.
# NOTE(review): tx and ty are not used in any cell visible here — presumably kept for
# an evaluation callback; confirm before removing.
tx, ty = gpu.(Xs[5]), gpu.(Ys[5]);


In [150]:
# Train over the paired (input, target) sequences. In the log below, lines with a
# matrix come from `loss` (one per call) and the "<count> (68, 40)" lines come from
# `ff` (one per model invocation). The recorded run was interrupted by hand.
Flux.train!(loss, zip(Xs, Ys), opt)

1 [-0.090128 -0.0257904 -0.0919167; 0.0725545 0.0122108 0.0210503]
0 (68, 40)
1 (68, 40)
2 (68, 40)
3 (68, 40)
4 (68, 40)
5 (68, 40)
6 (68, 40)
7 (68, 40)
8 (68, 40)
9 (68, 40)
10 (68, 40)
11 (68, 40)
12 (68, 40)
13 (68, 40)
14 (68, 40)
15 (68, 40)
16 (68, 40)
17 (68, 40)
18 (68, 40)
19 (68, 40)
20 (68, 40)
21 (68, 40)
22 (68, 40)
23 (68, 40)
24 (68, 40)
25 (68, 40)
26 (68, 40)
27 (68, 40)
28 (68, 40)
29 (68, 40)
30 (68, 40)
31 (68, 40)
32 (68, 40)
33 (68, 40)
34 (68, 40)
35 (68, 40)
36 (68, 40)
37 (68, 40)
38 (68, 40)
39 (68, 40)
40 (68, 40)
41 (68, 40)
42 (68, 40)
43 (68, 40)
44 (68, 40)
45 (68, 40)
46 (68, 40)
47 (68, 40)
48 (68, 40)
49 (68, 40)
2 [-0.0944669 -0.0245902 -0.0895075; 0.0725689 0.0171154 0.0239312]
50 (68, 40)
51 (68, 40)
52 (68, 40)
53 (68, 40)
54 (68, 40)
55 (68, 40)
56 (68, 40)
57 (68, 40)
58 (68, 40)
59 (68, 40)
60 (68, 40)
61 (68, 40)
62 (68, 40)
63 (68, 40)
64 (68, 40)
65 (68, 40)
66 (68, 40)
67 (68, 40)
68 (68, 40)
69 (68, 40)
70 (68, 40)
71 (68, 40)
72 (68, 40)

InterruptException: InterruptException:

In [110]:
# # Sampling
# m = cpu(m)

# function sample(m, alphabet, len; temp = 1)
#   Flux.reset!(m)
#   buf = IOBuffer()
#   c = rand(alphabet)
#   for i = 1:len
#     write(buf, c)
#     c = wsample(alphabet, m(onehot(c, alphabet)).data)
#   end
#   return String(take!(buf))
# end

# sample(m, alphabet, 1000) |> println

# # evalcb = function ()
# #   @show loss(Xs[5], Ys[5])
# #   println(sample(deepcopy(m), alphabet, 100))
# # end
