In [1]:
import dynet as dy

In [37]:
import random
from collections import defaultdict
from itertools import count
import sys

# --- network hyper-parameters -------------------------------------------
LAYERS = 2        # number of stacked recurrent layers
INPUT_DIM = 40    # size of each character embedding
HIDDEN_DIM = 50   # size of the recurrent hidden state

# --- vocabulary ----------------------------------------------------------
# Lowercase letters and the space character, followed by the "<EOS>" marker
# that is used to delimit sentences.
characters = list("abcdefghijklmnopqrstuvwxyz ") + ["<EOS>"]

# index -> symbol (an independent copy of `characters`) and symbol -> index
int2char = characters[:]
char2int = dict(zip(characters, range(len(characters))))

VOCAB_SIZE = len(int2char)

In [38]:
# Single parameter collection: the RNN builders and the entries of `params`
# below are all created against it, and the trainers are constructed from it.
pc = dy.ParameterCollection()

In [39]:
# Three interchangeable recurrent builders with identical layer count and
# input/hidden sizes. All of them register their weights in the same `pc`.
rnn = dy.SimpleRNNBuilder(LAYERS, INPUT_DIM, HIDDEN_DIM, pc)
lstm = dy.LSTMBuilder(LAYERS, INPUT_DIM, HIDDEN_DIM, pc)
gru = dy.GRUBuilder(LAYERS, INPUT_DIM, HIDDEN_DIM, pc)

In [40]:
# Parameters that live outside the recurrent builders, all registered in `pc`.
# The entries are created in the order lookup -> R -> bias.
params = {
    # one INPUT_DIM-sized embedding per vocabulary entry, indexed by char id
    "lookup": pc.add_lookup_parameters((VOCAB_SIZE, INPUT_DIM)),
    # output projection: scores = R * rnn_output + bias
    "R": pc.add_parameters((VOCAB_SIZE, HIDDEN_DIM)),
    "bias": pc.add_parameters((VOCAB_SIZE)),
}

In [39]:
# build the computation graph for one sentence and return its loss expression
def do_one_sentence(rnn, sentence):
    """Return the total next-character loss of `rnn` on a single sentence.

    The sentence is wrapped in "<EOS>" markers and fed to the network one
    character at a time; at every position the model is scored on how well
    it predicts the character that follows.

    Args:
        rnn: a DyNet RNN builder (the notebook passes `rnn` or `lstm`).
        sentence: an iterable of characters; every character must be a key
            of the global `char2int`.

    Returns:
        A DyNet expression holding the sum, over all positions, of the
        negative log-likelihood of the next character. Nothing is evaluated
        until the caller asks for `.value()`.

    Raises:
        KeyError: if `sentence` contains a character missing from `char2int`.
    """
    # start a fresh computation graph for this sentence
    dy.renew_cg()
    state = rnn.initial_state()

    R = dy.parameter(params["R"])
    bias = dy.parameter(params["bias"])
    lookup = params["lookup"]

    # "<EOS>" acts both as the start symbol and as the final prediction target
    char_ids = [char2int[c] for c in ["<EOS>"] + list(sentence) + ["<EOS>"]]

    losses = []
    for char, next_char in zip(char_ids, char_ids[1:]):
        state = state.add_input(lookup[char])
        scores = R * state.output() + bias
        # Fused equivalent of -log(softmax(scores)[next_char]); avoids taking
        # the log of a probability that has underflowed to zero.
        losses.append(dy.pickneglogsoftmax(scores, next_char))
    return dy.esum(losses)  # sum of the per-position losses

In [71]:
# generate from model:
def generate(rnn, max_len=None):
    """Sample a string from the model, one character at a time.

    Generation starts from the "<EOS>" symbol and stops as soon as the model
    samples "<EOS>" again. The terminating marker is not part of the result.

    Args:
        rnn: a DyNet RNN builder (the notebook passes `rnn` or `lstm`).
        max_len: optional upper bound on the number of generated characters.
            With the default `None` generation only stops on "<EOS>", which
            can take arbitrarily long for a model that rarely predicts it.

    Returns:
        The generated characters joined into a single `str`.
    """
    def sample(probs):
        # Draw an index from the discrete distribution `probs` by subtracting
        # successive probabilities from a uniform random number. If rounding
        # leaves a positive remainder, the last index is returned.
        rnd = random.random()
        for i, p in enumerate(probs):
            rnd -= p
            if rnd <= 0:
                break
        return i

    # start a fresh computation graph for this sample
    dy.renew_cg()
    s0 = rnn.initial_state()

    R = dy.parameter(params["R"])
    bias = dy.parameter(params["bias"])
    lookup = params["lookup"]

    eos = char2int["<EOS>"]
    state = s0.add_input(lookup[eos])
    out = []
    while max_len is None or len(out) < max_len:
        probs = dy.softmax(R * state.output() + bias).vec_value()
        next_char = sample(probs)
        if next_char == eos:
            break  # end of sentence: do not emit the marker itself
        out.append(int2char[next_char])
        state = state.add_input(lookup[next_char])
    return "".join(out)

In [42]:
# train on one sentence, printing the loss and a sample every few iterations
def train(rnn, sentence, iterations=200, report_every=5):
    """Repeatedly fit `rnn` to a single sentence with plain SGD.

    A new `SimpleSGDTrainer` over the global `pc` is created on every call.

    Args:
        rnn: a DyNet RNN builder (the notebook passes `rnn` or `lstm`).
        sentence: the training sentence, forwarded to `do_one_sentence`.
        iterations: number of update steps to perform (default 200).
        report_every: print the loss and a generated sample whenever the
            iteration index is a multiple of this value (default 5). A falsy
            value (0 or None) disables reporting.
    """
    trainer = dy.SimpleSGDTrainer(pc)
    for i in range(iterations):
        loss = do_one_sentence(rnn, sentence)
        loss_value = loss.value()  # forward pass
        loss.backward()            # backward pass
        trainer.update()
        if report_every and i % report_every == 0:
            print('loss value: {}'.format(loss_value))
            print(generate(rnn))

### single step

In [101]:
# Same graph construction as do_one_sentence(), written out at top level so
# that the intermediate objects (`lookup`, `loss`, ...) stay available to the
# following cells.
dy.renew_cg()
s0 = rnn.initial_state()

R = dy.parameter(params["R"])
bias = dy.parameter(params["bias"])
lookup = params["lookup"]

sentence = "a quick brown fox jumped over the lazy dog"
# wrap in "<EOS>" markers, then map every symbol to its integer id
sentence = ["<EOS>"] + list(sentence) + ["<EOS>"]
sentence = [char2int[c] for c in sentence]
s = s0
loss = []
# pair each character with the one that follows it (the prediction target)
for char,next_char in zip(sentence,sentence[1:]):
    s = s.add_input(lookup[char])
    probs = dy.softmax(R*s.output() + bias) # the probability of each character
    loss.append( -dy.log(dy.pick(probs,next_char)) )
loss = dy.esum(loss) # sum of the per-position loss expressions


In [102]:
# One manual optimisation step on the `loss` expression built in the cell above.
trainer = dy.SimpleSGDTrainer(pc)
loss_value = loss.value() # forward
loss.backward() # backward
trainer.update() # apply the update to the parameters in `pc`

print('loss value: {}'.format(loss_value))

loss value: 86.3531494140625


In [103]:
# Sample from the model after the single update step above.
generate(rnn)

'c zrmx hopg'

In [104]:
# Inspect the embedding stored for index 0, which char2int maps to "a".
lookup[0].value()

[-0.057095859199762344,
 0.03747761249542236,
 0.0346369743347168,
 0.2968907952308655,
 0.2164282649755478,
 -0.0005062874406576157,
 0.22434881329536438,
 0.21788105368614197,
 0.09886930137872696,
 -0.13218142092227936,
 -0.04024209454655647,
 0.18948598206043243,
 0.2418624311685562,
 -0.08214465528726578,
 0.11116494983434677,
 -0.07358702272176743,
 -0.16337384283542633,
 0.14538949728012085,
 0.037504035979509354,
 -0.18335182964801788,
 0.09826783090829849,
 0.03053700365126133,
 0.28696954250335693,
 -0.19025973975658417,
 0.08363141864538193,
 -0.10596631467342377,
 -0.16728824377059937,
 -0.17335285246372223,
 0.07082431018352509,
 0.06108779460191727,
 0.1542903631925583,
 0.2096983790397644,
 -0.17789234220981598,
 -0.11972109973430634,
 0.2188454568386078,
 0.19367754459381104,
 0.2011972963809967,
 0.14911124110221863,
 -0.14412280917167664,
 0.2826405465602875]

### train with rnn

In [43]:
# Fit the simple RNN to a single sentence; train() prints the loss and a
# generated sample every 5 iterations.
sentence = "a quick brown fox jumped over the lazy dog"
train(rnn, sentence) 

loss value: 156.53305053710938
obvkwqrwbcumlcx
loss value: 103.2486572265625
cv  l nqhtxujwhmonqbniet whrxjuavex cyf  mesa c fpo fdpprqdgib  webjvmfzhps zviveuav wkvxp dinkwlc uy bipmr gtt rgpmwateum
loss value: 73.2106704711914
usl
loss value: 44.31501388549805
aw  ajg ivesawuligv uzbpnjrkpynfrh
loss value: 25.60333824157715
o 
loss value: 13.611724853515625
wkyuipx brojn loh buote yohwn moe juiped ovpr ooejramte ovf d z edcoxen ovxx ocz  oewc ohg ruwy d ivld over the lazy dov
loss value: 5.072129249572754
z qumckdbdhwn fou jsmped over twe lazy dou
loss value: 1.6574616432189941
a lnick bronn fox jumped over the lazy doe
loss value: 0.809463381767273
a quick brown fox jumped over the lazy dog
loss value: 0.5666172504425049
a quick brown fox jumped over the lazy dog
loss value: 0.4357306957244873
a quick brown fox jumped over the lazy dog
loss value: 0.353680282831192
a quick brown fox jumped over the lazy dog
loss value: 0.2974169850349426
a quick brown fox jumped over the lazy dog
lo

### train with lstm

In [45]:
# Same sentence, now with the LSTM builder. The lookup/R/bias entries in
# `params` are shared with the `rnn` run above, so they are not freshly
# initialised here.
sentence = "a quick brown fox jumped over the lazy dog"
train(lstm, sentence)

loss value: 21.115440368652344
a quicb kownx fox mmpmed overr tllzyy oag
loss value: 19.564010620117188
 qqiik bronn o ox uumped overrrr tez lzay od
loss value: 17.545866012573242
 qqucck rronn ffox muped over the hllyyy dogg
loss value: 16.65489959716797
 iukkk boww fox jmmmpdd oveert taaay yy oggg 
loss value: 15.41010570526123
 quck brown fox jumee ovvr te llzy dog
loss value: 14.564728736877441
b  uiicb rrowrn foxj juuppddo over thh lyzz doo
loss value: 12.651235580444336
a juick bronn jox mmmpe over tt llye od
loss value: 11.91175651550293
 quccck brownn fon jpmeedd over the lzzy dogg 
loss value: 10.690475463867188
a qucc bbown fox jjmpdd ovrr the llyz og
loss value: 9.658659934997559
aq iick brown fox jummed ovve tle ayz dog
loss value: 8.690759658813477
 uucik brown fox mumpedd over hhe lazy dogg
loss value: 7.7903828620910645
 quic bown fox jmmpe oov ee tee llzy dog
loss value: 6.987741947174072
a quuikk bbrown ox jjmpe over hhe laaz do
loss value: 6.039589881896973
a qiuck br

In [47]:
# Continue training the already-trained `rnn` on a different sentence; the
# first printed samples below still reproduce the previous sentence.
another_sentence = 'these pretzels are making me thirsty'
train(rnn, another_sentence)

loss value: 364.0517883300781
a quick brown fox jumped over the lazy dog
loss value: 115.98564910888672
a quick brown fox jumped over the jazy dog
loss value: 35.31338119506836
a quick ba wne zhe hazi dvg lazt tog oa jumck baown moe ma tuuck ba wnicke eaetktak woe
loss value: 5.6694207191467285
thes  van mce mazr tharqqhadquick baare maernthseetaayldog jazrnfmped therlthy maz
loss value: 1.7543278932571411
these pretzees ahe making me thirsty
loss value: 0.2038564383983612
ahe making me thirsty
loss value: 0.1282159388065338
these phetzels are making me thirsty
loss value: 0.0984501913189888
these pretzels are making me thirsty
loss value: 0.08078761398792267
these pretzels are making me thirsty
loss value: 0.0689062550663948
these pretzels are making me thirsty
loss value: 0.06029056757688522
these pretzels are making me thirsty
loss value: 0.05371999740600586
these pretzels are making me thirsty
loss value: 0.04852164536714554
these pretzels are making me thirsty
loss value: 0.044297