In [81]:
import collections
import os
import pathlib
import re
import string
import sys
import tempfile
import time

import numpy as np
import matplotlib.pyplot as plt

import tensorflow_datasets as tfds
import tensorflow_text as text
import tensorflow as tf

In [82]:
tf.get_logger().setLevel('ERROR')
pwd = pathlib.Path.cwd()

In [83]:
dataset_name = 'wmt14_translate/fr-en'
data_dir = 'nlp_lab_dataset/'
train_samples = 60000  



# Load the dataset with specified splits
ds_splits = tfds.load(dataset_name, split=['train', 'validation', 'test'], data_dir=data_dir)

# Take a subset of the training set
train_ds = ds_splits[0].take(train_samples)
val_ds = ds_splits[1]
test_ds = ds_splits[2]

In [84]:
# Example of how to use the subsets
print("Training set samples:")
for batch in train_ds.batch(3).take(1):
    print('> Examples in English:')
    en_examples = batch["en"].numpy()
    for en in en_examples:
        print(en.decode("utf-8"))

    print()
    
    print('> Examples in French:')
    fr_examples = batch["fr"].numpy()
    for fr in fr_examples:
        print(fr.decode("utf-8"))

Training set samples:
> Examples in English:
In his briefing on economic development, Al Horner will give you details of programs we fund to foster partnerships between the private sector and First Nations and Inuit communities, in areas like resource development projects, for example.
(b) Positive aspects
Crop insurance payments include only government crop insurance programs; private hail insurance payments are excluded.

> Examples in French:
Dans sa présentation sur le développement économique, M. Al Horner vous donnera des détails sur les programmes que nous finançons pour favoriser l'établissement de partenariats entre le secteur privé et les collectivités des Premières nations et inuites dans des domaines comme celui de l'exploitation des ressources naturelles.
b) Aspects positifs
Les indemnités d’assurance-récolte comprennent uniquement celles des programmes publics; les indemnités de l’assurance-grêle privée sont exclues.


2024-05-19 21:47:56.958686: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [85]:
train_en = train_ds.map(lambda train: train["en"])
train_fr = train_ds.map(lambda train: train["fr"])

### Generate the vocabulary

In [86]:
from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab

In [87]:
bert_tokenizer_params=dict(lower_case=True)
reserved_tokens=["[PAD]", "[UNK]", "[START]", "[END]"]

bert_vocab_args = dict(
    # The target vocabulary size
    vocab_size = 8000,
    # Reserved tokens that must be included in the vocabulary
    reserved_tokens=reserved_tokens,
    # Arguments for `text.BertTokenizer`
    bert_tokenizer_params=bert_tokenizer_params,
    # Arguments for `wordpiece_vocab.wordpiece_tokenizer_learner_lib.learn`
    learn_params={},
)

In [88]:
%%time
fr_vocab = bert_vocab.bert_vocab_from_dataset(
    train_fr.batch(1000).prefetch(2),
    **bert_vocab_args
)



2024-05-19 21:47:58.443344: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


CPU times: user 1min 25s, sys: 1.81 s, total: 1min 27s
Wall time: 1min 20s


In [89]:
print(fr_vocab[:10])
print(fr_vocab[100:110])
print(fr_vocab[1000:1010])
print(fr_vocab[-10:])

['[PAD]', '[UNK]', '[START]', '[END]', '!', '"', '#', '$', '%', '&']
['ʼ', 'ˆ', 'ˇ', 'α', 'β', 'γ', 'δ', 'ε', 'η', 'θ']
['propriete', 'attention', 'vos', 'assurance', '##ure', 'debat', 'http', 'donner', 'eux', 'protocole']
['##\uf76d', '##\uf76e', '##\uf76f', '##\uf770', '##\uf772', '##\uf773', '##\uf774', '##\uf775', '##\uf7e9', '##\uf8e7']


In [90]:
def write_vocab_file(filepath, vocab):
  with open(filepath, 'w') as f:
    for token in vocab:
      print(token, file=f)

In [91]:
write_vocab_file('fr_vocab.txt', fr_vocab)

In [92]:
%%time
en_vocab = bert_vocab.bert_vocab_from_dataset(
    train_en.batch(1000).prefetch(2),
    **bert_vocab_args
)

2024-05-19 21:49:18.259709: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


CPU times: user 1min 9s, sys: 1.14 s, total: 1min 11s
Wall time: 1min 6s


In [93]:
print(en_vocab[:10])
print(en_vocab[100:110])
print(en_vocab[1000:1010])
print(en_vocab[-10:])

['[PAD]', '[UNK]', '[START]', '[END]', '!', '"', '#', '$', '%', '&']
['ʼ', 'ˆ', 'α', 'β', 'γ', 'δ', 'ε', 'η', 'θ', 'ι']
['forces', 'days', 'final', 'pay', 'un', '29', 'attention', 'capital', 'prevention', 'previous']
['##\uf766', '##\uf767', '##\uf769', '##\uf76e', '##\uf76f', '##\uf772', '##\uf774', '##\uf775', '##\uf8e7', '##�']


In [94]:
write_vocab_file('en_vocab.txt', en_vocab)

### Build the tokenizer

In [95]:
fr_tokenizer = text.BertTokenizer('fr_vocab.txt', **bert_tokenizer_params)
en_tokenizer = text.BertTokenizer('en_vocab.txt', **bert_tokenizer_params)

In [96]:
for batch in train_ds.batch(3).take(1):
    en_examples = batch["en"]
    fr_examples = batch["fr"]
    for ex in en_examples:
        print(ex.numpy())

b'In his briefing on economic development, Al Horner will give you details of programs we fund to foster partnerships between the private sector and First Nations and Inuit communities, in areas like resource development projects, for example.'
b'(b) Positive aspects'
b'Crop insurance payments include only government crop insurance programs; private hail insurance payments are excluded.'


2024-05-19 21:50:24.245867: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [97]:
# Tokenize the examples -> (batch, word, word-piece)
token_batch = en_tokenizer.tokenize(en_examples)
# Merge the word and word-piece axes -> (batch, tokens)
token_batch = token_batch.merge_dims(-2,-1)

for ex in token_batch.to_list():
  print(ex)

[295, 431, 3203, 298, 400, 331, 15, 1438, 49, 5212, 390, 316, 1013, 351, 1735, 292, 700, 324, 802, 294, 2793, 1740, 379, 291, 777, 588, 293, 378, 363, 293, 2442, 721, 15, 295, 519, 593, 1155, 331, 617, 15, 296, 691, 17]
[11, 43, 12, 1278, 1310]
[4060, 1164, 1409, 574, 391, 354, 4060, 1164, 700, 30, 777, 5578, 1578, 1164, 1409, 306, 3922, 17]


In [98]:
# Lookup each token id in the vocabulary.
txt_tokens = tf.gather(en_vocab, token_batch)
# Join with spaces.
tf.strings.reduce_join(txt_tokens, separator=' ', axis=-1)

<tf.Tensor: shape=(3,), dtype=string, numpy=
array([b'in his briefing on economic development , al h ##orn ##er will give you details of programs we fund to foster partnerships between the private sector and first nations and inuit communities , in areas like resource development projects , for example .',
       b'( b ) positive aspects',
       b'crop insurance payments include only government crop insurance programs ; private ha ##il insurance payments are excluded .'],
      dtype=object)>

In [99]:
words = en_tokenizer.detokenize(token_batch)
tf.strings.reduce_join(words, separator=' ', axis=-1)

<tf.Tensor: shape=(3,), dtype=string, numpy=
array([b'in his briefing on economic development , al horner will give you details of programs we fund to foster partnerships between the private sector and first nations and inuit communities , in areas like resource development projects , for example .',
       b'( b ) positive aspects',
       b'crop insurance payments include only government crop insurance programs ; private hail insurance payments are excluded .'],
      dtype=object)>

### Customization and export

#### Custom tokenization

In [100]:
START = tf.argmax(tf.constant(reserved_tokens) == "[START]")
END = tf.argmax(tf.constant(reserved_tokens) == "[END]")

def add_start_end(ragged):
  count = ragged.bounding_shape()[0]
  starts = tf.fill([count,1], START)
  ends = tf.fill([count,1], END)
  return tf.concat([starts, ragged, ends], axis=1)

In [101]:
words = en_tokenizer.detokenize(add_start_end(token_batch))
tf.strings.reduce_join(words, separator=' ', axis=-1)

<tf.Tensor: shape=(3,), dtype=string, numpy=
array([b'[START] in his briefing on economic development , al horner will give you details of programs we fund to foster partnerships between the private sector and first nations and inuit communities , in areas like resource development projects , for example . [END]',
       b'[START] ( b ) positive aspects [END]',
       b'[START] crop insurance payments include only government crop insurance programs ; private hail insurance payments are excluded . [END]'],
      dtype=object)>

#### Custom detokenization

In [102]:
def cleanup_text(reserved_tokens, token_txt):
  # Drop the reserved tokens, except for "[UNK]".
  bad_tokens = [re.escape(tok) for tok in reserved_tokens if tok != "[UNK]"]
  bad_token_re = "|".join(bad_tokens)

  bad_cells = tf.strings.regex_full_match(token_txt, bad_token_re)
  result = tf.ragged.boolean_mask(token_txt, ~bad_cells)

  # Join them into strings.
  result = tf.strings.reduce_join(result, separator=' ', axis=-1)

  return result

In [103]:
en_examples.numpy()

array([b'In his briefing on economic development, Al Horner will give you details of programs we fund to foster partnerships between the private sector and First Nations and Inuit communities, in areas like resource development projects, for example.',
       b'(b) Positive aspects',
       b'Crop insurance payments include only government crop insurance programs; private hail insurance payments are excluded.'],
      dtype=object)

In [104]:
token_batch = en_tokenizer.tokenize(en_examples).merge_dims(-2,-1)
words = en_tokenizer.detokenize(token_batch)
words

<tf.RaggedTensor [[b'in', b'his', b'briefing', b'on', b'economic', b'development', b',',
  b'al', b'horner', b'will', b'give', b'you', b'details', b'of',
  b'programs', b'we', b'fund', b'to', b'foster', b'partnerships',
  b'between', b'the', b'private', b'sector', b'and', b'first', b'nations',
  b'and', b'inuit', b'communities', b',', b'in', b'areas', b'like',
  b'resource', b'development', b'projects', b',', b'for', b'example', b'.'],
 [b'(', b'b', b')', b'positive', b'aspects'],
 [b'crop', b'insurance', b'payments', b'include', b'only', b'government',
  b'crop', b'insurance', b'programs', b';', b'private', b'hail',
  b'insurance', b'payments', b'are', b'excluded', b'.']                  ]>

In [105]:
cleanup_text(reserved_tokens, words).numpy()

array([b'in his briefing on economic development , al horner will give you details of programs we fund to foster partnerships between the private sector and first nations and inuit communities , in areas like resource development projects , for example .',
       b'( b ) positive aspects',
       b'crop insurance payments include only government crop insurance programs ; private hail insurance payments are excluded .'],
      dtype=object)

### Export

In [106]:
class CustomTokenizer(tf.Module):
  def __init__(self, reserved_tokens, vocab_path):
    self.tokenizer = text.BertTokenizer(vocab_path, lower_case=True)
    self._reserved_tokens = reserved_tokens
    self._vocab_path = tf.saved_model.Asset(vocab_path)

    vocab = pathlib.Path(vocab_path).read_text().splitlines()
    self.vocab = tf.Variable(vocab)

    ## Create the signatures for export:   

    # Include a tokenize signature for a batch of strings. 
    self.tokenize.get_concrete_function(
        tf.TensorSpec(shape=[None], dtype=tf.string))

    # Include `detokenize` and `lookup` signatures for:
    #   * `Tensors` with shapes [tokens] and [batch, tokens]
    #   * `RaggedTensors` with shape [batch, tokens]
    self.detokenize.get_concrete_function(
        tf.TensorSpec(shape=[None, None], dtype=tf.int64))
    self.detokenize.get_concrete_function(
          tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int64))

    self.lookup.get_concrete_function(
        tf.TensorSpec(shape=[None, None], dtype=tf.int64))
    self.lookup.get_concrete_function(
          tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int64))

    # These `get_*` methods take no arguments
    self.get_vocab_size.get_concrete_function()
    self.get_vocab_path.get_concrete_function()
    self.get_reserved_tokens.get_concrete_function()

  @tf.function
  def tokenize(self, strings):
    enc = self.tokenizer.tokenize(strings)
    # Merge the `word` and `word-piece` axes.
    enc = enc.merge_dims(-2,-1)
    enc = add_start_end(enc)
    return enc

  @tf.function
  def detokenize(self, tokenized):
    words = self.tokenizer.detokenize(tokenized)
    return cleanup_text(self._reserved_tokens, words)

  @tf.function
  def lookup(self, token_ids):
    return tf.gather(self.vocab, token_ids)

  @tf.function
  def get_vocab_size(self):
    return tf.shape(self.vocab)[0]

  @tf.function
  def get_vocab_path(self):
    return self._vocab_path

  @tf.function
  def get_reserved_tokens(self):
    return tf.constant(self._reserved_tokens)

In [107]:
tokenizers = tf.Module()
tokenizers.fr = CustomTokenizer(reserved_tokens, 'fr_vocab.txt')
tokenizers.en = CustomTokenizer(reserved_tokens, 'en_vocab.txt')

In [108]:
model_name = 'fr_en_tokenizer'
tf.saved_model.save(tokenizers, model_name)

In [109]:
reloaded_tokenizers = tf.saved_model.load(model_name)
reloaded_tokenizers.en.get_vocab_size().numpy()

7955

In [110]:
tokens = reloaded_tokenizers.en.tokenize(['Hello TensorFlow!'])
tokens.numpy()

array([[   2,  429, 2033,  423, 2541, 7258,  668, 4688,    4,    3]])

In [111]:
text_tokens = reloaded_tokenizers.en.lookup(tokens)
text_tokens

<tf.RaggedTensor [[b'[START]', b'##i', b'settlement', b'part', b'saint', b'##pelled',
  b'result', b'##ew', b'!', b'[END]']]>

In [112]:
round_trip = reloaded_tokenizers.en.detokenize(tokens)

print(round_trip.numpy()[0].decode('utf-8'))

hello tensorflow !


In [114]:
!zip -r {model_name}.zip {model_name}

  adding: fr_en_tokenizer/ (stored 0%)
  adding: fr_en_tokenizer/variables/ (stored 0%)
  adding: fr_en_tokenizer/variables/variables.data-00000-of-00001 (deflated 51%)
  adding: fr_en_tokenizer/variables/variables.index (deflated 33%)
  adding: fr_en_tokenizer/assets/ (stored 0%)
  adding: fr_en_tokenizer/assets/en_vocab.txt (deflated 54%)
  adding: fr_en_tokenizer/assets/fr_vocab.txt (deflated 57%)
  adding: fr_en_tokenizer/saved_model.pb (deflated 91%)
  adding: fr_en_tokenizer/fingerprint.pb (stored 0%)
