In [127]:
require 'nn'
require 'rnn'

In [128]:
maskLayer = nn.MaskedSelect()

In [129]:
function generateNbyK(n, k, a, b)
    local out = {}
    for i = 1, n do
        local tmp = {}
        for k = 1, k do
            tmp[k] = torch.random(a, b)
        end
        out[i] = tmp
    end
    return out
end

In [130]:
n_q = 10
n_s = 10

-- Generating the queries
queries = torch.LongTensor(generateNbyK(n_q, 4, 0, 10)):resize(n_q, 4)

-- Generating the streams
streams = torch.LongTensor(generateNbyK(n_q, 15, 0, 100))

-- Predicted Summary
predsummary = torch.zeros(n_q, 30)

--- Best possible actions
actions = torch.zeros(n_q, 2)

for i = 1, actions:size(1) do
    if (i % 2)==0 then
        actions[i][2] = 1
    end
end
actions = actions:byte()

In [131]:
batchSize = 10
embeddingSize = 5
vocabSize = 100
sentenceLookup = nn.Sequential()
                    :add(nn.LookupTableMaskZero(vocabSize, embeddingSize))
                    :add(nn.SplitTable(2))
                    :add(nn.Sequencer(nn.LSTM(embeddingSize, embeddingSize)))
                    :add(nn.SelectTable(-1))            -- selects last state of the LSTM
                    :add(nn.Linear(embeddingSize, embeddingSize))
                    :add(nn.ReLU())

queryLookup = sentenceLookup:clone("weight", "gradWeight") 
summaryLookup = sentenceLookup:clone("weight", "gradWeight")
pmodule = nn.ParallelTable()
            :add(sentenceLookup)
            :add(queryLookup)
            :add(summaryLookup)

nnmodel = nn.Sequential()
        :add(pmodule)
        :add(nn.JoinTable(2))
        :add(nn.ReLU())
        :add(nn.Linear(embeddingSize * 3, 2))

qValues = nnmodel:forward({streams, queries, predsummary})
print(qValues)

 0.2015  0.2582
 0.2086  0.2576
 0.2267  0.2594
 0.2328  0.2402
 0.2157  0.2258
 0.2241  0.2221
 0.2270  0.2444
 0.2416  0.2306
 0.2373  0.2378
 0.2475  0.2336
[torch.DoubleTensor of size 10x2]



In [None]:
predQOnActions = maskLayer:forward({qVlaues, actions_in}) 

In [137]:
-- Choosing max values
qValues = torch.rand(n_s, 2)
qMax, qindx = torch.max(qValues, 2)

In [138]:
actions = actions:double()

In [139]:
actions

 0  0
 0  1
 0  0
 0  1
 0  0
 0  1
 0  0
 0  1
 0  0
 0  1
[torch.DoubleTensor of size 10x2]



In [140]:
preds = torch.randn(10, 2)

In [141]:
qMax, qindx = torch.max(preds, 2)

In [142]:
actions = torch.zeros(qindx:size(1), 2)

In [143]:
actions

 0  0
 0  0
 0  0
 0  0
 0  0
 0  0
 0  0
 0  0
 0  0
 0  0
[torch.DoubleTensor of size 10x2]



# Min example

In [112]:
qValues = torch.rand(10, 2)
actions = torch.zeros(10, 2)
-- Generating the max values and getting the indices
qMax, qindx = torch.max(qVlaues, 2)

In [115]:
qindx

 1
 2
 1
 1
 2
 2
 2
 1
 2
 1
[torch.LongTensor of size 10]



In [116]:
--- I want to select 
actions:index(1, qindx:resize(10))

 0  0
 0  0
 0  0
 0  0
 0  0
 0  0
 0  0
 0  0
 0  0
 0  0
[torch.DoubleTensor of size 10x2]



In [21]:
sentenceLookup = nn.Sequential()
            :add(nn.LookupTableMaskZero(vocabSize, embeddingSize))
            :add(nn.Sum(2, 3, true)) -- Not averaging blows up model so keep this true
            :add(nn.Tanh())

queryLookup = sentenceLookup:clone("weight", "gradWeight") 
summaryLookup = sentenceLookup:clone("weight", "gradWeight")

pmodule = nn.ParallelTable()
        :add(sentenceLookup)
        :add(queryLookup)
        :add(summaryLookup)

-- nnmodel = nn.Sequential()
--         :add(pmodule)
--         :add(nn.JoinTable(2))
--         :add(nn.Tanh())
--         :add(nn.Linear(embeddingSize * 3, 2))

-- regmodel = nn.Sequential()
--     :add(pmodule)
--     :add(nn.JoinTable(2))
--     :add(nn.Linear(embeddingSize * 3, 1))
--     :add(nn.LogSigmoid())
--     :add(nn.SoftMax())


-- final = nn.ConcatTable()
--     :add(nnmodel)
--     :add(regmodel)

logmod = nn.Sequential()
    :add(nn.Linear(embeddingSize * 3, 1))
    :add(nn.LogSigmoid())
    :add(nn.SoftMax())

regmod = nn.Sequential()
    :add(nn.Linear(embeddingSize * 3, 2))

fullmod = nn.ConcatTable()
    :add(regmod)
    :add(logmod)

nnmodel = nn.Sequential()
    :add(pmodule)
    :add(nn.JoinTable(2))
    :add(fullmod)

In [22]:
nnmodel:forward({streams, queries, predsummary})

{
  1 : DoubleTensor - size: 10x2
  2 : DoubleTensor - size: 10x1
}


In [23]:
print({nnmodel:forward({streams, queries, predsummary}),regmodel:forward({streams, queries, predsummary})})
print(final:forward({streams, queries, predsummary}))

{
  1 : 
    {
      1 : DoubleTensor - size: 10x2
      2 : DoubleTensor - size: 10x1
    }
  2 : DoubleTensor - size: 10x1
}


{
  1 : DoubleTensor - size: 10x2
  2 : DoubleTensor - size: 10x1
}


In [24]:
preds = nnmodel:forward({streams, queries, predsummary})

In [26]:
criterion = nn.ParallelCriterion():add(nn.MSECriterion()):add(nn.BCECriterion())

In [27]:
qValues = torch.rand(10, 2)
ones = torch.ones(10, 1)

In [28]:
preds

{
  1 : DoubleTensor - size: 10x2
  2 : DoubleTensor - size: 10x1
}


In [29]:
criterion:forward(preds, {qVlaues, ones})

0.24190414505724	


In [30]:
criterion:backward(preds, {qVlaues, ones})

{
  1 : DoubleTensor - size: 10x2
  2 : DoubleTensor - size: 10x1
}
