In [1]:
import os
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn

from common.function import conv2d, deconv2d, lrelu, fc, embedding_lookup

In [2]:
GPU = torch.cuda.is_available()
GPU

True

## Generator

- source image를 입력받아서 Encoding과 Decoding 과정을 거쳐 fake image를 생성해내는 Generator

In [3]:
def Generator(images, En, De, embeddings, embedding_ids, GPU=False):
    encoded_source, encode_layers = En(images)
    local_embeddings = embedding_lookup(embeddings, embedding_ids, GPU=GPU)
    if GPU:
        encoded_source = encoded_source.cuda()
        local_embeddings = local_embeddings.cuda()
    embedded = torch.cat((encoded_source, local_embeddings), 1)
    fake_target = De(embedded, encode_layers)
    return fake_target, encoded_source

### Encoder
- Generator 모델 안의 Encoding을 담당하는 Encoder
- source image를 입력받아서 특징을 추출해 latent code를 출력

In [4]:
class Encoder(nn.Module):
    
    def __init__(self, img_dim=1, conv_dim=64):
        super(Encoder, self).__init__()
        self.conv1 = conv2d(img_dim, conv_dim, k_size=5, stride=2, pad=2, dilation=2, lrelu=False, bn=False)
        self.conv2 = conv2d(conv_dim, conv_dim*2, k_size=5, stride=2, pad=2, dilation=2)
        self.conv3 = conv2d(conv_dim*2, conv_dim*4, k_size=4, stride=2, pad=1, dilation=1)
        self.conv4 = conv2d(conv_dim*4, conv_dim*8)
        self.conv5 = conv2d(conv_dim*8, conv_dim*8)
        self.conv6 = conv2d(conv_dim*8, conv_dim*8)
        self.conv7 = conv2d(conv_dim*8, conv_dim*8)
        self.conv8 = conv2d(conv_dim*8, conv_dim*8)
    
    def forward(self, images):
        encode_layers = dict()
        
        e1 = self.conv1(images)
        encode_layers['e1'] = e1
        e2 = self.conv2(e1)
        encode_layers['e2'] = e2
        e3 = self.conv3(e2)
        encode_layers['e3'] = e3
        e4 = self.conv4(e3)
        encode_layers['e4'] = e4
        e5 = self.conv5(e4)
        encode_layers['e5'] = e5
        e6 = self.conv6(e5)
        encode_layers['e6'] = e6
        e7 = self.conv7(e6)
        encode_layers['e7'] = e7
        encoded_source = self.conv8(e7)
        encode_layers['e8'] = encoded_source
        
        return encoded_source, encode_layers

### Decoder
- Encoder가 생성한 latent code를 다시 decoding 해서 fake image를 출력

In [5]:
class Decoder(nn.Module):
    
    def __init__(self, img_dim=1, embedded_dim=640, conv_dim=64):
        super(Decoder, self).__init__()
        self.deconv1 = deconv2d(embedded_dim, conv_dim*8, dropout=True)
        self.deconv2 = deconv2d(conv_dim*16, conv_dim*8, dropout=True, k_size=4)
        self.deconv3 = deconv2d(conv_dim*16, conv_dim*8, k_size=5, dilation=2, dropout=True)
        self.deconv4 = deconv2d(conv_dim*16, conv_dim*8, k_size=4, dilation=2, stride=2)
        self.deconv5 = deconv2d(conv_dim*16, conv_dim*4, k_size=4, dilation=2, stride=2)
        self.deconv6 = deconv2d(conv_dim*8, conv_dim*2, k_size=4, dilation=2, stride=2)
        self.deconv7 = deconv2d(conv_dim*4, conv_dim*1, k_size=4, dilation=2, stride=2)
        self.deconv8 = deconv2d(conv_dim*2, img_dim, k_size=4, dilation=2, stride=2, bn=False)
    
    
    def forward(self, embedded, encode_layers):
        
        d1 = self.deconv1(embedded)
        d1 = torch.cat((d1, encode_layers['e7']), dim=1)
        d2 = self.deconv2(d1)
        d2 = torch.cat((d2, encode_layers['e6']), dim=1)
        d3 = self.deconv3(d2)
        d3 = torch.cat((d3, encode_layers['e5']), dim=1)
        d4 = self.deconv4(d3)
        d4 = torch.cat((d4, encode_layers['e4']), dim=1)
        d5 = self.deconv5(d4)
        d5 = torch.cat((d5, encode_layers['e3']), dim=1)
        d6 = self.deconv6(d5)
        d6 = torch.cat((d6, encode_layers['e2']), dim=1)
        d7 = self.deconv7(d6)
        d7 = torch.cat((d7, encode_layers['e1']), dim=1)
        d8 = self.deconv8(d7)        
        fake_target = torch.tanh(d8)
        
        return fake_target

## Discriminator
- Generator가 생성해낸 fake image와, real image인 target image를 입력받아서 진짜일 확률값(0~1) 출력

In [6]:
class Discriminator(nn.Module):
    def __init__(self, category_num, img_dim=2, disc_dim=64):
        super(Discriminator, self).__init__()
        self.conv1 = conv2d(img_dim, disc_dim, bn=False)
        self.conv2 = conv2d(disc_dim, disc_dim*2)
        self.conv3 = conv2d(disc_dim*2, disc_dim*4)
        self.conv4 = conv2d(disc_dim*4, disc_dim*8)
        self.fc1 = fc(disc_dim*8*8*8, 1)
        self.fc2 = fc(disc_dim*8*8*8, category_num)
        
    def forward(self, images):
        batch_size = images.shape[0]
        h1 = self.conv1(images)
        h2 = self.conv2(h1)
        h3 = self.conv3(h2)
        h4 = self.conv4(h3)
        
        tf_loss_logit = self.fc1(h4.reshape(batch_size, -1))
        tf_loss = torch.sigmoid(tf_loss_logit)
        cat_loss = self.fc2(h4.reshape(batch_size, -1))
        
        return tf_loss, tf_loss_logit, cat_loss

## Check shapes
- 각 모델이 생성하는 데이터의 shape들을 확인해보자

In [7]:
fixed_dir = './fixed_sample'

source = torch.load(os.path.join(fixed_dir, 'fixed_source_all.pkl')).cuda()
target = torch.load(os.path.join(fixed_dir, 'fixed_target_all.pkl')).cuda()
label = torch.load(os.path.join(fixed_dir, 'fixed_label_all.pkl')).cuda()
source.shape, target.shape, label.shape

(torch.Size([32, 1, 128, 128]),
 torch.Size([32, 1, 128, 128]),
 torch.Size([32]))

### Encoder

In [8]:
img_dim = 1
conv_dim = 64
encode_layers = dict()

conv1 = conv2d(img_dim, conv_dim, k_size=5, stride=2, pad=2, dilation=2, lrelu=False, bn=False).cuda()
conv2 = conv2d(conv_dim, conv_dim*2, k_size=5, stride=2, pad=2, dilation=2).cuda()
conv3 = conv2d(conv_dim*2, conv_dim*4, k_size=4, stride=2, pad=1, dilation=1).cuda()
conv4 = conv2d(conv_dim*4, conv_dim*8).cuda()
conv5 = conv2d(conv_dim*8, conv_dim*8).cuda()
conv6 = conv2d(conv_dim*8, conv_dim*8).cuda()
conv7 = conv2d(conv_dim*8, conv_dim*8).cuda()
conv8 = conv2d(conv_dim*8, conv_dim*8).cuda()

e1 = conv1(source)
encode_layers['e1'] = e1
e2 = conv2(e1)
encode_layers['e2'] = e2
e3 = conv3(e2)
encode_layers['e3'] = e3
e4 = conv4(e3)
encode_layers['e4'] = e4
e5 = conv5(e4)
encode_layers['e5'] = e5
e6 = conv6(e5)
encode_layers['e6'] = e6
e7 = conv7(e6)
encode_layers['e7'] = e7
encoded_source = conv8(e7)
encode_layers['e8'] = encoded_source

In [9]:
print("source shape:\t", source.shape)
print("e1 shape:\t", e1.shape)
print("e2 shape:\t", e2.shape)
print("e3 shape:\t", e3.shape)
print("e4 shape:\t", e4.shape)
print("e5 shape:\t", e5.shape)
print("e6 shape:\t", e6.shape)
print("e7 shape:\t", e7.shape)
print("encoded shape:\t", encoded_source.shape)

source shape:	 torch.Size([32, 1, 128, 128])
e1 shape:	 torch.Size([32, 64, 64, 64])
e2 shape:	 torch.Size([32, 128, 32, 32])
e3 shape:	 torch.Size([32, 256, 16, 16])
e4 shape:	 torch.Size([32, 512, 8, 8])
e5 shape:	 torch.Size([32, 512, 4, 4])
e6 shape:	 torch.Size([32, 512, 2, 2])
e7 shape:	 torch.Size([32, 512, 1, 1])
encoded shape:	 torch.Size([32, 512, 1, 1])


In [10]:
encoded_source.device

device(type='cuda', index=0)

### Encoded source + embedding

In [11]:
embeddings = torch.load(os.path.join(fixed_dir, 'EMBEDDINGS.pkl')).cuda()
embeddings.shape

torch.Size([100, 1, 1, 128])

In [12]:
local_embeddings = embedding_lookup(embeddings, label, GPU=GPU)
local_embeddings.shape

torch.Size([32, 128, 1, 1])

In [13]:
local_embeddings.device

device(type='cuda', index=0)

In [14]:
embedded = torch.cat((encoded_source, local_embeddings), 1)
embedded.shape

torch.Size([32, 640, 1, 1])

### Decoder

In [15]:
img_dim = 1
embedded_dim = 640
conv_dim = 64

deconv1 = deconv2d(embedded_dim, conv_dim*8, dropout=True).cuda()
deconv2 = deconv2d(conv_dim*16, conv_dim*8, dropout=True, k_size=4).cuda()
deconv3 = deconv2d(conv_dim*16, conv_dim*8, k_size=5, dilation=2, dropout=True).cuda()
deconv4 = deconv2d(conv_dim*16, conv_dim*8, k_size=4, dilation=2, stride=2).cuda()
deconv5 = deconv2d(conv_dim*16, conv_dim*4, k_size=4, dilation=2, stride=2).cuda()
deconv6 = deconv2d(conv_dim*8, conv_dim*2, k_size=4, dilation=2, stride=2).cuda()
deconv7 = deconv2d(conv_dim*4, conv_dim*1, k_size=4, dilation=2, stride=2).cuda()
deconv8 = deconv2d(conv_dim*2, img_dim, k_size=4, dilation=2, stride=2, bn=False).cuda()

d1 = deconv1(embedded)
d1 = torch.cat((d1, encode_layers['e7']), dim=1)
d2 = deconv2(d1)
d2 = torch.cat((d2, encode_layers['e6']), dim=1)
d3 = deconv3(d2)
d3 = torch.cat((d3, encode_layers['e5']), dim=1)
d4 = deconv4(d3)
d4 = torch.cat((d4, encode_layers['e4']), dim=1)
d5 = deconv5(d4)
d5 = torch.cat((d5, encode_layers['e3']), dim=1)
d6 = deconv6(d5)
d6 = torch.cat((d6, encode_layers['e2']), dim=1)
d7 = deconv7(d6)
d7 = torch.cat((d7, encode_layers['e1']), dim=1)
d8 = deconv8(d7)
fake_target = torch.tanh(d8)

In [16]:
print("embedded shape:\t", embedded.shape)
print("d1 shape:\t", d1.shape)
print("concat d1:\t", d1.shape)
print("d2 shape:\t", d2.shape)
print("concat d2:\t", d2.shape)
print("d3 shape:\t", d3.shape)
print("concat d3:\t", d3.shape)
print("d4 shape:\t", d4.shape)
print("concat d4:\t", d4.shape)
print("d5 shape:\t", d5.shape)
print("concat d5:\t", d5.shape)
print("d6 shape:\t", d6.shape)
print("concat d6:\t", d6.shape)
print("d7 shape:\t", d7.shape)
print("concat d7:\t", d7.shape)
print("d8 shape:\t", d8.shape)

embedded shape:	 torch.Size([32, 640, 1, 1])
d1 shape:	 torch.Size([32, 1024, 1, 1])
concat d1:	 torch.Size([32, 1024, 1, 1])
d2 shape:	 torch.Size([32, 1024, 2, 2])
concat d2:	 torch.Size([32, 1024, 2, 2])
d3 shape:	 torch.Size([32, 1024, 4, 4])
concat d3:	 torch.Size([32, 1024, 4, 4])
d4 shape:	 torch.Size([32, 1024, 8, 8])
concat d4:	 torch.Size([32, 1024, 8, 8])
d5 shape:	 torch.Size([32, 512, 16, 16])
concat d5:	 torch.Size([32, 512, 16, 16])
d6 shape:	 torch.Size([32, 256, 32, 32])
concat d6:	 torch.Size([32, 256, 32, 32])
d7 shape:	 torch.Size([32, 128, 64, 64])
concat d7:	 torch.Size([32, 128, 64, 64])
d8 shape:	 torch.Size([32, 1, 128, 128])


### Generator 

In [17]:
En = Encoder().cuda()
De = Decoder().cuda()

fake_image, encoded_source = Generator(source, En, De, embeddings, label, GPU=GPU)
print("fake_image shape:", fake_image.shape)
print("encoded_source shape:", encoded_source.shape)

fake_image shape: torch.Size([32, 1, 128, 128])
encoded_source shape: torch.Size([32, 512, 1, 1])


### Discriminator

In [18]:
real_TS = torch.cat([source, target], dim=1).cuda()
fake_TS = torch.cat([source, fake_image], dim=1).cuda()
real_TS.shape, fake_TS.shape

(torch.Size([32, 2, 128, 128]), torch.Size([32, 2, 128, 128]))

In [19]:
img_dim = 2
disc_dim = 64
category_num = 25

conv1 = conv2d(img_dim, disc_dim, bn=False).cuda()
conv2 = conv2d(disc_dim, disc_dim*2).cuda()
conv3 = conv2d(disc_dim*2, disc_dim*4).cuda()
conv4 = conv2d(disc_dim*4, disc_dim*8).cuda()
fc1 = fc(disc_dim*8*8*8, 1).cuda()
fc2 = fc(disc_dim*8*8*8, category_num).cuda()

In [20]:
batch_size = real_TS.shape[0]
print("batch_size:\t", batch_size)

print("trg & src :\t", real_TS.shape)
h1 = conv1(real_TS)
print("h1 shape:\t", h1.shape)
h2 = conv2(h1).cuda()
print("h2 shape:\t", h2.shape)
h3 = conv3(h2).cuda()
print("h3 shape:\t", h3.shape)
h4 = conv4(h3).cuda()
print("h4 shape:\t", h4.shape)

tf_loss_logit = fc1(h4.reshape(batch_size, -1))
tf_loss = torch.sigmoid(tf_loss_logit)
print("tf_loss shape:\t", tf_loss.shape)
cat_loss = fc2(h4.reshape(batch_size, -1))
print("cat_loss shape:\t", cat_loss.shape)

batch_size:	 32
trg & src :	 torch.Size([32, 2, 128, 128])
h1 shape:	 torch.Size([32, 64, 64, 64])
h2 shape:	 torch.Size([32, 128, 32, 32])
h3 shape:	 torch.Size([32, 256, 16, 16])
h4 shape:	 torch.Size([32, 512, 8, 8])
tf_loss shape:	 torch.Size([32, 1])
cat_loss shape:	 torch.Size([32, 25])


In [21]:
D = Discriminator(category_num=25).cuda()

tf_loss, tf_loss_logit, cat_loss = D(real_TS)
tf_loss.shape, tf_loss_logit.shape, cat_loss.shape

(torch.Size([32, 1]), torch.Size([32, 1]), torch.Size([32, 25]))