# StarGAN 실습

In [None]:
# pyTorch 관련 된 라이브러리.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim # optimization에 관한 모듈.
import torchvision # 이미지 관련 전처리, pretrained된 모델, 데이터 로딩에 관한 패키지입니다.
from torchvision.utils import save_image # 이미지 저장을 위한 torchvision의 모듈
import torchvision.datasets as vision_dsets
import torchvision.transforms as T # 이미지 전처리 모듈입니다.
from torchvision.datasets import ImageFolder
from torch.utils import data

# 기타 필요한 라이브러리.
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import datetime
import random
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

## StarGAN

#### StarGAN은 multi-domain의 image translation이 가능하다는 것에 가장 큰 contribution을 갖는다.

![구조](./imgs/4_stargan1.jpg)

---

#### 위 구조를 실제 dataset 상황에서 보자면 아래와 같다.



![구조2](./imgs/4_stargan2.jpg)

# Training에 사용될 Hyper-parameter를 지정합니다.

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # First CUDA device when available, otherwise the CPU; tensors/models are moved with .to(device).
g_lr = 0.0001 # learning rate for Generator
d_lr = 0.0001 # learning rate for Discriminator
batch_size = 16 # mini-batch size

# Training에 사용될 데이터를 불러옵니다.

!bash download.sh celeba

## Data Loader

CelebA의 경우에는 pytorch에서 기본으로 제공하는 data loader가 없기에 조금 복잡한 data loader를 구현해줘야 합니다. dataloader의 경우는 각각의 dataset들마다 코딩해줘야하는 방향이 천차만별입니다.

In [None]:
class CelebA(data.Dataset):
    """Dataset of CelebA images paired with a chosen subset of binary attributes.

    The attribute file is parsed once at construction time. Its rows are
    shuffled with a fixed seed and split into a test part and a train part;
    ``mode`` selects which part the instance serves.
    """

    def __init__(self, image_dir, attr_path, selected_attrs, transform, mode):
        """Store the configuration and parse the attribute file.

        Args:
            image_dir: Directory that contains the image files.
            attr_path: Path of the attribute annotation text file.
            selected_attrs: Attribute names that make up each label, in order.
            transform: Callable applied to every loaded PIL image.
            mode: ``'train'`` serves the train split; any other value serves
                the test split.
        """
        self.image_dir = image_dir
        self.attr_path = attr_path
        self.selected_attrs = selected_attrs
        self.transform = transform
        self.mode = mode
        self.train_dataset = []
        self.test_dataset = []
        self.attr2idx = {}
        self.idx2attr = {}
        self.preprocess()

        if mode == 'train':
            self.num_images = len(self.train_dataset)
        else:
            self.num_images = len(self.test_dataset)

    def preprocess(self):
        """Parse the attribute file and fill the train/test sample lists.

        The line at index 1 holds the attribute names; every following line is
        ``<filename> <value> <value> ...``. A value of ``'1'`` is stored as
        ``True``, anything else as ``False``.
        """
        # Read through a context manager so the file handle is always closed.
        with open(self.attr_path, 'r') as attr_file:
            lines = [line.rstrip() for line in attr_file]

        all_attr_names = lines[1].split()
        for i, attr_name in enumerate(all_attr_names):
            self.attr2idx[attr_name] = i
            self.idx2attr[i] = attr_name

        lines = lines[2:]
        # Fixed seed: the shuffle (and therefore the split) is reproducible.
        # Note that this reseeds the global `random` generator.
        random.seed(1234)
        random.shuffle(lines)
        for i, line in enumerate(lines):
            split = line.split()
            filename = split[0]
            values = split[1:]

            label = []
            for attr_name in self.selected_attrs:
                idx = self.attr2idx[attr_name]
                label.append(values[idx] == '1')

            # The first 1999 shuffled rows form the test split, the rest train.
            if (i+1) < 2000:
                self.test_dataset.append([filename, label])
            else:
                self.train_dataset.append([filename, label])

        print('Finished preprocessing the CelebA dataset...')

    def __getitem__(self, index):
        """Return the transformed image and its label as a FloatTensor."""
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        filename, label = dataset[index]
        image = Image.open(os.path.join(self.image_dir, filename))
        return self.transform(image), torch.FloatTensor(label)

    def __len__(self):
        """Return the number of images in the split selected by ``mode``."""
        return self.num_images
    
def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
               batch_size=16, dataset='CelebA', mode='train', num_workers=1):
    """Create a DataLoader over the CelebA dataset.

    Images are centre-cropped to ``crop_size``, resized to ``image_size``,
    converted to tensors and normalised with mean 0.5 / std 0.5 per channel.
    In ``'train'`` mode a random horizontal flip is applied first and the
    loader shuffles; otherwise neither happens.

    The ``dataset`` argument is accepted but not used to choose a dataset:
    a ``CelebA`` instance is always built.
    """
    is_train = (mode == 'train')

    steps = [T.RandomHorizontalFlip()] if is_train else []
    steps += [
        T.CenterCrop(crop_size),
        T.Resize(image_size),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
    pipeline = T.Compose(steps)

    celeba = CelebA(image_dir, attr_path, selected_attrs, pipeline, mode)

    return data.DataLoader(dataset=celeba,
                           batch_size=batch_size,
                           shuffle=is_train,
                           num_workers=num_workers)

!bash download.sh celeba

In [None]:
# Build the train and test data loaders.
celeba_image_dir = 'data/celeba/images'
attr_path = 'data/celeba/list_attr_celeba.txt'
selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']
celeba_crop_size = 178
image_size = 128
num_workers = 1

# Both loaders share every setting except `mode`.
celeba_trainloader = get_loader(image_dir=celeba_image_dir,
                                attr_path=attr_path,
                                selected_attrs=selected_attrs,
                                crop_size=celeba_crop_size,
                                image_size=image_size,
                                batch_size=batch_size,
                                dataset='CelebA',
                                mode='train',
                                num_workers=num_workers)

celeba_testloader = get_loader(image_dir=celeba_image_dir,
                               attr_path=attr_path,
                               selected_attrs=selected_attrs,
                               crop_size=celeba_crop_size,
                               image_size=image_size,
                               batch_size=batch_size,
                               dataset='CelebA',
                               mode='test',
                               num_workers=num_workers)


NameError: name 'get_loader' is not defined

## 모델 구현

### ResidualBlock

Residual Block은 2015년 ImageNet Challenge에서 우승을 차지한 ResNet 모델을 통해 그 효과가 입증된 구조이다.

![resnet](./imgs/4_stargan3.png)

In [None]:
'''
코드 단순화를 위한 함수들을 정의해 줍니다.
'''
class ResidualBlock(nn.Module):
    """Two 3x3 conv + instance-norm stages wrapped in a skip connection.

    The stacked layers are stored in ``self.main``; the block returns the
    input plus the output of that stack, so ``dim_in`` must be compatible
    with ``dim_out`` for the addition.
    """
    def __init__(self, dim_in, dim_out):
        super().__init__()
        first_stage = [
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
        ]
        # No activation after the second normalisation.
        second_stage = [
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
        ]
        self.main = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        """Return ``x`` plus the output of the stacked layers."""
        residual = self.main(x)
        return x + residual

def conv(c_in, c_out, k_size, stride=2, pad=1, In=True, activation='relu'):
    """Build a Conv2d -> (InstanceNorm2d) -> (activation) block.

    Args:
        c_in, c_out: Input / output channel counts of the convolution.
        k_size, stride, pad: Kernel size, stride and padding of the convolution.
        In: When truthy, an affine InstanceNorm2d follows the convolution.
        activation: ``'lrelu'`` (slope 0.01), ``'tanh'`` or ``'relu'``; any
            other value (e.g. ``'none'`` or ``None``) adds no activation.

    Returns:
        An ``nn.Sequential`` holding the selected layers. The convolution has
        no bias term.
    """
    block = [nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=False)]

    if In:
        block.append(nn.InstanceNorm2d(c_out, affine=True, track_running_stats=True))

    # Unrecognised activation names fall through and add nothing.
    if activation == 'lrelu':
        block.append(nn.LeakyReLU(0.01))
    elif activation == 'tanh':
        block.append(nn.Tanh())
    elif activation == 'relu':
        block.append(nn.ReLU(inplace=True))

    return nn.Sequential(*block)
    
def deconv(c_in, c_out, k_size, stride=2, pad=1, In=True, activation='relu'):
    """Build a ConvTranspose2d -> (InstanceNorm2d) -> (activation) block.

    Args:
        c_in, c_out: Input / output channel counts of the transposed convolution.
        k_size, stride, pad: Kernel size, stride and padding of that layer.
        In: When truthy, an affine InstanceNorm2d follows the transposed convolution.
        activation: ``'lrelu'`` (slope 0.01), ``'tanh'`` or ``'relu'``; any
            other value (e.g. ``'none'`` or ``None``) adds no activation.

    Returns:
        An ``nn.Sequential`` holding the selected layers. The transposed
        convolution has no bias term.
    """
    block = [nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, bias=False)]

    if In:
        block.append(nn.InstanceNorm2d(c_out, affine=True, track_running_stats=True))

    # Unrecognised activation names fall through and add nothing.
    if activation == 'lrelu':
        block.append(nn.LeakyReLU(0.01))
    elif activation == 'tanh':
        block.append(nn.Tanh())
    elif activation == 'relu':
        block.append(nn.ReLU(inplace=True))

    return nn.Sequential(*block)

In [None]:
'''
Generator와 Discriminator를 선언해 줍니다. 
이 때 StarGAN의 Generator에는 input image와 target attribute에 대한 입력값을 받는 것을 기억합니다!!! 
'''

class Generator(nn.Module):
    """Generator: translates an image conditioned on a target-domain label.

    Layout: one 7x7 conv, two stride-2 down-sampling convs, six residual
    blocks, two stride-2 transposed convs, a final 7x7 conv and a tanh.
    """
    def __init__(self, conv_dim=64, c_dim=5):
        """
        Args:
            conv_dim: Channel count after the first convolution.
            c_dim: Number of label channels concatenated to the 3 image channels.
        """
        super().__init__()
        bottleneck_dim = conv_dim * 4

        # Stem: takes the image channels plus one channel per label entry.
        self.conv1 = conv(3 + c_dim, conv_dim, 7, 1, 3)

        # Down-sampling layers.
        self.conv2 = conv(conv_dim, conv_dim * 2, 4, 2, 1)
        self.conv3 = conv(conv_dim * 2, bottleneck_dim, 4, 2, 1)

        # Bottleneck layers.
        self.res1 = ResidualBlock(dim_in=bottleneck_dim, dim_out=bottleneck_dim)
        self.res2 = ResidualBlock(dim_in=bottleneck_dim, dim_out=bottleneck_dim)
        self.res3 = ResidualBlock(dim_in=bottleneck_dim, dim_out=bottleneck_dim)
        self.res4 = ResidualBlock(dim_in=bottleneck_dim, dim_out=bottleneck_dim)
        self.res5 = ResidualBlock(dim_in=bottleneck_dim, dim_out=bottleneck_dim)
        self.res6 = ResidualBlock(dim_in=bottleneck_dim, dim_out=bottleneck_dim)

        # Up-sampling layers.
        self.deconv1 = deconv(bottleneck_dim, conv_dim * 2, 4, 2, 1)
        self.deconv2 = deconv(conv_dim * 2, conv_dim, 4, 2, 1)

        # Output head: plain convolution (no norm, no activation) followed by tanh.
        self.conv4 = conv(conv_dim, 3, 7, 1, 3, In=False, activation=None)
        self.tanh = nn.Tanh()

    def forward(self, x, c):
        """Translate image batch ``x`` towards the domain label ``c``.

        ``c`` is reshaped to (N, C, 1, 1), tiled to the spatial size of ``x``
        and concatenated to ``x`` along the channel dimension before entering
        the network.
        """
        label_map = c.view(c.size(0), c.size(1), 1, 1)
        label_map = label_map.repeat(1, 1, x.size(2), x.size(3))
        hidden = torch.cat([x, label_map], dim=1)

        pipeline = (
            self.conv1, self.conv2, self.conv3,
            self.res1, self.res2, self.res3, self.res4, self.res5, self.res6,
            self.deconv1, self.deconv2,
            self.conv4, self.tanh,
        )
        for stage in pipeline:
            hidden = stage(hidden)

        return hidden

class Discriminator(nn.Module):
    """Discriminator with a real/fake head and a domain-classification head.

    Six stride-2 convolutions (LeakyReLU, no normalisation) feed two output
    convolutions: ``gen`` produces a one-channel real/fake map and ``cls``
    produces ``c_dim`` domain logits.
    """
    def __init__(self, image_size=128, conv_dim=64, c_dim=5):
        """
        Args:
            image_size: Image side length; it sizes the kernel of the
                classification head (``image_size / 2**6``).
            conv_dim: Channel count after the first convolution; it doubles
                at each following layer.
            c_dim: Number of domain logits produced by the classification head.
        """
        super().__init__()

        def down(c_in, c_out):
            # Stride-2 4x4 conv, no instance norm, LeakyReLU activation.
            return conv(c_in, c_out, 4, 2, 1, In=False, activation='lrelu')

        self.conv1 = down(3, conv_dim)

        self.conv2 = down(conv_dim, conv_dim * 2)
        self.conv3 = down(conv_dim * 2, conv_dim * 4)
        self.conv4 = down(conv_dim * 4, conv_dim * 8)
        self.conv5 = down(conv_dim * 8, conv_dim * 16)
        self.conv6 = down(conv_dim * 16, conv_dim * 32)

        feature_dim = conv_dim * 32
        # Six halvings shrink the side by 2**6; the classifier kernel spans what is left.
        cls_kernel = int(image_size / 2 ** 6)
        self.gen = nn.Conv2d(feature_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.cls = nn.Conv2d(feature_dim, c_dim, kernel_size=cls_kernel, bias=False)

    def forward(self, x):
        """Return ``(real/fake map, domain logits of shape (N, c_dim))``.

        The final ``view`` only succeeds when the classification head's
        output is 1x1 spatially.
        """
        features = x
        for layer in (self.conv1, self.conv2, self.conv3,
                      self.conv4, self.conv5, self.conv6):
            features = layer(features)

        src_map = self.gen(features)
        cls_logits = self.cls(features)

        return src_map, cls_logits.view(cls_logits.size(0), cls_logits.size(1))

## 학습을 위한 사전 모델/optimizer 선언 & 기타준비

In [None]:
# Instantiate the models and put them in train mode.
Gen = Generator(64, 5).train()
Dis = Discriminator(128, 64, 5).train()

# Move the models to the selected device (GPU when available).
Gen.to(device)
Dis.to(device)

# Optimizers: one Adam instance per network.
g_optimizer = torch.optim.Adam(Gen.parameters(), lr=g_lr, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(Dis.parameters(), lr=d_lr, betas=(0.5, 0.999))

# Number of training iterations.
num_iters = 20000

# Keep one fixed batch so that the generated images can be compared across training steps.
celeba_iter = iter(celeba_trainloader)
x_fixed, c_org = next(celeba_iter)
x_fixed = x_fixed.to(device)
sample_step = 10000  # NOTE(review): not referenced in the visible code; the training loop saves samples every 1000 steps.
sample_dir = 'stargan_celeba/samples' # Directory where the model outputs for the fixed batch are saved.
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
    
# Directory where model parameters are saved during training.
model_save_dir = 'stargan_celeba/models'
if not os.path.exists(model_save_dir):
    os.makedirs(model_save_dir)

def create_labels(c_org, c_dim=5, selected_attrs=None):
    """Generate target domain labels for debugging and testing.

    For each attribute index ``i`` in ``range(c_dim)`` a copy of ``c_org`` is
    made and modified in column ``i``:

    * if attribute ``i`` is a hair colour, column ``i`` is set to 1 and every
      other hair-colour column to 0;
    * otherwise column ``i`` is inverted (0 -> 1, non-zero -> 0).

    Args:
        c_org: 2-D label tensor (rows are samples, columns are attributes).
        c_dim: Number of attribute columns to produce targets for.
        selected_attrs: Attribute names in column order. ``None`` is treated
            as "no hair-colour attributes", so every column is inverted.

    Returns:
        A list of ``c_dim`` tensors moved to the module-level ``device``.
        ``c_org`` itself is not modified.
    """
    hair_colors = ('Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair')

    # The default of None used to crash in enumerate(); fall back to an empty list.
    attr_names = selected_attrs if selected_attrs is not None else []
    hair_color_indices = [i for i, attr_name in enumerate(attr_names)
                          if attr_name in hair_colors]

    c_trg_list = []
    for i in range(c_dim):
        c_trg = c_org.clone()
        if i in hair_color_indices:  # Set one hair color to 1 and the rest to 0.
            c_trg[:, i] = 1
            for j in hair_color_indices:
                if j != i:
                    c_trg[:, j] = 0
        else:
            c_trg[:, i] = (c_trg[:, i] == 0)  # Reverse attribute value.

        c_trg_list.append(c_trg.to(device))
    return c_trg_list
    
# Target labels for the fixed sample batch; the training loop feeds them to the generator when saving progress images.
c_fixed_list = create_labels(c_org, 5, selected_attrs=selected_attrs)

# Gradient-penalty loss term used while training the discriminator.
def gradient_penalty(y, x):
    """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.

    Args:
        y: Output tensor computed from ``x``.
        x: Input tensor with ``requires_grad=True``; its first dimension
            indexes the samples.

    Returns:
        Scalar tensor: the mean over samples of ``(||dy/dx||_2 - 1) ** 2``.
        The graph is kept (``create_graph=True``) so the penalty itself can
        be back-propagated.
    """
    # ones_like matches y's device and dtype, so no module-level `device` is needed.
    weight = torch.ones_like(y)
    dydx = torch.autograd.grad(outputs=y,
                               inputs=x,
                               grad_outputs=weight,
                               retain_graph=True,
                               create_graph=True,
                               only_inputs=True)[0]

    # Flatten everything but the sample dimension before taking the norm.
    dydx = dydx.view(dydx.size(0), -1)
    dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
    return torch.mean((dydx_l2norm-1)**2)

# Rescaling helper applied to images before they are saved.
def denorm(x):
    """Map values from [-1, 1] to [0, 1], clipping anything that falls outside."""
    rescaled = (x + 1) / 2
    return torch.clamp(rescaled, min=0, max=1)

## Training StarGAN 

여기서의 loss는 크게 7가지로 나누어집니다.

- D: Real images들을 1로 분류하기 위한 loss (d_loss_real)
- D: Fake images들을 0으로 분류하기 위한 loss (d_loss_fake)
- D: Images의 domain을 분류하기 위한 loss (d_loss_cls)
- D: gradient penalty loss (d_loss_gp)

- G: D를 속이는 Fake images들을 만들기 위한 loss (D에서 1로 분류함) (g_loss_fake)
- G: 다시 돌아 갔을 때 reconstruction을 위한 loss (g_loss_rec)
- G: Images의 domain을 분류하기 위한 loss (g_loss_cls)

In [None]:
# Main training loop: the discriminator is updated every step, the generator every 5th step.
start_time = time.time()
for step in range(num_iters):
    try:
        x_real, label_org = next(celeba_iter)
    except StopIteration:
        # The loader is exhausted: start a new pass over the training data.
        # Any other error (I/O, worker failure, ...) now propagates.
        celeba_iter = iter(celeba_trainloader)
        x_real, label_org = next(celeba_iter)

    # Build target-domain labels by shuffling the labels within the batch.
    rand_idx = torch.randperm(label_org.size(0))
    label_trg = label_org[rand_idx]

    c_org = label_org.clone()
    c_trg = label_trg.clone()

    x_real = x_real.to(device)           # Input images.
    c_org = c_org.to(device)             # Original domain labels.
    c_trg = c_trg.to(device)             # Target domain labels.
    label_org = label_org.to(device)     # Labels for computing classification loss.
    label_trg = label_trg.to(device)     # Labels for computing classification loss.

    # =================================================================================== #
    #                             1. Train the discriminator                              #
    # =================================================================================== #

    # Compute loss with real images.
    out_src, out_cls = Dis(x_real)
    d_loss_real = - torch.mean(out_src)
    # Summed BCE divided by the batch size (reduction='sum' replaces the deprecated size_average=False).
    d_loss_cls = F.binary_cross_entropy_with_logits(out_cls, label_org, reduction='sum') / out_cls.size(0)

    # Compute loss with fake images; detach so no gradient reaches the generator here.
    x_fake = Gen(x_real, c_trg)
    out_src, out_cls = Dis(x_fake.detach())
    d_loss_fake = torch.mean(out_src)

    # Compute loss for gradient penalty on random interpolations of real and fake images.
    alpha = torch.rand(x_real.size(0), 1, 1, 1).to(device)
    x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
    out_src, _ = Dis(x_hat)
    d_loss_gp = gradient_penalty(out_src, x_hat)

    # Backward and optimize.
    d_loss = d_loss_real + d_loss_fake + 1.0 * d_loss_cls + 10.0 * d_loss_gp
    Dis.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # Logging.
    loss = {}
    loss['D/loss_real'] = d_loss_real.item()
    loss['D/loss_fake'] = d_loss_fake.item()
    loss['D/loss_cls'] = d_loss_cls.item()
    loss['D/loss_gp'] = d_loss_gp.item()

    # =================================================================================== #
    #                               2. Train the generator                                #
    # =================================================================================== #

    if (step + 1) % 5 == 0:
        # Original-to-target domain.
        x_fake = Gen(x_real, c_trg)
        out_src, out_cls = Dis(x_fake)
        g_loss_fake = - torch.mean(out_src)
        g_loss_cls = F.binary_cross_entropy_with_logits(out_cls, label_trg, reduction='sum') / out_cls.size(0)

        # Target-to-original domain.
        x_reconst = Gen(x_fake, c_org)
        g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

        # Backward and optimize.
        g_loss = g_loss_fake + 10.0 * g_loss_rec + 1.0 * g_loss_cls
        Gen.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Logging.
        loss['G/loss_fake'] = g_loss_fake.item()
        loss['G/loss_rec'] = g_loss_rec.item()
        loss['G/loss_cls'] = g_loss_cls.item()

    if (step + 1) % 10 == 0:  # Print the current losses every 10 iterations.
        et = time.time() - start_time
        et = str(datetime.timedelta(seconds=et))[:-7]
        log = "Elapsed [{}], Iteration [{}/{}]".format(et, step+1, num_iters)
        for tag, value in loss.items():
            log += ", {}: {:.4f}".format(tag, value)
        print(log)

    if (step + 1) % 1000 == 0:  # Save translations of the fixed batch every 1000 iterations.
        with torch.no_grad():
            x_fake_list = [x_fixed]
            for c_fixed in c_fixed_list:
                x_fake_list.append(Gen(x_fixed, c_fixed))
            # Place the input and each translation side by side along the width.
            x_concat = torch.cat(x_fake_list, dim=3)
            sample_path = os.path.join(sample_dir, '{}-images.jpg'.format(step+1))
            save_image(denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(sample_path))

    if (step + 1) % 10000 == 0:  # Save model checkpoints every 10000 iterations.
        G_path = os.path.join(model_save_dir, '{}-G.ckpt'.format(step+1))
        D_path = os.path.join(model_save_dir, '{}-D.ckpt'.format(step+1))
        torch.save(Gen.state_dict(), G_path)
        torch.save(Dis.state_dict(), D_path)
        print('Saved model checkpoints into {}...'.format(model_save_dir))


## Testing StarGAN 

In [None]:
# Must be a step at which the training loop wrote a checkpoint. With the
# settings in this notebook checkpoints are written every 10000 iterations up
# to num_iters = 20000, so a value such as 30 has no matching file.
resume_iters = 20000

# Load the generator parameters saved after `resume_iters` training steps.
print('Loading the trained models from step {}...'.format(resume_iters))
Gen = Generator(64, 5)

Gen.to(device)

G_path = os.path.join(model_save_dir, '{}-G.ckpt'.format(resume_iters))
# map_location keeps the loaded tensors on the CPU; load_state_dict copies them into the model.
Gen.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))

# Directory where the translation results are saved.
result_dir = 'stargan_celeba/results'
if not os.path.exists(result_dir):
    os.makedirs(result_dir)

# Translate every test batch and save the results.
with torch.no_grad():
    for i, (x_real, c_org) in enumerate(celeba_testloader):
        # Prepare input images and target domain labels.
        x_real = x_real.to(device)
        c_trg_list = create_labels(c_org, 5, selected_attrs=selected_attrs)

        # Translate images.
        x_fake_list = [x_real]
        for c_trg in c_trg_list:
            x_fake_list.append(Gen(x_real, c_trg))

        # Save the translated images.
        x_concat = torch.cat(x_fake_list, dim=3)
        result_path = os.path.join(result_dir, '{}-images.jpg'.format(i+1))
        save_image(denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
        print('Saved real and fake images into {}...'.format(result_path))

## Pretrained Weight Download

StarGAN의 경우 모델 학습에 요구되는 iteration 수가 크기 때문에 실습 시간이라는 제한 내에서 결과를 보기 어렵습니다. 따라서 이번 실습에서는 위의 train code와 test code의 진행 가능 여부를 확인했으니, 많은 iteration으로 미리 학습해둔, 즉 pretrained weight를 통해 test를 진행하겠습니다.

Pretrained Weight는 보통 pt, pth, ckpt 등의 확장자를 가지고 있는 파일로, 이전에 학습해둔 모델의 weight 값들을 모델의 구조에 맞추어 저장하고 있습니다.

!bash download.sh pretrained-celeba-128x128

In [None]:
# Load the generator parameters from the downloaded pretrained checkpoint.
print('Loading the trained models with pretrained weights...')
resume_iters = 200000

Gen = Generator(64, 5)

Gen.to(device)

G_path = os.path.join(model_save_dir, '{}-G.ckpt'.format(resume_iters))

G_pretrained_dict = torch.load(G_path, map_location=lambda storage, loc: storage)

# The checkpoint entries are matched to this model's parameters by POSITION,
# not by name. This model's `num_batches_tracked` buffers are skipped when
# building the key list.
G_dict = Gen.state_dict()
key_list = [k for k in G_dict if k.split('.')[-1] != 'num_batches_tracked']
value_list = list(G_pretrained_dict.values())

# A positional mapping is only meaningful when both sides have the same length.
if len(key_list) != len(value_list):
    raise ValueError(
        'Pretrained checkpoint has {} entries but the model expects {}.'.format(
            len(value_list), len(key_list)))

G_new_pretrained_dict = dict(zip(key_list, value_list))
Gen.load_state_dict(G_new_pretrained_dict)

# Directory where the translation results are saved.
result_dir = 'stargan_celeba/results'
if not os.path.exists(result_dir):
    os.makedirs(result_dir)

# Translate every test batch and save the results.
with torch.no_grad():
    for i, (x_real, c_org) in enumerate(celeba_testloader):
        # Prepare input images and target domain labels.
        x_real = x_real.to(device)
        c_trg_list = create_labels(c_org, 5, selected_attrs=selected_attrs)

        # Translate images.
        x_fake_list = [x_real]
        for c_trg in c_trg_list:
            x_fake_list.append(Gen(x_real, c_trg))

        # Save the translated images.
        x_concat = torch.cat(x_fake_list, dim=3)
        result_path = os.path.join(result_dir, '{}-images.jpg'.format(i+1))
        save_image(denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
        print('Saved real and fake images into {}...'.format(result_path))