
initial commit

bellchenx committed Jan 10, 2018
1 parent 0adaf69 commit ae065256913d427a1ef1a0cbe62cb72c111dad72
Showing with 439 additions and 0 deletions.
  1. +243 −0 main.py
  2. +132 −0 model.py
  3. +64 −0 util.py
243 main.py
@@ -0,0 +1,243 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import itertools
import argparse

import util
import model

import torch
from torch.autograd import Variable
from torch.optim import Adam
from torchvision.utils import save_image
import torch.backends.cudnn as cudnn

parser = argparse.ArgumentParser(description='CycleGAN')

# Directory
parser.add_argument('--dataset_A', type=str, default='A')
parser.add_argument('--dataset_B', type=str, default='B')

# Data
parser.add_argument('--image_size', type=int, default=256)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--resume', '-r', action='store_true')

# Network
parser.add_argument('--G_channel', type=int, default=32)
parser.add_argument('--D_channel', type=int, default=32)
parser.add_argument('--G_downsample', type=int, default=2)
parser.add_argument('--D_downsample', type=int, default=5)

parser.add_argument('--G_input', type=int, default=3)
parser.add_argument('--G_output', type=int, default=3)
parser.add_argument('--D_input', type=int, default=3)
parser.add_argument('--D_output', type=int, default=1)
parser.add_argument('--D_layer', type=int, default=5)

parser.add_argument('--G_block', type=int, default=6)
parser.add_argument('--G_block_type', type=str, default='conv')
parser.add_argument('--G_enable_se', type=bool, default=True)

# Training
parser.add_argument('--learning_rate', type=float, default=2e-4)
parser.add_argument('--lr_decay_epoch', type=int, default=100)
parser.add_argument('--max_epoch', type=int, default=200)
parser.add_argument('--log_frequency', type=int, default=25)
parser.add_argument('--save_frequency', type=int, default=20)

config = parser.parse_args()

use_cuda = torch.cuda.is_available()

if config.resume:
    print('-- Resuming From Checkpoint')
    assert os.path.isdir('checkpoint'), '-- Error: No checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/cyclegan.nn')
    G_A = checkpoint['G_A']
    G_B = checkpoint['G_B']
    D_A = checkpoint['D_A']
    D_B = checkpoint['D_B']
    start = checkpoint['epoch'] + 1
else:
    G_A = model.Generator(config)
    G_B = model.Generator(config)
    D_A = model.Discriminator(config)
    D_B = model.Discriminator(config)
    start = 1

if use_cuda:
    G_A = G_A.cuda()
    G_B = G_B.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    cudnn.benchmark = True

util.print_network(G_A)
util.print_network(D_A)
G_A.train()
G_B.train()
D_A.train()
D_B.train()

MSE_Loss = torch.nn.MSELoss()
L1_Loss = torch.nn.L1Loss()

G_A_Optimizer = Adam(G_A.parameters(), lr=config.learning_rate, betas=(0.5, 0.999))
G_B_Optimizer = Adam(G_B.parameters(), lr=config.learning_rate, betas=(0.5, 0.999))
D_A_Optimizer = Adam(D_A.parameters(), lr=config.learning_rate, betas=(0.5, 0.999))
D_B_Optimizer = Adam(D_B.parameters(), lr=config.learning_rate, betas=(0.5, 0.999))

a_loader = util.get_loader(config, config.dataset_A + '/train')
b_loader = util.get_loader(config, config.dataset_B + '/train')
a_test_loader = util.get_loader(config, config.dataset_A + '/test')
b_test_loader = util.get_loader(config, config.dataset_B + '/test')

# iter(...).next() is Python 2 only; next(iter(...)) works under Python 2 and 3.
a_real_fixed = Variable(next(iter(a_test_loader))[0], volatile=True)
b_real_fixed = Variable(next(iter(b_test_loader))[0], volatile=True)
if use_cuda:
    a_real_fixed = a_real_fixed.cuda()
    b_real_fixed = b_real_fixed.cuda()

a_fake_pool = util.ItemPool()
b_fake_pool = util.ItemPool()

def adjust_learning_rate(optimizer, epoch):
    lr_now = config.learning_rate
    if epoch > config.lr_decay_epoch:
        # Linearly decay the learning rate towards zero after lr_decay_epoch.
        lr_now = lr_now - lr_now*(epoch - config.lr_decay_epoch)/config.lr_decay_epoch
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_now

def train(start, epoch):
    last_time = time.time()
    epoch_time = time.time()
    print('-- Current Epoch: %d' % epoch)

    adjust_learning_rate(G_A_Optimizer, epoch)
    adjust_learning_rate(G_B_Optimizer, epoch)
    adjust_learning_rate(D_A_Optimizer, epoch)
    adjust_learning_rate(D_B_Optimizer, epoch)

    # itertools.izip is Python 2 only; the built-in zip pairs the two loaders on both versions.
    for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
        # Train Generators
        a_real = Variable(a_real[0])
        b_real = Variable(b_real[0])
        if use_cuda:
            a_real = a_real.cuda()
            b_real = b_real.cuda()

        a_fake = G_A(b_real)
        b_fake = G_B(a_real)
        a_rec = G_A(b_fake)
        b_rec = G_B(a_fake)
        a_fake_result = D_A(a_fake)
        b_fake_result = D_B(b_fake)

        real_labels = Variable(torch.ones(a_fake_result.size()))
        if use_cuda:
            real_labels = real_labels.cuda()

        G_A_loss = MSE_Loss(a_fake_result, real_labels)
        G_B_loss = MSE_Loss(b_fake_result, real_labels)
        a_rec_loss = L1_Loss(a_rec, a_real)
        b_rec_loss = L1_Loss(b_rec, b_real)
        G_loss = G_A_loss + G_B_loss + a_rec_loss*10 + b_rec_loss*10

        G_A.zero_grad()
        G_B.zero_grad()
        G_loss.backward()
        G_A_Optimizer.step()
        G_B_Optimizer.step()

        # Train Discriminators
        a_fake = Variable(torch.Tensor(a_fake_pool([a_fake.cpu().data.numpy()])[0]))
        b_fake = Variable(torch.Tensor(b_fake_pool([b_fake.cpu().data.numpy()])[0]))
        if use_cuda:
            a_fake = a_fake.cuda()
            b_fake = b_fake.cuda()

        a_real_result = D_A(a_real)
        a_fake_result = D_A(a_fake)
        b_real_result = D_B(b_real)
        b_fake_result = D_B(b_fake)

        real_labels = Variable(torch.ones(a_real_result.size()))
        fake_labels = Variable(torch.zeros(a_fake_result.size()))
        if use_cuda:
            real_labels = real_labels.cuda()
            fake_labels = fake_labels.cuda()

        D_A_real_loss = MSE_Loss(a_real_result, real_labels)
        D_A_fake_loss = MSE_Loss(a_fake_result, fake_labels)
        D_B_real_loss = MSE_Loss(b_real_result, real_labels)
        D_B_fake_loss = MSE_Loss(b_fake_result, fake_labels)

        D_A_loss = D_A_fake_loss + D_A_real_loss
        D_B_loss = D_B_fake_loss + D_B_real_loss

        D_A.zero_grad()
        D_B.zero_grad()
        D_A_loss.backward()
        D_B_loss.backward()
        D_A_Optimizer.step()
        D_B_Optimizer.step()

        # Log
        if i % config.log_frequency == 0:
            speed = time.time() - last_time
            last_time = time.time()
            format_str = ('Step: %d; Loss: G-A: %.3f, D-A: %.3f, G-B: %.3f, D-B: %.3f; Speed: %.2f sec/step')
            # .data[0] extracts the scalar loss value from a Variable (PyTorch 0.3-era API).
            print(format_str % (i, G_A_loss.data[0], D_A_loss.data[0], G_B_loss.data[0],
                                D_B_loss.data[0], speed/config.log_frequency))

    # Save Data
    print('-- Saving parameters and sample images.')
    state = {'G_A': G_A, 'G_B': G_B, 'D_A': D_A, 'D_B': D_B, 'epoch': epoch}
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(state, './checkpoint/cyclegan.nn')

    if epoch >= 10 and epoch % config.save_frequency == 0:
        # Test Images
        for i, (a_real_test, b_real_test) in enumerate(zip(a_test_loader, b_test_loader)):
            a_real_test = Variable(a_real_test[0])
            b_real_test = Variable(b_real_test[0])
            if use_cuda:
                a_real_test = a_real_test.cuda()
                b_real_test = b_real_test.cuda()

            a_fake_test = G_A(b_real_test)
            b_fake_test = G_B(a_real_test)
            a_rec_test = G_A(b_fake_test)
            b_rec_test = G_B(a_fake_test)

            test = torch.cat([a_real_test, b_fake_test, a_rec_test, b_real_test, a_fake_test, b_rec_test], dim=0)
            test = util.denorm(test).data
            if not os.path.isdir('result'):
                os.mkdir('result')
            save_image(test, 'result/test%d-epoch-%d.jpg' % (i, epoch))
    else:
        # Sample Image
        a_fake_fixed = G_A(b_real_fixed)
        b_fake_fixed = G_B(a_real_fixed)
        a_rec_fixed = G_A(b_fake_fixed)
        b_rec_fixed = G_B(a_fake_fixed)
        sample = torch.cat([a_real_fixed, b_fake_fixed, a_rec_fixed, b_real_fixed, a_fake_fixed, b_rec_fixed], dim=0)
        sample = util.denorm(sample).data
        if not os.path.isdir('result'):
            os.mkdir('result')
        save_image(sample, 'result/sample-epoch-%d.jpg' % (epoch))

    epoch_time = (time.time() - epoch_time)/60
    time_remain = (epoch_time * (config.max_epoch - epoch))/60
    print('-- Epoch %d completed. Epoch Time: %.2f min, Time Est: %.2f hour.' % (epoch, epoch_time, time_remain))

# Training Loop
print('-- Start Training')
for epoch in range(start, config.max_epoch):
    train(start, epoch)
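
main.py relies on four helpers from util.py (util.get_loader, util.ItemPool, util.denorm, util.print_network) whose 64 lines are not expanded in this view. The sketch below is only an illustration of what such helpers could look like, inferred from how main.py calls them; the directory layout, the [-1, 1] normalization, and the pool size of 50 are assumptions, and the actual util.py in this commit may be implemented differently.

# Hypothetical sketch of util.py's interface, inferred from main.py's usage.
# The real util.py in this commit is not shown here and may differ.
import random

import torch.utils.data as data
from torchvision import datasets, transforms


def get_loader(config, path):
    # Assumption: images live under path/<class_name>/*.jpg (ImageFolder layout);
    # main.py passes e.g. 'A/train'.
    transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.CenterCrop(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # map pixels to [-1, 1]
    ])
    dataset = datasets.ImageFolder(path, transform=transform)
    return data.DataLoader(dataset, batch_size=config.batch_size,
                           shuffle=True, num_workers=config.num_workers)


def denorm(x):
    # Invert the assumed [-1, 1] normalization so save_image writes viewable pixels.
    return (x + 1) / 2


def print_network(net):
    num_params = sum(p.numel() for p in net.parameters())
    print(net)
    print('Total number of parameters: %d' % num_params)


class ItemPool(object):
    # CycleGAN-style history buffer: with probability 0.5, hand back a previously
    # generated item instead of the newest one, to stabilise discriminator training.
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.items = []

    def __call__(self, new_items):
        returned = []
        for item in new_items:
            if len(self.items) < self.max_size:
                self.items.append(item)
                returned.append(item)
            elif random.random() > 0.5:
                idx = random.randint(0, self.max_size - 1)
                returned.append(self.items[idx])
                self.items[idx] = item
            else:
                returned.append(item)
        return returned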
132 model.py
@@ -0,0 +1,132 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class se_block_conv(nn.Module):
    # Residual block built from two convolutions, with an optional
    # squeeze-and-excitation (SE) channel-gating branch.
    def __init__(self, channel, kernel_size, stride, padding, enable):
        super(se_block_conv, self).__init__()
        self.enable = enable

        self.conv1 = nn.Conv2d(channel, channel, kernel_size, stride, padding)
        self.conv1_norm = nn.InstanceNorm2d(channel)
        self.conv2 = nn.Conv2d(channel, channel, kernel_size, stride, padding)
        self.conv2_norm = nn.InstanceNorm2d(channel)

        if self.enable:
            self.se_conv1 = nn.Conv2d(channel, channel//16, kernel_size=1)
            self.se_conv2 = nn.Conv2d(channel//16, channel, kernel_size=1)

    def forward(self, x):
        output = F.relu(self.conv1_norm(self.conv1(x)))
        output = self.conv2_norm(self.conv2(output))

        if self.enable:
            se = F.avg_pool2d(output, output.size(2))
            se = F.relu(self.se_conv1(se))
            se = F.sigmoid(self.se_conv2(se))
            output = output * se

        output += x
        output = F.relu(output)
        return output

class se_block_deconv(nn.Module):
    # Same residual + SE structure as se_block_conv, but built from transposed convolutions.
    def __init__(self, channel, kernel_size, stride, padding, enable):
        super(se_block_deconv, self).__init__()
        self.enable = enable

        self.conv1 = nn.ConvTranspose2d(channel, channel, kernel_size, stride, padding, bias=True)
        self.conv1_norm = nn.InstanceNorm2d(channel)
        self.conv2 = nn.ConvTranspose2d(channel, channel, kernel_size, stride, padding, bias=True)
        self.conv2_norm = nn.InstanceNorm2d(channel)

        self.se_conv1 = nn.Conv2d(channel, channel//16, kernel_size=1)
        self.se_conv2 = nn.Conv2d(channel//16, channel, kernel_size=1)

    def forward(self, x):
        output = F.relu(self.conv1_norm(self.conv1(x)))
        output = self.conv2_norm(self.conv2(output))

        if self.enable:
            se = F.avg_pool2d(output, output.size(2))
            se = F.relu(self.se_conv1(se))
            se = F.sigmoid(self.se_conv2(se))
            output = output * se

        output += x
        output = F.relu(output)
        return output

class Generator(nn.Module):
    # ResNet-style CycleGAN generator: reflect-padded 7x7 stem, strided downsampling,
    # a stack of residual (SE) blocks, then transposed-conv upsampling.
    def __init__(self, config):
        super(Generator, self).__init__()
        input_nc = config.G_input
        output_nc = config.G_output
        n_downsample = config.G_downsample
        ngf = config.G_channel
        nb = config.G_block
        enable_se = config.G_enable_se
        if config.G_block_type == 'deconv':
            block_type = se_block_deconv
        else:
            block_type = se_block_conv

        downsample = [nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0),
                      nn.InstanceNorm2d(ngf, affine=True),
                      nn.ReLU(True)]
        for i in range(n_downsample):
            mult = 2**i
            downsample += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                           nn.InstanceNorm2d(ngf * mult * 2, affine=True),
                           nn.ReLU(True)]
        self.downsample = nn.Sequential(*downsample)

        # Instantiate a fresh block on every iteration so the residual blocks
        # do not share weights.
        resnet_blocks = []
        for i in range(nb):
            resnet_blocks += [block_type(ngf * (2**n_downsample), kernel_size=3,
                                         stride=1, padding=1, enable=enable_se)]
        self.resnet_blocks = nn.Sequential(*resnet_blocks)

        upsample = []
        for i in range(n_downsample):
            mult = 2**(n_downsample - i)
            upsample += [nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, kernel_size=3, stride=2, padding=1, output_padding=1),
                         nn.InstanceNorm2d(ngf * mult // 2, affine=True),
                         nn.ReLU(True)]
        self.upsample = nn.Sequential(*upsample)
        self.final_conv = nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0)

    def forward(self, x):
        out = F.pad(x, (3, 3, 3, 3), 'reflect')
        out = self.downsample(out)
        out = self.resnet_blocks(out)
        out = self.upsample(out)
        out = F.pad(out, (3, 3, 3, 3), 'reflect')
        out = F.tanh(self.final_conv(out))
        return out

class Discriminator(nn.Module):
    # PatchGAN-style discriminator: outputs a score map over overlapping image patches.
    def __init__(self, config):
        super(Discriminator, self).__init__()
        input_nc = config.D_input
        output_nc = config.D_output
        ndf = config.D_channel
        n_layers = config.D_downsample - 2

        model = [nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
                 nn.LeakyReLU(0.2)]

        for i in range(n_layers):
            mult = 2**i
            stride = 2
            if i >= 2:
                stride = 1
            model += [nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=stride, padding=1),
                      nn.InstanceNorm2d(ndf * mult * 2, affine=True),
                      nn.LeakyReLU(0.2, True)]

        model += [nn.Conv2d(ndf * (2**n_layers), output_nc, kernel_size=4, stride=1, padding=1)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
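
As a quick sanity check of the shapes implied by the default flags (image_size 256, G_downsample 2, D_downsample 5), the hedged snippet below, which is not part of the commit, builds both networks with those defaults and pushes a dummy batch through them. The strides above give a 3-channel 256x256 generator output and a 1x30x30 PatchGAN-style score map from the discriminator.

# Hedged shape check (not part of the commit): instantiate model.py's networks
# with the default flags from main.py and verify the output sizes.
import argparse

import torch
from torch.autograd import Variable

import model

config = argparse.Namespace(
    G_input=3, G_output=3, G_channel=32, G_downsample=2,
    G_block=6, G_block_type='conv', G_enable_se=True,
    D_input=3, D_output=1, D_channel=32, D_downsample=5, D_layer=5)

G = model.Generator(config)
D = model.Discriminator(config)

x = Variable(torch.randn(1, 3, 256, 256))  # one 256x256 RGB image
fake = G(x)
score = D(fake)

print(fake.size())   # expected: (1, 3, 256, 256)
print(score.size())  # expected: (1, 1, 30, 30) -- per-patch real/fake scores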
