Added Pix2Pix and PatchGAN discriminator for CycleGAN
eriklindernoren committed Apr 21, 2018
1 parent 11de060 commit 0fcd191
Showing 8 changed files with 493 additions and 8 deletions.
File renamed without changes.
8 changes: 8 additions & 0 deletions data/download_pix2pix_dataset.sh
@@ -0,0 +1,8 @@
FILE=$1
URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
TAR_FILE=./$FILE.tar.gz
TARGET_DIR=./$FILE/
wget -N "$URL" -O "$TAR_FILE"
mkdir -p "$TARGET_DIR"
tar -zxvf "$TAR_FILE" -C ./
rm "$TAR_FILE"
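Usage note: running `bash data/download_pix2pix_dataset.sh facades` downloads facades.tar.gz from the Berkeley server and unpacks it to ./facades/. The script assumes wget and tar are available and that the dataset name passed as the first argument exists at that URL.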
8 changes: 6 additions & 2 deletions implementations/cyclegan/cyclegan.py
@@ -44,6 +44,10 @@
 
 cuda = True if torch.cuda.is_available() else False
 
+# Calculate output of image discriminator (PatchGAN)
+patch_h, patch_w = int(opt.img_height / 2**3), int(opt.img_width / 2**3)
+patch = (opt.batch_size, 1, patch_h, patch_w)
+
 # Initialize generator and discriminator
 G_AB = GeneratorResNet() if opt.generator_type == 'resnet' else GeneratorUNet()
 G_BA = GeneratorResNet() if opt.generator_type == 'resnet' else GeneratorUNet()
@@ -94,8 +98,8 @@
 input_A = Tensor(opt.batch_size, opt.channels, opt.img_height, opt.img_width)
 input_B = Tensor(opt.batch_size, opt.channels, opt.img_height, opt.img_width)
 # Adversarial ground truths
-valid = Variable(Tensor(opt.batch_size).fill_(1.0), requires_grad=False)
-fake = Variable(Tensor(opt.batch_size).fill_(0.0), requires_grad=False)
+valid = Variable(Tensor(np.ones(patch)), requires_grad=False)
+fake = Variable(Tensor(np.zeros(patch)), requires_grad=False)
 
 # Buffers of previously generated samples
 fake_A_buffer = ReplayBuffer()
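For context, a minimal sketch (not part of the commit, with assumed option values) of how the patch-shaped targets pair with the PatchGAN output under the least-squares GAN loss used by CycleGAN:

import torch
import torch.nn as nn

criterion_GAN = nn.MSELoss()                     # LSGAN objective

batch_size, img_height, img_width = 1, 256, 256  # assumed values of opt.*
patch = (batch_size, 1, img_height // 2**3, img_width // 2**3)  # (1, 1, 32, 32)

valid = torch.ones(patch)    # target map: every patch labelled real
fake = torch.zeros(patch)    # target map: every patch labelled fake

# The discriminator returns a map of the same shape, one score per
# receptive-field patch, so the loss compares maps element-wise:
pred = torch.rand(patch)     # stand-in for D_B(fake_B)
loss_GAN = criterion_GAN(pred, valid)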
10 changes: 4 additions & 6 deletions implementations/cyclegan/models.py
@@ -152,7 +152,7 @@ def __init__(self, in_channels=3):
 
         def discriminator_block(in_filters, out_filters, stride, normalize):
             """Returns layers of each discriminator block"""
-            layers = [nn.Conv2d(in_filters, out_filters, 4, stride, 1)]
+            layers = [nn.Conv2d(in_filters, out_filters, 3, stride, 1)]
             if normalize:
                 layers.append(nn.InstanceNorm2d(out_filters))
             layers.append(nn.LeakyReLU(0.2, inplace=True))
@@ -167,11 +167,9 @@ def discriminator_block(in_filters, out_filters, stride, normalize):
             layers.extend(discriminator_block(in_filters, out_filters, stride, normalize))
             in_filters = out_filters
 
-        layers.append(nn.Conv2d(out_filters, 1, 4, 1))
+        layers.append(nn.Conv2d(out_filters, 1, 3, 1, 1))
 
         self.model = nn.Sequential(*layers)
 
-    def forward(self, x):
-        x = self.model(x)
-        # Average pooling and flatten
-        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
+    def forward(self, img):
+        return self.model(img)
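A quick shape check (illustrative only, assuming the file above is importable as models and a 256x256 input) shows the effect of the change: the discriminator now returns a grid of per-patch scores instead of the average-pooled scalar it produced before:

import torch
from models import Discriminator  # implementations/cyclegan/models.py above

D = Discriminator()
out = D(torch.randn(1, 3, 256, 256))
print(out.shape)  # a patch map, e.g. torch.Size([1, 1, 32, 32]), matching `patch`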
28 changes: 28 additions & 0 deletions implementations/pix2pix/datasets.py
@@ -0,0 +1,28 @@
import glob
import random
import os
import numpy as np

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, mode='train'):
        self.transform = transforms.Compose(transforms_)

        self.files = sorted(glob.glob(os.path.join(root, '%s' % mode) + '/*.*'))

    def __getitem__(self, index):
        # Each file stores an A|B pair side by side in a single image
        img_pair = self.transform(Image.open(self.files[index % len(self.files)]))
        _, h, w = img_pair.shape
        half_w = int(w / 2)

        item_A = img_pair[:, :, :half_w]
        item_B = img_pair[:, :, half_w:]

        return {'A': item_A, 'B': item_B}

    def __len__(self):
        return len(self.files)
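A usage sketch (the transform list and the ./facades layout are assumptions for illustration, not part of this file):

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from datasets import ImageDataset

transforms_ = [
    transforms.Resize((256, 512)),  # keep the 1:2 aspect of the paired image
    transforms.ToTensor(),          # channels-first, so img_pair.shape is (C, H, W)
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

dataloader = DataLoader(ImageDataset('./facades', transforms_=transforms_),
                        batch_size=1, shuffle=True)
batch = next(iter(dataloader))
print(batch['A'].shape, batch['B'].shape)  # torch.Size([1, 3, 256, 256]) each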
182 changes: 182 additions & 0 deletions implementations/pix2pix/models.py
@@ -0,0 +1,182 @@
import torch.nn as nn
import torch.nn.functional as F
import torch

##############################
# U-NET
##############################

class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, bn=True, dropout=0.0):
        super(UNetDown, self).__init__()
        model = [nn.Conv2d(in_size, out_size, 3, stride=2, padding=1),  # halves spatial size
                 nn.LeakyReLU(0.2, inplace=True)]

        if bn:
            model += [nn.InstanceNorm2d(out_size)]

        if dropout:
            model += [nn.Dropout(dropout)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class UNetUp(nn.Module):
    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        model = [nn.Upsample(scale_factor=2),
                 nn.Conv2d(in_size, out_size, 3, stride=1, padding=1),
                 nn.LeakyReLU(0.2, inplace=True),
                 nn.InstanceNorm2d(out_size)]

        if dropout:
            model += [nn.Dropout(dropout)]

        self.model = nn.Sequential(*model)

    def forward(self, x, skip_input):
        x = self.model(x)
        # Skip connection: concatenate with the matching encoder feature map
        out = torch.cat((x, skip_input), 1)
        return out

class GeneratorUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(GeneratorUNet, self).__init__()

        self.down1 = UNetDown(in_channels, 64, bn=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)

        # Input channels double where the skip connection is concatenated
        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 256)
        self.up5 = UNetUp(512, 128)
        self.up6 = UNetUp(256, 64)

        final = [nn.Upsample(scale_factor=2),
                 nn.Conv2d(128, out_channels, 3, 1, 1),
                 nn.Tanh()]
        self.final = nn.Sequential(*final)

    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        u1 = self.up1(d7, d6)
        u2 = self.up2(u1, d5)
        u3 = self.up3(u2, d4)
        u4 = self.up4(u3, d3)
        u5 = self.up5(u4, d2)
        u6 = self.up6(u5, d1)

        return self.final(u6)


##############################
# RESNET
##############################

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      nn.InstanceNorm2d(in_features),
                      nn.ReLU(inplace=True),
                      nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      nn.InstanceNorm2d(in_features)]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class GeneratorResNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=9):
        super(GeneratorResNet, self).__init__()

        # Initial convolution block
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(in_channels, 64, 7),
                 nn.InstanceNorm2d(64),
                 nn.ReLU(inplace=True)]

        # Downsampling
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features*2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features//2
        for _ in range(2):
            model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features//2

        # Output layer
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(64, out_channels, 7),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, stride, normalize):
            """Returns layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 3, stride, 1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels*2  # the (A, B) pair is stacked by channel
        for out_filters, stride, normalize in [(64, 2, False),
                                               (128, 2, True),
                                               (256, 2, True),
                                               (512, 2, True)]:
            layers.extend(discriminator_block(in_filters, out_filters, stride, normalize))
            in_filters = out_filters

        # Output layer: one real/fake score per patch
        layers.append(nn.Conv2d(out_filters, 1, 3, 1, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        # Concatenate image and condition image by channels to produce input
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)
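To tie the new modules together, a small shape check (sizes are assumptions for illustration): the U-Net maps an image to an image of the same size, and the conditional discriminator scores a channel-stacked (A, B) pair per patch:

import torch
from models import GeneratorUNet, Discriminator  # implementations/pix2pix/models.py above

G = GeneratorUNet()
D = Discriminator()

real_A = torch.randn(1, 3, 256, 256)  # seven stride-2 down steps reach a 2x2 bottleneck
real_B = torch.randn(1, 3, 256, 256)

fake_B = G(real_A)
print(fake_B.shape)       # torch.Size([1, 3, 256, 256])

pred_real = D(real_A, real_B)
pred_fake = D(real_A, fake_B)
print(pred_real.shape)    # torch.Size([1, 1, 16, 16]): four stride-2 blocks give 256 / 2**4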
