#### Copyright 2021 Google LLC.

In [None]:
# Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Hourglass: enwik8 evaluation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/research/examples/hourglass_enwik8.ipynb)

This notebook was designed to run on TPU.

To use TPUs in Colab, click "Runtime" on the main menu bar, select "Change runtime type", and set "TPU" as the hardware accelerator.

### Install dependencies

In [2]:
TRAX_GITHUB_URL = 'git+https://github.com/google/trax.git'
!pip install -q --upgrade jax==0.2.21
!pip install -q --upgrade jaxlib==0.1.71+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install -q $TRAX_GITHUB_URL
!pip install -q pickle5
!pip install -q neptune-client
!pip install -q gin

[K     |████████████████████████████████| 4.4 MB 5.1 MB/s 
[?25h  Building wheel for trax (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 132 kB 5.2 MB/s 
[?25h  Building wheel for pickle5 (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 275 kB 5.2 MB/s 
[K     |████████████████████████████████| 829 kB 41.4 MB/s 
[K     |████████████████████████████████| 52 kB 1.4 MB/s 
[K     |████████████████████████████████| 180 kB 55.3 MB/s 
[K     |████████████████████████████████| 131 kB 62.7 MB/s 
[K     |████████████████████████████████| 8.0 MB 37.9 MB/s 
[K     |████████████████████████████████| 79 kB 6.9 MB/s 
[K     |████████████████████████████████| 138 kB 56.7 MB/s 
[K     |████████████████████████████████| 63 kB 1.6 MB/s 
[K     |████████████████████████████████| 127 kB 47.8 MB/s 
[K     |████████████████████████████████| 67 kB 5.1 MB/s 
[K     |████████████████████████████████| 129 kB 47.2 MB/s 
[?25h  Building wheel for

In [3]:
# Execute this for a proper TPU setup!
# Make sure the Colab Runtime is set to Accelerator: TPU.
# Reading os.environ['COLAB_TPU_ADDR'] below raises KeyError when that
# variable is not set.
import jax
import requests
import os
# `TPU_DRIVER_MODE` acts as a sentinel: it is assigned after the request, so
# re-running this cell in the same kernel does not send the request again.
if 'TPU_DRIVER_MODE' not in globals():
    # POST to port 8475 on the host part of COLAB_TPU_ADDR, requesting the
    # tpu_driver version named in the URL path.
    url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20200416'
    # NOTE(review): the response is stored but its status is never checked.
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
# Last expression of the cell: displays the list of devices JAX reports.
jax.devices()

grpc://10.55.87.106:8470


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

### Download enwik8 dataset and load data

The data is preprocessed with the standard enwik8 script (`prep_enwik8.py` from the `salesforce/awd-lstm-lm` repository).

In [7]:
!wget --continue http://mattmahoney.net/dc/enwik8.zip
!wget https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py
!python3 prep_enwik8.py

--2021-10-07 05:41:17--  http://mattmahoney.net/dc/enwik8.zip
Resolving mattmahoney.net (mattmahoney.net)... 67.195.197.24
Connecting to mattmahoney.net (mattmahoney.net)|67.195.197.24|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 36445475 (35M) [application/zip]
Saving to: ‘enwik8.zip’


2021-10-07 05:41:34 (2.00 MB/s) - ‘enwik8.zip’ saved [36445475/36445475]

--2021-10-07 05:41:34--  https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 818 [text/plain]
Saving to: ‘prep_enwik8.py.1’


2021-10-07 05:41:34 (29.6 MB/s) - ‘prep_enwik8.py.1’ saved [818/818]

Length of enwik8: 100000000
train.txt will have 90000000 bytes
- Tokenizing...
- Writing...
va

In [5]:
# The checkpoint was trained with Python 3.8, which can write pickle protocol 5
# (presumably the protocol used for the checkpoint). The standard `pickle`
# module of Python 3.7 and older cannot read that protocol, but the `pickle5`
# backport can, hence this hack: rewrite trax's `import pickle` line so that it
# imports `pickle5` instead.
# Run this cell before trax is imported; a module that is already loaded is not
# affected by editing its source file.
import importlib.util
import os
import sys

if sys.version_info < (3, 8):  # Python 3.8+ reads protocol 5 natively.
    # Locate the installed trax package without importing it (a top-level
    # `find_spec` does not import the module), instead of hard-coding a
    # python3.7 dist-packages path.
    trax_spec = importlib.util.find_spec('trax')
    if trax_spec is None:
        raise ModuleNotFoundError('trax is not installed; run the install cell first')
    layers_base_path = os.path.join(
        list(trax_spec.submodule_search_locations)[0], 'layers', 'base.py')
    with open(layers_base_path, 'r') as f:
        lines = f.readlines()
    # Patch only while the original import is still present, so that re-running
    # this cell is a no-op rather than a ValueError from `lines.index`.
    if 'import pickle\n' in lines:
        lines[lines.index('import pickle\n')] = 'import pickle5 as pickle\n'
        with open(layers_base_path, 'w') as f:
            f.writelines(lines)

In [8]:
import tensorflow.compat.v1 as tf
from trax.fastmath import numpy as jnp

def raw_ds_to_tensor(raw_file_path):
    """Read a raw byte file into a 1-D integer tensor of its byte values.

    Also prints the number of bytes that were read.

    Args:
        raw_file_path: Path of the file to read; it is opened in binary mode
            through `tf.io.gfile.GFile`.

    Returns:
        A `jnp` array with one int32 element (value 0-255) per byte of the file.
    """
    import numpy as np  # Local import: numpy is not imported by this cell.

    with tf.io.gfile.GFile(raw_file_path, mode='rb') as f:
        raw_data = f.read()
    print(f'Bytes in {raw_file_path}:', len(raw_data))
    # View the bytes through numpy instead of materialising a Python list with
    # one int object per byte. The int32 cast matches the dtype JAX assigns to
    # Python ints when 64-bit mode is disabled (its default).
    return jnp.array(np.frombuffer(raw_data, dtype=np.uint8).astype(np.int32))

# Load the test split first, then the validation split.
testset_tensor = raw_ds_to_tensor('/content/test.txt.raw')
validset_tensor = raw_ds_to_tensor('/content/valid.txt.raw')

Bytes in /content/test.txt.raw: 5000000
Bytes in /content/valid.txt.raw: 5000000


### Download and load the trained checkpoint

In [None]:
# Fetch the trained checkpoint from Google Drive, then unpack the archive.
# NOTE(review): the tar command assumes gdown saved the download as
# enwik8_checkpoint.tar.gz in the current directory — confirm.
!gdown https://drive.google.com/uc?id=18wrzKZLBtLuFOHwzuF-7i_p-rD2miE_6
!tar -zxvf enwik8_checkpoint.tar.gz

In [9]:
import gin
import trax

# Directory that config.gin and model.pkl.gz are read from below.
MODEL_DIR = 'enwik8_checkpoint'

# The gin config is parsed before the model is constructed.
# NOTE(review): presumably HourglassLM takes its hyperparameters from gin,
# which is why this has to run first — confirm.
gin.parse_config_file(f'./{MODEL_DIR}/config.gin')

model = trax.models.HourglassLM(mode='eval')

# Initialise the model from the checkpoint file (weights only).
# NOTE(review): the file appears to be pickle-based (.pkl.gz); only load
# checkpoints that come from a trusted source.
model.init_from_file(
    f'./{MODEL_DIR}/model.pkl.gz',
    weights_only=True
)

loss_fn = trax.layers.WeightedCategoryCrossEntropy()
# Chain the model and the loss into one layer and wrap it in Accelerate; the
# evaluation code calls the result with an (inputs, targets, mask) tuple.
model_eval = trax.layers.Accelerate(trax.layers.Serial(
    model,
    loss_fn
))

### Evaluate on the test set

In [10]:
from trax import fastmath
from trax.fastmath import numpy as jnp
from tqdm import tqdm


def batched_inputs(data_gen, batch_size):
  """Group `(input, mask)` examples into stacked batches where possible.

  Examples are collected `batch_size` at a time. A full group whose inputs all
  have the same length is yielded as one `(input_batch, mask_batch)` pair of
  stacked tensors. A full group with differing input lengths, and any examples
  left over when `data_gen` is exhausted, are yielded one by one, unbatched.

  Args:
    data_gen: Iterable of `(input_example, mask)` pairs.
    batch_size: Number of examples per stacked batch.

  Yields:
    Either a stacked `(input_batch, mask_batch)` pair or a single unbatched
    `(input_example, mask)` pair.
  """
  pending = []

  for example, example_mask in data_gen:
    pending.append((example, example_mask))
    if len(pending) % batch_size != 0:
      continue

    # A full group was collected: stack it only when all inputs share a length.
    distinct_lengths = {len(inp) for inp, _ in pending}
    if len(distinct_lengths) == 1:
      yield (jnp.stack([inp for inp, _ in pending]),
             jnp.stack([msk for _, msk in pending]))
    else:
      yield from pending
    pending = []

  # Whatever did not fill a whole group is emitted example by example.
  yield from pending


def run_full_evaluation(accelerated_model_with_loss, examples_data_gen,
                        batch_size, pad_to_len=None):
  """Compute the mask-weighted average loss over a stream of examples.

  Examples are grouped with `batched_inputs`. Every batch is passed to the
  model as `(inputs, targets, mask)` with the same tensor used for inputs and
  targets. The returned losses are weighted by the mask sum of each example
  and accumulated; a running value is printed every 200 batches.

  Args:
    accelerated_model_with_loss: Callable taking an `(inputs, targets, mask)`
      tuple and returning the loss values for the batch.
    examples_data_gen: Iterable of unbatched `(input, mask)` pairs.
    batch_size: Number of examples per model call. With more than one local
      device it must equal the local device count.
    pad_to_len: If truthy, tensors shorter than this along axis 1 are padded
      at the end (with `jnp.pad` defaults) up to this length.

  Returns:
    The accumulated weighted loss divided by the total mask sum.
  """
  # Important: we assume batch size per device = 1
  assert batch_size % fastmath.local_device_count() == 0
  assert (fastmath.local_device_count() == 1
          or batch_size == fastmath.local_device_count())

  def pad_to_target_len(tensor):
    # Leave the tensor alone when padding is disabled; never pad negatively.
    if not pad_to_len:
      return tensor
    missing = max(0, pad_to_len - tensor.shape[1])
    return jnp.pad(tensor, [[0, 0], [0, missing]])

  def tile_to_batch(tensor):
    # Add a leading batch axis and copy the single example `batch_size` times.
    return jnp.repeat(tensor[None, ...], repeats=batch_size, axis=0)

  total_loss, total_tokens = 0.0, 0
  batches = batched_inputs(examples_data_gen, batch_size)

  for step, (inp, mask) in tqdm(enumerate(batches)):
    # `batched_inputs` yields leftover examples as unbatched rank 1 tensors
    # rather than rank 2 batches. Such an example is replicated to a full
    # batch for the model call, and only its first row is counted afterwards.
    is_leftover = len(inp.shape) == 1
    if is_leftover:
      inp, mask = tile_to_batch(inp), tile_to_batch(mask)

    inp, mask = pad_to_target_len(inp), pad_to_target_len(mask)

    losses = accelerated_model_with_loss((inp, inp, mask))

    if is_leftover:
      losses, mask = losses[:1], mask[:1]

    tokens_per_example = mask.sum(axis=-1)
    total_loss += (tokens_per_example * losses).sum()
    total_tokens += mask.sum()

    if step % 200 == 0:
      print(f'Batches: {step}, current loss: {total_loss / float(total_tokens)}')

  return total_loss / float(total_tokens)

We evaluate chunks of $128$ bytes, each preceded by up to $128 \cdot 53 = 6784$ bytes of context, so the full input window is $128 \cdot 54 = 6912$ bytes.

In [11]:
# Prepare the input generator: it should yield (input, mask) tuples
def contextful_eval_data(bytes_tensor, CHUNK_LEN, N_CHUNKS_BEFORE):
    """Yield `(window, mask)` pairs, one per chunk of `bytes_tensor`.

    Each window ends with a chunk of up to `CHUNK_LEN` elements and is preceded
    by up to `N_CHUNKS_BEFORE * CHUNK_LEN` earlier elements (fewer near the
    start of the tensor). The mask is 1 on the chunk at the end of the window
    and 0 on the preceding context. Both are then passed through trax's
    `_pad_to_multiple_of` with `CHUNK_LEN` along axis 0.

    Args:
        bytes_tensor: 1-D tensor to split.
        CHUNK_LEN: Step between chunk starts and maximum chunk length.
        N_CHUNKS_BEFORE: Maximum number of preceding chunks kept as context.

    Yields:
        `(window, mask)` tuples.
    """
    total_len = len(bytes_tensor)
    context_len = N_CHUNKS_BEFORE * CHUNK_LEN

    for start in range(0, total_len, CHUNK_LEN):
        window_start = max(0, start - context_len)
        window = bytes_tensor[window_start:start + CHUNK_LEN]

        # The last chunk may be shorter than CHUNK_LEN.
        n_scored = min(CHUNK_LEN, total_len - start)
        mask = fastmath.index_update(
            jnp.zeros_like(window), jax.ops.index[-n_scored:], 1)

        yield (
            trax.data.inputs._pad_to_multiple_of(window, CHUNK_LEN, axis=0),
            trax.data.inputs._pad_to_multiple_of(mask, CHUNK_LEN, axis=0),
        )

# Each example is one chunk of CHUNK_LEN bytes preceded by up to
# N_CHUNKS_BEFORE chunks of context: at most (53 + 1) * 128 = 6912 bytes.
N_CHUNKS_BEFORE = 53
CHUNK_LEN = 128
# Every batch is padded to this length.
# We need to pad because shorten factor 3 is used.
PAD_TO_LEN = 6912

BATCH_SIZE = 8

test_data_gen = contextful_eval_data(
    testset_tensor, CHUNK_LEN=CHUNK_LEN, N_CHUNKS_BEFORE=N_CHUNKS_BEFORE)

loss = run_full_evaluation(
    model_eval, test_data_gen, batch_size=BATCH_SIZE, pad_to_len=PAD_TO_LEN)

1it [08:27, 507.55s/it]

Batches: 0, current loss: 1.1698029041290283


201it [15:34,  2.17s/it]

Batches: 200, current loss: 0.6762865781784058


401it [22:42,  2.15s/it]

Batches: 400, current loss: 0.6353754997253418


601it [29:47,  2.13s/it]

Batches: 600, current loss: 0.6671227812767029


801it [36:53,  2.13s/it]

Batches: 800, current loss: 0.6871113777160645


1001it [44:01,  2.14s/it]

Batches: 1000, current loss: 0.7012861967086792


1201it [51:08,  2.12s/it]

Batches: 1200, current loss: 0.7057864665985107


1401it [58:13,  2.12s/it]

Batches: 1400, current loss: 0.7055128216743469


1601it [1:05:19,  2.12s/it]

Batches: 1600, current loss: 0.7026289701461792


1801it [1:12:27,  2.14s/it]

Batches: 1800, current loss: 0.6966000199317932


2001it [1:19:36,  2.14s/it]

Batches: 2000, current loss: 0.6651851534843445


2201it [1:26:43,  2.13s/it]

Batches: 2200, current loss: 0.6696659922599792


2401it [1:33:49,  2.12s/it]

Batches: 2400, current loss: 0.6754175424575806


2601it [1:40:55,  2.14s/it]

Batches: 2600, current loss: 0.6817601919174194


2801it [1:48:03,  2.15s/it]

Batches: 2800, current loss: 0.6837883591651917


3001it [1:55:13,  2.15s/it]

Batches: 3000, current loss: 0.6857511401176453


3201it [2:02:22,  2.15s/it]

Batches: 3200, current loss: 0.6849543452262878


3401it [2:09:32,  2.14s/it]

Batches: 3400, current loss: 0.6871944665908813


3601it [2:16:40,  2.15s/it]

Batches: 3600, current loss: 0.6876044273376465


3801it [2:23:47,  2.12s/it]

Batches: 3800, current loss: 0.690453052520752


4001it [2:30:53,  2.12s/it]

Batches: 4000, current loss: 0.689698338508606


4201it [2:37:59,  2.14s/it]

Batches: 4200, current loss: 0.6900818943977356


4401it [2:45:05,  2.14s/it]

Batches: 4400, current loss: 0.6902956366539001


4601it [2:52:15,  2.17s/it]

Batches: 4600, current loss: 0.689020037651062


4801it [2:59:22,  2.13s/it]

Batches: 4800, current loss: 0.6893221735954285


4938it [3:04:14,  2.24s/it]


In [12]:
# `loss` is the average cross-entropy returned by the evaluation, in nats per
# byte; it is not a perplexity. Dividing by ln(2) converts it to bits.
print(f'Final loss (nats per byte): {loss}, final bpd: {loss / jnp.log(2)}')

Final perplexity: 0.6912750005722046, final bpd: 0.997299075126648


### Generate text from the model

In [None]:
import numpy as np
from tqdm import tqdm

def autoregressive_sample(model, temp=1.0, batch_size=8, l=3072, vocab_size=256):
  """Sample `batch_size` sequences of length `l`, one position per step.

  Starts from an all-zero token buffer. At step `i` the whole buffer is fed to
  the model, position `i` is sampled from the logits at that position, and the
  result is written back into the buffer.

  Args:
    model: Layer mapping a `(batch_size, l)` token array to logits indexed as
      `[batch, position, vocab]`.
    temp: Temperature passed to `trax.layers.logsoftmax_sample`.
    batch_size: Number of sequences sampled in parallel.
    l: Number of positions to generate.
    vocab_size: Size of the last logits axis; only used to shape the initial
      "previous logits" buffer.

  Returns:
    An int32 numpy array of shape `(batch_size, l)` with the sampled tokens.

  Raises:
    AssertionError: If the logits for already generated positions differ from
      those of the previous step.
  """
  fast_model = trax.layers.Accelerate(model)
  tokens = np.zeros((batch_size, l), dtype=np.int32)
  previous_logits = np.zeros((batch_size, l, vocab_size), dtype=np.float32)

  for pos in tqdm(range(l)):
    current_logits = fast_model(tokens)
    # Sanity check: writing a new token must not change earlier logits.
    np.testing.assert_array_almost_equal(previous_logits[:, :pos],
                                         current_logits[:, :pos])

    tokens[:, pos] = trax.layers.logsoftmax_sample(
        current_logits[:, pos, :], temperature=temp)
    previous_logits = current_logits

  return tokens

In [None]:
# Sample a batch of sequences (default batch_size=8), each 1026 positions long.
# NOTE(review): the saved progress bar below reports 513 steps rather than 1026,
# so the stored output seems to come from a run with a different `l` — confirm.
samples = autoregressive_sample(model, l=1026)

100%|██████████| 513/513 [05:47<00:00,  1.48it/s]


Text sample generated by the model (unconditional generation - without any prompts):

In [None]:
# Sampled bytes are not guaranteed to form valid UTF-8 (for example, a
# multi-byte character can be cut off at the end of the sample), so undecodable
# bytes are replaced instead of raising UnicodeDecodeError.
bytes((samples[0]).tolist()).decode('utf-8', errors='replace')

' the political laws.  War also helped develop the [[Soviet Union|Soviet]] system in western Europe, as did [[Luxembourg]] and the Church of [[Sweden]]. The immediate impact of nuclear war took place in eastern Europe, bankrupted by serious [[Nuclear fallout|fallout]] from [[Early Modern Europe|miners from both sides]]. The state ally strategic eastern Europe had immediately concluded the war in both sides, although it was extremely weak. Meanwhile, the US was developing an urban and commercial life of more t'