Skip to content

Commit

Permalink
Add train_size and validation training (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
minimaxir committed May 2, 2018
1 parent f81a51a commit a40184f
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions textgenrnn/textgenrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,16 @@ def train_on_texts(self, texts, context_labels=None,
verbose=1,
new_model=False,
gen_epochs=1,
prop_keep=1.0,
train_size=1.0,
max_gen_length=300,
**kwargs):

if context_labels:
context_labels = LabelBinarizer().fit_transform(context_labels)

if 'prop_keep' in kwargs:
train_size = prop_keep

if self.config['word_level']:
texts = [text_to_word_sequence(text, filters='') for text in texts]

Expand All @@ -115,8 +118,19 @@ def train_on_texts(self, texts, context_labels=None,
if self.config['single_text']:
indices_list = indices_list[self.config['max_length']:-2, :]

indices_list = indices_list[np.random.rand(
indices_list.shape[0]) < prop_keep, :]
indices_mask = np.random.rand(indices_list.shape[0]) < train_size

gen_val = None
val_steps = None
if train_size < 1.0:
indices_list_val = indices_list[~indices_mask, :]
gen_val = generate_sequences_from_texts(
texts, indices_list_val, self, context_labels, batch_size)
val_steps = max(
int(np.floor(indices_list_val.shape[0] / batch_size)), 1)

indices_list = indices_list[indices_mask, :]

num_tokens = indices_list.shape[0]
assert num_tokens >= batch_size, "Fewer tokens than batch_size."

Expand Down Expand Up @@ -157,7 +171,9 @@ def lr_linear_decay(epoch):
save_model_weights(
self.config['name'])],
verbose=verbose,
max_queue_size=2
max_queue_size=2,
validation_data=gen_val,
validation_steps=val_steps
)

# Keep the text-only version of the model if using context labels
Expand Down

0 comments on commit a40184f

Please sign in to comment.