In [5]:
require 'cutorch'
require 'cunn';
require 'nn';
require 'image';
require 'mnist_loader';

In [6]:
-- Conditional generator: maps a {noise, class} pair to a 32x32 image.
-- Input is a table {noise (N x 100), class vector (N x 10)}; each entry is
-- embedded by its own branch, the two 256-d embeddings are concatenated
-- along dimension 2 and decoded by an MLP ending in Tanh.
generator = nn.Sequential()
local p = nn.ParallelTable()

-- Noise branch: 100 -> 256, batch-normalised, ReLU.
local noiseBranch = nn.Sequential()
noiseBranch:add(nn.Linear(100, 256))
noiseBranch:add(nn.BatchNormalization(256))
noiseBranch:add(nn.ReLU())

-- Class branch: 10 -> 256, batch-normalised, ReLU.
-- The batch-norm layer belongs to classBranch; it was previously appended to
-- noiseBranch by mistake, leaving this branch un-normalised and putting a
-- second BatchNormalization after the noise branch's ReLU.
local classBranch = nn.Sequential()
classBranch:add(nn.Linear(10, 256))
classBranch:add(nn.BatchNormalization(256))
classBranch:add(nn.ReLU())

p:add(noiseBranch)
p:add(classBranch)

generator:add(p)
-- Concatenate the two 256-d branch outputs into one 512-d vector per sample.
generator:add(nn.JoinTable(2))
generator:add(nn.Linear(512, 512))
generator:add(nn.BatchNormalization(512))
generator:add(nn.ReLU())
generator:add(nn.Linear(512, 1024))
generator:add(nn.BatchNormalization(1024))
generator:add(nn.ReLU())
generator:add(nn.Dropout(0.5))
generator:add(nn.Linear(1024, 32*32))
-- Tanh bounds the output to [-1, 1] before it is reshaped to 32x32.
generator:add(nn.Tanh())
generator:add(nn.Reshape(32, 32))

In [7]:
--- Re-initialise every nn.Linear layer reachable from `model`.
-- Walks the module tree breadth-first, descending only into nn.Sequential
-- and nn.ParallelTable containers. For each nn.Linear found, the weight
-- tensor is shifted to zero mean and rescaled in place to standard deviation
-- `std`, and the bias (when the layer has one) is zeroed. All other module
-- types are left untouched.
-- @param model root module to traverse
-- @param std   optional target standard deviation of the weights (default 0.02)
function init(model, std)
    std = std or 0.02
    -- Work list local to this call (it used to leak as the global `queue`).
    local queue = { model }
    local idx = 1
    while idx <= #queue do
        local node = queue[idx]
        local t = torch.type(node)
        if t == "nn.Sequential" or t == "nn.ParallelTable" then
            -- Container: schedule its children for a later iteration.
            for _, child in ipairs(node.modules) do
                queue[#queue + 1] = child
            end
        elseif t == "nn.Linear" then
            local weight = node.weight
            -- In-place standardisation: zero mean first, then unit std
            -- scaled to the requested std.
            weight:add(-weight:mean())
            weight:div(weight:std()):mul(std)
            -- A Linear layer constructed without bias has no bias tensor.
            if node.bias then
                node.bias:zero()
            end
        end
        idx = idx + 1
    end
end
init(generator)

before	
6.2222344318386e-07	
0.02552632963001	
after	
7.1891921398209e-19	
0.02	
before	
1.1009459294502e-05	


0.025527276196163	


after	


-1.1647737000512e-22	


0.02	
before	
-3.7970726798967e-05	


0.018039531450596	


after	


8.0856574004813e-19	


0.02	
before	
0.00045112352991782	
0.05756007579683	
after	
5.9983485192761e-19	
0.02	
before	
-0.0042674754936695	
0.18106573153292	
after	
-1.3491540783866e-18	
0.02	


In [None]:
-- Flatten the generator's parameters, then print their mean and standard
-- deviation before and after running init() on the generator again.
parameters, gradParameters = generator:getParameters()

-- Prints the mean, then the std, of the flattened parameter vector.
local function reportParameterStats()
    print(parameters:mean())
    print(parameters:std())
end

reportParameterStats()
init(generator)
reportParameterStats()

In [None]:
-- Smoke test: push one random batch through the generator and display it.
local noiseDim, classDim = 100, 10
batchsize = 128
noises = torch.randn(batchsize, noiseDim)
-- NOTE(review): the class input is Gaussian noise here, not a class encoding;
-- presumably only the tensor shapes matter for this check -- confirm.
labels = torch.randn(batchsize, classDim)
local generatorInput = { noises, labels }
imgs = generator:forward(generatorInput)
itorch.image(imgs)

In [None]:
-- Negative log-likelihood criterion.
-- NOTE(review): presumably meant to be applied to the discriminator's output,
-- which ends in nn.LogSoftMax over 2 units -- confirm against the training code.
criterion = nn.ClassNLLCriterion()

In [None]:
-- Conditional discriminator: scores an {image, class} pair.
-- Input is a table {image batch (N x 32 x 32), class vector (N x 10)}; each
-- entry is embedded to 1024-d by its own branch, the embeddings are
-- concatenated along dimension 2 and classified into 2 log-probabilities.
discriminator = nn.Sequential()

local p = nn.ParallelTable()

-- Image branch: flatten 32x32 -> 1024, Linear, LeakyReLU.
local imgBranch = nn.Sequential()
imgBranch:add(nn.Reshape(32*32))
imgBranch:add(nn.Linear(32*32, 1024))
imgBranch:add(nn.LeakyReLU(0.2))

-- Class branch: 10 -> 1024, LeakyReLU.
-- The activation belongs to classBranch; it was previously appended to
-- imgBranch by mistake, stacking two LeakyReLUs on the image branch and
-- leaving this branch purely linear.
local classBranch = nn.Sequential()
classBranch:add(nn.Linear(10, 1024))
classBranch:add(nn.LeakyReLU(0.2))

p:add(imgBranch)
p:add(classBranch)

discriminator:add(p)
-- Concatenate the two 1024-d branch outputs into one 2048-d vector per sample.
discriminator:add(nn.JoinTable(2))
discriminator:add(nn.Linear(2048, 512))
discriminator:add(nn.BatchNormalization(512))
discriminator:add(nn.LeakyReLU(0.2))
discriminator:add(nn.Linear(512, 256))
discriminator:add(nn.BatchNormalization(256))
discriminator:add(nn.LeakyReLU(0.2))
-- Two output units followed by LogSoftMax: log-probabilities over 2 classes.
discriminator:add(nn.Linear(256, 2))
discriminator:add(nn.LogSoftMax())

In [None]:
-- Flatten the discriminator's parameters and print their mean and std.
-- NOTE(review): `paramters` looks like a misspelling of `parameters`, and this
-- assignment overwrites the `gradParameters` global obtained from the
-- generator in an earlier cell -- confirm which names later cells rely on
-- before renaming.
paramters, gradParameters = discriminator:getParameters()
local paramMean = paramters:mean()
print(paramMean)
local paramStd = paramters:std()
print(paramStd)

In [None]:
-- Load the training set. `mnist` is presumably the global defined by the
-- 'mnist_loader' module required at the top of the notebook -- confirm.
-- A later cell reads train_set[i][1] (viewed as 32x32) and train_set[i][2].
train_set = mnist.loadTrainSet()

In [None]:
-- Assemble mini-batch number `batch_count` from train_set: one 32x32 image
-- and one 10-element vector per sample, copied into preallocated tensors.
batchsize = 5
batch_count = 1
inputImgInputBatch = torch.zeros(batchsize, 32, 32)
inputVectorInputBatch = torch.zeros(batchsize, 10)

-- Index of the sample just before the first one of this batch.
local batchOffset = (batch_count - 1) * batchsize
for slot = 1, batchsize do
    i = batchOffset + slot
    inputImgInput = train_set[i][1]:view(32, 32)
    inputVectorInput = train_set[i][2]
    inputImgInputBatch[slot] = inputImgInput
    inputVectorInputBatch[slot] = inputVectorInput
end

In [None]:
-- Print the tensor type of the image batch, then of the vector batch.
for _, batch in ipairs({ inputImgInputBatch, inputVectorInputBatch }) do
    print(batch:type())
end

In [None]:
-- Run the assembled {image batch, vector batch} pair through the
-- discriminator and print its output.
local discriminatorInput = { inputImgInputBatch, inputVectorInputBatch }
res = discriminator:forward(discriminatorInput)
print(res)