In [94]:
import os
from pathlib import Path

import tensorflow as tf
import tensorflow_text
from biomedicus.sentences.vocabulary import Vocabulary
# Root of the BioMedICUS data directory. Taken from the BIOMEDICUS_DATA
# environment variable when it is set, otherwise ~/BIOMEDICUS_DATA, so the
# notebook is not tied to one user's absolute home path.
BIOMEDICUS_DATA = Path(os.environ.get('BIOMEDICUS_DATA', Path.home() / 'BIOMEDICUS_DATA'))

# Sentence-model vocabulary loaded from <data root>/sentences/vocab.
vocabulary = Vocabulary(str(BIOMEDICUS_DATA / 'sentences' / 'vocab'))

In [129]:
# Build, for every whitespace token of every document, one sequence of int32 ids:
#   [PREV_TOKEN, <whitespace before>, TOKEN_START, <token chars>, TOKEN_END,
#    <whitespace after>, NEXT_TOKEN]
# The result is a ragged tensor indexed as [document, token, id].

# Id returned by the lookup table for any character missing from the mapping.
UNKNOWN_CHAR_ID = 11
# Marker ids spliced around each token's characters.
PREV_TOKEN_ID = 3
TOKEN_START_ID = 1
TOKEN_END_ID = 2
# NOTE(review): same value as TOKEN_START_ID, which looks like a copy-paste
# slip. Kept as-is to preserve the existing output; confirm the intended id
# against the vocabulary before relying on it.
NEXT_TOKEN_ID = 1


def _slice_spans(texts, begins, lengths):
    """Cut one substring per (begin, length) pair out of the matching document.

    `begins` and `lengths` are ragged [document, token] offsets; the result is
    a ragged string tensor with the same row lengths as `begins`.
    """
    # substr needs rectangular inputs. to_tensor() pads with 0, so padded slots
    # become empty strings, and from_tensor(lengths=...) trims them off again.
    dense = tf.strings.substr(tf.expand_dims(texts, -1), begins.to_tensor(), lengths.to_tensor())
    return tf.RaggedTensor.from_tensor(dense, lengths=begins.row_lengths())


def _marker(like, marker_id):
    """Ragged int32 tensor [document, token, 1] holding `marker_id` once per token of `like`."""
    column = tf.fill([tf.size(like.flat_values), 1], tf.constant(marker_id, tf.int32))
    return tf.RaggedTensor.from_row_lengths(column, like.row_lengths())


def _char_ids(strings):
    """Split each string into UTF-8 characters and map every character through `char_table`."""
    return tf.ragged.map_flat_values(char_table.lookup, tf.strings.unicode_split(strings, 'UTF-8'))


# NOTE(review): reads the private `_character_to_id` mapping of Vocabulary;
# switch to a public accessor if the class offers one.
keys, values = zip(*vocabulary._character_to_id.items())
char_table = tf.lookup.StaticHashTable(tf.lookup.KeyValueTensorInitializer(keys, values), UNKNOWN_CHAR_ID)

docs = ["The quick brown fox jumps\n over the lazy dog.", "Simplify, then add lightness."]
print('docs:', docs)
print(tf.constant(docs).shape)
tokenizer = tensorflow_text.WhitespaceTokenizer()
tokens, starts, ends = tokenizer.tokenize_with_offsets(docs)
print('tokens:', tokens)
print('starts:', starts)
print('ends:', ends)

# Text between the previous token's end (0 for the first token) and this token's start.
# The sentinel is concatenated BEFORE slicing so that a document with no tokens
# yields an empty row instead of a stray one-element row that no longer lines
# up with `starts`.
first_prev_end = tf.fill([tokens.nrows(), 1], tf.cast(0, tf.int64))
prev_ends = tf.concat([first_prev_end, ends], -1)[:, :-1]
prior_lens = starts - prev_ends
priors = _slice_spans(docs, prev_ends, prior_lens)
print('priors:', priors)

# Text between this token's end and the next token's start. The last token uses
# int64.max as its "next start", so its substr length runs to the end of the document.
last_next_start = tf.fill(first_prev_end.shape, tf.constant(tf.int64.max, tf.int64))
next_starts = tf.concat([starts, last_next_start], -1)[:, 1:]
post_lens = next_starts - ends
posts = _slice_spans(docs, ends, post_lens)
print('posts:', posts)

prev_token_marker = _marker(tokens, PREV_TOKEN_ID)
print('prev_token_marker:', prev_token_marker)
token_start_marker = _marker(tokens, TOKEN_START_ID)
token_end_marker = _marker(tokens, TOKEN_END_ID)
next_token_marker = _marker(tokens, NEXT_TOKEN_ID)

prior_char_ids = _char_ids(priors)
token_char_ids = _char_ids(tokens)
post_char_ids = _char_ids(posts)
print('token_char_ids', token_char_ids)
print('token_start_marker', token_start_marker)

# Stitch the pieces together along the innermost (per-token id) axis.
token_char_windows = tf.concat([
    prev_token_marker,
    prior_char_ids,
    token_start_marker,
    token_char_ids,
    token_end_marker,
    post_char_ids,
    next_token_marker
], axis=-1)
token_char_windows


docs: ['The quick brown fox jumps\n over the lazy dog.', 'Simplify, then add lightness.']
(2,)
tokens: <tf.RaggedTensor [[b'The', b'quick', b'brown', b'fox', b'jumps', b'over', b'the', b'lazy', b'dog.'], [b'Simplify,', b'then', b'add', b'lightness.']]>
starts: <tf.RaggedTensor [[0, 4, 10, 16, 20, 27, 32, 36, 41], [0, 10, 15, 19]]>
ends: <tf.RaggedTensor [[3, 9, 15, 19, 25, 31, 35, 40, 45], [9, 14, 18, 29]]>
priors: <tf.RaggedTensor [[b'', b' ', b' ', b' ', b' ', b'\n ', b' ', b' ', b' '], [b'', b' ', b' ', b' ']]>
posts: <tf.RaggedTensor [[b' ', b' ', b' ', b' ', b'\n ', b' ', b' ', b' ', b''], [b' ', b' ', b' ', b'']]>
prev_token_marker: <tf.RaggedTensor [[[3], [3], [3], [3], [3], [3], [3], [3], [3]], [[3], [3], [3], [3]]]>
token_char_ids <tf.RaggedTensor [[[63, 83, 80], [92, 96, 84, 78, 86], [77, 93, 90, 98, 89], [81, 90, 99], [85, 96, 88, 91, 94], [90, 97, 80, 93], [95, 83, 80], [87, 76, 101, 100], [79, 90, 82, 25]], [[62, 84, 88, 91, 87, 84, 81, 100, 23], [95, 83, 80, 89], [76, 79,

<tf.RaggedTensor [[[3, 1, 63, 83, 80, 2, 9, 1], [3, 9, 1, 92, 96, 84, 78, 86, 2, 9, 1], [3, 9, 1, 77, 93, 90, 98, 89, 2, 9, 1], [3, 9, 1, 81, 90, 99, 2, 9, 1], [3, 9, 1, 85, 96, 88, 91, 94, 2, 7, 9, 1], [3, 7, 9, 1, 90, 97, 80, 93, 2, 9, 1], [3, 9, 1, 95, 83, 80, 2, 9, 1], [3, 9, 1, 87, 76, 101, 100, 2, 9, 1], [3, 9, 1, 79, 90, 82, 25, 2, 1]], [[3, 1, 62, 84, 88, 91, 87, 84, 81, 100, 23, 2, 9, 1], [3, 9, 1, 95, 83, 80, 89, 2, 9, 1], [3, 9, 1, 76, 79, 79, 2, 9, 1], [3, 9, 1, 87, 84, 82, 83, 95, 89, 80, 94, 94, 25, 2, 1]]]>