Merge pull request #49 from google-research/en_de
Adds vocabulary for en_de paracrawl translation
ramasesh committed Jan 14, 2021
2 parents 5e51864 + a88fccf commit d8d3af5
Showing 5 changed files with 60,219 additions and 6 deletions.
69 changes: 68 additions & 1 deletion renn/data/datasets.py
@@ -7,7 +7,7 @@
import os

from renn import utils
from renn.data.tokenizers import load_tokenizer, SEP
from renn.data.tokenizers import load_tokenizer, SEP, tensor_punctuation_separator
from renn.data import data_utils

__all__ = [
@@ -41,6 +41,13 @@ def tokenize_fun(tokenizer):
  return utils.compose(tokenizer.tokenize, wsp.tokenize, text.case_fold_utf8)


def tokenize_w_punctuation(tokenizer):
  """Text processing function which splits off punctuation."""
  wsp = text.WhitespaceTokenizer()
  return utils.compose(tokenizer.tokenize, wsp.tokenize,
                       tensor_punctuation_separator, text.case_fold_utf8)


def padded_batch(dset, batch_size, sequence_length, label_shape=()):
"""Pads examples to a fixed length, and collects them into batches."""

@@ -107,6 +114,66 @@ def load_csv(name, split, preprocess_fun, filter_fn=None, data_dir=None):
  return dset


def paracrawl(language_pair,
              vocab_files,
              sequence_length,
              batch_size=64,
              transform_fn=utils.identity,
              filter_fn=None,
              data_dir=None):
  """Loads a paracrawl translation dataset from TFDS.

  Arguments:
    language_pair: str, e.g. 'ende', specifying both languages.
    vocab_files: List[str], vocab filenames for each language.
    sequence_length: int, maximum length; longer examples are filtered out
      and shorter ones are padded to this length.
    batch_size: int, number of examples per batch.
    transform_fn: optional function applied to each tokenized example.
    filter_fn: optional filter function passed to the pipeline.
    data_dir: optional TFDS data directory.
  """

  PARACRAWL_LANGUAGE_PAIRS = [
      'enbg', 'encs', 'enda', 'ende', 'enel', 'enes', 'enet', 'enfi', 'enfr',
      'enga', 'enhr', 'enhu', 'enit', 'enlt', 'enlv', 'enmt', 'ennl', 'enpl',
      'enpt', 'enro', 'ensk', 'ensl', 'ensv'
  ]

  if language_pair not in PARACRAWL_LANGUAGE_PAIRS:
    raise ValueError(f'language_pair must be one of {PARACRAWL_LANGUAGE_PAIRS}')
  languages = [language_pair[:2], language_pair[2:]]

  tokenizer_list = [
      tokenize_w_punctuation(load_tokenizer(f)) for f in vocab_files
  ]
  tokenizer_dict = dict(zip(languages, tokenizer_list))

  def _preprocess(d):
    tokens = {l: tokenizer_dict[l](d[l]).flat_values for l in languages}
    for l in languages:
      tokens.update({f'{l}_index': tf.size(tokens[l])})
      tokens.update({f'{l}_orig': d[l]})
    return transform_fn(tokens)

  dataset = tfds.load(
      f'para_crawl/{language_pair}',
      split='train',  # para_crawl only has a train split
      data_dir=data_dir)

  dset = pipeline(dataset, preprocess_fun=_preprocess, filter_fn=filter_fn)

  # Filter out examples longer than sequence length.
  for l in languages:
    dset = dset.filter(lambda d: d[f'{l}_index'] <= sequence_length)

  # Each language gets padded tokens, plus a scalar index and the original string.
  padded_shapes = {}
  for l in languages:
    padded_shapes[f'{l}_index'] = ()
    padded_shapes[f'{l}_orig'] = ()
    padded_shapes[l] = (sequence_length,)

  # Pad remaining examples to the sequence length.
  dset = dset.padded_batch(batch_size, padded_shapes)

  return dset, tokenizer_dict


def ag_news(split,
            vocab_file,
            sequence_length=100,
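For orientation, a minimal usage sketch of the paracrawl() loader added above. The vocab file paths are placeholders (not shipped with this commit), and the shapes assume a full batch:

from renn.data import datasets

# Placeholder vocab files, one per language (see build_vocab_tr below).
vocab_files = ['/tmp/vocab.en.txt', '/tmp/vocab.de.txt']

dset, tokenizer_dict = datasets.paracrawl(
    language_pair='ende',
    vocab_files=vocab_files,
    sequence_length=50,
    batch_size=32)

# Each batch is a dict with per-language token ids padded to sequence_length,
# plus '<lang>_index' (token counts) and '<lang>_orig' (raw strings).
for batch in dset.take(1):
  print(batch['en'].shape, batch['de'].shape)  # (32, 50) and (32, 50) for a full batch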
114 changes: 110 additions & 4 deletions renn/data/tokenizers.py
@@ -4,9 +4,15 @@
import itertools

from renn.data import wordpiece_tokenizer_learner_lib as vocab_learner
from renn.utils import identity

import tensorflow_text as text
import tensorflow as tf
import tensorflow.strings as strings

from typing import Optional, Callable

import re

__all__ = ['build_vocab', 'load_tokenizer']

@@ -15,6 +21,104 @@
UNK = '<unk>'
CLS = '<cls>'
SEP = '<sep>'
EOS = '<eos>'
BOS = '<bos>'

def punctuation_separator(s: str) -> str:
  """Separates punctuation at the end of word and end of line."""
  punctuation_chars = ['.', ',', ':', ';', '?', '!']
  special_chars = ['.', '?']  # regex chars
  s_ = s
  for c in punctuation_chars:
    if c in special_chars:
      # separate punctuation at end-of-line
      s_ = re.sub(f'\{c}$', f' {c}', s_)
      # separate punctuation at end-of-word
      s_ = re.sub(f'\{c} ', f' {c} ', s_)
    else:
      # separate punctuation at end-of-line
      s_ = re.sub(f'{c}$', f' {c}', s_)
      # separate punctuation at end-of-word
      s_ = re.sub(f'{c} ', f' {c} ', s_)
  return s_


def tensor_punctuation_separator(x: tf.Tensor) -> tf.Tensor:
  """Separates punctuation at the end of word and end of line.

  In behavior this function is identical to punctuation_separator
  above. The only difference is that this acts on TF Tensors rather
  than strings."""
  punctuation_chars = ['.', ',', ':', ';', '?', '!']
  special_chars = ['.', '?']  # regex chars

  for c in punctuation_chars:
    if c in special_chars:
      # separate punctuation at end-of-line
      x = strings.regex_replace(x, f'\{c}$', f' {c}')
      # separate punctuation at end-of-word
      x = strings.regex_replace(x, f'\{c} ', f' {c} ')
    else:
      # separate punctuation at end-of-line
      x = strings.regex_replace(x, f'{c}$', f' {c}')
      # separate punctuation at end-of-word
      x = strings.regex_replace(x, f'{c} ', f' {c} ')
  return x

def lowercase_strip(x: str) -> str:
  """Lowercases and separates punctuation.

  Can be used as a transform_fn for text_generator()."""

  return punctuation_separator(x.lower())

def text_generator(dataset: dict,
                   split: str,
                   language: str,
                   num_examples: int,
                   transform_fn: Callable[[str], str] = identity):
  """Builds a generator from a TF dataset.

  Given a dataset, returns a generator which yields single-language examples
  from that dataset, one at a time as strings.

  Arguments:
    dataset: dictionary of datasets.
    split: 'train', 'test', etc. dataset[split] should yield an iterable.
    language: which language to generate text from.
    num_examples: desired length of the generator.
    transform_fn: string transformation, defaults to the identity fn.
  """

  it = iter(dataset[split])
  for _ in range(num_examples):
    yield transform_fn(next(it)[language].numpy().decode('UTF-8'))

def build_vocab_tr(corpus_generator, vocab_size, split_fun=str.split):
  """Builds a vocab file from a text generator for translation.

  Unlike build_vocab() below, these vocabularies will have 3
  reserved tokens:
    <unk> - unknown
    <bos> - beginning of sentence
    <eos> - end of sentence
  This also does not include a joiner token.
  """

  # Split documents into words.
  words = itertools.chain(*map(split_fun, corpus_generator))

  # Count words in the corpus.
  word_counts = Counter(words)

  # Find the most common words.
  most_common_words = sorted(word_counts, key=word_counts.get,
                             reverse=True)[:vocab_size]

  reserved_tokens = [UNK, EOS, BOS]
  vocab = reserved_tokens + list(most_common_words)

  return vocab


def build_vocab(corpus_generator, vocab_size, split_fun=str.split):
@@ -50,10 +154,12 @@ def load_tokenizer(vocab_file, default_value=-1):
"""Loads a tokenizer from a vocab file."""

# Build lookup table that maps subwords to ids.
table = tf.lookup.TextFileInitializer(vocab_file, tf.string,
tf.lookup.TextFileIndex.WHOLE_LINE,
tf.int64,
tf.lookup.TextFileIndex.LINE_NUMBER)
table = tf.lookup.TextFileInitializer(
filename=vocab_file,
key_dtype=tf.string,
key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
value_dtype=tf.int64,
value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
static_table = tf.lookup.StaticHashTable(table, default_value)

# Build tokenizer.
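To produce the vocab files that paracrawl() consumes, the new helpers in this file can be chained roughly as follows. The file paths, example count, and vocab size are illustrative assumptions, not values fixed by the commit:

import tensorflow_datasets as tfds

from renn.data import tokenizers

# para_crawl only has a 'train' split; tfds.load without a split argument
# returns a dict of splits, which is what text_generator expects.
raw = tfds.load('para_crawl/ende')

for lang in ['en', 'de']:
  gen = tokenizers.text_generator(raw,
                                  split='train',
                                  language=lang,
                                  num_examples=100000,
                                  transform_fn=tokenizers.lowercase_strip)
  vocab = tokenizers.build_vocab_tr(gen, vocab_size=32000)
  with open(f'/tmp/vocab.{lang}.txt', 'w') as f:
    f.write('\n'.join(vocab))

The written file has one token per line, which matches the WHOLE_LINE / LINE_NUMBER lookup used by load_tokenizer() above.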
36 changes: 35 additions & 1 deletion tests/data_test.py
@@ -17,7 +17,7 @@
import pytest
import tempfile
from renn.data import datasets
from renn.data.tokenizers import load_tokenizer
from renn.data.tokenizers import load_tokenizer, punctuation_separator


@pytest.fixture
@@ -40,3 +40,37 @@ def test_tokenizer_fun(vocab):
  actual = list(tokenize("this is a test.").flat_values.numpy())
  expected = [5, 6, 2, 6, 3, 4, 7]
  assert actual == expected


@pytest.fixture
def sentences():
  """Test sentences for punctuation separation."""

  with_punctuation = [
      'Hello, how are you?', 'Pictures of Quebec: roofs of Quebec',
      'In a not-so-shocking survey, they discovered.',
      'linens, hair dryer, shower cabin, washing machine per arrangement',
      'Los Feliz (Greater L.A.), which nowadays belongs to actress.',
      'word1.word2', 'word1,word2', 'word1:word2', 'word1;word2', 'word1?word2',
      'word1!word2', 'word1. word2', 'word1, word2', 'word1: word2',
      'word1; word2', 'word1? word2', 'word1! word2', 'word.', 'word,', 'word:',
      'word;', 'word?', 'word!'
  ]

  processed_ideal = [
      'Hello , how are you ?', 'Pictures of Quebec : roofs of Quebec',
      'In a not-so-shocking survey , they discovered .',
      'linens , hair dryer , shower cabin , washing machine per arrangement',
      'Los Feliz (Greater L.A.) , which nowadays belongs to actress .',
      'word1.word2', 'word1,word2', 'word1:word2', 'word1;word2', 'word1?word2',
      'word1!word2', 'word1 . word2', 'word1 , word2', 'word1 : word2',
      'word1 ; word2', 'word1 ? word2', 'word1 ! word2', 'word .', 'word ,',
      'word :', 'word ;', 'word ?', 'word !'
  ]

  return (with_punctuation, processed_ideal)


def test_punctuation_separator(sentences):
  for sentence, processed in zip(*sentences):
    assert punctuation_separator(sentence) == processed

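A natural companion check, sketched here under the assumption that TensorFlow eager execution is enabled, applies the same fixture to the tensor variant (this test is not part of the commit):

import tensorflow as tf

from renn.data.tokenizers import tensor_punctuation_separator


def test_tensor_punctuation_separator(sentences):
  with_punctuation, processed_ideal = sentences
  # The tensor variant should agree with the string variant element-wise.
  processed = tensor_punctuation_separator(tf.constant(with_punctuation))
  actual = [s.decode('UTF-8') for s in processed.numpy()]
  assert actual == processed_ideal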