In [1]:
require 'nn'
require 'rnn'

In [2]:
-- Some useful functions
--- Generate an n-by-k tensor of random integers.
-- Every entry is drawn independently with torch.random(a, b).
-- @param n number of rows
-- @param k number of columns
-- @param a lower bound handed to torch.random
-- @param b upper bound handed to torch.random
-- @return a torch.LongTensor of size n x k
function genNbyK(n, k, a, b)
    -- `local` keeps the buffer from leaking into (or clobbering) a global `out`
    local out = torch.LongTensor(n, k)
    for i = 1, n do
        for j = 1, k do
            out[i][j] = torch.random(a, b)
        end
    end
    return out
end

--- Build the three-input network over (sentence, query, summary) token tensors.
-- Each input goes through its own encoder; the query and summary encoders are
-- clones of the sentence encoder made with clone("weight", "gradWeight").
-- The three encodings are joined along dimension 2 and fed to the output head.
-- @param model 'bow' selects the bag-of-words encoder; any other value selects the LSTM encoder
-- @param vocabSize size handed to nn.LookupTableMaskZero
-- @param embeddingSize embedding width (also the LSTM / Linear width)
-- @param metric only used in the log message printed at construction time
-- @param adapt when truthy, the head is replaced by a ConcatTable that returns
--        {Linear(3*emb, 2) output, Linear(3*emb, 1) -> LogSigmoid -> SoftMax output}
-- @param use_cuda when truthy, the model is converted with :cuda() before being returned
-- @return the assembled nn module
function buildModel(model, vocabSize, embeddingSize, metric, adapt, use_cuda)
    -- Small experiments seem to show that the Tanh activations performed better
    --      than the ReLU for the bow model
    local sentenceLookup    -- local: previously leaked as a global
    local headActivation    -- non-linearity applied after the join in the default head
    if model == 'bow' then
        print(string.format("Running bag-of-words model to learn %s", metric))
        sentenceLookup = nn.Sequential()
                    :add(nn.LookupTableMaskZero(vocabSize, embeddingSize))
                    :add(nn.Sum(2, 3, true)) -- Not averaging blows up model so keep this true
                    :add(nn.Tanh())
        headActivation = nn.Tanh()
    else
        print(string.format("Running LSTM model to learn %s", metric))
        sentenceLookup = nn.Sequential()
                    :add(nn.LookupTableMaskZero(vocabSize, embeddingSize))
                    :add(nn.SplitTable(2))
                    :add(nn.Sequencer(nn.LSTM(embeddingSize, embeddingSize)))
                    :add(nn.SelectTable(-1))            -- selects last state of the LSTM
                    :add(nn.Linear(embeddingSize, embeddingSize))
                    :add(nn.ReLU())
        headActivation = nn.ReLU()
    end

    -- One encoder per input stream
    local queryLookup = sentenceLookup:clone("weight", "gradWeight")
    local summaryLookup = sentenceLookup:clone("weight", "gradWeight")
    local pmodule = nn.ParallelTable()
                :add(sentenceLookup)
                :add(queryLookup)
                :add(summaryLookup)

    -- Default head: join -> activation -> 2 outputs.
    -- Built even when `adapt` replaces it, so parameter initialisation happens
    -- in the same order as before.
    local nnmodel = nn.Sequential()
        :add(pmodule)
        :add(nn.JoinTable(2))
        :add(headActivation)
        :add(nn.Linear(embeddingSize * 3, 2))

    if adapt then
        print("Adaptive regularization")
        -- NOTE(review): SoftMax over a single output unit looks like it would
        -- always yield 1 — confirm this branch behaves as intended.
        local logmod = nn.Sequential()
            :add(nn.Linear(embeddingSize * 3, 1))
            :add(nn.LogSigmoid())
            :add(nn.SoftMax())

        local regmod = nn.Sequential()
            :add(nn.Linear(embeddingSize * 3, 2))

        local fullmod = nn.ConcatTable()
            :add(regmod)
            :add(logmod)

        -- The adaptive head consumes the raw joined encodings (no activation in between)
        nnmodel = nn.Sequential()
            :add(pmodule)
            :add(nn.JoinTable(2))
            :add(fullmod)
    end

    if use_cuda then
        return nnmodel:cuda()
    end
    return nnmodel
end

In [3]:
-- Setting parameters
-- (kept as globals so that later notebook cells can read them)
n = 10          -- number of rows in every simulated batch
n_s = 5         -- number of sentence batches / action tensors that get simulated
k = 7           -- number of token columns per simulated sentence
q = 5           -- number of token columns per simulated query
a = 1           -- lower bound for the random token ids
b = 100         -- upper bound for the random token ids; also passed as vocabSize to buildModel
embDim = 50     -- embedding size passed to buildModel
-- NOTE(review): SELECT / SKIP are not referenced by name in the visible cells;
-- presumably they are the action column indices — confirm against the caller.
SELECT = 2
SKIP = 1

In [4]:
-- Simulated queries: an n x q tensor of random token ids
queries = genNbyK(n, q, a, b)

-- Sentences are batched by sentence position, i.e. sentences[1] holds the
-- first sentence of each article, sentences[2] the second one, and so on
sentences = {}
for s = 1, n_s do
    table.insert(sentences, genNbyK(n, k, a, b))
end

-- Simulated "optimal" actions: one one-hot n x 2 tensor per sentence position
true_actions = {}
for s = 1, n_s do
    -- random stand-in for the true Q-values
    trueqValues = torch.rand(n, 2)
    -- row-wise maximum together with the column index where it occurs
    qMaxtrue, qindxtrue = trueqValues:max(2)
    -- put a 1 in the arg-max column of every row
    local onehot = torch.zeros(n, 2)
    onehot:scatter(2, qindxtrue, torch.ones(trueqValues:size()))
    true_actions[s] = onehot
end

In [5]:
--- Prepend a new batch of experience to the replay memory and cap its length.
-- Both arguments are tables laid out as
--   { {sentences, queries, summaries}, rewards, actions [, regularisation] }.
-- The new rows are concatenated in front of the history along dimension 1,
-- so the first rows are the most recent ones; at most `memsize` rows are kept.
-- @param newinput freshly collected batch
-- @param memory_hist memory accumulated so far
-- @param memsize maximum number of rows to keep (must be >= 1)
-- @param adapt when truthy, a fourth (regularisation) entry is stacked and returned as well
-- @param use_cuda when truthy, inputs, rewards and actions are returned as CUDA tensors
-- @return {inputMemory, rewardMemory, actionMemory} plus regMemory as a fourth entry when `adapt`
function stackMemory(newinput, memory_hist, memsize, adapt, use_cuda)
    assert(memsize >= 1, "memsize must be at least 1")

    local sentMemory = torch.cat(newinput[1][1]:double(), memory_hist[1][1]:double(), 1)
    local queryMemory = torch.cat(newinput[1][2]:double(), memory_hist[1][2]:double(), 1)
    local sumryMemory = torch.cat(newinput[1][3]:double(), memory_hist[1][3]:double(), 1)
    local rewardMemory = torch.cat(newinput[2]:double(), memory_hist[2]:double(), 1)

    local regMemory
    if adapt then
        regMemory = torch.cat(newinput[4]:double(), memory_hist[4]:double(), 1)
    end

    local actionMemory
    if use_cuda then
        -- bring the actions to a CPU double tensor before concatenating
        actionMemory = torch.cat(newinput[3]:double(), memory_hist[3]:double(), 1)
    else
        actionMemory = torch.cat(newinput[3], memory_hist[3], 1)
    end

    -- Keep the first (= newest) rows. The previous window arithmetic always
    -- started at row 1 but ended at memsize + 1, i.e. one row too many.
    local nstart = 1
    local nend = math.min(sentMemory:size(1), memsize)

    sentMemory = sentMemory[{{nstart, nend}}]
    queryMemory = queryMemory[{{nstart, nend}}]
    sumryMemory = sumryMemory[{{nstart, nend}}]
    rewardMemory = rewardMemory[{{nstart, nend}}]
    actionMemory = actionMemory[{{nstart, nend}}]

    local inputMemory
    if use_cuda then
        -- previously this table was overwritten right away with the CPU tensors
        inputMemory = {sentMemory:cuda(), queryMemory:cuda(), sumryMemory:cuda()}
        rewardMemory = rewardMemory:cuda()
        actionMemory = torch.ByteTensor(#actionMemory):copy(actionMemory):cuda()
    else
        inputMemory = {sentMemory, queryMemory, sumryMemory}
    end

    if adapt then
        -- NOTE(review): regMemory stays a CPU double tensor even with use_cuda,
        -- exactly as before — confirm whether the caller expects a CUDA tensor.
        regMemory = regMemory[{{nstart, nend}}]
        return {inputMemory, rewardMemory, actionMemory, regMemory}
    end
    return {inputMemory, rewardMemory, actionMemory}
end

In [6]:
-- Number of columns of the query tensor and of the first sentence batch
querySize = queries:size(2)
streamSize = sentences[1]:size(2)
-- Earlier batching experiment, kept for reference (summaryBuffer and
-- sentenceStream are not defined in the visible cells):
-- summaryBatch = summaryBuffer:narrow(1, 1, streamSize)
-- queryBatch = queries:view(1, querySize):expand(streamSize, querySize) 
-- input = {sentenceStream, queryBatch, summaryBatch}

In [7]:
-- Bag-of-words model on the CPU without the adaptive head; the vocabulary
-- size is b, the upper bound of the simulated token ids
model = buildModel('bow', b, embDim, 'f1', false, false)

Running bag-of-words model to learn f1	


In [8]:
-- Forward pass on the first sentence batch; an all-zero n x q tensor is fed
-- in as the summary input
preds = model:forward({sentences[1], queries, torch.zeros(n, q)})
print(preds)

 0.0133  0.3126
 0.0867  0.2945
-0.0580  0.0579
 0.2729  0.0461
-0.1056 -0.1760
-0.1638 -0.1753
 0.1614 -0.3803
 0.0252 -0.0558
 0.0155 -0.5019
 0.1707 -0.0464
[torch.DoubleTensor of size 10x2]



In [10]:
-- Row-wise maximum of the predictions and the column where it occurs
qMax, qindx = torch.max(preds, 2)
-- One-hot n x 2 encoding of the arg-max column of every row
actions = torch.zeros(n, 2):scatter(2, qindx, torch.ones(preds:size()))

In [11]:
-- Build the predicted summary row by row: rows whose second action column is
-- set receive the tokens of the matching sentence, all other rows stay zero
predsummary = torch.zeros(sentences[1]:size())
for i=1, actions:size(1) do
    if actions[i][2]==1 then
        predsummary[i]:copy(sentences[1][i])
    end
end

In [12]:
-- Row numbers 1..n as a LongTensor
indices = torch.linspace(1, actions:size(1), actions:size(1) ):long()
-- Row numbers whose FIRST action column is set.
-- NOTE(review): the row-wise loop that fills predsummary tests column 2
-- instead; with SKIP = 1 / SELECT = 2 this may be collecting the skipped
-- rows — confirm which column is intended.
selected = indices[actions:select(2,1):eq(1)]

In [13]:
torch.totable(selected)

{
  1 : 4
  2 : 5
  3 : 6
  4 : 7
  5 : 8
  6 : 9
  7 : 10
}


In [14]:
predsummary = torch.zeros(sentences[1]:size())

In [15]:
sentences[1]:index(1, torch.LongTensor(selected)):double()

  62   34   41    8   13    4    3
  36   69   64   82   25   77   39
  15   74   19   54   10   22   91
  36   99    5   50   56   31   79
  88   75   79  100    5   31   10
  55   33   43   50   70   32   20
  22   67   31   21   50   14   46
[torch.DoubleTensor of size 7x7]



In [None]:
-- predsummary:scatter(1, torch.LongTensor(selected), sentences[1]:double())

In [16]:
-- NOTE(review): index() hands back a freshly allocated tensor, so the copy
-- below fills that temporary and predsummary itself is left untouched (it
-- still prints as all zeros afterwards). An in-place write would need
-- something like predsummary:indexCopy(1, selected, ...) — confirm intent.
predsummary:index(1, torch.LongTensor(selected)):copy(sentences[1]:index(1, torch.LongTensor(selected)):double())

  62   34   41    8   13    4    3
  36   69   64   82   25   77   39
  15   74   19   54   10   22   91
  36   99    5   50   56   31   79
  88   75   79  100    5   31   10
  55   33   43   50   70   32   20
  22   67   31   21   50   14   46
[torch.DoubleTensor of size 7x7]



In [19]:
torch.eq(actions:select(2,2), 1):view(n, 1):expand(n, k)

 1  1  1  1  1  1  1
 1  1  1  1  1  1  1
 1  1  1  1  1  1  1
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
[torch.ByteTensor of size 10x7]



In [20]:
predsummary

 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
 0  0  0  0  0  0  0
[torch.DoubleTensor of size 10x7]



In [21]:
-- Scratch n x 7 buffer for the copy experiment that is commented out below
test = torch.zeros(n, 7)
-- test:copy(actions:select(2, 1))

In [22]:
-- torch.zeros(n, 7):scatter(2, actions:select(2, 1), 

In [23]:
-- Module instance; it is not used by the statements in this cell, which call
-- the tensor method maskedSelect directly
maskedSelect = nn.MaskedSelect()
-- Row flag taken from the second action column, expanded across all k token columns
mask1 = torch.eq(actions:select(2,2), 1):view(n, 1):expand(n, k)
-- Flat vector holding every token of the flagged rows
allTokens = sentences[1]:maskedSelect(mask1)
-- Keep only strictly positive token ids (presumably 0 marks padding — confirm)
mask2 = torch.gt(allTokens,0)
allTokens = allTokens:maskedSelect(mask2)

In [24]:
-- Scatter experiment: qindx has a single column, so only the first column of
-- sentences[1] gets written, at the arg-max column of every row
torch.zeros(n, 15):scatter(2, qindx, sentences[1]:double())

  0  61   0   0   0   0   0   0   0   0   0   0   0   0   0
  0  49   0   0   0   0   0   0   0   0   0   0   0   0   0
  0  25   0   0   0   0   0   0   0   0   0   0   0   0   0
 62   0   0   0   0   0   0   0   0   0   0   0   0   0   0
 36   0   0   0   0   0   0   0   0   0   0   0   0   0   0
 15   0   0   0   0   0   0   0   0   0   0   0   0   0   0
 36   0   0   0   0   0   0   0   0   0   0   0   0   0   0
 88   0   0   0   0   0   0   0   0   0   0   0   0   0   0
 55   0   0   0   0   0   0   0   0   0   0   0   0   0   0
 22   0   0   0   0   0   0   0   0   0   0   0   0   0   0
[torch.DoubleTensor of size 10x15]



In [25]:
sentences[1]

  61   48    7   38   28   14    6
  49   25   35   96   57   83   24
  25   41    6   54   90   87   65
  62   34   41    8   13    4    3
  36   69   64   82   25   77   39
  15   74   19   54   10   22   91
  36   99    5   50   56   31   79
  88   75   79  100    5   31   10
  55   33   43   50   70   32   20
  22   67   31   21   50   14   46
[torch.LongTensor of size 10x7]



In [None]:
-- Generic form of the "row numbers where a flag equals 1" selection.
-- NOTE(review): `x` is not defined in any visible cell and this cell was
-- never executed — presumably a placeholder for a 1-D flag tensor.
indices = torch.linspace(1, x:size(1), x:size(1) ):long()
selected = indices[x:eq(1)]

In [27]:
-- Row numbers 1..(number of action rows)
indx = torch.range(1, actions:size(1))
-- Summary index for selecting the words: row numbers whose first action
-- column is set. The mask is built with eq(1) on the column view; the earlier
-- resize(10, 1) on that strided view re-read the underlying storage
-- contiguously and therefore mixed both action columns into the mask.
-- NOTE(review): column 1 is kept from the original; the row-wise predsummary
-- loop uses column 2 — confirm which column is intended.
sumindx = nn.MaskedSelect():forward({indx, actions:select(2, 1):eq(1)})  --- sentences
-- need a LongTensor to do the selection
-- sumindx = torch.LongTensor(sumindx:size(1)):copy(sumindx)

In [None]:
-- Row numbers 1..(number of action rows)
indx = torch.range(1, actions:size(1))
-- Summary index for selecting the words: row numbers whose first action
-- column is set. The mask is built with eq(1) on the column view; the earlier
-- resize(10, 1) on that strided view re-read the underlying storage
-- contiguously, which is what produced the wrong selection.
-- NOTE(review): column 1 is kept from the original; the row-wise predsummary
-- loop uses column 2 — confirm which column is intended.
sumindx = nn.MaskedSelect():forward({indx, actions:select(2, 1):eq(1)})  --- sentences
-- need a LongTensor to do the selection
sumindx = torch.LongTensor(sumindx:nElement()):copy(sumindx)
-- `sentences` is a Lua table of batches, so index into the first batch;
-- skip the lookup when no row was flagged
if sumindx:nElement() > 0 then
    print(sentences[1]:index(1, sumindx))
end