In [18]:
#F to 12 x 12 x 256
#stacked transposed convolutions, separated by ReLUs, with a kernel width of 5 and stride of 2 to upsample to 96 x 96 x 32
#1 x 1 convolution to 96 x 96 x 3

# torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(17)

f_vector = 736 # nn4 model의 output layer 전 마지막 layer dimension
landmark_num = 68

class Decoder(nn.Module):
    
    def __init__(self, useCude=False):
        super(Decoder, self).__init__()
        
        self.base_fc = nn.Linear(f_vector, f_vector)
        
        self.shallow_mlp = nn.Sequential(
            nn.Linear(f_vector, 256), # hidden1
            nn.ReLU(),
            nn.Linear(256, 128), # hidden2
            nn.ReLU(),
            nn.Linear(128, 68), # out
            nn.ReLU())
                
        self.texture_fc = nn.Linear(f_vector, 256 * 12 * 12)
        
        self.tConv1 = nn.ConvTranspose2d(256, 128, 5, stride=2, padding=2)
        self.tConv2 = nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2)
        self.tConv3 = nn.ConvTranspose2d(64, 32, 5, stride=2, padding=2)
        
        self.conv = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=1),
            nn.BatchNorm2d(3),
            nn.ReLU())

    # x := 논문에서 1024, 우리는 736 vector
    def forward(self, x):
        feature_vec = self.base_fc(x)
        
        x1 = self.shallow_mlp(feature_vec)
        y1 = self.shallow_mlp(feature_vec)
        
        texture_out = self.texture_fc(feature_vec)
        texture_out = texture_out.view(1, 256, 12, 12)
        texture_out = F.relu(self.tConv1(texture_out, output_size=[None, None, 24, 24]))
        texture_out = F.relu(self.tConv2(texture_out, output_size=[None, None, 48, 48]))
        texture_out = F.relu(self.tConv3(texture_out, output_size=[None, None, 96, 96]))
        texture_out = self.conv(texture_out)
        return x1, y1, texture_out

dtype = torch.FloatTensor

# input : 736 dimension
x = Variable(torch.randn(1, 1, 1, f_vector))

# ground_truth_textures := actual input image
#y = Variable(ground_truth_textures, requires_grad=False)
ground_truth_textures = Variable(torch.randn(1, 3, 96, 96), requires_grad=False)
landmark_x = Variable(torch.randn(1, landmark_num))
landmark_y = Variable(torch.randn(1, landmark_num))

net = Decoder()

loss_fn_landmark = torch.nn.MSELoss(size_average=False)
loss_fn_texture = torch.nn.L1Loss()

learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

for t in range(500):
    x_pred, y_pred, texture_pred = net(x)

    loss_x = loss_fn_landmark(x_pred, landmark_x)
    loss_y = loss_fn_landmark(y_pred, landmark_y)
    loss_texture = loss_fn_texture(texture_pred, ground_truth_textures)
    
    total_loss = loss_x + loss_y + loss_texture
    if (t % 100 == 0):
        print(t, total_loss.data[0])

    optimizer.zero_grad()

    total_loss.backward()

    optimizer.step()

0 133.39932250976562
100 107.00054168701172
200 106.9864730834961
300 106.9794692993164
400 106.97309875488281
