Skip to content

Commit

Permalink
1. Added one_direction_test_model that generates the outputs in only one direction
Browse files Browse the repository at this point in the history

2. Changed the option naming from ntrain to max_dataset_size
  • Loading branch information
Taesung Park committed Apr 27, 2017
1 parent af7420f commit e5b2fd6
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -87,7 +87,7 @@ python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix
- Test the model (`bash ./scripts/test_pix2pix.sh`):
```bash
#!./scripts/test_pix2pix.sh
python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --align_data --use_dropout
python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --align_data
```
The test results will be saved to a html file here: `./results/facades_pix2pix/latest_val/index.html`.

Expand Down
10 changes: 5 additions & 5 deletions data/aligned_data_loader.py
Expand Up @@ -8,10 +8,10 @@
from builtins import object

class PairedData(object):
def __init__(self, data_loader, fineSize, ntrain):
def __init__(self, data_loader, fineSize, max_dataset_size):
self.data_loader = data_loader
self.fineSize = fineSize
self.ntrain = ntrain
self.max_dataset_size = max_dataset_size
# st()

def __iter__(self):
Expand All @@ -21,7 +21,7 @@ def __iter__(self):

def __next__(self):
self.iter += 1
if self.iter > self.ntrain:
if self.iter > self.max_dataset_size:
raise StopIteration

AB, AB_paths = next(self.data_loader_iter)
Expand Down Expand Up @@ -60,7 +60,7 @@ def initialize(self, opt):
num_workers=int(self.opt.nThreads))

self.dataset = dataset
self.paired_data = PairedData(data_loader, opt.fineSize, opt.ntrain)
self.paired_data = PairedData(data_loader, opt.fineSize, opt.max_dataset_size)

def name(self):
return 'AlignedDataLoader'
Expand All @@ -69,4 +69,4 @@ def load_data(self):
return self.paired_data

def __len__(self):
return min(len(self.dataset), self.opt.ntrain)
return min(len(self.dataset), self.opt.max_dataset_size)
10 changes: 5 additions & 5 deletions data/unaligned_data_loader.py
Expand Up @@ -7,12 +7,12 @@
from pdb import set_trace as st

class PairedData(object):
def __init__(self, data_loader_A, data_loader_B, ntrain):
def __init__(self, data_loader_A, data_loader_B, max_dataset_size):
self.data_loader_A = data_loader_A
self.data_loader_B = data_loader_B
self.stop_A = False
self.stop_B = False
self.ntrain = ntrain
self.max_dataset_size = max_dataset_size

def __iter__(self):
self.stop_A = False
Expand Down Expand Up @@ -41,7 +41,7 @@ def __next__(self):
self.data_loader_B_iter = iter(self.data_loader_B)
B, B_paths = next(self.data_loader_B_iter)

if (self.stop_A and self.stop_B) or self.iter > self.ntrain:
if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size:
self.stop_A = False
self.stop_B = False
raise StopIteration()
Expand Down Expand Up @@ -79,7 +79,7 @@ def initialize(self, opt):
num_workers=int(self.opt.nThreads))
self.dataset_A = dataset_A
self.dataset_B = dataset_B
self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.ntrain)
self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.max_dataset_size)

def name(self):
return 'UnalignedDataLoader'
Expand All @@ -88,4 +88,4 @@ def load_data(self):
return self.paired_data

def __len__(self):
return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.ntrain)
return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.max_dataset_size)
9 changes: 7 additions & 2 deletions models/models.py
Expand Up @@ -4,12 +4,17 @@ def create_model(opt):
print(opt.model)
if opt.model == 'cycle_gan':
from .cycle_gan_model import CycleGANModel
assert(opt.align_data == False)
#assert(opt.align_data == False)
model = CycleGANModel()
if opt.model == 'pix2pix':
elif opt.model == 'pix2pix':
from .pix2pix_model import Pix2PixModel
assert(opt.align_data == True)
model = Pix2PixModel()
elif opt.model == 'one_direction_test':
from .one_direction_test_model import OneDirectionTestModel
model = OneDirectionTestModel()
else:
raise ValueError("Model [%s] not recognized." % opt.model)
model.initialize(opt)
print("model [%s] was created" % (model.name()))
return model
51 changes: 51 additions & 0 deletions models/one_direction_test_model.py
@@ -0,0 +1,51 @@
from torch.autograd import Variable
from collections import OrderedDict
import util.util as util
from .base_model import BaseModel
from . import networks


class OneDirectionTestModel(BaseModel):
    """Inference-only model that runs a single generator in one direction.

    Depending on ``--which_direction`` it restores either the saved A->B
    generator (``G_A``) or the B->A generator (``G_B``) from a checkpoint
    and translates inputs in that one direction only. Training is not
    supported.
    """

    def name(self):
        """Return the printable model name."""
        return 'OneDirectionTestModel'

    def initialize(self, opt):
        """Build the generator, load its checkpoint, and allocate the input buffer."""
        BaseModel.initialize(self, opt)

        # Pre-allocate the input tensor with the configured batch size,
        # channel count, and spatial resolution.
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
                                   opt.fineSize, opt.fineSize)

        # This model is test-only.
        assert(not self.isTrain)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG,
                                        opt.norm, opt.use_dropout,
                                        self.gpu_ids)
        # Choose which saved generator to restore based on the direction flag;
        # note the attribute is always named netG_A even when G_B is loaded.
        forward_dir = self.opt.which_direction == 'AtoB'
        self.load_network(self.netG_A,
                          'G_A' if forward_dir else 'G_B',
                          opt.which_epoch)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy the batch for the active direction into the input buffer."""
        forward_dir = self.opt.which_direction == 'AtoB'
        data = input['A' if forward_dir else 'B']
        self.input_A.resize_(data.size()).copy_(data)
        self.image_paths = input['A_paths' if forward_dir else 'B_paths']

    def test(self):
        """Run the generator on the current input batch."""
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG_A.forward(self.real_A)

    def get_image_paths(self):
        """Return the file paths of the current input images."""
        return self.image_paths

    def get_current_visuals(self):
        """Return the input and generated images as displayable arrays."""
        return OrderedDict([('real_A', util.tensor2im(self.real_A.data)),
                            ('fake_B', util.tensor2im(self.fake_B.data))])

2 changes: 2 additions & 0 deletions options/base_options.py
Expand Up @@ -35,6 +35,8 @@ def initialize(self):
self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')

self.initialized = True

def parse(self):
Expand Down
1 change: 0 additions & 1 deletion options/train_options.py
Expand Up @@ -13,7 +13,6 @@ def initialize(self):
self.parser.add_argument('--niter', type=int, default=200, help='# of iter at starting learning rate')
self.parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero')
self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
self.parser.add_argument('--ntrain', type=int, default=float("inf"), help='# of examples per epoch.')
self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
Expand Down

0 comments on commit e5b2fd6

Please sign in to comment.