#### Copyright 2021 Google LLC.

In [None]:
# Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Hourglass: enwik8 evaluation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/research/examples/hourglass_enwik8.ipynb)

This notebook was designed to run on TPU.

To use a TPU in Colab, click "Runtime" on the main menu bar, select "Change runtime type", and set "TPU" as the hardware accelerator.

In [None]:
# Download the checkpoint archive from Google Drive and unpack it into the
# current working directory.
# NOTE(review): gdown is presumably saving the file as
# enwik8_checkpoint.tar.gz (the name the tar command expects) - confirm.
!gdown https://drive.google.com/uc?id=18wrzKZLBtLuFOHwzuF-7i_p-rD2miE_6
!tar -zxvf enwik8_checkpoint.tar.gz

### Install dependencies

In [None]:
TRAX_GITHUB_URL = 'git+https://github.com/google/trax.git'
!pip install -q --upgrade jax==0.2.21
!pip install -q --upgrade jaxlib==0.1.71+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install -q $TRAX_GITHUB_URL
!pip install -q pickle5
!pip install -q neptune-client
!pip install -q gin

In [None]:
# Execute this for a proper TPU setup!
# Make sure the Colab Runtime is set to Accelerator: TPU.
#
# This cell points JAX at the Colab TPU: it first asks the TPU host for a
# specific driver version and then selects the "tpu_driver" backend.
# It reads the COLAB_TPU_ADDR environment variable and raises KeyError if that
# variable is not set (i.e. when the runtime has no TPU attached).
import jax
import requests
import os
# The module-level TPU_DRIVER_MODE flag ensures that the version request is
# sent only once per kernel session, even if this cell is re-run.
if 'TPU_DRIVER_MODE' not in globals():
    # Same host as COLAB_TPU_ADDR, but on port 8475.
    url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20200416'
    # NOTE(review): the response is stored but never inspected, so a failed
    # request goes unnoticed here.
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
# Last expression of the cell: displays the devices JAX can see.
jax.devices()

### Download enwik8 dataset and load data

A standard script for enwik8 preprocessing is used.

In [None]:
# Fetch the raw enwik8 archive (resuming a partial download if one exists) and
# the preprocessing script from the awd-lstm-lm repository, then run the script.
# NOTE(review): later cells read test.txt.raw and valid.txt.raw from /content;
# presumably this script writes them to the working directory - confirm.
!wget --continue http://mattmahoney.net/dc/enwik8.zip
!wget https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py
!python3 prep_enwik8.py

In [None]:
# The checkpoint was written under Python 3.8, presumably with pickle
# protocol 5, which the Python 3.7 standard library cannot read. As a hack,
# trax's layers/base.py is rewritten on disk to import the `pickle5` backport
# in place of the stdlib `pickle` module.
layers_base_path = '/usr/local/lib/python3.7/dist-packages/trax/layers/base.py'
with open(layers_base_path, 'r') as f:
    lines = f.readlines()
# Skip the rewrite when the file has already been patched, so re-running this
# cell does not fail with ValueError on the missing 'import pickle' line.
if 'import pickle5 as pickle\n' not in lines:
    # Still raises ValueError if base.py has no plain 'import pickle' line.
    idx = lines.index('import pickle\n')
    lines[idx] = 'import pickle5 as pickle\n'
    with open(layers_base_path, 'w') as f:
        f.writelines(lines)

In [None]:
import tensorflow.compat.v1 as tf
from trax.fastmath import numpy as jnp

def raw_ds_to_tensor(raw_file_path):
    """Reads a file in binary mode and returns its bytes as a 1-D array.

    Also prints the number of bytes that were read.

    Args:
      raw_file_path: Path of the file to read.

    Returns:
      An array holding one integer per byte of the file, in file order.
    """
    with tf.io.gfile.GFile(raw_file_path, mode='rb') as raw_file:
        contents = raw_file.read()
    print(f'Bytes in {raw_file_path}:', len(contents))
    # Iterating over a bytes object yields one int per byte.
    byte_values = [byte for byte in contents]
    return jnp.array(byte_values)

# Load the test and validation splits, in that order, as arrays of bytes.
testset_tensor, validset_tensor = [
    raw_ds_to_tensor(split_path)
    for split_path in ('/content/test.txt.raw', '/content/valid.txt.raw')
]

### Load the trained checkpoint

In [None]:
import gin
import trax

# Directory holding config.gin and model.pkl.gz.
# NOTE(review): presumably created by unpacking enwik8_checkpoint.tar.gz.
MODEL_DIR = 'enwik8_checkpoint'

# The gin config is parsed before the model is constructed.
# NOTE(review): presumably it supplies HourglassLM's hyperparameters - confirm.
gin.parse_config_file(f'./{MODEL_DIR}/config.gin')

model = trax.models.HourglassLM(mode='eval')

# Initialise the model from the checkpoint file (weights_only=True).
model.init_from_file(
    f'./{MODEL_DIR}/model.pkl.gz',
    weights_only=True
)

# Chain the model with the loss layer and wrap the result in Accelerate.
# The evaluation code below calls model_eval with an
# (inputs, targets, mask) tuple.
loss_fn = trax.layers.WeightedCategoryCrossEntropy()
model_eval = trax.layers.Accelerate(trax.layers.Serial(
    model,
    loss_fn
))

### Evaluate on the test set

In [None]:
from trax import fastmath
from trax.fastmath import numpy as jnp
from tqdm import tqdm


def batched_inputs(data_gen, batch_size):
  """Groups (input, mask) pairs from a generator into stacked batches.

  Pairs are collected until `batch_size` of them are available. If all inputs
  in the group have the same length, the inputs and the masks are stacked
  along a new leading axis and yielded as one (input_batch, mask_batch) pair.
  If the input lengths differ, the pairs of the group are yielded one by one,
  unstacked. Pairs left over once the generator is exhausted are also yielded
  one by one.

  Args:
    data_gen: Iterable of (input_example, mask) pairs.
    batch_size: Number of pairs per stacked batch.

  Yields:
    Either a stacked (input_batch, mask_batch) pair or a single unbatched
    (input_example, mask) pair.
  """
  pending = []

  for input_example, mask in data_gen:
    pending.append((input_example, mask))
    if len(pending) % batch_size != 0:
      continue

    distinct_lengths = {len(example) for example, _ in pending}
    if len(distinct_lengths) > 1:
      # Ragged group: cannot be stacked, so hand the pairs out individually.
      yield from pending
    else:
      examples, masks = zip(*pending)
      yield jnp.stack(examples), jnp.stack(masks)
    pending = []

  # Incomplete final group.
  yield from pending


def run_full_evaluation(accelerated_model_with_loss, examples_data_gen,
                        batch_size, pad_to_len=None):
  """Computes the mask-weighted average loss over all examples.

  Examples are grouped with `batched_inputs`. An unbatched (1-D) example is
  repeated `batch_size` times to form a full batch, and only the first copy
  is counted. Progress is printed every 200 batches.

  Args:
    accelerated_model_with_loss: Callable invoked with an
      (inputs, targets, mask) tuple, where targets are the inputs themselves;
      its result is treated as the per-example losses.
    examples_data_gen: Iterable of (input, mask) pairs.
    batch_size: Number of examples per batch. Must be divisible by the local
      device count and, with more than one device, equal to it.
    pad_to_len: If truthy, inputs and masks are zero-padded on the right of
      axis 1 up to this length.

  Returns:
    The accumulated (mask count * loss) divided by the total mask count.
  """
  # Important: we assume batch size per device = 1
  n_devices = fastmath.local_device_count()
  assert batch_size % n_devices == 0
  assert n_devices == 1 or batch_size == n_devices

  def right_padded(tensor):
    if not pad_to_len:
      return tensor
    missing = max(0, pad_to_len - tensor.shape[1])
    return jnp.pad(tensor, [[0, 0], [0, missing]])

  def tiled_to_batch(tensor):
    # Add a leading batch axis and repeat the example along it.
    return jnp.repeat(tensor[None, ...], repeats=batch_size, axis=0)

  loss_sum, n_tokens = 0.0, 0

  for i, (inp, mask) in tqdm(enumerate(
      batched_inputs(examples_data_gen, batch_size))):
    # A 1-D input is a single example that could not be stacked into a batch.
    is_leftover = len(inp.shape) == 1
    if is_leftover:
      inp, mask = tiled_to_batch(inp), tiled_to_batch(mask)

    inp, mask = right_padded(inp), right_padded(mask)

    losses = accelerated_model_with_loss((inp, inp, mask))

    if is_leftover:
      # All copies are identical; keep just one of them.
      losses, mask = losses[:1], mask[:1]

    tokens_per_example = mask.sum(axis=-1)
    loss_sum += (tokens_per_example * losses).sum()
    n_tokens += mask.sum()

    if i % 200 == 0:
      print(f'Batches: {i}, current loss: {loss_sum / float(n_tokens)}')

  return loss_sum / float(n_tokens)

We evaluate chunks of $128$ bytes, each preceded by a context of up to $128 \cdot 31 = 3968$ bytes, so the model sees at most $4096$ bytes at a time and only the final chunk contributes to the loss.

In [None]:
# Prepare the input generator: it should yield (input, mask) tuples
def contextful_eval_data(bytes_tensor, CHUNK_LEN, N_CHUNKS_BEFORE):
    """Yields evaluation windows in which only the last chunk is scored.

    For every chunk start, the window covers up to
    N_CHUNKS_BEFORE * CHUNK_LEN preceding elements plus the chunk itself.
    The mask is 1 on the chunk (at most CHUNK_LEN trailing elements of the
    window) and 0 on the context. Window and mask are then padded along
    axis 0 to a multiple of CHUNK_LEN.

    Args:
      bytes_tensor: 1-D array of the data to evaluate.
      CHUNK_LEN: Number of elements scored per window.
      N_CHUNKS_BEFORE: Number of chunks of context before the scored chunk.

    Yields:
      (window, mask) pairs.
    """
    total_len = len(bytes_tensor)
    context_len = N_CHUNKS_BEFORE * CHUNK_LEN

    for start in range(0, total_len, CHUNK_LEN):
        window_start = max(0, start - context_len)
        window = bytes_tensor[window_start:start + CHUNK_LEN]

        # The last chunk of the data may be shorter than CHUNK_LEN.
        n_scored = min(CHUNK_LEN, total_len - start)
        mask = fastmath.index_update(jnp.zeros_like(window),
                                     jax.ops.index[-n_scored:], 1)

        padded_window = trax.data.inputs._pad_to_multiple_of(
            window, CHUNK_LEN, axis=0)
        padded_mask = trax.data.inputs._pad_to_multiple_of(
            mask, CHUNK_LEN, axis=0)

        yield padded_window, padded_mask

# Evaluation setup: each example is a window of up to
# (N_CHUNKS_BEFORE + 1) * CHUNK_LEN = 4096 bytes, of which only the last
# CHUNK_LEN bytes are scored.
PAD_TO_LEN = 4098 # We need to pad because shorten factor 3 is used (4098 is the smallest multiple of 3 that is >= 4096).
CHUNK_LEN = 128 # Number of bytes scored per example.
N_CHUNKS_BEFORE = 31 # Number of context chunks preceding the scored chunk.

# run_full_evaluation asserts that, with more than one local device, the
# batch size equals the local device count.
BATCH_SIZE = 8

test_data_gen = contextful_eval_data(testset_tensor, CHUNK_LEN, N_CHUNKS_BEFORE)

# Average loss over the scored test bytes; the bpd computation in the next
# cell divides it by ln(2).
loss = run_full_evaluation(model_eval, test_data_gen, BATCH_SIZE, PAD_TO_LEN)

1it [00:01,  1.48s/it]

Batches: 0, current loss: 1.1697633266448975


201it [03:35,  1.09s/it]

Batches: 200, current loss: 0.671483039855957


401it [07:11,  1.09s/it]

Batches: 400, current loss: 0.6324439644813538


601it [10:46,  1.09s/it]

Batches: 600, current loss: 0.669732928276062


801it [14:22,  1.08s/it]

Batches: 800, current loss: 0.690700888633728


1001it [17:59,  1.08s/it]

Batches: 1000, current loss: 0.7042409777641296


1201it [21:35,  1.07s/it]

Batches: 1200, current loss: 0.7051951885223389


1401it [25:09,  1.06s/it]

Batches: 1400, current loss: 0.7044456005096436


1601it [28:45,  1.12s/it]

Batches: 1600, current loss: 0.7035351991653442


1801it [32:20,  1.07s/it]

Batches: 1800, current loss: 0.690122663974762


2001it [35:55,  1.07s/it]

Batches: 2000, current loss: 0.6649767756462097


2201it [39:30,  1.07s/it]

Batches: 2200, current loss: 0.6716358661651611


2401it [43:04,  1.08s/it]

Batches: 2400, current loss: 0.6756933331489563


2601it [46:39,  1.08s/it]

Batches: 2600, current loss: 0.6825714707374573


2801it [50:13,  1.06s/it]

Batches: 2800, current loss: 0.6843773722648621


3001it [53:46,  1.06s/it]

Batches: 3000, current loss: 0.6865374445915222


3201it [57:22,  1.09s/it]

Batches: 3200, current loss: 0.6855794787406921


3401it [1:00:58,  1.07s/it]

Batches: 3400, current loss: 0.6887989640235901


3601it [1:04:32,  1.06s/it]

Batches: 3600, current loss: 0.688316822052002


3801it [1:08:05,  1.06s/it]

Batches: 3800, current loss: 0.6921071410179138


4001it [1:11:40,  1.07s/it]

Batches: 4000, current loss: 0.6904897093772888


4201it [1:15:16,  1.07s/it]

Batches: 4200, current loss: 0.6908246278762817


4401it [1:18:52,  1.07s/it]

Batches: 4400, current loss: 0.6909059882164001


4601it [1:22:26,  1.20s/it]

Batches: 4600, current loss: 0.6896733045578003


4801it [1:26:02,  1.08s/it]

Batches: 4800, current loss: 0.6903342604637146


4917it [1:28:07,  1.08s/it]


In [None]:
# `loss` is the average cross-entropy in nats (it is divided by ln(2) below to
# obtain bits), so it is reported as a loss rather than as a perplexity.
print(f'Final loss: {loss}, final bpd: {loss / jnp.log(2)}')

Final perplexity: 0.6918511986732483, final bpd: 0.9981303811073303


### Generate text from the model

In [None]:
import numpy as np
from tqdm import tqdm

def autoregressive_sample(model, temp=1.0, batch_size=8, l=3072, vocab_size=256):
  """Samples `batch_size` sequences of length `l`, one position at a time.

  The model is re-run on the whole (partially filled) sequence at every step.
  Each step also asserts that the logits of already generated positions match
  those of the previous step.

  Args:
    model: Layer mapping an int32 array of shape (batch_size, l) to logits
      indexed as [batch, position, vocab]; it is wrapped in
      trax.layers.Accelerate.
    temp: Sampling temperature passed to logsoftmax_sample.
    batch_size: Number of sequences sampled in parallel.
    l: Length of every sampled sequence.
    vocab_size: Size of the last logits axis, used for the initial all-zero
      logits buffer.

  Returns:
    A numpy int32 array of shape (batch_size, l) with the sampled tokens.
  """
  fast_model = trax.layers.Accelerate(model)
  tokens = np.zeros((batch_size, l), dtype=np.int32)
  previous_logits = np.zeros((batch_size, l, vocab_size), dtype=np.float32)

  for position in tqdm(range(l)):
    current_logits = fast_model(tokens)
    # Positions before `position` were fixed in earlier steps, so their
    # logits must not have changed.
    np.testing.assert_array_almost_equal(previous_logits[:, :position],
                                         current_logits[:, :position])

    tokens[:, position] = trax.layers.logsoftmax_sample(
        current_logits[:, position, :], temperature=temp)
    previous_logits = current_logits

  return tokens

In [None]:
# Sample with the default batch size and temperature, 1026 tokens per sequence.
# NOTE(review): 1026 is a multiple of 3, presumably chosen to match the
# shorten factor mentioned next to PAD_TO_LEN - confirm.
samples = autoregressive_sample(model, l=1026)

100%|██████████| 513/513 [05:47<00:00,  1.48it/s]


In [None]:
# Sampled bytes are not guaranteed to form valid UTF-8 (for example, the
# sequence may end in the middle of a multi-byte character), so undecodable
# bytes are replaced instead of raising UnicodeDecodeError.
bytes((samples[0]).tolist()).decode('utf-8', errors='replace')

' the political laws.  War also helped develop the [[Soviet Union|Soviet]] system in western Europe, as did [[Luxembourg]] and the Church of [[Sweden]]. The immediate impact of nuclear war took place in eastern Europe, bankrupted by serious [[Nuclear fallout|fallout]] from [[Early Modern Europe|miners from both sides]]. The state ally strategic eastern Europe had immediately concluded the war in both sides, although it was extremely weak. Meanwhile, the US was developing an urban and commercial life of more t'