Commit

more robust parser semantics + housekeeping
kavorite committed Jul 14, 2021
1 parent a5eb196 commit 5aa52b7
Showing 1 changed file with 25 additions and 12 deletions.
train.py: 37 changes (25 additions & 12 deletions)
@@ -1,31 +1,39 @@
 import os
 
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
 
 import tensorflow as tf
 
 from model import LEVEL_DELIM, VOCAB_SIZE, tokenizer, transformer
 
 
-def read_level_seqs(level_path, seq_len):
+def read_level_ngrams(level_path, seq_len):
     tokenize = tokenizer()
 
-    def token_windows():
+    def token_ngrams():
         with tf.io.gfile.GFile(level_path) as istrm:
             chunk_size = 8192
-            buf = ""
-            ctx = []
-            while line := istrm.readline():
-                buf += line
-                if len(buf) > chunk_size:
-                    ctx.extend(tokenize(tf.convert_to_tensor(line)).numpy())
-                    buf = ""
 
+            def parse_ngrams(chunk):
+                ctx = tokenize(tf.convert_to_tensor(chunk).numpy())
                 while len(ctx) > seq_len:
                     for i in range(len(ctx) - seq_len - 1):
                         window = ctx[i : i + seq_len + 1]
                         source = window[:-1]
                         target = window[1:]
                         yield source, target
-                    ctx = ctx[seq_len + 1 :]
+                    ctx = ctx[seq_len:]
+
+            chunk = ""
+            while line := istrm.readline():
+                chunk += line
+                if len(chunk) > chunk_size:
+                    yield from parse_ngrams(chunk)
+                    chunk = ""
+            yield from parse_ngrams(chunk)
 
     return tf.data.Dataset.from_generator(
-        token_windows,
+        token_ngrams,
         output_signature=(
            tf.TensorSpec(shape=[seq_len], dtype=tf.int32),
            tf.TensorSpec(shape=[seq_len], dtype=tf.int32),
@@ -41,7 +49,7 @@ def token_windows():
 seq_len = 768
 batch_size = 32
 dataset = (
-    read_level_seqs("./levels.txt", seq_len)
+    read_level_ngrams("./levels.txt", seq_len)
     .batch(batch_size)
     .cache()
     .shuffle(total_levels)
@@ -73,4 +81,9 @@ def token_windows():
         tf.math.ceil((total_tokens - seq_len) / total_levels / batch_size)
     ),
     epochs=8,
+    callbacks=[
+        tf.keras.callbacks.EarlyStopping(monitor="acc@1", restore_best_weights=True)
+    ],
 )
+
+model.save("./generator.h5", include_optimizer=False)
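
For reference, the refactored generator splits every window of seq_len + 1 consecutive token IDs into a source (the first seq_len tokens) and a target (the same window advanced by one token). Below is a minimal standalone sketch of that pairing, using plain Python lists in place of the tokenized tensors and a hypothetical helper name (ngram_pairs); the committed code additionally streams chunks from disk and carries leftover context between them.

# Illustration only: the same source/target windowing on a toy token list.
def ngram_pairs(tokens, seq_len):
    # Each window of seq_len + 1 tokens yields (first seq_len, last seq_len):
    # the target is the source shifted one position to the right.
    for i in range(len(tokens) - seq_len):
        window = tokens[i : i + seq_len + 1]
        yield window[:-1], window[1:]

# With tokens 0..5 and seq_len=3: ([0, 1, 2], [1, 2, 3]), ([1, 2, 3], [2, 3, 4]), ...
for source, target in ngram_pairs(list(range(6)), 3):
    print(source, target)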
