In [1]:
require 'nn'
require 'rnn'

In [10]:
-- Some useful functions
function genNbyK(n, k, a, b)
    out = torch.LongTensor(n, k)
    for i=1, n do
        for j = 1, k do
            out[i][j] = torch.random(a, b)
        end
    end
    return out
end

function buildModel(model, vocabSize, embeddingSize, metric, adapt, use_cuda)
    -- Small experiments seem to show that the Tanh activations performed better\
    --      than the ReLU for the bow model
    if model == 'bow' then
        print(string.format("Running bag-of-words model to learn %s", metric))
        sentenceLookup = nn.Sequential()
                    :add(nn.LookupTableMaskZero(vocabSize, embeddingSize))
                    :add(nn.Sum(2, 3, true)) -- Not averaging blows up model so keep this true
                    :add(nn.Tanh())
    else
        print(string.format("Running LSTM model to learn %s", metric))
        sentenceLookup = nn.Sequential()
                    :add(nn.LookupTableMaskZero(vocabSize, embeddingSize))
                    :add(nn.SplitTable(2))
                    :add(nn.Sequencer(nn.LSTM(embeddingSize, embeddingSize)))
                    :add(nn.SelectTable(-1))            -- selects last state of the LSTM
                    :add(nn.Linear(embeddingSize, embeddingSize))
                    :add(nn.ReLU())
    end
    local queryLookup = sentenceLookup:clone("weight", "gradWeight") 
    local summaryLookup = sentenceLookup:clone("weight", "gradWeight")
    local pmodule = nn.ParallelTable()
                :add(sentenceLookup)
                :add(queryLookup)
                :add(summaryLookup)

    if model == 'bow' then
        nnmodel = nn.Sequential()
            :add(pmodule)
            :add(nn.JoinTable(2))
            :add(nn.Tanh())
            :add(nn.Linear(embeddingSize * 3, 2))
    else
        nnmodel = nn.Sequential()
            :add(pmodule)
            :add(nn.JoinTable(2))
            :add(nn.ReLU())
            :add(nn.Linear(embeddingSize * 3, 2))
    end

    if adapt then 
        print("Adaptive regularization")
        local logmod = nn.Sequential()
            :add(nn.Linear(embeddingSize * 3, 1))
            :add(nn.LogSigmoid())
            :add(nn.SoftMax())

        local regmod = nn.Sequential()
            :add(nn.Linear(embeddingSize * 3, 2))

        local fullmod = nn.ConcatTable()
            :add(regmod)
            :add(logmod)

        local final = nn.Sequential()
            :add(pmodule)
            :add(nn.JoinTable(2))
            :add(fullmod)

        nnmodel = final
    end

    if use_cuda then
        return nnmodel:cuda()
    end
    return nnmodel
end

In [17]:
n = 10
k = 15
q = 5
a = 1
b = 100
embDim = 50

In [18]:
trueqValues = torch.rand(n, 2)              ---- Simulating the data
qMaxtrue, qindxtrue = torch.max(trueqValues, 2)  --- Generating the max values and getting the indices

--- I want to select the qindx elements for each row
trueactions = torch.zeros(n, 2):scatter(2, qindxtrue, torch.ones(trueqValues:size()))

-- Simulating streams and queries
sentences = genNbyK(n, k, a, b)
queries = genNbyK(n, q, a, b)

In [19]:
model = buildModel('bow', b, embDim, 'f1', false, false)

Running bag-of-words model to learn f1	


In [20]:
preds = model:forward({sentences, queries, torch.zeros(n, q)})
print(preds)

 0.0518  0.0576
 0.0192  0.2835
 0.1888  0.0355
 0.0983  0.0208
 0.0706  0.1803
-0.0029  0.2815
-0.0268 -0.1170
 0.0811 -0.0627
-0.1178  0.0416
-0.0486  0.0576
[torch.DoubleTensor of size 10x2]



In [22]:
qMax, qindx = torch.max(preds, 2)
actions = torch.zeros(n, 2):scatter(2, qindx, torch.ones(preds:size()))

In [23]:
sentences

  55   32   85   37    3   10   79   57   47   64   89   81   69   62   30
  58   10   70   93   59   57   48   31   26   93   69   11   76   82   83
  31   66   50   50   76   57   23   42   50   57    8   29   58   70   27
  55   78   35   61   14   55   45   75   64   62   46    1   31   64   37
  93    3   97   89    5   70   77   68   77   46   63   90   92   19   58
  40   56   35    8   96   75   53   86   28   76   16    2   19   93    6
  44   85   62   10   43   18  100   13   18   69   13   60   85    5   33
  68   17   35   52   53   33   77    3   87   40   32    2    5   37   66
  33   79   91   66    7   77   10   18   15   85   13    9   10   57   32
  24   22   43   44   19    6   88   54   27   94   82   24   98   69   57
[torch.LongTensor of size 10x15]



In [24]:
indx = torch.range(1, 10)
-- Summary index for selecting the words
sumindx = nn.MaskedSelect():forward({indx, actions:select(2,1):resize(10, 1):byte()})  --- sentences
-- need a LongTensor to do the selection
sumindx = torch.LongTensor(sumindx:size(1)):copy(sumindx)
--- This gives the wrong selection we want
print(sentences:index(1, sumindx))

  58   10   70   93   59   57   48   31   26   93   69   11   76   82   83
  55   78   35   61   14   55   45   75   64   62   46    1   31   64   37
  93    3   97   89    5   70   77   68   77   46   63   90   92   19   58
  44   85   62   10   43   18  100   13   18   69   13   60   85    5   33
  24   22   43   44   19    6   88   54   27   94   82   24   98   69   57
[torch.LongTensor of size 5x15]

