Add example training an LSTM language model.
PiperOrigin-RevId: 303707157
Change-Id: I25a8b0f4d5a9707766fba2c8980de1e5d8dfdb66
aslanides authored and Copybara-Service committed Mar 30, 2020
1 parent 9201846 commit a7bb782
Showing 3 changed files with 280 additions and 0 deletions.
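
The example is built around Haiku's central idiom: an impure, module-building function is passed through hk.transform, which returns a pure init/apply pair that JAX can then jit and differentiate. As a quick orientation before the diff, here is a minimal sketch of that pattern (the toy forward function is hypothetical and not part of this commit; the call signatures mirror how hk.transform is used in train.py below):

import haiku as hk
import jax
import jax.numpy as jnp


def forward(x):
  # Impure: hk.Linear creates its parameters as a side effect of being called.
  return hk.Linear(128)(x)


# hk.transform turns `forward` into a pure (init, apply) pair.
init, apply = hk.transform(forward)
params = init(jax.random.PRNGKey(42), jnp.zeros([1, 8]))  # Sample input fixes shapes.
logits = apply(params, jnp.zeros([1, 8]))                  # Pure function of params.
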
30 changes: 30 additions & 0 deletions examples/rnn/BUILD
@@ -0,0 +1,30 @@
load("//haiku/_src:build_defs.bzl", "hk_py_binary", "hk_py_library")

licenses(["notice"])

exports_files(["LICENSE"])

hk_py_library(
    name = "dataset",
    srcs = ["dataset.py"],
    deps = [
        # pip: numpy
        # pip: tensorflow
        # pip: tensorflow_datasets
    ],
)

hk_py_binary(
    name = "train",
    srcs = ["train.py"],
    deps = [
        ":dataset",
        # pip: absl:app
        # pip: absl/flags
        # pip: absl/logging
        "//haiku",
        # pip: jax
        # pip: jax:optix
        # pip: numpy
    ],
)
60 changes: 60 additions & 0 deletions examples/rnn/dataset.py
@@ -0,0 +1,60 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tiny Shakespeare as a language modelling dataset."""

from typing import Iterator, Mapping

import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

Batch = Mapping[str, np.ndarray]
NUM_CHARS = 128


def load(
    split: tfds.Split,
    *,
    batch_size: int,
    sequence_length: int,
) -> Iterator[Batch]:
  """Creates the Tiny Shakespeare dataset as a character modelling task."""

  def preprocess_fn(x: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    x = x['text']
    x = tf.strings.unicode_split(x, 'UTF-8')
    x = tf.squeeze(tf.io.decode_raw(x, tf.uint8), axis=-1)
    x = tf.cast(x, tf.int32)
    return {'input': x[:-1], 'target': x[1:]}

  ds = tfds.load(name='tiny_shakespeare', split=split)
  ds = ds.map(preprocess_fn)
  ds = ds.unbatch()
  ds = ds.batch(sequence_length, drop_remainder=True)
  ds = ds.shuffle(100)
  ds = ds.repeat()
  ds = ds.batch(batch_size)
  ds = ds.map(lambda b: tf.nest.map_structure(tf.transpose, b))  # Time major.

  return tfds.as_numpy(ds)


def decode(x: np.ndarray) -> str:
  return ''.join([chr(x) for x in x])


def encode(x: str) -> np.ndarray:
  return np.array([ord(s) for s in x])
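
A small usage sketch for the new dataset module (not part of the commit): it assumes the tiny_shakespeare TFDS dataset can be downloaded locally and that the module is importable as haiku.examples.rnn.dataset, matching the import used in train.py below.

from haiku.examples.rnn import dataset

# One time-major batch: both 'input' and 'target' have shape
# [sequence_length, batch_size], with targets shifted by one character.
train_data = dataset.load('train', batch_size=4, sequence_length=64)
batch = next(train_data)
assert batch['input'].shape == (64, 4)

# Decode the first sequence in the batch back into text.
print(dataset.decode(batch['input'][:, 0]))
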
190 changes: 190 additions & 0 deletions examples/rnn/train.py
@@ -0,0 +1,190 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Character-level language modelling with a recurrent network in JAX."""

from typing import Any, NamedTuple

from absl import app
from absl import flags
from absl import logging

import haiku as hk
from haiku.examples.rnn import dataset
import jax
from jax import lax
from jax import ops
from jax.experimental import optix
import jax.numpy as jnp
import numpy as np

flags.DEFINE_integer('train_batch_size', 32, '')
flags.DEFINE_integer('eval_batch_size', 1000, '')
flags.DEFINE_integer('sequence_length', 128, '')
flags.DEFINE_integer('hidden_size', 256, '')
flags.DEFINE_integer('sample_length', 128, '')
flags.DEFINE_float('learning_rate', 1e-3, '')
flags.DEFINE_integer('training_steps', 100_000, '')
flags.DEFINE_integer('evaluation_interval', 100, '')
flags.DEFINE_integer('sampling_interval', 100, '')
flags.DEFINE_integer('seed', 42, '')

FLAGS = flags.FLAGS


class LoopValues(NamedTuple):
  tokens: jnp.ndarray
  state: Any
  rng_key: jnp.ndarray


class TrainingState(NamedTuple):
  params: hk.Params
  opt_state: Any


def make_network() -> hk.RNNCore:
  """Defines the network architecture."""
  model = hk.DeepRNN([
      lambda x: hk.one_hot(x, num_classes=dataset.NUM_CHARS),
      hk.LSTM(FLAGS.hidden_size),
      jax.nn.relu,
      hk.LSTM(FLAGS.hidden_size),
      hk.nets.MLP([FLAGS.hidden_size, dataset.NUM_CHARS]),
  ])
  return model


def make_optimizer() -> optix.InitUpdate:
  """Defines the optimizer."""
  return optix.adam(FLAGS.learning_rate)


def sequence_loss(batch: dataset.Batch) -> jnp.ndarray:
  """Unrolls the network over a sequence of inputs & targets, gets loss."""
  # Note: this function is impure; we hk.transform() it below.
  core = make_network()
  sequence_length, batch_size = batch['input'].shape
  initial_state = core.initial_state(batch_size)
  logits, _ = hk.dynamic_unroll(core, batch['input'], initial_state)
  log_probs = jax.nn.log_softmax(logits)
  one_hot_labels = hk.one_hot(batch['target'], num_classes=logits.shape[-1])
  return -jnp.sum(one_hot_labels * log_probs) / (sequence_length * batch_size)


@jax.jit
def update(state: TrainingState, batch: dataset.Batch) -> TrainingState:
  """Does a step of SGD given inputs & targets."""
  _, optimizer = optix.adam(FLAGS.learning_rate)
  _, loss_fn = hk.transform(sequence_loss)
  gradients = jax.grad(loss_fn)(state.params, batch)
  updates, new_opt_state = optimizer(gradients, state.opt_state)
  new_params = optix.apply_updates(state.params, updates)
  return TrainingState(params=new_params, opt_state=new_opt_state)


def sample(
    rng_key: jnp.ndarray,
    context: jnp.ndarray,
    sample_length: int,
) -> jnp.ndarray:
  """Draws samples from the model, given an initial context."""
  # Note: this function is impure; we hk.transform() it below.
  assert context.ndim == 1  # No batching for now.
  core = make_network()

  def body_fn(t: int, v: LoopValues) -> LoopValues:
    token = v.tokens[t]
    next_logits, next_state = core(token, v.state)
    key, subkey = jax.random.split(v.rng_key)
    next_token = jax.random.categorical(subkey, next_logits, axis=-1)
    new_tokens = ops.index_update(v.tokens, ops.index[t + 1], next_token)
    return LoopValues(tokens=new_tokens, state=next_state, rng_key=key)

  logits, state = hk.dynamic_unroll(core, context, core.initial_state(None))
  key, subkey = jax.random.split(rng_key)
  first_token = jax.random.categorical(subkey, logits[-1])
  tokens = np.zeros(sample_length, dtype=np.int32)
  tokens = ops.index_update(tokens, ops.index[0], first_token)
  initial_values = LoopValues(tokens=tokens, state=state, rng_key=key)
  values: LoopValues = lax.fori_loop(0, sample_length, body_fn, initial_values)

  return values.tokens


def main(_):
  FLAGS.alsologtostderr = True

  # Make training dataset.
  train_data = dataset.load(
      'train',
      batch_size=FLAGS.train_batch_size,
      sequence_length=FLAGS.sequence_length)

  # Make evaluation dataset(s).
  eval_data = {  # pylint: disable=g-complex-comprehension
      split: dataset.load(
          split,
          batch_size=FLAGS.eval_batch_size,
          sequence_length=FLAGS.sequence_length)
      for split in ['train', 'test']
  }

  # Make loss, sampler, and optimizer.
  params_init, loss_fn = hk.transform(sequence_loss)
  _, sample_fn = hk.transform(sample)
  opt_init, _ = make_optimizer()

  loss_fn = jax.jit(loss_fn)
  sample_fn = jax.jit(sample_fn, static_argnums=[3])

  # Initialize training state.
  rng = hk.PRNGSequence(FLAGS.seed)
  initial_params = params_init(next(rng), next(train_data))
  initial_opt_state = opt_init(initial_params)
  state = TrainingState(params=initial_params, opt_state=initial_opt_state)

  # Training loop.
  for step in range(FLAGS.training_steps + 1):
    # Do a batch of SGD.
    train_batch = next(train_data)
    state = update(state, train_batch)

    # Periodically generate samples.
    if step % FLAGS.sampling_interval == 0:
      context = train_batch['input'][:, 0]  # First element of training batch.
      assert context.ndim == 1
      rng_key = next(rng)
      samples = sample_fn(state.params, rng_key, context, FLAGS.sample_length)

      prompt = dataset.decode(context)
      continuation = dataset.decode(samples)

      logging.info('Prompt: %s', prompt)
      logging.info('Continuation: %s', continuation)

    # Periodically evaluate training and test loss.
    if step % FLAGS.evaluation_interval == 0:
      for split, ds in eval_data.items():
        eval_batch = next(ds)
        loss = loss_fn(state.params, eval_batch)
        logging.info({
            'step': step,
            'loss': float(loss),
            'split': split,
        })


if __name__ == '__main__':
  app.run(main)
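
To try the example end to end, train.py can be run as an ordinary Python program (the exact module path depends on how Haiku and its examples are installed, since it imports haiku.examples.rnn.dataset). Hyperparameters such as --hidden_size, --sequence_length, --learning_rate and --training_steps are exposed through the absl flags defined at the top of the file, and prompts, continuations, and losses are reported via absl logging.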
