In [None]:
require 'torch'
require 'nn'
require 'rnn'

-- our vocabulary
V = {["I"]=1, ["you"]=2, ["the"]=3, ["this"]=4, ["to"]=5, ["fire"]=6, ["Hey"]=7, ["is"]=8, 
    ["just"]=9, ["zee"]=10, ["And"]=11, ["just"]=12, ["rain"]=13, ["cray"]=14, ["met"]=15, ["Set"]=16}

-- get indices of words in each 5-gram
-- zero padding is the issue
songData = torch.LongTensor({ { V["Hey"],   V["I"],      V["just"],  V["met"],    V["you"] },
                              { V["And"],   V["this"],   V["is"],    V["cray"],   0 },
                              ---{ V["And"],   V["this"],   V["is"],    V["cray"],   V["zee"] },
                              { V["Set"],   V["fire"],   V["to"],    V["the"],    V["rain"] } })

masterpieceOrNot = torch.Tensor({{1},   -- #carlyrae4ever   
                                 {1},
                                 {0}}) 

print(songData)
-- we'll use a LookupTable to map word indices into vectors in R^6
vocab_size = 16
embed_dim = 6
LT = nn.LookupTable(vocab_size, embed_dim)

-- For batch inputs, it's a little easier to start with sequence-length x batch-size tensor, so we transpose songData
songDataT = songData:t()
batchSongLSTM = nn.Sequential()
batchSongLSTM:add(LT) -- will return a sequence-length x batch-size x embedDim tensor
batchSongLSTM:add(nn.SplitTable(1, 3)) -- splits into a sequence-length table with batch-size x embed Dim entries
print(batchSongLSTM:forward(songDataT)) -- sanity check
-- now let's add the LSTM stuff
batchSongLSTM:add(nn.Sequencer(nn.LSTM(embed_dim, embed_dim)))
batchSongLSTM:add(nn.SelectTable(-1)) -- selects last state of the LSTM
batchSongLSTM:add(nn.Linear(embed_dim, 1)) -- map last state to a score for classification
batchSongLSTM:add(nn.Sigmoid()) -- convert score to a probability
songPreds = batchSongLSTM:forward(songDataT)
print(songPreds)

-- we can now call :backward() as follows
bceCrit = nn.BCECriterion()
loss = bceCrit:forward(songPreds, masterpieceOrNot)
dLdPreds = bceCrit:backward(songPreds, masterpieceOrNot)
batchSongLSTM:backward(songDataT, dLdPreds)

print(loss)

In [4]:
#songData

 3
 5
[torch.LongStorage of size 2]



In [3]:
songData

  7   1  12  15   2
 11   4   8  14  10
 16   6   5   3  13
[torch.LongTensor of size 3x5]



In [2]:
masterpieceOrNot

 1
 1
 0
[torch.DoubleTensor of size 3x1]

