Merge conflicts
elvisyjlin committed Aug 20, 2019
2 parents 9b4ce5c + 257cf58 commit 97f1502
Showing 20 changed files with 269 additions and 298 deletions.
91 changes: 54 additions & 37 deletions README.md
@@ -3,6 +3,10 @@
A PyTorch implementation of AttGAN - [Arbitrary Facial Attribute Editing: Only Change What You Want](https://arxiv.org/abs/1711.10678)

![Teaser](https://github.com/elvisyjlin/AttGAN-PyTorch/blob/master/pics/teaser.jpg)
Test on the CelebA validation set

![Custom](https://github.com/elvisyjlin/AttGAN-PyTorch/blob/master/pics/custom_images.jpg)
Test on my custom set

Inverting 13 attributes respectively. From left to right: _Input, Reconstruction, Bald, Bangs, Black_Hair, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses, Male, Mouth_Slightly_Open, Mustache, No_Beard, Pale_Skin, Young_

@@ -35,7 +39,7 @@ pip3 install --upgrade torch==0.4.0
* Please see [here](https://github.com/willylulu/celeba-hq-modified).
* _Images_ should be placed in `./data/celeba-hq/celeba-*/*.jpg`
* _Image list_ should be placed in `./data/image_list.txt`
* [Pretrained models](https://goo.gl/mQkqNo): download the models you need and unzip the files to `./output/` as below,
* [Pretrained models](http://bit.ly/attgan-pretrain): download the models you need and unzip the files to `./output/` as below,
```text
output
├── 128_shortcut1_inject0_none
@@ -50,75 +54,88 @@ pip3 install --upgrade torch==0.4.0

## Usage

To train an AttGAN on CelebA 128x128
#### To train an AttGAN on CelebA 128x128

```bash
CUDA_VISIBLE_DEVICES=0 \
python train.py \
--img_size 128 \
--shortcut_layers 1 \
--inject_layers 1 \
--experiment_name 128_shortcut1_inject1_none \
CUDA_VISIBLE_DEVICES=0 \
python train.py \
--img_size 128 \
--shortcut_layers 1 \
--inject_layers 1 \
--experiment_name 128_shortcut1_inject1_none \
--gpu
```

To train an AttGAN on CelebA-HQ 256x256 with multiple GPUs
#### To train an AttGAN on CelebA-HQ 256x256 with multiple GPUs

```bash
CUDA_VISIBLE_DEVICES=0 \
python train.py \
--data CelebA-HQ \
--img_size 256 \
--shortcut_layers 1 \
--inject_layers 1 \
--experiment_name 256_shortcut1_inject1_none_hq \
--gpu
CUDA_VISIBLE_DEVICES=0 \
python train.py \
--data CelebA-HQ \
--img_size 256 \
--shortcut_layers 1 \
--inject_layers 1 \
--experiment_name 256_shortcut1_inject1_none_hq \
--gpu \
--multi_gpu
```

To visualize training details
#### To visualize training details

```bash
tensorboard \
--logdir ./output
```

To test with single attribute editing
#### To test with single attribute editing

![Test](https://github.com/elvisyjlin/AttGAN-PyTorch/blob/master/pics/sample_testing.jpg)

```bash
CUDA_VISIBLE_DEVICES=0 \
python test.py \
--experiment_name 128_shortcut1_inject1_none \
--test_int 1.0 \
CUDA_VISIBLE_DEVICES=0 \
python test.py \
--experiment_name 128_shortcut1_inject1_none \
--test_int 1.0 \
--gpu
```

To test with multiple attributes editing
#### To test with multiple attributes editing

![Test Multi](https://github.com/elvisyjlin/AttGAN-PyTorch/blob/master/pics/sample_testing_multi.jpg)

```bash
CUDA_VISIBLE_DEVICES=0 \
python test_multi.py \
--experiment_name 128_shortcut1_inject1_none \
--test_atts Pale_Skin Male \
--test_ints 0.5 0.5 \
CUDA_VISIBLE_DEVICES=0 \
python test_multi.py \
--experiment_name 128_shortcut1_inject1_none \
--test_atts Pale_Skin Male \
--test_ints 0.5 0.5 \
--gpu
```

To test with attribute intensity control
#### To test with attribute intensity control

![Test Slide](https://github.com/elvisyjlin/AttGAN-PyTorch/blob/master/pics/sample_testing_slide.jpg)

```bash
CUDA_VISIBLE_DEVICES=0 \
python test_slide.py \
--experiment_name 128_shortcut1_inject1_none \
--test_att Male \
--test_int_min -1.0 \
--test_int_max 1.0 \
--n_slide 10 \
CUDA_VISIBLE_DEVICES=0 \
python test_slide.py \
--experiment_name 128_shortcut1_inject1_none \
--test_att Male \
--test_int_min -1.0 \
--test_int_max 1.0 \
--n_slide 10 \
--gpu
```

#### To test with your custom images (supports `test.py`, `test_multi.py`, `test_slide.py`)

```bash
CUDA_VISIBLE_DEVICES=0 \
python test.py \
--experiment_name 384_shortcut1_inject1_none_hq \
--test_int 1.0 \
--gpu \
--custom_img
```

Place your custom images in `./data/custom` and list their attributes in `./data/list_attr_custom.txt`. Crop and resize the images to squares in advance.
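The README asks for square custom images but does not say how to produce them. The following is a minimal sketch of one way to do it, not part of this commit: `center_square_box` and `make_square` are hypothetical helper names, the target size of 384 is only assumed from the `384_shortcut1_inject1_none_hq` experiment name, and Pillow is assumed to be installed for the second function.

```python
def center_square_box(width, height):
    """Return (left, top, right, bottom) of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return (left, top, left + side, top + side)


def make_square(src_path, dst_path, size=384):
    """Center-crop the image at src_path to a square and save it resized.

    Hypothetical helper: the size default is an assumption, not a value
    taken from the repository.
    """
    # Pillow is imported lazily so the box helper above has no dependencies.
    from PIL import Image

    with Image.open(src_path) as img:
        img = img.convert("RGB")  # drop alpha / palette modes before saving
        box = center_square_box(*img.size)
        img.crop(box).resize((size, size), Image.BICUBIC).save(dst_path)


# A 640x480 landscape image keeps its full height and loses 80 px per side.
print(center_square_box(640, 480))  # (80, 0, 560, 480)
```

A center crop is only one choice; for portraits where the face is off-center, cropping by hand may give better results.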
48 changes: 26 additions & 22 deletions attgan.py
@@ -1,6 +1,6 @@
# Copyright (C) 2018 Elvis Yu-Jing Lin <elvisyjlin@gmail.com>
#
# This work is licensed under the MIT License. To view a copy of this license,
# This work is licensed under the MIT License. To view a copy of this license,
# visit https://opensource.org/licenses/MIT.

"""AttGAN, generator, and discriminator."""
@@ -16,8 +16,8 @@
MAX_DIM = 64 * 16 # 1024

class Generator(nn.Module):
def __init__(self, enc_dim=64, enc_layers=5, enc_norm_fn='batchnorm', enc_acti_fn='lrelu',
dec_dim=64, dec_layers=5, dec_norm_fn='batchnorm', dec_acti_fn='relu',
def __init__(self, enc_dim=64, enc_layers=5, enc_norm_fn='batchnorm', enc_acti_fn='lrelu',
dec_dim=64, dec_layers=5, dec_norm_fn='batchnorm', dec_acti_fn='relu',
n_attrs=13, shortcut_layers=1, inject_layers=0, img_size=128):
super(Generator, self).__init__()
self.shortcut_layers = min(shortcut_layers, dec_layers - 1)
@@ -85,7 +85,7 @@ def forward(self, x, a=None, mode='enc-dec'):

class Discriminators(nn.Module):
# No instancenorm in fcs in source code, which is different from paper.
def __init__(self, dim=64, norm_fn='instancenorm', acti_fn='lrelu',
def __init__(self, dim=64, norm_fn='instancenorm', acti_fn='lrelu',
fc_dim=1024, fc_norm_fn='none', fc_acti_fn='lrelu', n_layers=5, img_size=128):
super(Discriminators, self).__init__()
self.f_size = img_size // 2**n_layers
@@ -100,11 +100,11 @@ def __init__(self, dim=64, norm_fn='instancenorm', acti_fn='lrelu',
n_in = n_out
self.conv = nn.Sequential(*layers)
self.fc_adv = nn.Sequential(
LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn),
LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn),
LinearBlock(fc_dim, 1, 'none', 'none')
)
self.fc_cls = nn.Sequential(
LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn),
LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn),
LinearBlock(fc_dim, 13, 'none', 'none')
)

@@ -122,32 +122,32 @@ def forward(self, x):

# multilabel_soft_margin_loss = sigmoid + binary_cross_entropy

l1 = 100.0
l2 = 10.0
l3 = 1.0

class AttGAN():
def __init__(self, args):
self.mode = args.mode
self.gpu = args.gpu
self.multi_gpu = args.multi_gpu if 'multi_gpu' in args else False
self.lambda_1 = args.lambda_1
self.lambda_2 = args.lambda_2
self.lambda_3 = args.lambda_3
self.lambda_gp = args.lambda_gp

self.G = Generator(
args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti,
args.dec_dim, args.dec_layers, args.dec_norm, args.dec_acti,
args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti,
args.dec_dim, args.dec_layers, args.dec_norm, args.dec_acti,
args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size
)
self.G.train()
if self.gpu: self.G.cuda()
summary(self.G, [(3, args.img_size, args.img_size), (args.n_attrs,)], batch_size=4, use_gpu=self.gpu)
summary(self.G, [(3, args.img_size, args.img_size), (args.n_attrs, 1, 1)], batch_size=4, device='cuda' if args.gpu else 'cpu')

self.D = Discriminators(
args.dis_dim, args.dis_norm, args.dis_acti,
args.dis_dim, args.dis_norm, args.dis_acti,
args.dis_fc_dim, args.dis_fc_norm, args.dis_fc_acti, args.dis_layers, args.img_size
)
self.D.train()
if self.gpu: self.D.cuda()
summary(self.D, [(3, args.img_size, args.img_size)], batch_size=4, use_gpu=self.gpu)
summary(self.D, [(3, args.img_size, args.img_size)], batch_size=4, device='cuda' if args.gpu else 'cpu')

if self.multi_gpu:
self.G = nn.DataParallel(self.G)
@@ -179,14 +179,14 @@ def trainG(self, img_a, att_a, att_a_, att_b, att_b_):
gf_loss = F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake))
gc_loss = F.binary_cross_entropy_with_logits(dc_fake, att_b)
gr_loss = F.l1_loss(img_recon, img_a)
g_loss = gf_loss + l2 * gc_loss + l1 * gr_loss
g_loss = gf_loss + self.lambda_2 * gc_loss + self.lambda_1 * gr_loss

self.optim_G.zero_grad()
g_loss.backward()
self.optim_G.step()

errG = {
'g_loss': g_loss.item(), 'gf_loss': gf_loss.item(),
'g_loss': g_loss.item(), 'gf_loss': gf_loss.item(),
'gc_loss': gc_loss.item(), 'gr_loss': gr_loss.item()
}
return errG
@@ -235,7 +235,7 @@ def interpolate(a, b=None):
F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))
df_gp = gradient_penalty(self.D, img_a)
dc_loss = F.binary_cross_entropy_with_logits(dc_real, att_a)
d_loss = df_loss + 10 * df_gp + l3 * dc_loss
d_loss = df_loss + self.lambda_gp * df_gp + self.lambda_3 * dc_loss

self.optim_D.zero_grad()
d_loss.backward()
@@ -257,9 +257,9 @@ def eval(self):

def save(self, path):
states = {
'G': self.G.state_dict(),
'D': self.D.state_dict(),
'optim_G': self.optim_G.state_dict(),
'G': self.G.state_dict(),
'D': self.D.state_dict(),
'optim_G': self.optim_G.state_dict(),
'optim_D': self.optim_D.state_dict()
}
torch.save(states, path)
@@ -302,6 +302,10 @@ def saveG(self, path):
parser.add_argument('--dec_acti', dest='dec_acti', type=str, default='relu')
parser.add_argument('--dis_acti', dest='dis_acti', type=str, default='lrelu')
parser.add_argument('--dis_fc_acti', dest='dis_fc_acti', type=str, default='relu')
parser.add_argument('--lambda_1', dest='lambda_1', type=float, default=100.0)
parser.add_argument('--lambda_2', dest='lambda_2', type=float, default=10.0)
parser.add_argument('--lambda_3', dest='lambda_3', type=float, default=1.0)
parser.add_argument('--lambda_gp', dest='lambda_gp', type=float, default=10.0)
parser.add_argument('--mode', dest='mode', default='wgan', choices=['wgan', 'lsgan', 'dcgan'])
parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='learning rate')
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5)
@@ -310,4 +314,4 @@ def saveG(self, path):
args = parser.parse_args()
args.n_attrs = 13
args.betas = (args.beta1, args.beta2)
attgan = AttGAN(args)
attgan = AttGAN(args)
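The attgan.py changes above replace the module-level constants `l1`, `l2`, `l3` and the hard-coded gradient-penalty factor `10` with command-line arguments. The sketch below illustrates how those weights combine the loss terms, using plain floats in place of tensors. The four flags and their defaults are copied from the diff; `build_parser`, `generator_loss`, and `discriminator_loss` are hypothetical names introduced only for this illustration.

```python
import argparse


def build_parser():
    """Parser holding only the four loss-weight flags added in this commit."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--lambda_1', dest='lambda_1', type=float, default=100.0)
    parser.add_argument('--lambda_2', dest='lambda_2', type=float, default=10.0)
    parser.add_argument('--lambda_3', dest='lambda_3', type=float, default=1.0)
    parser.add_argument('--lambda_gp', dest='lambda_gp', type=float, default=10.0)
    return parser


def generator_loss(gf_loss, gc_loss, gr_loss, args):
    # adversarial + lambda_2 * classification + lambda_1 * reconstruction
    return gf_loss + args.lambda_2 * gc_loss + args.lambda_1 * gr_loss


def discriminator_loss(df_loss, df_gp, dc_loss, args):
    # adversarial + lambda_gp * gradient penalty + lambda_3 * classification
    return df_loss + args.lambda_gp * df_gp + args.lambda_3 * dc_loss


# Defaults reproduce the previously hard-coded weights.
args = build_parser().parse_args([])
print(generator_loss(1.0, 0.5, 0.25, args))      # 31.0
print(discriminator_loss(2.0, 0.5, 0.25, args))  # 7.25
```

With the defaults, the reconstruction term dominates the generator objective; passing for example `--lambda_1 50` halves its weight without editing the source.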
48 changes: 34 additions & 14 deletions data.py
@@ -1,6 +1,6 @@
# Copyright (C) 2018 Elvis Yu-Jing Lin <elvisyjlin@gmail.com>
#
# This work is licensed under the MIT License. To view a copy of this license,
# This work is licensed under the MIT License. To view a copy of this license,
# visit https://opensource.org/licenses/MIT.

"""Custom datasets for CelebA and CelebA-HQ."""
@@ -10,9 +10,31 @@
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from skimage import io
from PIL import Image


class Custom(data.Dataset):
def __init__(self, data_path, attr_path, image_size, selected_attrs):
self.data_path = data_path
att_list = open(attr_path, 'r', encoding='utf-8').readlines()[1].split()
atts = [att_list.index(att) + 1 for att in selected_attrs]
self.images = np.loadtxt(attr_path, skiprows=2, usecols=[0], dtype=str)
self.labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=int)

self.tf = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

def __getitem__(self, index):
img = self.tf(Image.open(os.path.join(self.data_path, self.images[index])))
att = torch.tensor((self.labels[index] + 1) // 2)
return img, att

def __len__(self):
return len(self.images)

class CelebA(data.Dataset):
def __init__(self, data_path, attr_path, image_size, mode, selected_attrs):
super(CelebA, self).__init__()
@@ -33,16 +55,15 @@ def __init__(self, data_path, attr_path, image_size, mode, selected_attrs):
self.labels = labels[182637:]

self.tf = transforms.Compose([
transforms.ToPILImage(),
transforms.CenterCrop(170),
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transforms.CenterCrop(170),
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

self.length = len(self.images)
def __getitem__(self, index):
img = self.tf(io.imread(os.path.join(self.data_path, self.images[index])))
img = self.tf(Image.open(os.path.join(self.data_path, self.images[index])))
att = torch.tensor((self.labels[index] + 1) // 2)
return img, att
def __len__(self):
@@ -72,15 +93,14 @@ def __init__(self, data_path, attr_path, image_list_path, image_size, mode, sele
self.labels = labels[28500:]

self.tf = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

self.length = len(self.images)
def __getitem__(self, index):
img = self.tf(io.imread(os.path.join(self.data_path, self.images[index])))
img = self.tf(Image.open(os.path.join(self.data_path, self.images[index])))
att = torch.tensor((self.labels[index] + 1) // 2)
return img, att
def __len__(self):
@@ -125,7 +145,7 @@ def _set(att, value, att_name):
import torchvision.utils as vutils

attrs_default = [
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
]

Binary file added data/custom/donald_trump.jpg
Binary file added data/custom/emma_watson.jpg
Binary file added data/custom/jay_chou.jpeg
Binary file added data/custom/ji-eun_lee.jpg
Binary file added data/custom/tom_cruise.jpg
Binary file added data/custom/yui_aragaki.jpg
8 changes: 8 additions & 0 deletions data/list_attr_custom.txt
@@ -0,0 +1,8 @@
6
Bald Bangs Black_Hair Blond_Hair Brown_Hair Bushy_Eyebrows Eyeglasses Male Mouth_Slightly_Open Mustache No_Beard Pale_Skin Young
donald_trump.jpg -1 -1 -1 1 -1 1 -1 1 1 -1 1 -1 -1
emma_watson.jpg -1 1 -1 -1 1 -1 -1 -1 -1 -1 1 -1 1
jay_chou.jpeg -1 1 1 -1 -1 1 -1 1 1 -1 1 -1 1
ji-eun_lee.jpg -1 1 -1 -1 1 -1 -1 -1 1 -1 1 -1 1
tom_cruise.jpg -1 -1 -1 -1 1 1 -1 1 1 -1 1 -1 -1
yui_aragaki.jpg -1 1 -1 -1 1 -1 -1 -1 -1 -1 1 -1 1
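The attribute list above follows the CelebA layout: a count line, a line of attribute names, then one row per image with values of -1 or 1. The sketch below mirrors in pure Python what the `Custom` dataset in data.py does with NumPy: it selects columns by attribute name and maps labels from {-1, 1} to {0, 1} via `(label + 1) // 2`. The function name `parse_attr_list` and the two-row sample are illustrative assumptions, not code from the repository.

```python
def parse_attr_list(text, selected_attrs):
    """Parse a CelebA-style attribute list into file names and 0/1 labels."""
    lines = text.strip().splitlines()
    # Line 0 is the image count, line 1 holds the attribute names.
    att_list = lines[1].split()
    # +1 because column 0 of every data row is the file name.
    cols = [att_list.index(att) + 1 for att in selected_attrs]
    images, labels = [], []
    for line in lines[2:]:
        fields = line.split()
        images.append(fields[0])
        # Map -1 -> 0 and 1 -> 1, as data.py does with (labels + 1) // 2.
        labels.append([(int(fields[c]) + 1) // 2 for c in cols])
    return images, labels


# Two rows copied from data/list_attr_custom.txt, with the count adjusted.
sample = """2
Bald Bangs Black_Hair Blond_Hair Brown_Hair Bushy_Eyebrows Eyeglasses Male Mouth_Slightly_Open Mustache No_Beard Pale_Skin Young
donald_trump.jpg -1 -1 -1 1 -1 1 -1 1 1 -1 1 -1 -1
emma_watson.jpg -1 1 -1 -1 1 -1 -1 -1 -1 -1 1 -1 1
"""

images, labels = parse_attr_list(sample, ['Blond_Hair', 'Male', 'Young'])
print(images)  # ['donald_trump.jpg', 'emma_watson.jpg']
print(labels)  # [[1, 1, 0], [0, 0, 1]]
```

Looking columns up by name rather than by position means the file may list attributes in any order, as long as every selected attribute appears in the header line.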
2 changes: 1 addition & 1 deletion helpers.py
@@ -1,6 +1,6 @@
# Copyright (C) 2018 Elvis Yu-Jing Lin <elvisyjlin@gmail.com>
#
# This work is licensed under the MIT License. To view a copy of this license,
# This work is licensed under the MIT License. To view a copy of this license,
# visit https://opensource.org/licenses/MIT.

"""Helper functions for training."""