In [11]:
require 'torch'
require 'nn'
require 'cunn'
require 'cudnn'

local activation = nn.ReLU

local function inference_block(opt)
    -- Builds one inference unit.
    --
    -- Forward contract (single table in, single flat table out):
    --     X, Q_o', Q_v' = unpack(net:forward({X, Q_o, Q_v}))
    -- X is passed through unchanged, Q_o' is computed from (X, Q_v) and
    -- Q_v' is computed from (X, Q_o), so units can be chained directly.
    --
    -- opt.nFeatures / opt.nObjects / opt.nVerbs give the widths of X, Q_o
    -- and Q_v; they fall back to 4096 / 38 / 33 when absent.

    local featDim = opt.nFeatures or 4096
    local objDim = opt.nObjects or 38
    local verbDim = opt.nVerbs or 33

    -- Fuses a pair of inputs: each is projected to outDim by its own Linear,
    -- the two projections are summed in place, then batch-normalised and
    -- passed through the file-level activation.
    local function fuse(dimA, dimB, outDim)
        local projections = nn.ParallelTable()
        projections:add(nn.Linear(dimA, outDim))
        projections:add(nn.Linear(dimB, outDim))

        local block = nn.Sequential()
        block:add(projections)
        block:add(nn.CAddTable(true))
        block:add(nn.BatchNormalization(outDim))
        block:add(activation())
        return block
    end

    -- Branch producing Q_o': picks {X, Q_v} (elements 1 and 3).
    local pickFeatAndVerbs = nn.ConcatTable()
    pickFeatAndVerbs:add(nn.SelectTable(1))
    pickFeatAndVerbs:add(nn.SelectTable(3))

    local objBranch = nn.Sequential()
    objBranch:add(pickFeatAndVerbs)
    objBranch:add(fuse(featDim, verbDim, objDim))

    -- Branch producing Q_v': picks {X, Q_o} (elements 1 and 2).
    local verbBranch = nn.Sequential()
    verbBranch:add(nn.NarrowTable(1, 2))
    verbBranch:add(fuse(featDim, objDim, verbDim))

    local updates = nn.ConcatTable()
    updates:add(objBranch)
    updates:add(verbBranch)

    -- {X, {Q_o', Q_v'}} ...
    local fanOut = nn.ConcatTable()
    fanOut:add(nn.SelectTable(1))
    fanOut:add(updates)

    -- ... flattened to {X, Q_o', Q_v'}.
    local unit = nn.Sequential()
    unit:add(fanOut)
    unit:add(nn.FlattenTable())
    return unit
end

--- Builds the full network: a chain of inference units followed by a cuDNN LSTM.
--
-- Input to the returned model is the table {X, Q_o, Q_v}. After the units,
-- the refined Q_o' and Q_v' are concatenated along dimension 2, viewed as
-- (batchSize, timesteps, nObjects + nVerbs) and fed to a batch-first LSTM
-- with 256 hidden units. The model is moved to the GPU and printed.
--
-- @param opt table of options:
--   unitNum   number of inference units (default 8)
--   share     when truthy, all units share weights/biases and their gradients
--   nObjects  width of Q_o (default 38, same fallback as inference_block)
--   nVerbs    width of Q_v (default 33, same fallback as inference_block)
--   nFeatures width of X (read by inference_block)
--   batchSize, timesteps  required; used to reshape the LSTM input
-- @return the assembled nn.Sequential, already converted with :cuda()
function createModel(opt)
    local nUnits = opt.unitNum or 8
    -- Use the same fallbacks as inference_block so the LSTM input width
    -- always matches what the units emit, instead of failing on nil arithmetic.
    local nObjects = opt.nObjects or 38
    local nVerbs = opt.nVerbs or 33
    local inputSize = nObjects + nVerbs
    local lstmOutputSize = 256

    assert(opt.batchSize and opt.timesteps,
        'createModel: opt.batchSize and opt.timesteps are required')

    local model = nn.Sequential()

    -- Inference units
    local sharedUnit
    for i = 1, nUnits do
        local unit
        if opt.share then
            if sharedUnit == nil then
                sharedUnit = inference_block(opt)
                unit = sharedUnit
            else
                -- Share parameters through a clone rather than re-adding the
                -- same instance: one instance used at several stages reuses
                -- the same output/gradInput buffers, so a later stage
                -- overwrites tensors that earlier results still alias and
                -- backward reads the wrong intermediate activations.
                -- Only weights, biases and their gradients are tied; each
                -- clone keeps its own batch-norm running statistics.
                -- NOTE(review): relies on model:cuda() below preserving the
                -- sharing -- confirm with the installed nn version.
                unit = sharedUnit:clone('weight', 'bias', 'gradWeight', 'gradBias')
            end
        else
            unit = inference_block(opt)
        end
        model:add(unit)
    end
    -- Drop X, keep {Q_o', Q_v'} and concatenate them feature-wise.
    model:add(nn.NarrowTable(2,2))
    model:add(nn.JoinTable(2))

    -- LSTM layer (single layer, batch-first input)
    local lstm = cudnn.LSTM(inputSize, lstmOutputSize, 1, true)
    model:add(nn.View(opt.batchSize, opt.timesteps, -1))
    model:add(lstm)
--     -- Dropout layer
--     if opt.dropout > 0 then
--         model:add(nn.Dropout(opt.dropout))
--     end
    -- Last FC layer
    --model:add(nn.Reshape(opt.batchSize * opt.timesteps, lstmOutputSize))
    --model:add(nn.Linear(lstmOutputSize, opt.nClasses))
    model:cuda()

    print(tostring(model))

    return model
end

In [13]:
-- Driver cell: builds an unshared 8-unit model from random inputs and dumps
-- its computation graph with optnet's graph generator.
-- Names are left global so other notebook cells can see them.
opt = {
    share = false,
    unitNum = 8,
    nObjects = 50,
    nFeatures = 256,
    nVerbs = 35,
    nClasses = 157,
    dropout = 0.5,
    batchSize = 32,
    timesteps = 16
}
-- Random GPU inputs, one row per (batch item, timestep) pair:
-- X has nFeatures columns, Y has nObjects columns, Z has nVerbs columns.
X = torch.rand(opt.batchSize * opt.timesteps, opt.nFeatures):cuda()
Y = torch.rand(opt.batchSize * opt.timesteps, opt.nObjects):cuda()
Z = torch.rand(opt.batchSize * opt.timesteps, opt.nVerbs):cuda()
model = createModel(opt)
-- Disabled training smoke test (note it would need `input`, defined below).
-- pred = model:forward(input)
-- grad = model:backward(input, torch.rand(opt.nClasses))
-- model:updateParameters(0.3)
-- The model consumes a single table in the order {features, objects, verbs}.
input = {X, Y, Z}

generateGraph = require 'optnet.graphgen'

-- visual properties of the generated graph
-- follows graphviz attributes
graphOpts = {
displayProps =  {shape='ellipse',fontsize=14, style='solid'},
-- Appends the tensor's element count to each node's label.
nodeData = function(oldData, tensor)
  return oldData .. '\n' .. 'Size: '.. tensor:numel()
end
}

g = generateGraph(model, input, graphOpts)

-- NOTE(review): presumably 'SRT_Net' is the graph title and 'no_share' the
-- output file name (matching opt.share = false) -- confirm against graph.dot.
graph.dot(g,'SRT_Net','no_share')

nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> (12) -> output]
  (1): nn.Sequential {
    [input -> (1) -> (2) -> output]
    (1): nn.ConcatTable {
      input
        |`-> (1): nn.SelectTable(1)
         `-> (2): nn.ConcatTable {
               input
                 |`-> (1): nn.Sequential {
                 |      [input -> (1) -> (2) -> output]
                 |      (1): nn.ConcatTable {
                 |        input
                 |          |`-> (1): nn.SelectTable(1)
                 |           `-> (2): nn.SelectTable(3)
                 |           ... -> output
                 |      }
                 |      (2): nn.Sequential {
                 |        [input -> (1) -> (2) -> (3) -> (4) -> output]
                 |        (1): nn.ParallelTable {
                 |          input
                 |            |`-> (1): nn.Linear(256 -> 50)
                 |             `-> (2): nn.Linear(35 -> 50)
        

l {
    [input -> (1) -> (2) -> output]
    (1): nn.ConcatTable {
      input
        |`-> (1): nn.SelectTable(1)
         `-> (2): nn.ConcatTable {
               input
                 |`-> (1): nn.Sequential {
                 |      [input -> (1) -> (2) -> output]
                 |      (1): nn.ConcatTable {
                 |        input
                 |          |`-> (1): nn.SelectTable(1)
                 |           `-> (2): nn.SelectTable(3)
                 |           ... -> output
                 |      }
                 |      (2): nn.Sequential {
                 |        [input -> (1) -> (2) -> (3) -> (4) -> output]
                 |        (1): nn.ParallelTable {
                 |          input
                 |            |`-> (1): nn.Linear(256 -> 50)
                 |             `-> (2): nn.Linear(35 -> 50)
                 |             ... -> output
                 |        }
                 |        (2): nn.CAddTable
                 |        (3): nn

able {
               input
                 |`-> (1): nn.Sequential {
                 |      [input -> (1) -> (2) -> output]
                 |      (1): nn.ConcatTable {
                 |        input
                 |          |`-> (1): nn.SelectTable(1)
                 |           `-> (2): nn.SelectTable(3)
                 |           ... -> output
                 |      }
                 |      (2): nn.Sequential {
                 |        [input -> (1) -> (2) -> (3) -> (4) -> output]
                 |        (1): nn.ParallelTable {
                 |          input
                 |            |`-> (1): nn.Linear(256 -> 50)
                 |             `-> (2): nn.Linear(35 -> 50)
                 |             ... -> output
                 |        }
                 |        (2): nn.CAddTable
                 |        (3): nn.BatchNormalization (2D) (50)
                 |        (4): nn.ReLU
                 |      }
                 |    }
                  `-> (

  |      (1): nn.ConcatTable {
                 |        input
                 |          |`-> (1): nn.SelectTable(1)
                 |           `-> (2): nn.SelectTable(3)
                 |           ... -> output
                 |      }
                 |      (2): nn.Sequential {
                 |        [input -> (1) -> (2) -> (3) -> (4) -> output]
                 |        (1): nn.ParallelTable {
                 |          input
                 |            |`-> (1): nn.Linear(256 -> 50)
                 |             `-> (2): nn.Linear(35 -> 50)
                 |             ... -> output
                 |        }
                 |        (2): nn.CAddTable
                 |        (3): nn.BatchNormalization (2D) (50)
                 |        (4): nn.ReLU
                 |      }
                 |    }
                  `-> (2): nn.Sequential {
                        [input -> (1) -> (2) -> output]
                        (1): nn.NarrowTable
                      

-0.0854
-0.5113
-0.0553
-0.2110
-0.1224
-0.0754
-0.3007
-0.3311
-0.3634
-0.5457
-0.0456
 0.0232
-0.1950
-0.2410
-0.5033
-0.0544
-0.4879
-0.0754
-0.2398
-0.4883
-0.1356
-0.2180
-0.2295
-0.3872
-0.4070
-0.2387
-0.2168
-0.4807
-0.3030
-0.0083
-0.3612
-0.3549
-0.3698
-0.3232
-0.2190
-0.5129
-0.2721
-0.5506
-0.1055
-0.1006
-0.2926
-0.0384
-0.3929
-0.1558
-0.4377
-0.0852
-0.0822
-0.2133
-0.4290
-0.3381
-0.5543
-0.2004
-0.1907
-0.3440
-0.0562
-0.4822
-0.1900
-0.2204
-0.3183
-0.2787
-0.3207
-0.3483
-0.3247
-0.5408
-0.2595
-0.4197
-0.4415
-0.1337
-0.2871
-0.4672
-0.2751
-0.3196
-0.0304
-0.3869
-0.1587
-0.4773
-0.2737
-0.4733
-0.1766
-0.3362
-0.3366
-0.2382
-0.4900
-0.4009
-0.1153
-0.0874
 0.0487
-0.1632
-0.1880
-0.1299
-0.2065
-0.2117
-0.0801
-0.5871
-0.1296
-0.0394
-0.1505
-0.3845
-0.4766
-0.1491
-0.1355
-0.3939
-0.5421
-0.5591
-0.1229
-0.0675
-0.4530
-0.3716
-0.1814
-0.3645
-0.0615
-0.3651
-0.3483
-0.1198
-0.2136
-0.1269
-0.1818
-0.5597
-0.1125
-0.4452
-0.5154
-0.6353
-0.5312
-0.0872
-0.1969
