Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example training an LSTM language model.
PiperOrigin-RevId: 303707157 Change-Id: I25a8b0f4d5a9707766fba2c8980de1e5d8dfdb66
- Loading branch information
Showing
3 changed files
with
280 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
load("//haiku/_src:build_defs.bzl", "hk_py_binary", "hk_py_library") | ||
|
||
licenses(["notice"]) | ||
|
||
exports_files(["LICENSE"]) | ||
|
||
hk_py_library( | ||
name = "dataset", | ||
srcs = ["dataset.py"], | ||
deps = [ | ||
# pip: numpy | ||
# pip: tensorflow | ||
# pip: tensorflow_datasets | ||
], | ||
) | ||
|
||
hk_py_binary( | ||
name = "train", | ||
srcs = ["train.py"], | ||
deps = [ | ||
":dataset", | ||
# pip: absl:app | ||
# pip: absl/flags | ||
# pip: absl/logging | ||
"//haiku", | ||
# pip: jax | ||
# pip: jax:optix | ||
# pip: numpy | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Lint as: python3 | ||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tiny Shakespeare as a language modelling dataset.""" | ||
|
||
from typing import Iterator, Mapping | ||
|
||
import numpy as np | ||
import tensorflow.compat.v2 as tf | ||
import tensorflow_datasets as tfds | ||
|
||
Batch = Mapping[str, np.ndarray] | ||
NUM_CHARS = 128 | ||
|
||
|
||
def load( | ||
split: tfds.Split, | ||
*, | ||
batch_size: int, | ||
sequence_length: int, | ||
) -> Iterator[Batch]: | ||
"""Creates the Tiny Shakespeare dataset as a character modelling task.""" | ||
|
||
def preprocess_fn(x: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]: | ||
x = x['text'] | ||
x = tf.strings.unicode_split(x, 'UTF-8') | ||
x = tf.squeeze(tf.io.decode_raw(x, tf.uint8), axis=-1) | ||
x = tf.cast(x, tf.int32) | ||
return {'input': x[:-1], 'target': x[1:]} | ||
|
||
ds = tfds.load(name='tiny_shakespeare', split=split) | ||
ds = ds.map(preprocess_fn) | ||
ds = ds.unbatch() | ||
ds = ds.batch(sequence_length, drop_remainder=True) | ||
ds = ds.shuffle(100) | ||
ds = ds.repeat() | ||
ds = ds.batch(batch_size) | ||
ds = ds.map(lambda b: tf.nest.map_structure(tf.transpose, b)) # Time major. | ||
|
||
return tfds.as_numpy(ds) | ||
|
||
|
||
def decode(x: np.ndarray) -> str: | ||
return ''.join([chr(x) for x in x]) | ||
|
||
|
||
def encode(x: str) -> np.ndarray: | ||
return np.array([ord(s) for s in x]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
# Lint as: python3 | ||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Character-level language modelling with a recurrent network in JAX.""" | ||
|
||
from typing import Any, NamedTuple | ||
|
||
from absl import app | ||
from absl import flags | ||
from absl import logging | ||
|
||
import haiku as hk | ||
from haiku.examples.rnn import dataset | ||
import jax | ||
from jax import lax | ||
from jax import ops | ||
from jax.experimental import optix | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
flags.DEFINE_integer('train_batch_size', 32, '') | ||
flags.DEFINE_integer('eval_batch_size', 1000, '') | ||
flags.DEFINE_integer('sequence_length', 128, '') | ||
flags.DEFINE_integer('hidden_size', 256, '') | ||
flags.DEFINE_integer('sample_length', 128, '') | ||
flags.DEFINE_float('learning_rate', 1e-3, '') | ||
flags.DEFINE_integer('training_steps', 100_000, '') | ||
flags.DEFINE_integer('evaluation_interval', 100, '') | ||
flags.DEFINE_integer('sampling_interval', 100, '') | ||
flags.DEFINE_integer('seed', 42, '') | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
|
||
class LoopValues(NamedTuple): | ||
tokens: jnp.ndarray | ||
state: Any | ||
rng_key: jnp.ndarray | ||
|
||
|
||
class TrainingState(NamedTuple): | ||
params: hk.Params | ||
opt_state: Any | ||
|
||
|
||
def make_network() -> hk.RNNCore: | ||
"""Defines the network architecture.""" | ||
model = hk.DeepRNN([ | ||
lambda x: hk.one_hot(x, num_classes=dataset.NUM_CHARS), | ||
hk.LSTM(FLAGS.hidden_size), | ||
jax.nn.relu, | ||
hk.LSTM(FLAGS.hidden_size), | ||
hk.nets.MLP([FLAGS.hidden_size, dataset.NUM_CHARS]), | ||
]) | ||
return model | ||
|
||
|
||
def make_optimizer() -> optix.InitUpdate: | ||
"""Defines the optimizer.""" | ||
return optix.adam(FLAGS.learning_rate) | ||
|
||
|
||
def sequence_loss(batch: dataset.Batch) -> jnp.ndarray: | ||
"""Unrolls the network over a sequence of inputs & targets, gets loss.""" | ||
# Note: this function is impure; we hk.transform() it below. | ||
core = make_network() | ||
sequence_length, batch_size = batch['input'].shape | ||
initial_state = core.initial_state(batch_size) | ||
logits, _ = hk.dynamic_unroll(core, batch['input'], initial_state) | ||
log_probs = jax.nn.log_softmax(logits) | ||
one_hot_labels = hk.one_hot(batch['target'], num_classes=logits.shape[-1]) | ||
return -jnp.sum(one_hot_labels * log_probs) / (sequence_length * batch_size) | ||
|
||
|
||
@jax.jit | ||
def update(state: TrainingState, batch: dataset.Batch) -> TrainingState: | ||
"""Does a step of SGD given inputs & targets.""" | ||
_, optimizer = optix.adam(FLAGS.learning_rate) | ||
_, loss_fn = hk.transform(sequence_loss) | ||
gradients = jax.grad(loss_fn)(state.params, batch) | ||
updates, new_opt_state = optimizer(gradients, state.opt_state) | ||
new_params = optix.apply_updates(state.params, updates) | ||
return TrainingState(params=new_params, opt_state=new_opt_state) | ||
|
||
|
||
def sample( | ||
rng_key: jnp.ndarray, | ||
context: jnp.ndarray, | ||
sample_length: int, | ||
) -> jnp.ndarray: | ||
"""Draws samples from the model, given an initial context.""" | ||
# Note: this function is impure; we hk.transform() it below. | ||
assert context.ndim == 1 # No batching for now. | ||
core = make_network() | ||
|
||
def body_fn(t: int, v: LoopValues) -> LoopValues: | ||
token = v.tokens[t] | ||
next_logits, next_state = core(token, v.state) | ||
key, subkey = jax.random.split(v.rng_key) | ||
next_token = jax.random.categorical(subkey, next_logits, axis=-1) | ||
new_tokens = ops.index_update(v.tokens, ops.index[t + 1], next_token) | ||
return LoopValues(tokens=new_tokens, state=next_state, rng_key=key) | ||
|
||
logits, state = hk.dynamic_unroll(core, context, core.initial_state(None)) | ||
key, subkey = jax.random.split(rng_key) | ||
first_token = jax.random.categorical(subkey, logits[-1]) | ||
tokens = np.zeros(sample_length, dtype=np.int32) | ||
tokens = ops.index_update(tokens, ops.index[0], first_token) | ||
initial_values = LoopValues(tokens=tokens, state=state, rng_key=key) | ||
values: LoopValues = lax.fori_loop(0, sample_length, body_fn, initial_values) | ||
|
||
return values.tokens | ||
|
||
|
||
def main(_): | ||
FLAGS.alsologtostderr = True | ||
|
||
# Make training dataset. | ||
train_data = dataset.load( | ||
'train', | ||
batch_size=FLAGS.train_batch_size, | ||
sequence_length=FLAGS.sequence_length) | ||
|
||
# Make evaluation dataset(s). | ||
eval_data = { # pylint: disable=g-complex-comprehension | ||
split: dataset.load( | ||
split, | ||
batch_size=FLAGS.eval_batch_size, | ||
sequence_length=FLAGS.sequence_length) for split in ['train', 'test'] | ||
} | ||
|
||
# Make loss, sampler, and optimizer. | ||
params_init, loss_fn = hk.transform(sequence_loss) | ||
_, sample_fn = hk.transform(sample) | ||
opt_init, _ = make_optimizer() | ||
|
||
loss_fn = jax.jit(loss_fn) | ||
sample_fn = jax.jit(sample_fn, static_argnums=[3]) | ||
|
||
# Initialize training state. | ||
rng = hk.PRNGSequence(FLAGS.seed) | ||
initial_params = params_init(next(rng), next(train_data)) | ||
initial_opt_state = opt_init(initial_params) | ||
state = TrainingState(params=initial_params, opt_state=initial_opt_state) | ||
|
||
# Training loop. | ||
for step in range(FLAGS.training_steps + 1): | ||
# Do a batch of SGD. | ||
train_batch = next(train_data) | ||
state = update(state, train_batch) | ||
|
||
# Periodically generate samples. | ||
if step % FLAGS.sampling_interval == 0: | ||
context = train_batch['input'][:, 0] # First element of training batch. | ||
assert context.ndim == 1 | ||
rng_key = next(rng) | ||
samples = sample_fn(state.params, rng_key, context, FLAGS.sample_length) | ||
|
||
prompt = dataset.decode(context) | ||
continuation = dataset.decode(samples) | ||
|
||
logging.info('Prompt: %s', prompt) | ||
logging.info('Continuation: %s', continuation) | ||
|
||
# Periodically evaluate training and test loss. | ||
if step % FLAGS.evaluation_interval == 0: | ||
for split, ds in eval_data.items(): | ||
eval_batch = next(ds) | ||
loss = loss_fn(state.params, eval_batch) | ||
logging.info({ | ||
'step': step, | ||
'loss': float(loss), | ||
'split': split, | ||
}) | ||
|
||
|
||
if __name__ == '__main__': | ||
app.run(main) |