In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable

In [None]:
def C_BN_LR(c_in, c_out, activation, transpose=False, dropout=None):
    '''
    Build one (transposed-)convolution -> [dropout] -> batch-norm -> activation block.

    c_in, c_out: input / output channel counts of the convolution.
    activation:  module instance appended as the last stage (used as given,
                 not copied).
    transpose:   truthy selects nn.ConvTranspose2d, falsy selects nn.Conv2d.
                 Both use kernel_size=4, stride=2, padding=1.
    dropout:     when truthy, an nn.Dropout2d with this probability is placed
                 between the convolution and the batch norm.

    Returns an nn.Sequential holding the stages in that order.
    '''
    conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
    stages = [conv_cls(c_in, c_out, kernel_size=4, stride=2, padding=1)]
    if dropout:
        stages.append(nn.Dropout2d(dropout))
    stages += [nn.BatchNorm2d(c_out), activation]
    return nn.Sequential(*stages)

In [None]:
num_channels = 3
image_size = 64  # or 256
# Channel widths of the successive feature maps, starting with the image itself.
# Every configuration other than image_size == 64 gets one extra 512-wide stage.
k_list = [num_channels, 16, 32, 64, 128, 256, 512]
if image_size != 64:
    k_list.append(512)
batch_size = 32
num_attr = 23
# note that according to the paper, attr_dim = num_attr * 2
attr_dim = 23

In [None]:
class Encoder(nn.Module):
    '''
    Stack of stride-2 convolution blocks (C_BN_LR with LeakyReLU(0.2)).

    channels: list of channel widths, one C_BN_LR block per consecutive pair.
              Defaults to the module-level k_list.

    Input: (batch_size, channels[0], H, W)
    Output: (batch_size, channels[-1], H / 2**n, W / 2**n)
            with n = len(channels) - 1 (6 for the 7-entry k_list, 7 for the
            8-entry one).
    '''
    def __init__(self, channels=None):
        super(Encoder, self).__init__()
        if channels is None:
            # Looked up at construction time so the notebook's config cell
            # still drives the default architecture.
            channels = k_list
        # One activation instance is shared by every block.
        activation = nn.LeakyReLU(0.2)
        self.convs = nn.Sequential(*[
            C_BN_LR(c_in, c_out, activation)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])

    def forward(self, x):
        return self.convs(x)

In [None]:
# class Concat_C_BN_LR(nn.Module):
#     '''
#     Input: (batch_size, c_in, H, W), (batch_size, attr_dim)
#     Output: (batch_size, c_out + attr_dim, H, W)
#     '''
#     def __init__(self, c_in, c_out, activation, transpose=True):
#         super(Concat_C_BN_LR, self).__init__()
#         self.conv = C_BN_LR(c_in, c_out, activation, transpose)

#     def forward(self, x, attrs):
#         H, W = x.size()[2], x.size()[3]
#         attrs_ = attrs.repeat(H, W, 1, 1).permute(2, 3, 0, 1)
#         x = torch.cat([x, attrs_], dim=1)
#         x = self.conv(x)
#         return x, attrs

In [None]:
class Decoder(nn.Module):
    '''
    Rebuilds an image from an encoded feature map. The attribute vector is
    tiled over the spatial grid and concatenated to the input of every
    transposed-convolution stage.

    attr_dim:     length of the attribute vector.
    image_size:   256 adds one extra leading stage (deconv1), which needs
                  k_list[7] to exist; any other value uses six stages.
    num_channels: channel count of the generated image.

    Input: (batch_size, C, H, W), (batch_size, attr_dim)
           where C is k_list[7] when image_size == 256, else k_list[6].
    Output: (batch_size, num_channels, H * 2**n, W * 2**n)
            with n = 7 when image_size == 256, else 6.
    '''
    def __init__(self, attr_dim, image_size=256, num_channels=3):
        super(Decoder, self).__init__()
        activation = nn.ReLU()

        self.image_size = image_size
        # Attribute names deconv1..deconv7 are kept fixed so that state_dict
        # keys do not depend on how the stages are built.
        if self.image_size == 256:
            self.deconv1 = C_BN_LR(k_list[7] + attr_dim, k_list[6], activation, transpose=True)
        self.deconv2 = C_BN_LR(k_list[6] + attr_dim, k_list[5], activation, transpose=True)
        self.deconv3 = C_BN_LR(k_list[5] + attr_dim, k_list[4], activation, transpose=True)
        self.deconv4 = C_BN_LR(k_list[4] + attr_dim, k_list[3], activation, transpose=True)
        self.deconv5 = C_BN_LR(k_list[3] + attr_dim, k_list[2], activation, transpose=True)
        self.deconv6 = C_BN_LR(k_list[2] + attr_dim, k_list[1], activation, transpose=True)
        # The last stage emits the image itself: honour the num_channels
        # argument and squash the output with Tanh.
        self.deconv7 = C_BN_LR(k_list[1] + attr_dim, num_channels, nn.Tanh(), transpose=True)

    def repeat_concat(self, Ex, attrs):
        '''
        Tile attrs (batch_size, attr_dim) over the spatial grid of
        Ex (batch_size, C, H, W) and append it along the channel axis.

        Returns a (batch_size, C + attr_dim, H, W) tensor.
        '''
        H, W = Ex.size(2), Ex.size(3)
        # (B, A) -> (B, A, 1, 1) -> broadcast view of shape (B, A, H, W).
        attrs_ = attrs.unsqueeze(2).unsqueeze(3).expand(attrs.size(0), attrs.size(1), H, W)
        return torch.cat([Ex, attrs_], dim=1)

    def forward(self, Ex, attrs):
        if self.image_size == 256:
            Ex = self.deconv1(self.repeat_concat(Ex, attrs))
        stages = (self.deconv2, self.deconv3, self.deconv4,
                  self.deconv5, self.deconv6, self.deconv7)
        for deconv in stages:
            # Attributes are re-injected at every resolution.
            Ex = deconv(self.repeat_concat(Ex, attrs))
        return Ex

In [None]:
class Discriminator(nn.Module):
    '''
    Maps an encoded feature map to one score per attribute.

    num_attrs:  number of output scores.
    image_size: 256 adds one extra stride-2 C_BN_LR block in front of the
                fully connected layers; any other value omits it.

    Input: (batch_size, 512, h, w). The feature map must be 1 x 1 by the time
           it reaches the fully connected layers, i.e. h = w = 1 directly, or
           h = w = 2 when the extra convolution is present.
    Output: (batch_size, num_attrs)
    '''
    def __init__(self, num_attrs, image_size=256):
        super(Discriminator, self).__init__()
        # Stored on the instance so forward() depends only on how this module
        # was built, never on a module-level variable.
        self.image_size = image_size
        if image_size == 256:
            self.conv = C_BN_LR(512, 512, nn.LeakyReLU(0.2)) # ReLU? Dropout?
        self.fc1 = nn.Linear(512, 512)
        self.dp1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(512, num_attrs)
        self.dp2 = nn.Dropout(0.3)

    def forward(self, Ex):
        # The extra convolution only makes sense while there is still spatial
        # extent to reduce; a 1 x 1 map is passed straight to the linear
        # layers. This keeps an instance built with the default image_size
        # usable on features of a smaller image.
        already_flat = Ex.size(2) == 1 and Ex.size(3) == 1
        if self.image_size == 256 and not already_flat:
            Ex = self.conv(Ex)
        # view() fails loudly if the map is not 1 x 1 at this point.
        p = Ex.view(Ex.size(0), Ex.size(1))
        p = self.dp1(self.fc1(p))
        p = self.dp2(self.fc2(p))
        return p

In [None]:
if __name__ == '__main__':
    # Smoke test: push one dummy batch through every module and check shapes.
    n_test_attrs = 10   # discriminator outputs used for this check
    test_attr_dim = 20  # attribute-vector length fed to the decoder

    Enc = Encoder()
    x = Variable(torch.zeros(batch_size, num_channels, image_size, image_size))
    Ex = Enc(x)
    print(Ex.size())

    # Build the discriminator for the same image size as the encoder/decoder.
    Dis = Discriminator(n_test_attrs, image_size=image_size)
    p = Dis(Ex)
    print(p.size())
    assert tuple(p.size()) == (batch_size, n_test_attrs)

    Dec = Decoder(test_attr_dim, image_size=image_size, num_channels=num_channels)
    attrs = Variable(torch.zeros(batch_size, test_attr_dim))
    x_ = Dec(Ex, attrs)
    print(x_.size())
    # The reconstruction must have exactly the shape of the input batch.
    assert tuple(x_.size()) == tuple(x.size())