In [1]:
require 'nn'
require 'rnn'

In [2]:
function genNbyK(n, k, a, b)
    out = torch.LongTensor(n, k)
    for i=1, n do
        for j = 1, k do
            out[i][j] = torch.random(a, b)
        end
    end
    return out
end

In [5]:
trueqValues = torch.rand(10, 2)              ---- Simulating the data
qMaxtrue, qindxtrue = torch.max(trueqValues, 2)  --- Generating the max values and getting the indices
--- I want to select the qindx elements for each row
trueactions = torch.zeros(10, 2):scatter(2, qindxtrue, torch.ones(trueqValues:size()))

-- Simulating streams and queries
sentences = genNbyK(10, 15, 1, 100)
queries = genNbyK(10, 4, 1, 100)

In [6]:
function buildModel(model, vocabSize, embeddingSize, metric, adapt, use_cuda)
    -- Small experiments seem to show that the Tanh activations performed better\
    --      than the ReLU for the bow model
    if model == 'bow' then
        print(string.format("Running bag-of-words model to learn %s", metric))
        sentenceLookup = nn.Sequential()
                    :add(nn.LookupTableMaskZero(vocabSize, embeddingSize))
                    :add(nn.Sum(2, 3, true)) -- Not averaging blows up model so keep this true
                    :add(nn.Tanh())
    else
        print(string.format("Running LSTM model to learn %s", metric))
        sentenceLookup = nn.Sequential()
                    :add(nn.LookupTableMaskZero(vocabSize, embeddingSize))
                    :add(nn.SplitTable(2))
                    :add(nn.Sequencer(nn.LSTM(embeddingSize, embeddingSize)))
                    :add(nn.SelectTable(-1))            -- selects last state of the LSTM
                    :add(nn.Linear(embeddingSize, embeddingSize))
                    :add(nn.ReLU())
    end
    local queryLookup = sentenceLookup:clone("weight", "gradWeight") 
    local summaryLookup = sentenceLookup:clone("weight", "gradWeight")
    local pmodule = nn.ParallelTable()
                :add(sentenceLookup)
                :add(queryLookup)
                :add(summaryLookup)

    if model == 'bow' then
        nnmodel = nn.Sequential()
            :add(pmodule)
            :add(nn.JoinTable(2))
            :add(nn.Tanh())
            :add(nn.Linear(embeddingSize * 3, 2))
    else
        nnmodel = nn.Sequential()
            :add(pmodule)
            :add(nn.JoinTable(2))
            :add(nn.ReLU())
            :add(nn.Linear(embeddingSize * 3, 2))
    end

    if adapt then 
        print("Adaptive regularization")
        local logmod = nn.Sequential()
            :add(nn.Linear(embeddingSize * 3, 1))
            :add(nn.LogSigmoid())
            :add(nn.SoftMax())

        local regmod = nn.Sequential()
            :add(nn.Linear(embeddingSize * 3, 2))

        local fullmod = nn.ConcatTable()
            :add(regmod)
            :add(logmod)

        local final = nn.Sequential()
            :add(pmodule)
            :add(nn.JoinTable(2))
            :add(fullmod)

        nnmodel = final
    end

    if use_cuda then
        return nnmodel:cuda()
    end
    return nnmodel
end

In [7]:
model = buildModel('bow', 100, 50, 'f1', false, false)

Running bag-of-words model to learn f1	


In [8]:
preds = model:forward({sentences, queries, torch.zeros(10, 5)})
print(preds)

-0.4660  0.1093
-0.0732 -0.0767
-0.0166 -0.1861
 0.1325  0.1624
 0.4015 -0.1054
-0.1860 -0.0452
 0.0051 -0.1971
 0.1170 -0.2537
-0.1520 -0.2123
-0.0405 -0.1671
[torch.DoubleTensor of size 10x2]



In [9]:
qMax, qindx = torch.max(preds, 2)
actions = torch.zeros(10, 2):scatter(2, qindx, torch.ones(preds:size()))

In [10]:
indx = torch.range(1, 10)
-- Summary index for selecting the words
sumindx = nn.MaskedSelect():forward({indx, actions:select(2,1):resize(10, 1):byte()})  --- sentences
sumindx = torch.LongTensor(5):copy(sumindx)
print(sentences:index(1, sumindx))

 58  91   3  86  68  85  55  24  63  10  13  35  19  89   5
  5  59  56  40  37  86  94  62  47   1  69  19  49  28  78
 57   2  13  22  51  67   8  87  27  92  24  31  70  31  86
 31  38  45   6  20  57   6   2  29  57  85  67  64  34  70
 72  72  99   7  41  63  42  79  14   2  72  28  69  40   6
[torch.LongTensor of size 5x15]

