In [1]:
require 'torch'
require 'nn'

In [35]:
--- Build one inference step that refines (Q_o, Q_v) given a fixed feature X.
-- Forward input:  table {X, Q_o, Q_v}
-- Forward output: table {X, Q_o', Q_v'} -- same layout as the input, so blocks
-- can be stacked directly inside an nn.Sequential. With separate weights and
-- biases per Linear:
--   Q_o' = sigmoid(Linear(X) + Linear(Q_v))   (nInputQ_o units)
--   Q_v' = sigmoid(Linear(X) + Linear(Q_o))   (nInputQ_v units)
-- Both updates read the *incoming* Q_o / Q_v, i.e. they are computed in
-- parallel, not one after the other.
-- @tparam[opt=4096] number nInputX   size of X
-- @tparam[opt=200]  number nInputQ_o size of Q_o
-- @tparam[opt=150]  number nInputQ_v size of Q_v
-- @return nn.Sequential
function inference_block(nInputX, nInputQ_o, nInputQ_v)
    nInputX = nInputX or 4096
    nInputQ_o = nInputQ_o or 200
    nInputQ_v = nInputQ_v or 150

    -- {a, b} -> sigmoid(Linear(a) + Linear(b)), with nOutput units.
    local function basic(nInput1, nInput2, nOutput)
        return nn.Sequential()
            :add(nn.ParallelTable()
                :add(nn.Linear(nInput1, nOutput))
                :add(nn.Linear(nInput2, nOutput)))
            :add(nn.CAddTable(true)) -- in-place sum of the two projections
            :add(nn.Sigmoid())
    end

    -- {X, Q_o, Q_v} -> {X, Q_v} -> new Q_o
    local updateQ_o = nn.Sequential()
        :add(nn.ConcatTable()
            :add(nn.SelectTable(1))
            :add(nn.SelectTable(3)))
        :add(basic(nInputX, nInputQ_v, nInputQ_o))

    -- {X, Q_o, Q_v} -> {X, Q_o} -> new Q_v
    local updateQ_v = nn.Sequential()
        :add(nn.NarrowTable(1, 2))
        :add(basic(nInputX, nInputQ_o, nInputQ_v))

    -- {X, Q_o, Q_v} -> {X, {Q_o', Q_v'}} -> (flatten) -> {X, Q_o', Q_v'}
    return nn.Sequential()
        :add(nn.ConcatTable()
            :add(nn.SelectTable(1)) -- pass X through unchanged
            :add(nn.ConcatTable()
                :add(updateQ_o)
                :add(updateQ_v)))
        :add(nn.FlattenTable())
end

--- Stack several inference blocks and concatenate the final Q_o and Q_v.
-- Forward input:  table {X, Q_o, Q_v}
-- Forward output: tensor [Q_o; Q_v] joined along the feature dimension
--                 (JoinTable(1, 1) also handles a leading batch dimension).
-- @tparam[opt] table opt
--   unitNum (number, default 8): number of stacked inference blocks.
--   share   (boolean): if true, every block shares the weight/bias (and the
--           gradient accumulators) of the first block.
-- @return nn.Sequential
function buildModel(opt)
    opt = opt or {}
    local nUnits = opt.unitNum or 8

    local net = nn.Sequential()

    local template = inference_block()
    for i = 1, nUnits do
        local block
        if i == 1 then
            block = template
        elseif opt.share then
            -- Do NOT add the same instance repeatedly: nn.Sequential:backward
            -- reads each module's saved .output (and Sigmoid its own .output),
            -- and a single instance only keeps the buffers from its last
            -- forward. clone(...) shares the parameters but gives every step
            -- its own buffers.
            -- NOTE(review): with shared parameters, updateParameters() is
            -- applied once per block to the same tensors; consider a single
            -- flattened getParameters() + optimizer step instead.
            block = template:clone('weight', 'bias', 'gradWeight', 'gradBias')
        else
            block = inference_block()
        end
        net:add(block)
    end

    -- {X, Q_o, Q_v} -> {Q_o, Q_v} -> [Q_o; Q_v]
    net:add(nn.NarrowTable(2, 2))
    -- nInputDims = 1: join along the feature dim for both 1-D and batched input
    net:add(nn.JoinTable(1, 1))

    return net
end

In [41]:
-- Smoke test: one forward/backward/update pass on random data.
-- Globals are intentional so later notebook cells can inspect them.
input = { torch.rand(4096), torch.rand(200), torch.rand(150) }
opt = {
    share = true,
    unitNum = 8
}
model = buildModel(opt)
pred = model:forward(input)
-- backward() accumulates into the gradient buffers, which are not guaranteed
-- to start at zero, so clear them before the first backward pass.
model:zeroGradParameters()
-- Random gradOutput shaped like the prediction (size of Q_o + size of Q_v).
grad = model:backward(input, torch.rand(pred:size()))
model:updateParameters(0.3)

-0.0854
-0.5113
-0.0553
-0.2110
-0.1224
-0.0754
-0.3007
-0.3311
-0.3634
-0.5457
-0.0456
 0.0232
-0.1950
-0.2410
-0.5033
-0.0544
-0.4879
-0.0754
-0.2398
-0.4883
-0.1356
-0.2180
-0.2295
-0.3872
-0.4070
-0.2387
-0.2168
-0.4807
-0.3030
-0.0083
-0.3612
-0.3549
-0.3698
-0.3232
-0.2190
-0.5129
-0.2721
-0.5506
-0.1055
-0.1006
-0.2926
-0.0384
-0.3929
-0.1558
-0.4377
-0.0852
-0.0822
-0.2133
-0.4290
-0.3381
-0.5543
-0.2004
-0.1907
-0.3440
-0.0562
-0.4822
-0.1900
-0.2204
-0.3183
-0.2787
-0.3207
-0.3483
-0.3247
-0.5408
-0.2595
-0.4197
-0.4415
-0.1337
-0.2871
-0.4672
-0.2751
-0.3196
-0.0304
-0.3869
-0.1587
-0.4773
-0.2737
-0.4733
-0.1766
-0.3362
-0.3366
-0.2382
-0.4900
-0.4009
-0.1153
-0.0874
 0.0487
-0.1632
-0.1880
-0.1299
-0.2065
-0.2117
-0.0801
-0.5871
-0.1296
-0.0394
-0.1505
-0.3845
-0.4766
-0.1491
-0.1355
-0.3939
-0.5421
-0.5591
-0.1229
-0.0675
-0.4530
-0.3716
-0.1814
-0.3645
-0.0615
-0.3651
-0.3483
-0.1198
-0.2136
-0.1269
-0.1818
-0.5597
-0.1125
-0.4452
-0.5154
-0.6353
-0.5312
-0.0872
-0.1969
