# multiple_DR_GAN

ジェネレータに同じ人物の複数の画像（異なるポーズ，シーンなど）を入力し，それぞれの画像から得られた特徴量を
重み付けして足し合わせた特徴量を元に画像を生成する


In [3]:
import os
import numpy as np
import torch
from torch import nn, optim
from torch.autograd import Variable


## Discriminator の定義

single-image_DR_GAN と同じ
> - 論文で用いられている  TensorFlow のConv オプション padding="SAME"と同じ挙動を再現するために padding layer を間に追加
- 入力は バッチ数(B) x 1 x 96 x 96（1 チャネル画像．下のコードの Conv2d の入力チャネル数は 1）
- 個人の識別 (Nd+1 クラス) と姿勢の推定 (Np クラス) を同時に行う

In [4]:
class Discriminator(nn.Module):
    """DR-GAN discriminator: joint identity and pose classifier.

    The convolutional stack maps a batch of single-channel 96x96 images
    (B x 1 x 96 x 96) to a B x 321 feature vector, and one linear layer
    turns it into B x (Nd + 1 + Np) unnormalised scores.  Columns
    ``[:Nd + 1]`` are the identity scores and columns ``[Nd + 1:]`` are the
    pose scores.

    Stride-2 convolutions are preceded by an asymmetric ``ZeroPad2d`` so the
    spatial size is exactly halved (TensorFlow ``padding="SAME"`` behaviour).

    Args:
        Nd: number of identities; one extra identity class is appended.
        Np: number of discrete pose classes.
    """

    def __init__(self, Nd, Np):
        super(Discriminator, self).__init__()
        convLayers = [
            nn.Conv2d(1, 32, 3, 1, 1, bias=False),     # Bx1x96x96 -> Bx32x96x96
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.Conv2d(32, 64, 3, 1, 1, bias=False),    # Bx32x96x96 -> Bx64x96x96
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                # Bx64x96x96 -> Bx64x97x97
            nn.Conv2d(64, 64, 3, 2, 0, bias=False),    # Bx64x97x97 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.Conv2d(64, 64, 3, 1, 1, bias=False),    # Bx64x48x48 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.Conv2d(64, 128, 3, 1, 1, bias=False),   # Bx64x48x48 -> Bx128x48x48
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                # Bx128x48x48 -> Bx128x49x49
            nn.Conv2d(128, 128, 3, 2, 0, bias=False),  # Bx128x49x49 -> Bx128x24x24
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.Conv2d(128, 96, 3, 1, 1, bias=False),   # Bx128x24x24 -> Bx96x24x24
            nn.BatchNorm2d(96),
            nn.ELU(),
            nn.Conv2d(96, 192, 3, 1, 1, bias=False),   # Bx96x24x24 -> Bx192x24x24
            nn.BatchNorm2d(192),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                # Bx192x24x24 -> Bx192x25x25
            nn.Conv2d(192, 192, 3, 2, 0, bias=False),  # Bx192x25x25 -> Bx192x12x12
            nn.BatchNorm2d(192),
            nn.ELU(),
            nn.Conv2d(192, 128, 3, 1, 1, bias=False),  # Bx192x12x12 -> Bx128x12x12
            nn.BatchNorm2d(128),
            nn.ELU(),
            # padding must be 1 here: with padding 0 the map shrinks to 10x10
            # and the final 6x6 average pooling no longer fits.
            nn.Conv2d(128, 256, 3, 1, 1, bias=False),  # Bx128x12x12 -> Bx256x12x12
            nn.BatchNorm2d(256),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                # Bx256x12x12 -> Bx256x13x13
            nn.Conv2d(256, 256, 3, 2, 0, bias=False),  # Bx256x13x13 -> Bx256x6x6
            nn.BatchNorm2d(256),
            nn.ELU(),
            nn.Conv2d(256, 160, 3, 1, 1, bias=False),  # Bx256x6x6 -> Bx160x6x6
            nn.BatchNorm2d(160),
            nn.ELU(),
            # NOTE(review): the feature width is 321 here although older
            # comments said 320; the width is kept so parameter shapes do
            # not change -- confirm which one is intended.
            nn.Conv2d(160, 321, 3, 1, 1, bias=False),  # Bx160x6x6 -> Bx321x6x6
            nn.BatchNorm2d(321),
            nn.ELU(),
            nn.AvgPool2d(6, stride=1),                 # Bx321x6x6 -> Bx321x1x1
        ]

        self.convLayers = nn.Sequential(*convLayers)
        self.fc = nn.Linear(321, Nd + 1 + Np)

        # Initialise every conv / linear weight from N(mean=0, std=0.02).
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                m.weight.data.normal_(0, 0.02)

    def forward(self, input):
        """Return B x (Nd + 1 + Np) identity and pose scores for ``input``.

        Args:
            input: image batch of shape B x 1 x 96 x 96.
        """
        # Convolutions followed by average pooling: B x 321 x 1 x 1.
        x = self.convLayers(input)

        # Flatten to B x 321.  An explicit view keeps the batch dimension
        # even when B == 1, which a bare squeeze() would also remove.
        x = x.view(x.size(0), -1)

        # Fully connected layer: B x 321 -> B x (Nd + 1 + Np).
        x = self.fc(x)

        return x
    

# Generator の定義

- G_enc は 同一人物に n 枚の画像があるとして nB x 1 x 96 x 96 -> B x n x 321 -> B x 320 と特徴量を encode（321 チャネル目は特徴量を足し合わせる際の重み w で，重み付き和を取った後は 320 次元の特徴量だけが残る）
- G_dec は single-image DR_GANと同じ
> single-image の時
    - G_enc は Discriminator と最後の全結合層が無い以外同じ構造
    - G_dec のアップサンプリング時は，ダウンサンプリング時に Zeropadding を行なったことの逆で，ConvTranspose2d 後に Crop（negative padding?)

In [5]:
## A subclass of nn.Module must call the parent constructor through super();
## otherwise the member self._modules is never created and the later weight
## initialisation (which walks the registered modules) raises an error.
## self._modules is the container in which a module keeps its child modules.

class Crop(nn.Module):
    """Remove fixed margins from the last two (spatial) dimensions of a tensor.

    Used after a stride-2 ``ConvTranspose2d`` to undo the asymmetric zero
    padding applied before the matching stride-2 convolution.

    Args:
        crop_list: ``[crop_top, crop_bottom, crop_left, crop_right]``, the
            number of rows / columns to drop on each side.
    """

    def __init__(self, crop_list):
        super().__init__()

        # crop_list = [crop_top, crop_bottom, crop_left, crop_right]
        self.crop_list = crop_list

    def forward(self, x):
        """Return ``x`` with the configured margins cut off height and width.

        Leading dimensions (e.g. batch and channel) are left untouched.
        """
        top, bottom, left, right = self.crop_list
        # Only the trailing two dimensions are spatial; x is typically
        # B x C x H x W, so x.size() cannot be unpacked into two names.
        H, W = x.size()[-2:]
        return x[..., top:H - bottom, left:W - right]

In [4]:
class Generator(nn.Module):
    """Multi-image DR-GAN generator (encoder + decoder).

    Encoder (``G_enc_convLayers``): maps every input image to 321 values,
    i.e. a 320-dimensional feature plus one extra channel that acts as the
    weight of that image when the features of one person are summed.

    Decoder (``G_dec_fc`` + ``G_dec_convLayers``): takes the fused
    320-dimensional feature concatenated with a pose code and a noise vector
    and produces a B x 1 x 96 x 96 image.  Each stride-2 transposed
    convolution is followed by ``Crop`` to undo the asymmetric zero padding
    used by the encoder.

    Args:
        Np: length of the pose code.
        Nz: length of the noise vector.
    """

    def __init__(self, Np, Nz):
        super(Generator, self).__init__()
        G_enc_convLayers = [
            nn.Conv2d(1, 32, 3, 1, 1, bias=False),     # nBx1x96x96 -> nBx32x96x96
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.Conv2d(32, 64, 3, 1, 1, bias=False),    # nBx32x96x96 -> nBx64x96x96
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                # nBx64x96x96 -> nBx64x97x97
            nn.Conv2d(64, 64, 3, 2, 0, bias=False),    # nBx64x97x97 -> nBx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.Conv2d(64, 64, 3, 1, 1, bias=False),    # nBx64x48x48 -> nBx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.Conv2d(64, 128, 3, 1, 1, bias=False),   # nBx64x48x48 -> nBx128x48x48
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                # nBx128x48x48 -> nBx128x49x49
            nn.Conv2d(128, 128, 3, 2, 0, bias=False),  # nBx128x49x49 -> nBx128x24x24
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.Conv2d(128, 96, 3, 1, 1, bias=False),   # nBx128x24x24 -> nBx96x24x24
            nn.BatchNorm2d(96),
            nn.ELU(),
            nn.Conv2d(96, 192, 3, 1, 1, bias=False),   # nBx96x24x24 -> nBx192x24x24
            nn.BatchNorm2d(192),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                # nBx192x24x24 -> nBx192x25x25
            nn.Conv2d(192, 192, 3, 2, 0, bias=False),  # nBx192x25x25 -> nBx192x12x12
            nn.BatchNorm2d(192),
            nn.ELU(),
            nn.Conv2d(192, 128, 3, 1, 1, bias=False),  # nBx192x12x12 -> nBx128x12x12
            nn.BatchNorm2d(128),
            nn.ELU(),
            # padding must be 1 here: with padding 0 the map shrinks to 10x10
            # and the final 6x6 average pooling no longer fits.
            nn.Conv2d(128, 256, 3, 1, 1, bias=False),  # nBx128x12x12 -> nBx256x12x12
            nn.BatchNorm2d(256),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                # nBx256x12x12 -> nBx256x13x13
            nn.Conv2d(256, 256, 3, 2, 0, bias=False),  # nBx256x13x13 -> nBx256x6x6
            nn.BatchNorm2d(256),
            nn.ELU(),
            nn.Conv2d(256, 160, 3, 1, 1, bias=False),  # nBx256x6x6 -> nBx160x6x6
            nn.BatchNorm2d(160),
            nn.ELU(),

            # One extra channel (the 321st) carries the weight w used when the
            # features of several images of the same person are summed.
            nn.Conv2d(160, 321, 3, 1, 1, bias=False),  # nBx160x6x6 -> nBx321x6x6
            nn.BatchNorm2d(321),
            nn.ELU(),
            nn.AvgPool2d(6, stride=1),                 # nBx321x6x6 -> nBx321x1x1
        ]
        self.G_enc_convLayers = nn.Sequential(*G_enc_convLayers)

        G_dec_convLayers = [
            nn.ConvTranspose2d(320, 160, 3, 1, 1, bias=False),  # Bx320x6x6 -> Bx160x6x6
            nn.BatchNorm2d(160),
            nn.ELU(),
            nn.ConvTranspose2d(160, 256, 3, 1, 1, bias=False),  # Bx160x6x6 -> Bx256x6x6
            nn.BatchNorm2d(256),
            nn.ELU(),
            # Input here has 256 channels (output of the previous layer).
            nn.ConvTranspose2d(256, 256, 3, 2, 0, bias=False),  # Bx256x6x6 -> Bx256x13x13
            nn.BatchNorm2d(256),
            nn.ELU(),
            Crop([0, 1, 0, 1]),                                 # Bx256x13x13 -> Bx256x12x12
            nn.ConvTranspose2d(256, 128, 3, 1, 1, bias=False),  # Bx256x12x12 -> Bx128x12x12
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.ConvTranspose2d(128, 192, 3, 1, 1, bias=False),  # Bx128x12x12 -> Bx192x12x12
            nn.BatchNorm2d(192),
            nn.ELU(),
            nn.ConvTranspose2d(192, 192, 3, 2, 0, bias=False),  # Bx192x12x12 -> Bx192x25x25
            nn.BatchNorm2d(192),
            nn.ELU(),
            Crop([0, 1, 0, 1]),                                 # Bx192x25x25 -> Bx192x24x24
            nn.ConvTranspose2d(192, 96, 3, 1, 1, bias=False),   # Bx192x24x24 -> Bx96x24x24
            nn.BatchNorm2d(96),
            nn.ELU(),
            nn.ConvTranspose2d(96, 128, 3, 1, 1, bias=False),   # Bx96x24x24 -> Bx128x24x24
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.ConvTranspose2d(128, 128, 3, 2, 0, bias=False),  # Bx128x24x24 -> Bx128x49x49
            nn.BatchNorm2d(128),
            nn.ELU(),
            Crop([0, 1, 0, 1]),                                 # Bx128x49x49 -> Bx128x48x48
            nn.ConvTranspose2d(128, 64, 3, 1, 1, bias=False),   # Bx128x48x48 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.ConvTranspose2d(64, 64, 3, 1, 1, bias=False),    # Bx64x48x48 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.ConvTranspose2d(64, 64, 3, 2, 0, bias=False),    # Bx64x48x48 -> Bx64x97x97
            nn.BatchNorm2d(64),   # must match the 64 output channels above
            nn.ELU(),
            Crop([0, 1, 0, 1]),                                 # Bx64x97x97 -> Bx64x96x96
            nn.ConvTranspose2d(64, 32, 3, 1, 1, bias=False),    # Bx64x96x96 -> Bx32x96x96
            nn.BatchNorm2d(32),   # must match the 32 output channels above
            nn.ELU(),
            nn.ConvTranspose2d(32, 1, 3, 1, 1, bias=False),     # Bx32x96x96 -> Bx1x96x96
            # NOTE(review): the output activation is ELU as in the original
            # code; confirm it matches the value range of the training images.
            nn.ELU(),
        ]

        self.G_dec_convLayers = nn.Sequential(*G_dec_convLayers)

        self.G_dec_fc = nn.Linear(320 + Np + Nz, 320 * 6 * 6)

        # Initialise every conv / transposed-conv / linear weight from
        # N(mean=0, std=0.02).
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                m.weight.data.normal_(0, 0.02)

    def forward(self, input, pose, noise):
        """Generate one image per person from n images of that person.

        Args:
            input: (n*B) x 1 x 96 x 96 images.  The n images of one person
                must be stored consecutively (rows ``i*n .. i*n + n - 1``
                belong to person ``i``).  n is inferred as
                ``input.size(0) // pose.size(0)``; n == 1 gives the
                single-image behaviour.
            pose: B x Np pose codes.
            noise: B x Nz noise vectors.

        Returns:
            B x 1 x 96 x 96 generated images.

        Raises:
            ValueError: if the number of images is not a positive multiple
                of the number of pose codes.
        """
        feat = self.G_enc_convLayers(input)   # nB x 1 x 96 x 96 -> nB x 321 x 1 x 1
        feat = feat.view(feat.size(0), -1)    # nB x 321 (keeps batch dim when nB == 1)

        num_images = feat.size(0)
        num_persons = pose.size(0)
        if num_persons == 0 or num_images == 0 or num_images % num_persons != 0:
            raise ValueError(
                "number of input images ({0}) must be a positive multiple of "
                "the number of pose codes ({1})".format(num_images, num_persons))
        n = num_images // num_persons

        # Split the 321 channels: the first 320 are the feature, the last one
        # is the fusion weight, squashed into (0, 1) by a sigmoid.
        weight = torch.sigmoid(feat[:, -1:]).contiguous().view(num_persons, n, 1)
        feature = feat[:, :-1].contiguous().view(num_persons, n, -1)

        # Weighted average over the n images of each person: B x n x 320 -> B x 320.
        x = (weight * feature).sum(1) / weight.sum(1)

        x = torch.cat([x, pose, noise], 1)    # B x 320 -> B x (320+Np+Nz)

        x = self.G_dec_fc(x)                  # B x (320+Np+Nz) -> B x (320x6x6)

        x = x.view(-1, 320, 6, 6)             # B x (320x6x6) -> B x 320 x 6 x 6

        x = self.G_dec_convLayers(x)          # B x 320 x 6 x 6 -> B x 1 x 96 x 96

        return x
    

# 画像の取得

# 訓練の実行

In [6]:
batch_size = 64
epoch = 10000

Nd = 200  # number of identities (persons)
Np = 9    # number of discrete poses
Nz = 50   # dimension of the noise vector

lr_Adam = 0.0002
m_Adam = 0.5


def _to_class_index(labels):
    """Return a LongTensor of class indices from index labels or one-hot rows."""
    labels = np.asarray(labels)
    if labels.ndim == 2:
        # One-hot (or score) rows: nn.CrossEntropyLoss wants the class index.
        labels = labels.argmax(axis=1)
    return torch.from_numpy(labels.astype(np.int64))


# NOTE: `images`, `id_labels` and `pose_labels` are not defined in this cell;
# they must be provided by the (currently empty) data-loading section.
# assumes images is N x 1 x 96 x 96 and the labels have N entries -- TODO confirm
image_size = len(images)
epoch_time = int(np.ceil(image_size / batch_size))

D = Discriminator(Nd, Np)
G = Generator(Np, Nz)
# Pass the learning rate and first-moment decay defined above to Adam.
optimizer_D = optim.Adam(D.parameters(), lr=lr_Adam, betas=(m_Adam, 0.999))
optimizer_G = optim.Adam(G.parameters(), lr=lr_Adam, betas=(m_Adam, 0.999))
loss_criterion = nn.CrossEntropyLoss()

for e in range(epoch):
    for i in range(epoch_time):
        D.zero_grad()
        G.zero_grad()
        start = i * batch_size
        end = start + batch_size
        batch_image = images[start:end]
        batch_id_label = id_labels[start:end]
        batch_pose_label = pose_labels[start:end]
        # The last minibatch of an epoch may be smaller than batch_size.
        minibatch_size = len(batch_image)

        # Tensors used during this iteration.
        # Targets are class indices, as required by nn.CrossEntropyLoss.
        img_tensor = torch.FloatTensor(batch_image)
        id_label_tensor = _to_class_index(batch_id_label)
        pose_label_tensor = _to_class_index(batch_pose_label)
        # Synthesized images belong to the extra identity class (index Nd).
        syn_id_label_tensor = _to_class_index(np.full(minibatch_size, Nd))

        # Draw the noise vectors and one random target pose per image.
        fixed_noise_tensor = torch.FloatTensor(
            np.random.uniform(-1, 1, (minibatch_size, Nz)))
        pose_code_label = np.random.randint(Np, size=minibatch_size)
        pose_code = np.zeros((minibatch_size, Np))
        pose_code[np.arange(minibatch_size), pose_code_label] = 1
        pose_code_tensor = torch.FloatTensor(pose_code)
        pose_code_label_tensor = _to_class_index(pose_code_label)

        # Generate images with the generator.
        generated = G(img_tensor, pose_code_tensor, fixed_noise_tensor)

        # D and G are trained on alternating iterations.  The 1:4 ratio for
        # an accuracy of 90% or more mentioned in the original note is not
        # implemented; instead D is only updated while its loss exceeds 0.1.
        if i % 2 == 0:
            # Train the discriminator.
            real_output = D(img_tensor)
            # .detach() keeps the generator parameters out of this backward pass.
            syn_output = D(generated.detach())

            # Cross entropy of: identity on real images, pose on real images,
            # and the "synthesized" identity class on generated images.
            # Columns [:Nd+1] are identity scores, columns [Nd+1:] pose scores.
            d_loss = (loss_criterion(real_output[:, :Nd + 1], id_label_tensor)
                      + loss_criterion(real_output[:, Nd + 1:], pose_label_tensor)
                      + loss_criterion(syn_output[:, :Nd + 1], syn_id_label_tensor))

            if d_loss.item() > 0.1:
                d_loss.backward()
                optimizer_D.step()
                print("EPOCH : {0}, D : {1}".format(e, d_loss.item()))
        else:
            # Train the generator.
            syn_output = D(generated)

            # Cross entropy of: identity of the generated image against the
            # identity of the source image, and its pose against the pose
            # code given to the generator.
            g_loss = (loss_criterion(syn_output[:, :Nd + 1], id_label_tensor)
                      + loss_criterion(syn_output[:, Nd + 1:], pose_code_label_tensor))

            g_loss.backward()
            optimizer_G.step()
            print("EPOCH : {0}, G : {1}".format(e, g_loss.item()))

    # Save the models after every epoch.
    torch.save(D, "D.model")
    torch.save(G, "G.model")
    # TODO: generate and save sample images from the trained generator here.
    
    

In [23]:
# Scratch cell: one-hot identity target that marks every sample as the last
# of the Nd + 1 identity classes, wrapped as a float tensor.
syn_id_label = np.hstack([np.zeros((batch_size, Nd)), np.ones((batch_size, 1))])
syn_id_label_tensor = Variable(torch.FloatTensor(syn_id_label))

In [41]:
# Scratch check: draws a single random integer in [0, Np).
np.random.randint(Np)

7

In [46]:
# Scratch array holding the integers 1..5, used to try out slicing.
a = np.arange(1, 6)

In [49]:
# Scratch check: slicing from index 2 keeps the elements from the third one on.
a[2:]

array([3, 4, 5])