Skip to content

Commit

Permalink
Merge pull request #277 from genekogan/master
Browse files Browse the repository at this point in the history
Make dataset root directory a flag
  • Loading branch information
carpedm20 committed Apr 12, 2018
2 parents ddb7fc2 + 8fd1c70 commit 60aa97b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ Or, you can use your own dataset (without central crop) by:
$ # example
$ python main.py --dataset=eyes --input_fname_pattern="*_cropped.png" --train

If your dataset is located in a different root directory:

$ python main.py --dataset DATASET_NAME --data_dir DATASET_ROOT_DIR --train
$ python main.py --dataset DATASET_NAME --data_dir DATASET_ROOT_DIR
$ # example
$ python main.py --dataset=eyes --data_dir ../datasets/ --input_fname_pattern="*_cropped.png" --train


## Results

![result](assets/training.gif)
Expand Down
7 changes: 5 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("data_dir", "./data", "Root directory of dataset [data]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
Expand Down Expand Up @@ -60,7 +61,8 @@ def main(_):
input_fname_pattern=FLAGS.input_fname_pattern,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
sample_dir=FLAGS.sample_dir,
data_dir=FLAGS.data_dir)
else:
dcgan = DCGAN(
sess,
Expand All @@ -75,7 +77,8 @@ def main(_):
input_fname_pattern=FLAGS.input_fname_pattern,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
sample_dir=FLAGS.sample_dir,
data_dir=FLAGS.data_dir)

show_all_variables()

Expand Down
9 changes: 5 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
batch_size=64, sample_num = 64, output_height=64, output_width=64,
y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None):
input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None, data_dir='./data'):
"""
Args:
Expand Down Expand Up @@ -69,12 +69,13 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
self.dataset_name = dataset_name
self.input_fname_pattern = input_fname_pattern
self.checkpoint_dir = checkpoint_dir
self.data_dir = data_dir

if self.dataset_name == 'mnist':
self.data_X, self.data_y = self.load_mnist()
self.c_dim = self.data_X[0].shape[-1]
else:
self.data = glob(os.path.join("./data", self.dataset_name, self.input_fname_pattern))
self.data = glob(os.path.join(self.data_dir, self.dataset_name, self.input_fname_pattern))
imreadImg = imread(self.data[0])
if len(imreadImg.shape) >= 3: #check if image is a non-grayscale image by checking channel number
self.c_dim = imread(self.data[0]).shape[-1]
Expand Down Expand Up @@ -192,7 +193,7 @@ def train(self, config):
batch_idxs = min(len(self.data_X), config.train_size) // config.batch_size
else:
self.data = glob(os.path.join(
"./data", config.dataset, self.input_fname_pattern))
config.data_dir, config.dataset, self.input_fname_pattern))
batch_idxs = min(len(self.data), config.train_size) // config.batch_size

for idx in xrange(0, batch_idxs):
Expand Down Expand Up @@ -451,7 +452,7 @@ def sampler(self, z, y=None):
return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))

def load_mnist(self):
data_dir = os.path.join("./data", self.dataset_name)
data_dir = os.path.join(self.data_dir, self.dataset_name)

fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
Expand Down

0 comments on commit 60aa97b

Please sign in to comment.