In [1]:
# Install dependencies
!pip install --upgrade -q jax
!pip install --upgrade -q jaxlib
!pip install --upgrade -q trax
!pip install --upgrade -q sentencepiece
!pip install --upgrade -q gin 

# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

[?25l[K     |▊                               | 10kB 27.9MB/s eta 0:00:01[K     |█▍                              | 20kB 2.1MB/s eta 0:00:01[K     |██                              | 30kB 2.5MB/s eta 0:00:01[K     |██▊                             | 40kB 3.0MB/s eta 0:00:01[K     |███▍                            | 51kB 2.3MB/s eta 0:00:01[K     |████                            | 61kB 2.6MB/s eta 0:00:01[K     |████▊                           | 71kB 2.9MB/s eta 0:00:01[K     |█████▍                          | 81kB 3.1MB/s eta 0:00:01[K     |██████                          | 92kB 3.3MB/s eta 0:00:01[K     |██████▊                         | 102kB 3.3MB/s eta 0:00:01[K     |███████▍                        | 112kB 3.3MB/s eta 0:00:01[K     |████████                        | 122kB 3.3MB/s eta 0:00:01[K     |████████▊                       | 133kB 3.3MB/s eta 0:00:01[K     |█████████▍                      | 143kB 3.3MB/s eta 0:00:01[K     |██████████                

In [2]:
import gin
import os
import numpy as np
from scipy.special import softmax

# Zipping and downloading files
from google.colab import files
import shutil

# Trax
import jax
import trax
from trax.data import inputs
import jax.numpy as jnp

# NLP Vocab Generation
import sentencepiece as spm

# TensorFlow
from tensorflow.compat.v1.io.gfile import GFile
import tensorflow as tf

In [3]:
# Fetch Plato's `The Republic` as plain text.
# `get_file` is given `cache_dir='.'`; TEXT_PATH below expects the file to sit
# in the `datasets` sub-directory of the working directory.
URL = 'http://www.gutenberg.org/cache/epub/1497/pg1497.txt'
FILENAME = 'the_republic.txt'
tf.keras.utils.get_file(FILENAME, URL, cache_dir='.')
TEXT_PATH = os.path.join('datasets', FILENAME)

Downloading data from http://www.gutenberg.org/cache/epub/1497/pg1497.txt


In [4]:
# Keep only the body of the book: cut everything before the main text and
# everything from the Project Gutenberg end marker onwards.
with GFile(TEXT_PATH) as f:
    text = f.read()

# The body starts at the first 'The Republic' that follows the last
# 'INTRODUCTION AND ANALYSIS' heading and stops at the last end marker.
start = text.rfind('INTRODUCTION AND ANALYSIS')
start = text.find('The Republic', start + 1)
end = text.rfind('End of the Project Gutenberg EBook of The Republic, by Plato')

# `find`/`rfind` return -1 when a marker is absent. Used directly as slice
# bounds, -1 would silently produce an empty text (start) or drop the final
# character (end), so fall back to the corresponding edge of the file.
if start == -1:
    print('Warning: start marker not found; keeping text from the beginning.')
    start = 0
if end == -1:
    print('Warning: end marker not found; keeping text through the end.')
    end = len(text)
text = text[start:end].strip()

In [5]:
# Fit a byte-pair-encoding (BPE) tokenizer on the downloaded file.
# The backslash continuations sit inside the string literal, so the leading
# spaces of the continuation lines are part of the flag string; they are kept
# exactly as they were.
BPE_FLAGS = '--input=datasets/the_republic.txt \
                                --model_prefix=cp.320 \
                                --vocab_size=320 \
                                --model_type=bpe'
spm.SentencePieceTrainer.train(BPE_FLAGS)

# Load the BPE model file that belongs to the `cp.320` prefix.
TOKENIZER = spm.SentencePieceProcessor()
TOKENIZER.load('cp.320.model')

True

In [6]:
# Encode the book text as an int32 array of token ids.
IDS = np.array(TOKENIZER.EncodeAsIds(text), dtype=np.int32)
# Number of positions left over in a sequence of 512 * 1024 tokens.
PAD_AMOUNT = 512 * 1024 - IDS.shape[0]
print("Number of tokens:", len(IDS))

Number of tokens: 512874


In [7]:
# Data pipeline: an endless stream of padded copies of the token array.
def my_inputs(n_devices):
  """Yield `(inputs, targets, mask)` batches with one row per device, forever.

  Every row is the module-level `IDS` array padded with zeros to a length of
  `len(IDS) + PAD_AMOUNT`. The split of the padding between the left and the
  right side is drawn at random for each row, so each device sees the tokens
  at a different offset.

  Args:
    n_devices: number of rows to put in each batch.

  Yields:
    A 3-tuple `(inputs, inputs, mask)`: the same stacked token array twice
    (inputs and targets are one object) and a float32 mask that is 1 on token
    positions and 0 on padding.
  """
  while True:
    ones = np.ones_like(IDS, dtype=np.float32)
    left_pads = np.random.choice(PAD_AMOUNT, n_devices)
    # (left, right) padding widths per device; each pair sums to PAD_AMOUNT.
    widths = [(left, PAD_AMOUNT - left) for left in left_pads]
    tokens = np.stack([np.pad(IDS, w, mode='constant') for w in widths])
    weights = np.stack([np.pad(ones, w, mode='constant') for w in widths])
    yield (tokens, tokens, weights)

# Sanity check: draw one batch and show the shape of its token array.
n_devices = trax.fastmath.device_count()
print("(device count, tokens per device) = ", next(my_inputs(n_devices))[0].shape)

(device count, tokens per device) =  (8, 524288)


In [8]:
# Configure hyperparameters.
# Everything between the triple quotes is configuration text handed to
# `gin.parse_config` as one string; editing it changes runtime behaviour.
# Values visible below that line up with other cells of this notebook:
#   * n_tokens = 524288 = 512 * 1024, the sequence length used for PAD_AMOUNT.
#   * ReformerLM.vocab_size = 320, the `--vocab_size` given to SentencePiece.
#   * axial_pos_shape (512, 1024) multiplies to n_tokens, and d_axial_pos_embs
#     (64, 192) sums to d_model = 256.
# NOTE(review): these look like constraints that have to hold together --
# confirm against the Trax ReformerLM documentation before changing one alone.
gin.parse_config("""
import trax.layers
import trax.models
import trax.optimizers
import trax.data.inputs
import trax.supervised.trainer_lib

# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.ReformerLM
# Model will have 6 layers, alternating between the LSH attention
# and local attention within a certain context window.
n_layers = 6
attn_type = [
  @trax.layers.SelfAttention,
  @LSHSelfAttention,  
  @trax.layers.SelfAttention,
  @LSHSelfAttention,
  @trax.layers.SelfAttention,
  @LSHSelfAttention,
  ]
share_qk = False  # LSH attention ignores this flag and always shares q & k
n_heads = 2
attn_kv = 64
dropout = 0.05
n_tokens = 524288

# Parameters for multifactor:
# ==============================================================================
multifactor.constant = 0.01
multifactor.factors = 'constant * linear_warmup * cosine_decay'
multifactor.warmup_steps = 100
multifactor.steps_per_cycle = 900

# Parameters for Adam:
# ==============================================================================
Adam.weight_decay_rate=0.0
Adam.b1 = 0.86
Adam.b2 = 0.92
Adam.eps = 1e-9

# Parameters for SelfAttention:
# ==============================================================================
trax.layers.SelfAttention.attention_dropout = 0.05
trax.layers.SelfAttention.chunk_len = 64
trax.layers.SelfAttention.n_chunks_before = 1
trax.layers.SelfAttention.n_parallel_heads = 1

# Parameters for LSHSelfAttention:
# ==============================================================================
LSHSelfAttention.attention_dropout = 0.0
LSHSelfAttention.chunk_len = 64
LSHSelfAttention.n_buckets = [64, 128]
LSHSelfAttention.n_chunks_after = 0
LSHSelfAttention.n_chunks_before = 1
LSHSelfAttention.n_hashes = 1
LSHSelfAttention.n_parallel_heads = 1
LSHSelfAttention.predict_drop_len = 128
LSHSelfAttention.predict_mem_len = 1024

# Parameters for ReformerLM:
# ==============================================================================
ReformerLM.attention_type = %attn_type
ReformerLM.d_attention_key = %attn_kv
ReformerLM.d_attention_value = %attn_kv
ReformerLM.d_model = 256
ReformerLM.d_ff = 512
ReformerLM.dropout = %dropout
ReformerLM.ff_activation = @trax.layers.Relu
ReformerLM.max_len = %n_tokens
ReformerLM.mode = 'train'
ReformerLM.n_heads = %n_heads
ReformerLM.n_layers = %n_layers
ReformerLM.vocab_size = 320
ReformerLM.axial_pos_shape = (512, 1024)
ReformerLM.d_axial_pos_embs= (64, 192)
""")

In [9]:
# Trainer.
# 'model' contains no '~', so `expanduser` returns it unchanged: checkpoints
# live in a `model` directory relative to the current working directory.
output_dir = os.path.expanduser('model')

# Remove old model. The path is built from `output_dir` so the file deleted is
# the one in the directory the Trainer is pointed at (presumably the Trainer
# would otherwise pick the old checkpoint up -- verify against Trax docs).
stale_checkpoint = os.path.join(output_dir, 'model.pkl.gz')
if os.path.exists(stale_checkpoint):
  os.remove(stale_checkpoint)

trainer = trax.supervised.Trainer(
    model=trax.models.ReformerLM,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.multifactor(),
    inputs=trax.data.inputs.Inputs(my_inputs),
    output_dir=output_dir)

In [None]:
# Train the model: 50 calls to `train_epoch`, each asking for 100 steps and
# 20 evaluation steps.
EPOCHS, STEPS_PER_EPOCH, EVAL_STEPS = 50, 100, 20
for _epoch in range(EPOCHS):
  trainer.train_epoch(n_steps=STEPS_PER_EPOCH, n_eval_steps=EVAL_STEPS)

In [13]:
# Zip directory contents
# The archive is built in the system temp directory rather than inside the
# directory being archived: writing `project.zip` into "." while zipping "."
# lets the half-written archive (or one left by an earlier run) be packed into
# itself.
import tempfile

archive_path = shutil.make_archive(
    os.path.join(tempfile.gettempdir(), "project"), "zip", ".")

# Download zipped directory
files.download(archive_path)

'/content/text_generation.zip'

In [14]:
# According to the Reformer paper, more hashing rounds improve quality, and
# the number of rounds can be increased at evaluation time only. Override the
# training-time value before the model is built for prediction.
gin.parse_config('LSHSelfAttention.n_hashes = 8')

In [18]:
# Rebuild the Reformer in 'predict' mode and restore the trained weights.
checkpoint_path = os.path.join(output_dir, 'model.pkl.gz')
model = trax.models.ReformerLM(mode='predict')
model.init_from_file(checkpoint_path, weights_only=True)

# Draw a sample from the language model.
output_token_ids = trax.supervised.decoding.autoregressive_sample(
    model, temperature=0.2)

# Turn the sampled ids back into text. The sample is a batch holding a single
# sequence, hence the [0]; tolist() converts the int64 ids into the plain
# Python ints SentencePiece expects.
TOKENIZER.DecodeIds(output_token_ids[0].tolist())

'The Republic of Plato. The man is remost highest philosophy of Herodic has been said to have recognised an ancient to the authority of Hellas. The greatest of all knowing that he doubted at Socrates and Plato, is the practicabil'