# Disco GAN for Gender Exchange
- [github and paper](https://github.com/SKTBrain/DiscoGAN)
- [celebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
    - [images](https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg)
    - [attributes](https://drive.google.com/drive/folders/0B7EVK8r0v71pOC0wOVZlQnFfaGs)

In [48]:
%matplotlib inline
import matplotlib.pyplot as plt

In [133]:
from PIL import Image
import torch
from torch.autograd import Variable
from torch import nn, optim
from torch.utils import data
from torchvision import transforms
from torchvision import utils
from torch.nn import functional as F
from os import path

import pandas as pd
from itertools import chain
import random

## data and preprocessing

In [138]:
## Better to implement as image pairs

class CelebMaleFemale(data.Dataset):
    """Unpaired (male, female) image dataset built from an attribute table.

    Every item is a freshly drawn random pair: one image whose ``Male``
    attribute equals ``"1"`` and one whose ``Male`` attribute equals
    ``"-1"``.  The index passed to ``__getitem__`` is ignored.

    Args:
        image_folder: directory containing the image files named in the
            attribute table.
        attr_list: path to the whitespace-separated attribute file.  Its
            first line is ignored, its second line holds the 40 attribute
            names, and every following line holds a file name followed by
            40 attribute values.
        transform: optional callable applied to each opened image.
    """

    def __init__(self, image_folder, attr_list, transform=None):
        self.image_folder = image_folder
        self.attributes = self.parse_attributes(attr_list)
        self.transform = transform
        # Attribute values are kept as strings, hence the string comparisons.
        self.male_images = self.attributes[self.attributes.Male=="1"].image_path
        self.female_images = self.attributes[self.attributes.Male=="-1"].image_path

    def __len__(self):
        """Return the size of the larger of the two gender subsets."""
        return max(len(self.male_images), len(self.female_images))

    def __getitem__(self, i):
        """Return a random ``(male_img, female_img)`` pair; ``i`` is unused."""
        imale = random.randint(0, len(self.male_images)-1)
        male_img = Image.open(path.join(self.image_folder, self.male_images.iloc[imale]))
        ifemale = random.randint(0, len(self.female_images)-1)
        female_img = Image.open(path.join(self.image_folder, self.female_images.iloc[ifemale]))
        if self.transform is not None:
            male_img = self.transform(male_img)
            female_img = self.transform(female_img)
        return (male_img, female_img)

    def parse_attributes(self, attr_list):
        """Parse the attribute file into a DataFrame of string fields.

        Args:
            attr_list: path to the attribute file.

        Returns:
            A ``pandas.DataFrame`` with an ``image_path`` column followed by
            the 40 attribute columns named on the file's second line.

        Raises:
            AssertionError: if the header or a data row does not have
                exactly 41 whitespace-separated fields.
        """
        # Context manager so the file handle is closed deterministically.
        with open(attr_list) as attr_file:
            lines = attr_file.readlines()
        columns = ["image_path"] + lines[1].split()
        assert len(columns) == 41, "expected 40 attribute names in the header"
        attributes = []
        for line in lines[2:]:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            assert len(fields) == 41, "expected a file name plus 40 values"
            attributes.append(fields)
        return pd.DataFrame(attributes, columns=columns)

In [139]:
# Local paths to the aligned CelebA images and their attribute table.
image_folder = "/home/dola/ws/data/celebA/img_align_celeba/"
attr_list = "/home/dola/ws/data/celebA/list_attr_celeba.txt"
# Scale(64) only matches the *shorter* edge to 64 and keeps the aspect ratio,
# so non-square source images would not come out as 64x64.  The networks in
# this notebook are sized for 64x64 inputs (the generator always emits 64x64),
# so crop the centre square to keep inputs and reconstructions the same shape.
# NOTE(review): newer torchvision releases rename Scale to Resize - switch
# when upgrading.
transform = transforms.Compose([
    transforms.Scale(64),
    transforms.CenterCrop(64),
    transforms.ToTensor()
])
celeb_images = CelebMaleFemale(image_folder, attr_list, transform)

## Model

### discriminator

In [111]:
class Discriminator(nn.Module):
    """Convolutional discriminator that scores an image batch in [0, 1].

    Four stride-2 4x4 convolutions (batch norm on all but the first) feed a
    final unpadded 4x4 convolution with a single output channel, followed by
    a sigmoid.  Spatial sizes for a 64x64 input: 64 -> 32 -> 16 -> 8 -> 4 -> 1.
    """

    def __init__(self):
        super().__init__()
        width = 64

        # conv1: no batch norm on the first layer, 64 -> 32
        layers = [
            nn.Conv2d(3, width, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ELU(inplace=True),
        ]

        # conv2..conv4: each halves the resolution and doubles the channels
        channel_steps = (
            (width, width * 2),      # 32 -> 16
            (width * 2, width * 4),  # 16 -> 8
            (width * 4, width * 8),  # 8 -> 4
        )
        for n_in, n_out in channel_steps:
            layers.append(
                nn.Conv2d(n_in, n_out, kernel_size=4, stride=2, padding=1, bias=False)
            )
            layers.append(nn.BatchNorm2d(n_out))
            layers.append(nn.ELU(inplace=True))

        # conv5: collapse the 4x4 map to a single logit, 4 -> 1
        layers.append(
            nn.Conv2d(width * 8, 1, kernel_size=4, stride=1, padding=0, bias=False)
        )

        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid of the convolutional stack's output for ``x``."""
        return F.sigmoid(self.cnn(x))

In [113]:
# Shape check: push one random 3x64x64 image through the discriminator on the GPU.
d = Discriminator()
d = d.cuda()
x = Variable(torch.rand(1, 3, 64, 64)).cuda()
d(x).size()

torch.Size([1, 1, 1, 1])

### generator

In [129]:
class Generator(nn.Module):
    """Convolutional encoder/decoder that maps an image to an image.

    The encoder applies four stride-2 4x4 convolutions (64 -> 32 -> 16 -> 8
    -> 4 for a 64x64 input); the decoder mirrors it with four stride-2 4x4
    transposed convolutions (4 -> 8 -> 16 -> 32 -> 64).  A final sigmoid
    squashes the output into [0, 1].
    """

    def __init__(self):
        super().__init__()
        width = 64

        # Encoder: the first layer has no batch norm, the rest do.
        down = [
            nn.Conv2d(3, width, kernel_size=4, stride=2, padding=1, bias=False),
            nn.ELU(inplace=True),
        ]
        for n_in, n_out in (
            (width, width * 2),
            (width * 2, width * 4),
            (width * 4, width * 8),
        ):
            down += [
                nn.Conv2d(n_in, n_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(n_out),
                nn.ELU(inplace=True),
            ]
        self.encoder = nn.Sequential(*down)

        # Decoder: three normalised upsampling stages, then a bare projection
        # back to 3 channels.
        up = []
        for n_in, n_out in (
            (width * 8, width * 4),
            (width * 4, width * 2),
            (width * 2, width),
        ):
            up += [
                nn.ConvTranspose2d(n_in, n_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(n_out),
                nn.ELU(inplace=True),
            ]
        up.append(
            nn.ConvTranspose2d(width, 3, kernel_size=4, stride=2, padding=1, bias=False)
        )
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        """Encode ``x``, decode it, and squash the result into [0, 1]."""
        latent = self.encoder(x)
        return F.sigmoid(self.decoder(latent))  # normalize to [0,1]

In [130]:
# Shape check: push one random 3x64x64 image through the generator on the GPU.
g = Generator()
g = g.cuda()
x = Variable(torch.rand(1, 3, 64, 64)).cuda()
g(x).size()

torch.Size([1, 3, 64, 64])

## Train

In [147]:
n_epochs = 5
batch_size = 128

# One generator and one discriminator per target domain: generator_male
# produces male images, generator_female produces female images.
generator_male = Generator().cuda()
generator_female = Generator().cuda()
discriminator_male = Discriminator().cuda()
discriminator_female = Discriminator().cuda()

reconstruction_objective = nn.MSELoss()
gan_criterion = nn.BCELoss()

# The dataset draws a random pair for every index, so shuffling is left off.
image_pairs = data.DataLoader(celeb_images, batch_size=batch_size, shuffle=False, num_workers=2)

for epoch in range(n_epochs):
    for male_images, female_images in image_pairs:

        ## variables

        male = Variable(male_images).cuda()
        female = Variable(female_images).cuda()

        # Cross-domain translations.
        male2female = generator_female(male)
        female2male = generator_male(female)

        # Round trips back to the source domain.
        male2female2male = generator_male(male2female)
        female2male2female = generator_female(female2male)

        ## reconstruction loss
        # Criterions are called as criterion(input, target): the generated
        # round trip is the prediction and the real image is the target, so
        # gradients flow through the prediction and not the target.
        recon_loss_male = reconstruction_objective(male2female2male, male)
        recon_loss_female = reconstruction_objective(female2male2female, female)
        
        

Process Process-7:
Process Process-8:
Traceback (most recent call last):
  File "/home/dola/anaconda3/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/dola/anaconda3/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/home/dola/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dola/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dola/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 35, in _worker_loop
    r = index_queue.get()
  File "/home/dola/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 41, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/dola/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader

KeyboardInterrupt: 