[ML] add sentence-piece unigram tokenizer (#88858)
Add internal unigram tokenizer.

This tokenizer is the same one that XLM-RoBERTa uses, along with many other cross-lingual models and tasks.

This change does not fully integrate the tokenizer (adding configuration, wiring it into NLP tasks, etc.); it adds only the internal tokenization logic, plus some tests showing how it runs with a precompiled charsmap.
benwtrent committed Jul 29, 2022
1 parent 6b8dab7 commit 9f2b96d
Showing 5 changed files with 790 additions and 40 deletions.
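For context, a unigram tokenizer (Kudo, 2018; the unigram language model used by SentencePiece and XLM-RoBERTa) scores each candidate segmentation of the input by the sum of its pieces' log-probabilities and keeps the best one, typically found with Viterbi dynamic programming. Below is a minimal sketch of that search; the class name, toy vocabulary, and scores are invented for illustration, and the real tokenizer additionally handles normalization and the precompiled charsmap shown in the diff.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

// Sketch only: Viterbi search for the best-scoring unigram segmentation.
// (Assumes the vocabulary can cover the whole input.)
class UnigramViterbiSketch {

    static List<String> segment(String text, Map<String, Double> logProbs) {
        int n = text.length();
        double[] best = new double[n + 1]; // best[i] = score of the best segmentation of text[0, i)
        int[] backPointer = new int[n + 1];
        Arrays.fill(best, Double.NEGATIVE_INFINITY);
        best[0] = 0.0;
        for (int end = 1; end <= n; end++) {
            for (int start = 0; start < end; start++) {
                Double lp = logProbs.get(text.substring(start, end));
                if (lp != null && best[start] + lp > best[end]) {
                    best[end] = best[start] + lp;
                    backPointer[end] = start;
                }
            }
        }
        // Walk the back-pointers to recover the winning piece sequence
        List<String> pieces = new ArrayList<>();
        for (int end = n; end > 0; end = backPointer[end]) {
            pieces.add(text.substring(backPointer[end], end));
        }
        Collections.reverse(pieces);
        return pieces;
    }

    public static void main(String[] args) {
        Map<String, Double> vocab = Map.of(
            "▁", -2.0, "un", -3.0, "igram", -4.0, "uni", -3.5,
            "gram", -4.5, "▁unigram", -10.0, "u", -5.0, "n", -5.0
        );
        // prints [▁, un, igram]: -2.0 + -3.0 + -4.0 beats the whole-word piece
        System.out.println(segment("▁unigram", vocab));
    }
}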
DelimitedToken.java
@@ -7,10 +7,13 @@

package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.core.Strings.format;

public class DelimitedToken {

static DelimitedToken mergeTokens(List<DelimitedToken> tokens) {
@@ -67,6 +70,25 @@ public String toString() {
}

public static class Encoded extends DelimitedToken {
static DelimitedToken.Encoded mergeEncodedTokens(List<DelimitedToken.Encoded> tokens) {
if (tokens.size() == 1) {
return tokens.get(0);
}
int startOffset = tokens.get(0).startOffset();
int endOffset = tokens.get(tokens.size() - 1).endOffset();
final int encoding = tokens.get(0).encoding;
List<CharSequence> sequences = new ArrayList<>(tokens.size());
for (var t : tokens) {
if (t.encoding != encoding) {
throw new IllegalArgumentException(
format("all merged tokens must have the same encoding, expected [%s]; found [%s]", encoding, t.encoding)
);
}
sequences.add(t.charSequence());
}
return new DelimitedToken.Encoded(new MultiCharSequence(sequences), tokens.get(0).encoding, startOffset, endOffset);
}

private final int encoding;

public Encoded(CharSequence charSequence, int encoding, int startOffset, int endOffset) {
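The mergeEncodedTokens helper added above can be exercised as follows. This is a hypothetical snippet (the class name and token values are invented), and it would have to live in the same tokenizers package, since the helper is package-private.

import java.util.List;

class MergeExample {
    static DelimitedToken.Encoded mergeFoo() {
        // Two adjacent pieces with the same encoding (42) and contiguous offsets
        List<DelimitedToken.Encoded> pieces = List.of(
            new DelimitedToken.Encoded("fo", 42, 0, 2), // offsets [0, 2)
            new DelimitedToken.Encoded("o", 42, 2, 3)   // offsets [2, 3)
        );
        // The result wraps a MultiCharSequence over "fo" + "o", keeps encoding 42,
        // and spans offsets [0, 3); a piece with a different encoding would trip
        // the IllegalArgumentException above.
        return DelimitedToken.Encoded.mergeEncodedTokens(pieces);
    }
}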
PrecompiledCharMapNormalizer.java
@@ -13,14 +13,22 @@

import com.ibm.icu.text.BreakIterator;

import org.apache.lucene.analysis.charfilter.BaseCharFilter;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.CharsRefBuilder;
import org.apache.lucene.util.UnicodeUtil;

import java.io.CharArrayReader;
import java.io.IOException;
import java.io.Reader;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.OptionalInt;
@@ -39,10 +47,15 @@
* DARTS
* </a>
* - <a href="https://github.com/google/sentencepiece/blob/91809e5c70ed0e6364267a0f0fed66c144482ce4/src/normalizer.cc">SP normalizer</a>
*
* We implement this as a char filter to take advantage of the underlying offset correction and because normalization needs to occur before
* tokenization (just like a charfilter)
*/
-public class PrecompiledCharMapNormalizer {
public class PrecompiledCharMapNormalizer extends BaseCharFilter {

record Config(int[] offsets, String utf8str) {}

-static PrecompiledCharMapNormalizer fromBase64Str(String s) {
static Config fromBase64Str(String s) {
int offset = 0;
byte[] bytes = Base64.getDecoder().decode(s);
int trieSize = ByteBuffer.wrap(bytes, offset, 4).order(java.nio.ByteOrder.LITTLE_ENDIAN).getInt();
@@ -54,7 +67,7 @@ static PrecompiledCharMapNormalizer fromBase64Str(String s) {
offset += 4;
}
String utf8Str = new String(bytes, offset, bytes.length - offset, StandardCharsets.UTF_8);
-return new PrecompiledCharMapNormalizer(offsets, utf8Str);
return new Config(offsets, utf8Str);
}

// The offsets for each normalization piece. Used in DARTS algorithm to iterate and find appropriate section
@@ -64,8 +77,12 @@
private final byte[] normalizedStrUtf8Bytes;
// Continually reused to copy a single char into utf8 bytes
private final byte[] reusableCharByteBuffer = new byte[4];
// reusable char buffer for decoding utf8 bytes to determine char offset corrections
private final char[] reusableCharDecodeBuffer = new char[8];
private Reader transformedInput;

-public PrecompiledCharMapNormalizer(int[] offsets, String normalizedStr) {
public PrecompiledCharMapNormalizer(int[] offsets, String normalizedStr, Reader in) {
super(in);
this.offsets = offsets;
this.normalizedStrUtf8Bytes = normalizedStr.getBytes(StandardCharsets.UTF_8);
}
@@ -152,11 +169,7 @@ private Optional<BytesRef> normalizePart(byte[] strBytes, int offset, int len) {
return Optional.of(new BytesRef(normalizedStrUtf8Bytes, firstIndex, secondIndex - firstIndex));
}

-String normalize(String str) {
-return normalize((CharSequence) str).utf8ToString();
-}
-
-BytesRef normalize(CharSequence str) {
Reader normalize(CharSequence str) {
// We need to iterate actual Unicode graphemes (this includes surrogate pairs, etc.)
ByteBuffer byteBuffer = StandardCharsets.UTF_8.encode(CharBuffer.wrap(str));
byte[] strBytes = new byte[byteBuffer.limit()];
@@ -167,9 +180,10 @@ BytesRef normalize(CharSequence str) {
// We iterate the whole string, so b.first() is always `0`
int startIter = b.first();
int codePointPos = 0;
-BytesRefBuilder strBuilder = new BytesRefBuilder();
CharsRefBuilder strBuilder = new CharsRefBuilder();
strBuilder.grow(strBytes.length);
int bytePos = 0;
int normalizedCharPos = 0;
// Keep in mind, these break points aren't just surrogate pairs; they can also be codepoints combined with a combining mark
for (int end = b.next(); end != BreakIterator.DONE; startIter = end, end = b.next()) {
int byteLen = 0;
@@ -181,28 +195,69 @@
// The trie only goes up to a depth of 5 bytes.
// So even looking at it for graphemes (with combining, surrogate, etc.) that are 6+ bytes in length is useless.
if (byteLen < 6) {
-Optional<BytesRef> subStr = normalizePart(strBytes, bytePos, byteLen);
-if (subStr.isPresent()) {
-strBuilder.append(subStr.get());
Optional<BytesRef> maybeSubStr = normalizePart(strBytes, bytePos, byteLen);
if (maybeSubStr.isPresent()) {
BytesRef subStr = maybeSubStr.get();
int numChars = UnicodeUtil.UTF8toUTF16(subStr.bytes, subStr.offset, subStr.length, reusableCharDecodeBuffer);
normalizedCharPos += numChars;
if (numChars != end - startIter) {
addOffCorrectMap(normalizedCharPos, getLastCumulativeDiff() + end - startIter - numChars);
}
strBuilder.append(reusableCharDecodeBuffer, 0, numChars);
bytePos += byteLen;
continue;
}
}
int charByteIndex = 0;
for (int i = startIter; i < end; i++) {
int utf8CharBytes = numUtf8Bytes(str.charAt(i));
-Optional<BytesRef> subStr = normalizePart(strBytes, charByteIndex + bytePos, utf8CharBytes);
-if (subStr.isPresent()) {
-strBuilder.append(subStr.get());
Optional<BytesRef> maybeSubStr = normalizePart(strBytes, charByteIndex + bytePos, utf8CharBytes);
if (maybeSubStr.isPresent()) {
BytesRef subStr = maybeSubStr.get();
int numChars = UnicodeUtil.UTF8toUTF16(subStr.bytes, subStr.offset, subStr.length, reusableCharDecodeBuffer);
normalizedCharPos += numChars;
// numChars < 1 means normalization removed this char
if (numChars < 1) {
addOffCorrectMap(normalizedCharPos, getLastCumulativeDiff() + 1);
} else if (numChars > 1) {
addOffCorrectMap(normalizedCharPos, getLastCumulativeDiff() - 1);
}
strBuilder.append(reusableCharDecodeBuffer, 0, numChars);
} else {
-int numBytes = UnicodeUtil.UTF16toUTF8(str, i, 1, reusableCharByteBuffer);
-strBuilder.append(reusableCharByteBuffer, 0, numBytes);
normalizedCharPos += 1;
strBuilder.append(str.charAt(i));
}
charByteIndex += utf8CharBytes;
}
bytePos += byteLen;
}
-return strBuilder.get();
return new CharArrayReader(strBuilder.chars(), 0, strBuilder.length());
}

@Override
public int read(char[] cbuf, int off, int len) throws IOException {
if (transformedInput == null) {
fill();
}

return transformedInput.read(cbuf, off, len);
}

@Override
public int read() throws IOException {
if (transformedInput == null) {
fill();
}

return transformedInput.read();
}

private void fill() throws IOException {
List<CharSequence> charArrays = new ArrayList<>();
char[] temp = new char[1024];
for (int cnt = input.read(temp); cnt > 0; cnt = input.read(temp)) {
charArrays.add(new CharsRef(Arrays.copyOfRange(temp, 0, cnt), 0, cnt));
}
transformedInput = normalize(new MultiCharSequence(charArrays));
}
}
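The read()/fill() overrides above buffer the entire upstream input, normalize it in one pass, and then serve the normalized chars, with addOffCorrectMap recording how positions in the normalized text map back to the original. Here is the same BaseCharFilter mechanism in isolation, as a toy filter; the class and its ß-to-"ss" rule are invented for illustration and are not part of this commit.

import org.apache.lucene.analysis.charfilter.BaseCharFilter;

import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;

class EszettCharFilter extends BaseCharFilter {
    private Reader transformed;

    EszettCharFilter(Reader in) {
        super(in);
    }

    private void fill() throws IOException {
        StringBuilder out = new StringBuilder();
        for (int c = input.read(); c != -1; c = input.read()) {
            if (c == 'ß') {
                out.append("ss");
                // The output grew by one char, so offsets at or beyond this point
                // map back one position earlier in the original input.
                addOffCorrectMap(out.length(), getLastCumulativeDiff() - 1);
            } else {
                out.append((char) c);
            }
        }
        transformed = new StringReader(out.toString());
    }

    @Override
    public int read(char[] cbuf, int off, int len) throws IOException {
        if (transformed == null) {
            fill();
        }
        return transformed.read(cbuf, off, len);
    }
}

Feeding "fußball" through this filter yields "fussball"; correctOffset(4), the start of "ball" in the normalized text, maps back to offset 3 in the original, so downstream offsets still line up with the source text.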
