In [1]:
"""
    MultilayerPerceptron(layers...)

Sequential chain of callables. Calling the model with `x` feeds `x` through
`layers` left to right, with each output becoming the next input. It returns
the final output, or `x` unchanged when there are no layers.
"""
mutable struct MultilayerPerceptron
    layers  # tuple of callables, stored in application order
    function MultilayerPerceptron(layers...)
        return new(layers)
    end
end
# Thread the input through every layer in order.
(m::MultilayerPerceptron)(x) = foldl((h, layer) -> layer(h), m.layers; init=x)

In [2]:
"""
    Layer0(w, b)
    Layer0(ir::Int, ic::Int, o::Int)

Affine layer: calling it on `x` returns `w * x .+ b`.

The `Int` constructor allocates the weight with `param(o, ir)` and the bias
with `param0(o, ic)`, so the bias has `ic` columns rather than a single one.
NOTE(review): an `o×ic` bias only broadcasts against inputs with `ic` (or 1)
columns and learns a separate bias per column. Confirm this per-column bias is
intended rather than a shared `o×1` bias.
"""
struct Layer0
    w
    b
end

function Layer0(ir::Int, ic::Int, o::Int)
    weight = param(o, ir)
    bias = param0(o, ic)
    return Layer0(weight, bias)
end

function (l::Layer0)(x)
    return l.w * x .+ l.b
end

In [3]:
#=
mutable struct EmbedModel
    w
end
=#

In [4]:
#=
function EmbedModel()
    dim1 = 30
    dim2 = 64 #EMBEDDING_SIZE
    dim3 = 32
    w = Param(reshape(KnetArray{Float32}(Knet.xavier(dim1*dim2*dim3)), (1,dim1,dim2,dim3)))
    return EmbedModel(w)
end
=#

In [5]:
#=
function (e::EmbedModel)(x)
    output = conv4(value(e.w), x, dilation=2)
    return output
end
=#

In [6]:
"""
    ConvModel(w, b, f)
    ConvModel(w1, w2, cx, cy, f=relu)

Convolution layer: calling it on `x` returns `f.(conv4(w, x; dilation=2) .+ b)`.

The 4/5-argument constructor allocates a `w1×w2×cx×cy` filter with `param` and
a `1×2×cy×1` bias with `param0`.
NOTE(review): that bias is a full 1×2 spatial map per output channel, so it
presumably only fits inputs whose convolution output is 1×2 spatially. Confirm.
"""
struct ConvModel
    w
    b
    f
end

function (c::ConvModel)(x)
    y = conv4(c.w, x; dilation=2)
    return c.f.(y .+ c.b)
end

function ConvModel(w1, w2, cx, cy, f=relu)
    return ConvModel(param(w1, w2, cx, cy), param0(1, 2, cy, 1), f)
end

ConvModel

In [7]:
"""
    CaptionEncoder(conv_model)

Caption encoder holding the convolution model it applies to caption embeddings.
"""
mutable struct CaptionEncoder; conv_model; end

In [8]:
"""
    CaptionEncoder()

Build a caption encoder around `ConvModel(1, 30, 64, 32)`. That is a filter
spanning 30 positions along the caption axis of the conv input, over 64 input
channels (the embedding size, per the original note), producing 32 output
channels.
"""
function CaptionEncoder()
    taps, embed_dim, out_channels = 30, 64, 32  # embed_dim is meant to equal EMBEDDING_SIZE
    return CaptionEncoder(ConvModel(1, taps, embed_dim, out_channels))
end

CaptionEncoder

In [9]:
"""
    (c::CaptionEncoder)(captions, vocabid, vocab)

Encode a batch of captions.

`arrange` turns `captions` into embeddings and returns possibly-updated
`vocabid`/`vocab`. The embeddings are laid out by `createconvinput`, reshaped
to `(1, MAX_LENGTH, EMBEDDING_SIZE, n)` and convolved, where
`n = BATCH_SIZE * NUM_CAPTIONS_PER_SCENE`. The result is flattened to
`(EMBEDDING_SIZE, n)`.

Returns `(captions_hat, vocabid, vocab)`.
"""
function (c::CaptionEncoder)(captions, vocabid, vocab)
    new_vocabid, new_vocab, embeds = arrange(captions, vocabid, vocab)
    raw = createconvinput(embeds)
    ncaptions = BATCH_SIZE * NUM_CAPTIONS_PER_SCENE
    convin = reshape(raw, (1, MAX_LENGTH, EMBEDDING_SIZE, ncaptions))
    features = c.conv_model(convin)
    # The conv output is viewed as 2×32 per caption (cf. ConvModel's 1×2×cy bias),
    # then flattened column-wise. NOTE(review): relies on 2*32 == EMBEDDING_SIZE.
    per_caption = reshape(features, (2, 32, ncaptions))
    captions_hat = reshape(per_caption, (EMBEDDING_SIZE, ncaptions))
    return captions_hat, new_vocabid, new_vocab
end

In [10]:
"""
    AngleEncoder(mlp_model)

Camera-angle encoder holding the MLP applied to the angle features.
"""
mutable struct AngleEncoder; mlp_model; end

In [11]:
"""
    AngleEncoder()

Build the camera-angle encoder ("MLP1"). It is a single `Layer0` mapping 2
input features (cos and sin) to 32 outputs (size taken from the paper). The
layer's bias has one column per caption in the batch,
`BATCH_SIZE * NUM_CAPTIONS_PER_SCENE`.
"""
function AngleEncoder()
    ncols = BATCH_SIZE * NUM_CAPTIONS_PER_SCENE
    mlp1 = MultilayerPerceptron(Layer0(2, ncols, 32))
    return AngleEncoder(mlp1)
end

AngleEncoder

In [12]:
# Encode camera angles: build the angle features with `build_angles(cameras)`
# and run them through the encoder's MLP, returning the MLP output.
(a::AngleEncoder)(cameras) = a.mlp_model(build_angles(cameras))

In [13]:
#=
mutable struct ImageConvModel
    w
end
=#

In [14]:
#=
function ImageConvModel()
    dim1 = 17
    dim2 = 17
    dim3 = 3 # RGB
    w = reshape(KnetArray{Float32}(Knet.xavier(dim1*dim2*dim3)), (dim1,dim2,dim3,1))
    return ImageConvModel(w)
end
=#

In [15]:
#=
function (i::ImageConvModel)(x)
    output = conv4(i.w, x)
    return output
end
=#

In [16]:
"""
    ImageConvModel(w, b, f)
    ImageConvModel(w1, w2, cx, cy, f=relu)

Image convolution layer: calling it on `x` returns `f.(conv4(w, x) .+ b)`.

The 4/5-argument constructor allocates a `w1×w2×cx×cy` filter with `param` and
a `16×16×cy×1` bias with `param0`.
NOTE(review): the bias is a full 16×16 spatial map per output channel, so it
presumably only fits inputs whose convolution output is 16×16. Confirm.
"""
struct ImageConvModel
    w
    b
    f
end

function (i::ImageConvModel)(x)
    y = conv4(i.w, x)
    return i.f.(y .+ i.b)
end

function ImageConvModel(w1, w2, cx, cy, f=relu)
    return ImageConvModel(param(w1, w2, cx, cy), param0(16, 16, cy, 1), f)
end

ImageConvModel

In [17]:
"""
    ImageEncoder(image_conv_model, sampling_model)

Image encoder: a convolution stage followed by an MLP ("sampling") stage.
"""
mutable struct ImageEncoder; image_conv_model; sampling_model; end

In [18]:
"""
    ImageEncoder()

Build the image encoder from two stages:
- a 17×17 convolution with 3 input channels and 1 output channel;
- a two-layer MLP, 256 → 128 → 18, whose biases have 450 columns.
"""
function ImageEncoder()
    conv = ImageConvModel(17, 17, 3, 1)
    sampler = MultilayerPerceptron(
        Layer0(256, 450, 128),
        Layer0(128, 450, 18),
    )
    return ImageEncoder(conv, sampler)
end

ImageEncoder

In [19]:
"""
    (i::ImageEncoder)(images)

Encode `images` into a `(162, 50)` array. The steps are:
1. build the encoder input with `createimgencinput`;
2. pool it with window 4 and stride 4;
3. convolve, keeping only output channel 1;
4. flatten to `(256, 450)`, run the sampling MLP (18 outputs per column), and
   reshape the `18×450` result to `(162, 50)`.
"""
function (i::ImageEncoder)(images)
    pooled = pool(createimgencinput(images), window=4, stride=4)
    fmap = i.image_conv_model(pooled)[:, :, 1, :]
    flat = reshape(fmap, (256, 450))  # presumably 256 == 16×16 conv output positions
    return reshape(i.sampling_model(flat), (162, 50))
end

In [20]:
"""
    RepresentationModel(caption_encoder, angle_encoder, image_encoder, mlp_model)

Container for the representation network's sub-models. `mlp_model` is MLP2,
which maps the concatenated caption and camera encodings to the per-caption
representation `h`.
"""
mutable struct RepresentationModel; caption_encoder; angle_encoder; image_encoder; mlp_model; end

In [21]:
"""
    RepresentationModel()

Build the representation network: caption, angle and image encoders plus MLP2.
MLP2 is three `Layer0`s, 96 → 128 → 196 → 256, each with 500 bias columns.
The input size 96 presumably equals caption features (64) plus angle
features (32).
Sub-models are constructed in a fixed order, which determines the order in
which parameters are initialized.
"""
function RepresentationModel()
    caption_enc = CaptionEncoder()
    angle_enc = AngleEncoder()
    image_enc = ImageEncoder()
    mlp2 = MultilayerPerceptron(
        Layer0(96, 500, 128),
        Layer0(128, 500, 196),
        Layer0(196, 500, 256),
    )
    return RepresentationModel(caption_enc, angle_enc, image_enc, mlp2)
end

RepresentationModel

In [22]:
"""
    (re::RepresentationModel)(images, captions, cameras, vocabid, vocab)

Run the representation network on one batch.

Returns `(r, z, unseen_ang, unseen_img, vocabid, vocab)`:
- `r`: `aggregate` applied to MLP2's output on the row-wise concatenation of
  the caption and camera encodings;
- `z`: the image encoder's output for `images`;
- `unseen_ang`: camera-encoding columns 10, 20, 30, …, i.e. every 10th column;
- `unseen_img`: view 10 of `images` (`images[:, 10, :, :, :]`) with its first
  dimension moved last, pooled with window 4 and stride 4;
- `vocabid`, `vocab`: as updated by the caption encoder.
"""
function (re::RepresentationModel)(images, captions, cameras, vocabid, vocab)
    captions_hat, vocabid, vocab = re.caption_encoder(captions, vocabid, vocab)
    cameras_hat = re.angle_encoder(cameras)

    # The held-out ("unseen") view: every `heldout`-th camera column and the
    # matching view index in `images`. One constant keeps the two selections in sync.
    heldout = 10
    unseen_ang = hcat([cameras_hat[:, j] for j in heldout:heldout:size(cameras_hat, 2)]...)
    # Move the leading dimension to the end before pooling.
    # NOTE(review): presumably this puts the batch axis last, as conv/pool expect. Confirm.
    unseen_view = permutedims(images[:, heldout, :, :, :], (2, 3, 4, 1))
    unseen_img = pool(unseen_view, window=4, stride=4)

    h = re.mlp_model(cat(captions_hat, cameras_hat, dims=1))
    r = aggregate(h)
    z = re.image_encoder(images)

    return r, z, unseen_ang, unseen_img, vocabid, vocab
end

In [38]:
"""
    GenerationModel(w, b, f)
    GenerationModel(; f=relu)

Decoder layer: calling it on `x` returns `f.(deconv4(w, x) .+ b)`.

The keyword constructor allocates a `32×32×3×450` weight with `param` and a
`32×32×3×1` bias with `param0`. It stores `f`, which defaults to `relu`.
The network feeds this layer `(1, 1, 450, 50)` inputs. The 3 channels are
presumably RGB.
"""
struct GenerationModel; w; b; f; end

# Apply the stored activation `f`. The previous version hard-coded `relu` and
# ignored the `f` field, unlike ConvModel/ImageConvModel. The default `f` is
# `relu`, so default-constructed models behave the same.
(ge::GenerationModel)(x) = ge.f.(deconv4(ge.w, x) .+ ge.b)

GenerationModel(; f=relu) = GenerationModel(param(32, 32, 3, 450), param0(32, 32, 3, 1), f)

GenerationModel

In [32]:
#=
function GenerationModel()
    dim1 = 32
    dim2 = 32
    dim3 = 3 # RGB
    dim4 = 450
    w = reshape(KnetArray{Float32}(Knet.xavier(dim3*dim1*dim2*dim4)), (dim1,dim2,dim3,dim4))
    return GenerationModel(w)
end
=#

In [33]:
#=
function (g::GenerationModel)(x)
    final_output = deconv4(g.w, x)
    return final_output
end
=#

In [34]:
"""
    Network(representationModel, generationModel)

Full model: a representation model (encoders) followed by a generation model
(decoder).
"""
mutable struct Network; representationModel; generationModel; end

In [35]:
# Build a fresh network. The representation model is constructed first, then the
# generation model (arguments are evaluated left to right).
Network() = Network(RepresentationModel(), GenerationModel())

Network

In [42]:
"""
    (n::Network)(images, captions, cameras, vocabid, vocab)

Run the full model and return only the decoder output.

The decoder input stacks `z`, `r` and `unseen_ang` row-wise and reshapes them
to `(1, 1, 450, 50)`. The `vocabid`/`vocab` returned by the representation
model are discarded.
"""
function (n::Network)(images, captions, cameras, vocabid, vocab)
    r, z, unseen_ang, _, _, _ = n.representationModel(images, captions, cameras, vocabid, vocab)
    tail = vcat(r, unseen_ang)
    decoder_in = reshape(vcat(z, tail), (1, 1, 450, 50))
    return n.generationModel(decoder_in)
end

In [64]:
# Counterpart of `train!` in the DRAW code (per the original "train!@DRAW" note).

# Forward-pass helper. `Network`'s call method takes five positional arguments,
# so a tuple of inputs is splatted into it; any other `x` is passed through
# unchanged, as before.
_forward(n::Network, x::Tuple) = n(x...)
_forward(n::Network, x) = n(x)

"""
    update_weights!(n::Network, x, y; lr=0.1)

Take one plain gradient step on every parameter of `n` against the loss
`bce(n(x), y)`, and return the loss value.

`x` is normally the tuple `(images, captions, cameras, vocabid, vocab)`; it is
splatted into the network call, because `n(x)` with a single tuple matches no
`Network` method. Parameters that receive no gradient are left untouched.
"""
function update_weights!(n::Network, x, y; lr=0.1)
    J = @diff bce(_forward(n, x), y)
    for par in params(n)
        g = grad(J, par)
        g === nothing && continue  # parameter did not contribute to this loss
        update!(value(par), g; lr=lr)
    end
    return value(J)
end

32×2 KnetArray{Float32,2}:
  0.0122153    0.00147552
  0.181434    -0.104982  
 -0.00618836   0.0924101 
  0.14375     -0.192867  
 -0.158582     0.171397  
  0.170024    -0.114838  
  0.229864    -0.18594   
 -0.140441    -0.123064  
 -0.0501638   -0.157717  
  0.138755     0.219494  
  0.0531359   -0.105213  
 -0.083853    -0.155687  
  0.071052    -0.0338888 
  ⋮                      
  0.0477721   -0.139004  
  0.207031     0.190264  
 -0.0376994   -0.165688  
  0.0328238    0.210589  
  0.0606524   -0.0770972 
 -0.0142767   -0.0843005 
 -0.236633    -0.107708  
 -0.20422     -0.0378661 
  0.219459     0.14916   
  0.0987915   -0.14041   
 -0.172657     0.132751  
  0.136296    -0.145437  