
Commit

jakezhaojb committed Jun 11, 2018
1 parent 7f9e746 commit 0ea7f61
Showing 3 changed files with 5 additions and 27 deletions.
README.md: 2 changes (1 addition, 1 deletion)
@@ -12,7 +12,7 @@ Major updates on 06/11/2018:

 ## File structure
 * lang: ARAE for language generation, on both 1B word benchmark and SNLI
-* yelp: ARAE for yelp style transfer
+* yelp: ARAE for language style transfer
 * mnist (in Torch): ARAE for discretized MNIST


yelp/README.md: 2 changes (1 addition, 1 deletion)
@@ -1,4 +1,4 @@
-# ARAE for Yelp Transfer Experiment
+# ARAE for language style transfer on Yelp

 python train.py --data_path ./data

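With this commit, train.py no longer accepts the `--gan_clamp` and `--debug` flags (see the diff below); the WGAN-GP penalty weight remains configurable. A hypothetical invocation that spells out that surviving flag at its default value:

python train.py --data_path ./data --gan_gp_lambda 0.1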
yelp/train.py: 28 changes (3 additions, 25 deletions)
@@ -93,8 +93,6 @@
                     help='beta1 for adam. default=0.5')
 parser.add_argument('--clip', type=float, default=1,
                     help='gradient clipping, max norm')
-parser.add_argument('--gan_clamp', type=float, default=0.01,
-                    help='WGAN clamp')
 parser.add_argument('--gan_gp_lambda', type=float, default=0.1,
                     help='WGAN GP penalty lambda')
 parser.add_argument('--grad_lambda', type=float, default=0.01,
@@ -116,8 +114,6 @@
 parser.add_argument('--no-cuda', dest='cuda', action='store_false',
                     help='not using CUDA')
 parser.set_defaults(cuda=True)
-parser.add_argument('--debug', action='store_true',
-                    help='debug')
 parser.add_argument('--device_id', type=str, default='0')

 args = parser.parse_args()
@@ -153,8 +149,6 @@
              (os.path.join(args.data_path, "valid2.txt"), "valid2", False),
              (os.path.join(args.data_path, "train1.txt"), "train1", True),
              (os.path.join(args.data_path, "train2.txt"), "train2", True)]
-if args.debug:
-    datafiles = datafiles[:2]
 vocabdict = None
 if args.load_vocab != "":
     vocabdict = json.load(open(args.load_vocab))
@@ -163,8 +157,7 @@
                 maxlen=args.maxlen,
                 vocab_size=args.vocab_size,
                 lowercase=args.lowercase,
-                vocab=vocabdict,
-                debug=args.debug)
+                vocab=vocabdict)

 # dumping vocabulary
 with open('{}/vocab.json'.format(args.outf), 'w') as f:
@@ -183,12 +176,8 @@
 eval_batch_size = 100
 test1_data = batchify(corpus.data['valid1'], eval_batch_size, shuffle=False)
 test2_data = batchify(corpus.data['valid2'], eval_batch_size, shuffle=False)
-if args.debug:
-    train1_data = batchify(corpus.data['valid1'], args.batch_size, shuffle=True)
-    train2_data = batchify(corpus.data['valid2'], args.batch_size, shuffle=True)
-else:
-    train1_data = batchify(corpus.data['train1'], args.batch_size, shuffle=True)
-    train2_data = batchify(corpus.data['train2'], args.batch_size, shuffle=True)
+train1_data = batchify(corpus.data['train1'], args.batch_size, shuffle=True)
+train2_data = batchify(corpus.data['train2'], args.batch_size, shuffle=True)

 print("Loaded data!")

@@ -485,10 +474,6 @@ def calc_gradient_penalty(netD, real_data, fake_data):


 def train_gan_d(whichdecoder, batch):
-    # clamp parameters to a cube
-    #for p in gan_disc.parameters():
-    #    p.data.clamp_(-args.gan_clamp, args.gan_clamp)
-
     gan_disc.train()
     optimizer_gan_d.zero_grad()

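The commented-out clamp removed above is the weight-clipping trick from the original WGAN critic; this script keeps the WGAN-GP route instead, pairing the `--gan_gp_lambda` flag with the `calc_gradient_penalty(netD, real_data, fake_data)` function named in the hunk header. That function's body is not part of this diff; a minimal sketch of a standard WGAN-GP penalty consistent with that signature (the module-level `args.gan_gp_lambda` scaling is an assumption) might look like:

import torch

def calc_gradient_penalty(netD, real_data, fake_data):
    # Random per-example interpolation between real and fake codes.
    alpha = torch.rand(real_data.size(0), 1, device=real_data.device)
    alpha = alpha.expand(real_data.size())
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)

    # Differentiate the critic's output w.r.t. the interpolated inputs.
    disc_out = netD(interpolates)
    gradients = torch.autograd.grad(outputs=disc_out, inputs=interpolates,
                                    grad_outputs=torch.ones_like(disc_out),
                                    create_graph=True, retain_graph=True)[0]

    # Penalize deviation of the gradient norm from 1.
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * args.gan_gp_lambda

Unlike clamping, this keeps the critic roughly 1-Lipschitz without crushing its weights into a small cube, which is presumably why the clamp flag and comments are being dropped.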
@@ -527,10 +512,6 @@ def train_gan_d(whichdecoder, batch):


 def train_gan_d_into_ae(whichdecoder, batch):
-    ## clamp parameters to a cube
-    #for p in gan_disc.parameters():
-    #    p.data.clamp_(-args.gan_clamp, args.gan_clamp)
-
     autoencoder.train()
     optimizer_ae.zero_grad()

@@ -693,9 +674,6 @@ def train_gan_d_into_ae(whichdecoder, batch):
         evaluate_generator(1, fixed_noise, "end_of_epoch_{}".format(epoch))
         evaluate_generator(2, fixed_noise, "end_of_epoch_{}".format(epoch))

-    if args.debug:
-        continue
-
     # shuffle between epochs
     train1_data = batchify(corpus.data['train1'], args.batch_size, shuffle=True)
     train2_data = batchify(corpus.data['train2'], args.batch_size, shuffle=True)
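The shuffle-between-epochs lines above rebuild the training batches each epoch via `batchify`, a helper defined elsewhere in this repository and not shown in the diff. A hypothetical minimal version consistent with its call sites (a list of examples in, a list of batches out; the repository's version presumably also pads variable-length sentences per batch) could be:

import random

def batchify(data, bsz, shuffle=False):
    # Optionally shuffle a copy so repeated epochs see different batch orders.
    if shuffle:
        data = random.sample(data, len(data))
    # Group consecutive examples into batches of size bsz (last one may be short).
    return [data[i:i + bsz] for i in range(0, len(data), bsz)]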
