diff --git a/scripts/extra/esm.py b/scripts/extra/esm.py
index ac42f8a14..e4cba50c6 100644
--- a/scripts/extra/esm.py
+++ b/scripts/extra/esm.py
@@ -2,33 +2,37 @@
 from tokenizers import Tokenizer, pre_tokenizers, processors
 from tokenizers.models import WordPiece
 
+
 class EsmConverter(Converter):
     def converted(self) -> Tokenizer:
         vocab = self.original_tokenizer.vocab
-        tokenizer = Tokenizer(WordPiece(vocab, continuing_subword_prefix='', max_input_chars_per_word=int(1e10), unk_token=str(self.original_tokenizer.unk_token)))
+        tokenizer = Tokenizer(WordPiece(vocab, continuing_subword_prefix='', max_input_chars_per_word=int(
+            1e10), unk_token=str(self.original_tokenizer.unk_token)))
         tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
-
+
         cls = str(self.original_tokenizer.cls_token)
         cls_token_id = self.original_tokenizer.cls_token_id
-        sep = str(self.original_tokenizer.eos_token)  # No sep token in ESM vocabulary
+        # No sep token in ESM vocabulary
+        sep = str(self.original_tokenizer.eos_token)
         sep_token_id = self.original_tokenizer.eos_token_id
         if sep_token_id is None:
-            tokenizer.post_processor = processors.TemplateProcessing(
-                single=f"{cls}:0 $A:0",
-                special_tokens=[
-                    (cls, cls_token_id),
-                ],
-            )
+            tokenizer.post_processor = processors.TemplateProcessing(
+                single=f"{cls}:0 $A:0",
+                special_tokens=[
+                    (cls, cls_token_id),
+                ],
+            )
         else:
-            tokenizer.post_processor = processors.TemplateProcessing(
-                single=f"{cls}:0 $A:0 {sep}:0",
-                special_tokens=[
-                    (cls, cls_token_id),
-                    (sep, sep_token_id),
-                ],
-            )
+            tokenizer.post_processor = processors.TemplateProcessing(
+                single=f"{cls}:0 $A:0 {sep}:0",
+                pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
+                special_tokens=[
+                    (cls, cls_token_id),
+                    (sep, sep_token_id),
+                ],
+            )
 
         # For some reason, all tokens are added: none of them are special, but they all need special splitting.
         # See https://github.com/huggingface/transformers/blob/df5c5c62ae253055336f5bb0828ca8e3e15ab6bd/src/transformers/models/esm/tokenization_esm.py#L79-L80
@@ -44,6 +48,7 @@ def converted(self) -> Tokenizer:
         tokenizer.add_tokens(other_tokens)
         return tokenizer
 
+
 def generate_fast_tokenizer(tokenizer):
     tokenizer.vocab = tokenizer._token_to_id
     return EsmConverter(tokenizer).converted()
diff --git a/scripts/supported_models.py b/scripts/supported_models.py
index 0ea5c8c47..b5d0912df 100644
--- a/scripts/supported_models.py
+++ b/scripts/supported_models.py
@@ -109,6 +109,11 @@
         'unitary/toxic-bert',
         'BAAI/bge-reranker-large',
         'BAAI/bge-reranker-base',
+        'cross-encoder/ms-marco-TinyBERT-L-2-v2',
+        'cross-encoder/ms-marco-MiniLM-L-2-v2',
+        'cross-encoder/ms-marco-MiniLM-L-4-v2',
+        'cross-encoder/ms-marco-MiniLM-L-6-v2',
+        'cross-encoder/ms-marco-MiniLM-L-12-v2',
     ],
 
     # Token classification
diff --git a/src/models.js b/src/models.js
index 21617249c..802f05722 100644
--- a/src/models.js
+++ b/src/models.js
@@ -42,10 +42,6 @@ import {
     AutoConfig,
 } from './configs.js';
 
-import {
-    add_token_types,
-} from './tokenizers.js';
-
 import {
     Callable,
     isIntegralNumber,
@@ -512,9 +508,13 @@ async function encoderForward(self, model_inputs) {
         encoderFeeds[key] = model_inputs[key];
     }
     if (self.session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) {
-        // Assign default `token_type_ids` to the `encoderFeeds` if the model expects it,
+        // Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects it,
         // but they weren't created by the tokenizer.
-        add_token_types(encoderFeeds);
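+        // NOTE: `BigInt64Array` is zero-initialized, so the tensor constructed below is
+        // an all-zero `token_type_ids` with the same shape as `input_ids`.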
+        encoderFeeds.token_type_ids = new Tensor(
+            'int64',
+            new BigInt64Array(encoderFeeds.input_ids.data.length),
+            encoderFeeds.input_ids.dims
+        )
     }
     return await sessionRun(self.session, encoderFeeds);
 }
diff --git a/src/tokenizers.js b/src/tokenizers.js
index 94f5def35..a786cd1ed 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -7,8 +7,8 @@
 * ```javascript
 * import { AutoTokenizer } from '@xenova/transformers';
 *
- * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
- * let { input_ids } = await tokenizer('I love transformers!');
+ * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
+ * const { input_ids } = await tokenizer('I love transformers!');
 * // Tensor {
 * //   data: BigInt64Array(6) [101n, 1045n, 2293n, 19081n, 999n, 102n],
 * //   dims: [1, 6],
@@ -58,7 +58,7 @@ import { Template } from '@huggingface/jinja';
  */
 async function loadTokenizer(pretrained_model_name_or_path, options) {
 
-    let info = await Promise.all([
+    const info = await Promise.all([
         getModelJSON(pretrained_model_name_or_path, 'tokenizer.json', true, options),
         getModelJSON(pretrained_model_name_or_path, 'tokenizer_config.json', true, options),
     ])
@@ -196,20 +196,21 @@ function lowercase_and_remove_accent(text) {
 
 /**
  * Helper function to fuse consecutive values in an array equal to the specified value.
- * @param {Array} arr The input array
+ * @param {string[]} arr The input array
  * @param {any} value The value to fuse on.
+ * @param {Map} mapping The mapping from input domain to value.
  */
-function fuse(arr, value) {
-    let fused = [];
+function fuse(arr, value, mapping) {
+    const fused = [];
     let i = 0;
     while (i < arr.length) {
         fused.push(arr[i])
-        if (arr[i] !== value) {
+        if ((mapping.get(arr[i]) ?? value) !== value) {
             ++i;
             continue;
         }
 
-        while (i < arr.length && arr[i] === value) {
+        while (i < arr.length && (mapping.get(arr[i]) ?? value) === value) {
             ++i;
         }
     }
@@ -321,7 +322,12 @@ export class TokenizerModel extends Callable {
      * @returns {string[]} The encoded token IDs.
      */
     _call(tokens) {
-        return this.encode(tokens);
+        let ids = this.encode(tokens);
+        if (this.fuse_unk) {
+            // Fuse unknown tokens
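+            // (e.g. ['a', '<unk>', '<unk>', 'b'] collapses to ['a', '<unk>', 'b'] when
+            // consecutive tokens map to `unk_token_id`)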
+            ids = fuse(ids, this.unk_token_id, this.tokens_to_ids);
+        }
+        return ids;
     }
 
     /**
@@ -340,13 +346,7 @@
      * @returns {number[]} The converted token IDs.
      */
     convert_tokens_to_ids(tokens) {
-        let ids = tokens.map(t => this.tokens_to_ids.get(t) ?? this.unk_token_id);
-
-        if (this.fuse_unk) {
-            // Fuse unknown tokens
-            ids = fuse(ids, this.unk_token_id);
-        }
-        return ids;
+        return tokens.map(t => this.tokens_to_ids.get(t) ?? this.unk_token_id);
     }
 
     /**
@@ -413,9 +413,9 @@ class WordPieceTokenizer extends TokenizerModel {
      * @returns {string[]} An array of encoded tokens.
      */
     encode(tokens) {
-        let outputTokens = [];
-        for (let token of tokens) {
-            let chars = [...token];
+        const outputTokens = [];
+        for (const token of tokens) {
+            const chars = [...token];
             if (chars.length > this.max_input_chars_per_word) {
                 outputTokens.push(this.unk_token);
                 continue;
@@ -423,7 +423,7 @@
             let isUnknown = false;
             let start = 0;
-            let subTokens = [];
+            const subTokens = [];
 
             while (start < chars.length) {
                 let end = chars.length;
@@ -553,12 +553,12 @@ class Unigram extends TokenizerModel {
 
     /**
      * Encodes an array of tokens using Unigram encoding.
-     * @param {Array} tokens The tokens to encode.
-     * @returns {Array} An array of encoded tokens.
+     * @param {string[]} tokens The tokens to encode.
+     * @returns {string[]} An array of encoded tokens.
      */
     encode(tokens) {
-        let toReturn = [];
-        for (let token of tokens) {
+        const toReturn = [];
+        for (const token of tokens) {
             const tokenized = this.tokenize(token);
             toReturn.push(...tokenized);
         }
@@ -582,7 +582,7 @@ const BYTES_TO_UNICODE = (() => {
         ...Array.from({ length: "¬".charCodeAt(0) - "¡".charCodeAt(0) + 1 }, (_, i) => i + "¡".charCodeAt(0)),
         ...Array.from({ length: "ÿ".charCodeAt(0) - "®".charCodeAt(0) + 1 }, (_, i) => i + "®".charCodeAt(0)),
     ];
-    let cs = bs.slice();
+    const cs = bs.slice();
     let n = 0;
     for (let b = 0; b < 256; ++b) {
         if (!bs.includes(b)) {
@@ -591,7 +591,7 @@
             n += 1;
         }
     }
-    let ccs = cs.map(n => String.fromCharCode(n));
+    const ccs = cs.map(n => String.fromCharCode(n));
     return Object.fromEntries(bs.map((b, i) => [b, ccs[i]]));
 })();
 
@@ -809,12 +809,12 @@ class BPE extends TokenizerModel {
      * @returns {string[]} The resulting subword tokens after applying the BPE algorithm to the input sequence of tokens.
      */
     encode(tokens) {
-        let outputTokens = [];
+        const outputTokens = [];
 
-        for (let token of tokens) {
-            let bpe_token_list = this.bpe(token);
+        for (const token of tokens) {
+            const bpe_token_list = this.bpe(token);
 
-            for (let t of bpe_token_list) {
+            for (const t of bpe_token_list) {
                 if (this.tokens_to_ids.has(t)) {
                     outputTokens.push(t);
                 } else {
@@ -962,14 +962,10 @@ class Replace extends Normalizer {
      * @returns {string} The normalized text after replacing the pattern with the content.
      */
     normalize(text) {
-        let pattern = createPattern(this.config.pattern);
-        if (pattern === null) {
-            return text;
-        }
-
-        text = text.replaceAll(pattern, this.config.content)
-
-        return text;
+        const pattern = createPattern(this.config.pattern);
+        return pattern === null
+            ? text
+            : text.replaceAll(pattern, this.config.content);
     }
 }
 
@@ -1132,10 +1128,10 @@ class BertNormalizer extends Normalizer {
      */
     _tokenize_chinese_chars(text) {
         /* Adds whitespace around any CJK character. */
-        let output = [];
+        const output = [];
         for (let i = 0; i < text.length; ++i) {
-            let char = text[i];
-            let cp = char.charCodeAt(0);
+            const char = text[i];
+            const cp = char.charCodeAt(0);
             if (this._is_chinese_char(cp)) {
                 output.push(" ");
                 output.push(char);
@@ -1318,13 +1314,10 @@ class PreTokenizer extends Callable {
      * @returns {string[]} An array of pre-tokens.
      */
     pre_tokenize(text, options) {
-        let result = [];
-        if (Array.isArray(text)) {
-            result = text.map(x => this.pre_tokenize_text(x, options))
-        } else {
-            result = this.pre_tokenize_text(text, options);
-        }
-        return result.flat();
+        return (Array.isArray(text)
+            ? text.map(x => this.pre_tokenize_text(x, options))
+            : this.pre_tokenize_text(text, options)
+        ).flat();
     }
 
     /**
@@ -1417,7 +1410,7 @@ class ByteLevelPreTokenizer extends PreTokenizer {
         }
 
         // Split on whitespace and punctuation
-        let tokens = this.use_regex ? (text.match(this.pattern) || []) : [text];
+        const tokens = this.use_regex ? (text.match(this.pattern) || []) : [text];
 
         // Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
         return tokens.map(
@@ -1526,6 +1519,21 @@ class DigitsPreTokenizer extends PreTokenizer {
     }
 }
 
+/**
+ * @typedef {Object} PostProcessedOutput
+ * @property {string[]} tokens List of tokens produced by the post-processor.
+ * @property {number[]} [token_type_ids] List of token type ids produced by the post-processor.
+ */
+
+
+/**
+ * @typedef {Object} EncodingSingle
+ * @property {number[]} input_ids List of token ids to be fed to a model.
+ * @property {number[]} attention_mask List of indices specifying which tokens should be attended to by the model.
+ * @property {number[]} [token_type_ids] List of token type ids to be fed to a model.
+ */
+
+
 /**
  * @extends Callable
  */
@@ -1570,7 +1578,7 @@ class PostProcessor extends Callable {
      *
      * @param {Array} tokens The input tokens to be post-processed.
      * @param {...*} args Additional arguments required by the post-processing logic.
-     * @returns {Array} The post-processed tokens.
+     * @returns {PostProcessedOutput} The post-processed tokens.
      * @throws {Error} If the method is not implemented in subclass.
      */
     post_process(tokens, ...args) {
@@ -1581,7 +1589,7 @@ class PostProcessor extends Callable {
      * Alias for {@link PostProcessor#post_process}.
      * @param {Array} tokens The text or array of texts to post-process.
      * @param {...*} args Additional arguments required by the post-processing logic.
-     * @returns {Array} An array of post-processed tokens.
+     * @returns {PostProcessedOutput} The post-processed tokens.
      */
     _call(tokens, ...args) {
         return this.post_process(tokens, ...args);
@@ -1608,18 +1616,29 @@ class BertProcessing extends PostProcessor {
     /**
      * Adds the special tokens to the beginning and end of the input.
      * @param {string[]} tokens The input tokens.
-     * @param {string[]|null} tokens_pair An optional second set of input tokens.
-     * @returns {string[]} The input tokens with the special tokens added to the beginning and end.
+     * @param {string[]} [tokens_pair=null] An optional second set of input tokens.
+     * @returns {PostProcessedOutput} The post-processed tokens with the special tokens added to the beginning and end.
      */
-    post_process(tokens, tokens_pair = null) {
-        tokens = mergeArrays([this.cls], tokens, [this.sep]);
+    post_process(tokens, tokens_pair = null, {
+        add_special_tokens = true,
+    } = {}) {
+        if (add_special_tokens) {
+            tokens = mergeArrays([this.cls], tokens, [this.sep]);
+        }
 
-        // NOTE: It is intended to add 2 EOS tokens after the first set of tokens
-        // https://github.com/huggingface/tokenizers/issues/983
+        let token_type_ids = new Array(tokens.length).fill(0);
         if (tokens_pair !== null) {
-            tokens = mergeArrays(tokens, [this.sep], tokens_pair, [this.sep]);
+            // NOTE: It is intended to add 2 EOS tokens after the first set of tokens
+            // https://github.com/huggingface/tokenizers/issues/983
+            const middle = (add_special_tokens && this instanceof RobertaProcessing)
+                ? [this.sep]
+                : [];
+            const after = add_special_tokens ? [this.sep] : [];
+
+            tokens = mergeArrays(tokens, middle, tokens_pair, after);
+            token_type_ids = mergeArrays(token_type_ids, new Array(tokens_pair.length + middle.length + after.length).fill(1));
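+            // e.g. for BERT with `add_special_tokens=true`, tokens `a b` with pair `c`
+            // become [CLS] a b [SEP] c [SEP], with token_type_ids [0, 0, 0, 0, 1, 1].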
         }
-        return tokens;
+        return { tokens, token_type_ids };
     }
 }
 class RobertaProcessing extends BertProcessing { } // NOTE: extends BertProcessing
 
@@ -1644,28 +1663,35 @@ class TemplateProcessing extends PostProcessor {
 
     /**
      * Replaces special tokens in the template with actual tokens.
-     * @param {Array} tokens The list of tokens for the first sequence.
-     * @param {Array} [tokens_pair=null] The list of tokens for the second sequence (optional).
-     * @returns {Array} The list of tokens with the special tokens replaced with actual tokens.
+     * @param {string[]} tokens The list of tokens for the first sequence.
+     * @param {string[]} [tokens_pair=null] The list of tokens for the second sequence (optional).
+     * @returns {PostProcessedOutput} An object containing the list of tokens with the special tokens replaced with actual tokens.
      */
-    post_process(tokens, tokens_pair = null) {
-        let type = tokens_pair === null ? this.single : this.pair
+    post_process(tokens, tokens_pair = null, {
+        add_special_tokens = true,
+    } = {}) {
+        const type = tokens_pair === null ? this.single : this.pair
 
-        let toReturn = [];
-        for (let item of type) {
+        let processedTokens = [];
+        let types = [];
+        for (const item of type) {
             if ('SpecialToken' in item) {
-                toReturn.push(item.SpecialToken.id);
-
+                if (add_special_tokens) {
+                    processedTokens.push(item.SpecialToken.id);
+                    types.push(item.SpecialToken.type_id);
+                }
             } else if ('Sequence' in item) {
                 if (item.Sequence.id === 'A') {
-                    toReturn = mergeArrays(toReturn, tokens);
+                    processedTokens = mergeArrays(processedTokens, tokens);
+                    types = mergeArrays(types, new Array(tokens.length).fill(item.Sequence.type_id));
 
                 } else if (item.Sequence.id === 'B') {
-                    toReturn = mergeArrays(toReturn, tokens_pair);
+                    processedTokens = mergeArrays(processedTokens, tokens_pair);
+                    types = mergeArrays(types, new Array(tokens_pair.length).fill(item.Sequence.type_id));
                 }
             }
         }
-        return toReturn;
+        return { tokens: processedTokens, token_type_ids: types };
     }
 }
 
@@ -1676,11 +1702,15 @@ class ByteLevelPostProcessor extends PostProcessor {
     /**
      * Post process the given tokens.
-     * @param {string[]} tokens The tokens to be post processed.
-     * @returns {string[]} The post processed tokens.
+     * @param {string[]} tokens The list of tokens for the first sequence.
+     * @param {string[]} [tokens_pair=null] The list of tokens for the second sequence (optional).
+     * @returns {PostProcessedOutput} An object containing the post-processed tokens.
      */
-    post_process(tokens) {
-        return tokens;
+    post_process(tokens, tokens_pair = null) {
+        if (tokens_pair) {
+            tokens = mergeArrays(tokens, tokens_pair);
+        }
+        return { tokens };
     }
 }
 
@@ -1779,12 +1809,10 @@ class ReplaceDecoder extends Decoder {
 
     /** @type {Decoder['decode_chain']} */
     decode_chain(tokens) {
-        let pattern = createPattern(this.config.pattern);
-        if (pattern === null) {
-            return tokens;
-        }
-
-        return tokens.map(token => token.replaceAll(pattern, this.config.content))
+        const pattern = createPattern(this.config.pattern);
+        return pattern === null
+            ? tokens
+            : tokens.map(token => token.replaceAll(pattern, this.config.content))
     }
 }
 
@@ -1799,13 +1827,13 @@ class ByteFallback extends Decoder {
 
     /** @type {Decoder['decode_chain']} */
     decode_chain(tokens) {
-        let new_tokens = [];
+        const new_tokens = [];
         let previous_byte_tokens = [];
 
-        for (let token of tokens) {
+        for (const token of tokens) {
             let bytes = null;
             if (token.length === 6 && token.startsWith('<0x') && token.endsWith('>')) {
-                let byte = parseInt(token.slice(3, 5), 16);
+                const byte = parseInt(token.slice(3, 5), 16);
                 if (!isNaN(byte)) {
                     bytes = byte;
                 }
@@ -1814,7 +1842,7 @@
                 previous_byte_tokens.push(bytes);
             } else {
                 if (previous_byte_tokens.length > 0) {
-                    let string = this.text_decoder.decode(Uint8Array.from(previous_byte_tokens));
+                    const string = this.text_decoder.decode(Uint8Array.from(previous_byte_tokens));
                     new_tokens.push(string);
                     previous_byte_tokens = [];
                 }
@@ -1822,7 +1850,7 @@
             }
         }
         if (previous_byte_tokens.length > 0) {
-            let string = this.text_decoder.decode(Uint8Array.from(previous_byte_tokens));
+            const string = this.text_decoder.decode(Uint8Array.from(previous_byte_tokens));
             new_tokens.push(string);
             previous_byte_tokens = [];
         }
@@ -1948,10 +1976,9 @@ class ByteLevelDecoder extends Decoder {
      * @returns {string} The decoded string.
      */
     convert_tokens_to_string(tokens) {
-        let text = tokens.join('');
-
-        let byteArray = new Uint8Array([...text].map(c => this.byte_decoder[c]));
-        let decoded_text = this.text_decoder.decode(byteArray);
+        const text = tokens.join('');
+        const byteArray = new Uint8Array([...text].map(c => this.byte_decoder[c]));
+        const decoded_text = this.text_decoder.decode(byteArray);
         return decoded_text;
     }
 
@@ -1963,9 +1990,9 @@
         // To avoid mixing byte-level and unicode for byte-level BPT
         // we need to build string separately for added tokens and byte-level tokens
         // cf. https://github.com/huggingface/transformers/issues/1133
-        let sub_texts = [];
+        const sub_texts = [];
         let current_sub_text = [];
-        for (let token of tokens) {
+        for (const token of tokens) {
             // tokens sent here are already filtered, so we don't need to do this
             // if (skip_special_tokens && this.all_special_ids.includes(token)) {
             //     continue;
@@ -2013,7 +2040,7 @@ class CTCDecoder extends Decoder {
         if (tokens.length === 0) return '';
 
         // group same tokens into non-repeating tokens in CTC style decoding
-        let grouped_tokens = [tokens[0]];
+        const grouped_tokens = [tokens[0]];
         for (let i = 1; i < tokens.length; ++i) {
             if (tokens[i] !== grouped_tokens.at(-1)) {
                 grouped_tokens.push(tokens[i]);
@@ -2021,7 +2048,7 @@
         }
 
         // filter self.pad_token which is used as CTC-blank token
-        let filtered_tokens = grouped_tokens.filter(token => token !== this.pad_token);
+        const filtered_tokens = grouped_tokens.filter(token => token !== this.pad_token);
 
         let text = filtered_tokens.join('');
         if (this.cleanup) {
@@ -2169,7 +2196,7 @@ class MetaspaceDecoder extends Decoder {
 
     /** @type {Decoder['decode_chain']} */
     decode_chain(tokens) {
-        let result = [];
+        const result = [];
         for (let i = 0; i < tokens.length; ++i) {
             let normalized = tokens[i].replaceAll(this.replacement, ' ');
             if (this.addPrefixSpace && i == 0 && normalized.startsWith(' ')) {
@@ -2326,7 +2353,47 @@ const SPECIAL_TOKEN_ATTRIBUTES = [
     // additional_special_tokens (TODO)
 ]
 
+/**
+ *
+ * Helper function for padding values of an object, which are each arrays.
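+ * e.g. `padHelper({ input_ids: [101, 102], attention_mask: [1, 1] }, 4, () => 0, 'right')`
+ * mutates the object into `{ input_ids: [101, 102, 0, 0], attention_mask: [1, 1, 0, 0] }`.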
+ * NOTE: No additional checks are made here for validity of arguments.
+ * @param {Record<string, any[]>} item The input object.
+ * @param {number} length The length to pad to.
+ * @param {(key: string) => any} value_fn Determine the value to fill the array, based on its key.
+ * @param {'right'|'left'} side Which side to pad the array.
+ * @private
+ */
+function padHelper(item, length, value_fn, side) {
+    for (const key of Object.keys(item)) {
+        const diff = length - item[key].length;
+        const value = value_fn(key);
+
+        const padData = new Array(diff).fill(value);
+        item[key] = side === 'right'
+            ? mergeArrays(item[key], padData)
+            : mergeArrays(padData, item[key]);
+    }
+}
+
+/**
+ * Helper function for truncating values of an object, which are each arrays.
+ * NOTE: No additional checks are made here for validity of arguments.
+ * @param {Record<string, any[]>} item The input object.
+ * @param {number} length The length to truncate to.
+ * @private
+ */
+function truncateHelper(item, length) {
+    // Setting .length to a lower value truncates the array in-place:
+    // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/length
+    for (const key of Object.keys(item)) {
+        item[key].length = length;
+    }
+}
+
+
 export class PreTrainedTokenizer extends Callable {
+    return_token_type_ids = false;
+
     _default_chat_template = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}`;
 
     /**
@@ -2342,11 +2409,8 @@ export class PreTrainedTokenizer extends Callable {
         // Construct parts of the tokenizer from the JSON
         this.normalizer = Normalizer.fromConfig(tokenizerJSON.normalizer);
         this.pre_tokenizer = PreTokenizer.fromConfig(tokenizerJSON.pre_tokenizer);
-
         this.model = TokenizerModel.fromConfig(tokenizerJSON.model, tokenizerConfig);
         this.post_processor = PostProcessor.fromConfig(tokenizerJSON.post_processor);
-
-        // TODO: maybe, allow this to be null; in which case, we use model as decoder too?
         this.decoder = Decoder.fromConfig(tokenizerJSON.decoder);
 
         // Add added_tokens to model
@@ -2355,7 +2419,7 @@
 
         /** @type {AddedToken[]} */
         this.added_tokens = [];
-        for (let addedToken of tokenizerJSON.added_tokens) {
+        for (const addedToken of tokenizerJSON.added_tokens) {
             const token = new AddedToken(addedToken);
             this.added_tokens.push(token);
 
@@ -2411,6 +2475,7 @@
         this.do_lowercase_and_remove_accent = tokenizerConfig.do_lowercase_and_remove_accent ?? false;
 
         // TODO allow user to change this
+        /** @type {'right'|'left'} */
        this.padding_side = 'right';
 
         this.legacy = false;
 
@@ -2426,8 +2491,8 @@
      * @throws {Error} If an object is found for a matching key and its __type property is not "AddedToken".
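+     *
+     * e.g. `this.getToken('bos_token', 'cls_token')` returns the first of those values
+     * set in the tokenizer config, unwrapping a serialized `AddedToken` to its content.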
      */
     getToken(...keys) {
-        for (let key of keys) {
-            let item = this._tokenizer_config[key];
+        for (const key of keys) {
+            const item = this._tokenizer_config[key];
 
             if (!item) continue;
 
@@ -2462,7 +2527,7 @@
         legacy = null,
     } = {}) {
 
-        let info = await loadTokenizer(pretrained_model_name_or_path, {
+        const info = await loadTokenizer(pretrained_model_name_or_path, {
             progress_callback,
             config,
             cache_dir,
@@ -2476,14 +2541,13 @@
     }
 
     /**
-     * This function can be overridden by a subclass to apply additional preprocessing
-     * to a model's input data.
-     * @param {Object} inputs An object containing input data as properties.
-     * @returns {Object} The modified inputs object.
+     * @typedef {number[]|number[][]|Tensor} BatchEncodingItem
+     *
+     * @typedef {Object} BatchEncoding Holds the output of the tokenizer's call function.
+     * @property {BatchEncodingItem} input_ids List of token ids to be fed to a model.
+     * @property {BatchEncodingItem} attention_mask List of indices specifying which tokens should be attended to by the model.
+     * @property {BatchEncodingItem} [token_type_ids] List of token type ids to be fed to a model.
      */
-    prepare_model_inputs(inputs) {
-        return inputs;
-    }
 
     /**
      * Encode/tokenize the given text(s).
@@ -2495,7 +2559,7 @@
      * @param {boolean} [options.truncation=null] Whether to truncate the input sequences.
      * @param {number} [options.max_length=null] Maximum length of the returned list and optionally padding length.
      * @param {boolean} [options.return_tensor=true] Whether to return the results as Tensors or arrays.
-     * @returns {{ input_ids: number[]|number[][]|Tensor, attention_mask: any[]|Tensor }} Object to be passed to the model.
+     * @returns {BatchEncoding} Object to be passed to the model.
      */
     _call(
         // Required positional arguments
@@ -2512,10 +2576,12 @@
         } = {},
     ) {
 
-        /** @type {number[]|number[][]|Tensor} */
-        let tokens;
+        const isBatched = Array.isArray(text);
+
+        /** @type {EncodingSingle[]} */
+        let encodedTokens;
 
-        if (Array.isArray(text)) {
+        if (isBatched) {
             if (text.length === 0) {
                 throw Error('text array must be non-empty')
             }
@@ -2528,12 +2594,12 @@
                     throw Error('text and text_pair must have the same length')
                 }
 
-                tokens = text.map(
-                    (t, i) => this.encode(t, text_pair[i], { add_special_tokens })
+                encodedTokens = text.map(
+                    (t, i) => this._encode_plus(t, text_pair[i], { add_special_tokens })
                 )
 
             } else {
-                tokens = text.map(x => this.encode(x, null, { add_special_tokens }));
+                encodedTokens = text.map(x => this._encode_plus(x, null, { add_special_tokens }));
             }
 
         } else {
@@ -2546,7 +2612,7 @@
             }
 
             // For single input, we just wrap in an array, and then unwrap later.
-            tokens = [this.encode(text, text_pair, { add_special_tokens })];
+            encodedTokens = [this._encode_plus(text, text_pair, { add_special_tokens })];
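+            // (Each element is an `EncodingSingle`, e.g. { input_ids: [101, 1037, 102],
+            // attention_mask: [1, 1, 1] }, plus `token_type_ids` for models that use them.)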
         }
 
         // At this point, tokens is batched: [batch_size, tokens]
         // However, array may be jagged. So, we pad to max_length
@@ -2556,60 +2622,61 @@
         if (max_length === null) {
             if (padding === true) {
                 max_length = this.model_max_length;
             } else {
                 // Calculate max length from sequences
-                max_length = max(tokens.map(x => x.length))[0];
+                max_length = max(encodedTokens.map(x => x.input_ids.length))[0];
+            }
+        } else {
+            if (!truncation) {
+                console.warn(`Truncation was not explicitly activated but \`max_length\` is provided a specific value, please use \`truncation=true\` to explicitly truncate examples to max length.`)
             }
         }
 
         // Ensure it is less than model max length
         max_length = Math.min(max_length, this.model_max_length)
 
-        /** @type {any[]|Tensor} */
-        let attention_mask = [];
         if (padding || truncation) {
+            // Perform padding and/or truncation
-            for (let i = 0; i < tokens.length; ++i) {
-                if (tokens[i].length === max_length) {
-                    attention_mask.push(new Array(tokens[i].length).fill(1))
+            for (let i = 0; i < encodedTokens.length; ++i) {
+                if (encodedTokens[i].input_ids.length === max_length) {
                     continue;
-                } else if (tokens[i].length > max_length) {
+                } else if (encodedTokens[i].input_ids.length > max_length) {
                     // possibly truncate
                     if (truncation) {
-                        tokens[i] = tokens[i].slice(0, max_length);
+                        truncateHelper(encodedTokens[i], max_length);
                     }
-                    attention_mask.push(new Array(tokens[i].length).fill(1))
                 } else { // t.length < max_length
+                    // possibly pad
                     if (padding) {
-                        let diff = max_length - tokens[i].length;
-
-                        if (this.padding_side === 'right') {
-                            attention_mask.push(
-                                (new Array(tokens[i].length).fill(1)).concat(new Array(diff).fill(0))
-                            )
-                            tokens[i].push(...new Array(diff).fill(this.pad_token_id))
-                        } else { // left
-                            attention_mask.push(
-                                (new Array(diff).fill(0)).concat(new Array(tokens[i].length).fill(1))
-                            )
-                            tokens[i].unshift(...new Array(diff).fill(this.pad_token_id))
-                        }
-
-                    } else {
-                        attention_mask.push(new Array(tokens[i].length).fill(1))
+                        padHelper(
+                            encodedTokens[i],
+                            max_length,
+                            key => key === 'input_ids' ? this.pad_token_id : 0,
+                            this.padding_side
+                        );
                    }
                 }
             }
-        } else {
-            attention_mask = tokens.map(x => new Array(x.length).fill(1))
         }
 
+        const result = {};
+
         if (return_tensor) {
             if (!(padding && truncation)) {
                 // Not, guaranteed that all items have same length, so
                 // we perform additional check
-                if (tokens.some(x => x.length !== tokens[0].length)) {
+                if (
+                    encodedTokens.some(x => {
+                        for (const key of Object.keys(x)) {
+                            if (x[key].length !== encodedTokens[0][key]?.length) {
+                                return true;
+                            }
+                        }
+                        return false;
+                    })
+                ) {
                     throw Error(
                         "Unable to create tensor, you should probably activate truncation and/or padding " +
                         "with 'padding=true' and 'truncation=true' to have batched tensors with the same length."
                    )
                 }
             }
 
             // Now we actually convert to tensor
             // NOTE: In the same way as the python library, we return a batched tensor, regardless of
             // whether we have a single input or multiple inputs.
-            let dims = [tokens.length, tokens[0].length];
+            const dims = [encodedTokens.length, encodedTokens[0].input_ids.length];
 
-            tokens = new Tensor('int64',
-                BigInt64Array.from(tokens.flat().map(BigInt)),
-                dims
-            );
+            for (const key of Object.keys(encodedTokens[0])) {
+                result[key] = new Tensor('int64',
+                    BigInt64Array.from(encodedTokens.flatMap(x => x[key]).map(BigInt)),
+                    dims
+                );
+            }
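+            // e.g. two padded sequences of length 4 yield [2, 4] tensors for input_ids,
+            // attention_mask and (when present) token_type_ids alike.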
-            attention_mask = new Tensor(
-                'int64',
-                BigInt64Array.from(attention_mask.flat().map(BigInt)),
-                dims
-            )
         } else {
+            for (const key of Object.keys(encodedTokens[0])) {
+                result[key] = encodedTokens.map(x => x[key]);
+            }
+
             // If not returning a tensor, we match the input type
-            if (!Array.isArray(text)) {
+            if (!isBatched) {
                 // Input was not batched, so we unwrap
-                tokens = tokens[0];
-                attention_mask = attention_mask[0];
+                for (const key of Object.keys(result)) {
+                    result[key] = result[key][0];
+                }
             }
         }
 
-        // Finally, add attention mask, and possibly model-specific parameters
-        let modelInputs = {
-            input_ids: tokens,
-            attention_mask: attention_mask
-        }
-
-        // Optional post-processing
-        modelInputs = this.prepare_model_inputs(modelInputs);
-
-        return modelInputs
+        return /** @type {BatchEncoding} */(result);
     }
 
     /**
@@ -2705,22 +2764,48 @@
      * @param {string|null} text_pair The optional second text to encode.
      * @param {Object} options An optional object containing the following properties:
      * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model.
-     * @returns {number[]} An array of token IDs representing the encoded text(s).
+     * @returns {EncodingSingle} An object containing the encoded text.
+     * @private
      */
-    encode(text, text_pair = null, {
+    _encode_plus(text, text_pair = null, {
         add_special_tokens = true,
     } = {}) {
         // Function called by users to encode possibly multiple texts
-        let tokens = this._encode_text(text);
-        let tokens2 = this._encode_text(text_pair);
+        const tokens = this._encode_text(text);
+        const tokens2 = this._encode_text(text_pair);
 
-        // TODO improve `add_special_tokens` and ensure correctness
-        let combinedTokens = (this.post_processor !== null && add_special_tokens)
-            ? this.post_processor(tokens, tokens2)
-            : mergeArrays(tokens ?? [], tokens2 ?? []);
+        const combinedTokens = this.post_processor
+            ? this.post_processor(tokens, tokens2, { add_special_tokens })
+            : { tokens: mergeArrays(tokens ?? [], tokens2 ?? []) };
 
-        let ids = this.model.convert_tokens_to_ids(combinedTokens);
-        return ids;
+        const input_ids = this.model.convert_tokens_to_ids(combinedTokens.tokens);
+
+        const result = {
+            input_ids,
+            attention_mask: new Array(input_ids.length).fill(1),
+        }
+        if (this.return_token_type_ids && combinedTokens.token_type_ids) {
+            result.token_type_ids = combinedTokens.token_type_ids;
+        }
+        return result;
+    }
+
+    /**
+     * Encodes a single text or a pair of texts using the model's tokenizer.
+     *
+     * @param {string} text The text to encode.
+     * @param {string|null} text_pair The optional second text to encode.
+     * @param {Object} options An optional object containing the following properties:
+     * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model.
+     * @returns {number[]} An array of token IDs representing the encoded text(s).
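+     *
+     * e.g. for a BERT-style tokenizer, `tokenizer.encode('a', 'b')` returns something
+     * like `[101, 1037, 102, 1038, 102]`, i.e. [CLS] a [SEP] b [SEP].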
+     */
+    encode(text, text_pair = null, {
+        add_special_tokens = true,
+    } = {}) {
+        const { input_ids } = this._encode_plus(text, text_pair, {
+            add_special_tokens,
+        });
+        return input_ids;
     }
 
     /**
@@ -2916,116 +3001,53 @@ export class PreTrainedTokenizer extends Callable {
     }
 }
 
-/**
-* Helper method for adding `token_type_ids` to model inputs
-* @param {Object} inputs An object containing the input ids and attention mask.
-* @returns {Object} The prepared inputs object.
-*/
-export function add_token_types(inputs) {
-    // TODO ensure correctness when token pair is present
-    if (inputs.input_ids instanceof Tensor) {
-        inputs.token_type_ids = new Tensor(
-            'int64',
-            new BigInt64Array(inputs.input_ids.data.length),
-            inputs.input_ids.dims
-        )
-    } else if (Array.isArray(inputs.input_ids)) {
-
-        if (Array.isArray(inputs.input_ids[0])) {
-            // This means input is batched, so we need to batch the token_type_ids as well
-            inputs.token_type_ids = inputs.input_ids.map(
-                x => new Array(x.length).fill(0)
-            )
-        } else {
-            inputs.token_type_ids = new Array(inputs.input_ids.length).fill(0);
-        }
-    } else {
-        throw new Error('Input ids must be a Tensor or an Array')
-    }
-
-    return inputs;
-}
-
 /**
  * BertTokenizer is a class used to tokenize text for BERT models.
  * @extends PreTrainedTokenizer
  */
 export class BertTokenizer extends PreTrainedTokenizer {
-    /** @type {add_token_types} */
-    prepare_model_inputs(inputs) {
-        return add_token_types(inputs);
-    }
+    return_token_type_ids = true;
 }
 /**
  * Albert tokenizer
  * @extends PreTrainedTokenizer
  */
 export class AlbertTokenizer extends PreTrainedTokenizer {
-    /** @type {add_token_types} */
-    prepare_model_inputs(inputs) {
-        return add_token_types(inputs);
-    }
+    return_token_type_ids = true;
 }
 export class MobileBertTokenizer extends PreTrainedTokenizer {
-    /** @type {add_token_types} */
-    prepare_model_inputs(inputs) {
-        return add_token_types(inputs);
-    }
+    return_token_type_ids = true;
 }
 export class SqueezeBertTokenizer extends PreTrainedTokenizer {
-    /** @type {add_token_types} */
-    prepare_model_inputs(inputs) {
-        return add_token_types(inputs);
-    }
+    return_token_type_ids = true;
 }
 export class DebertaTokenizer extends PreTrainedTokenizer {
-    /** @type {add_token_types} */
-    prepare_model_inputs(inputs) {
-        return add_token_types(inputs);
-    }
+    return_token_type_ids = true;
 }
 export class DebertaV2Tokenizer extends PreTrainedTokenizer {
-    /** @type {add_token_types} */
-    prepare_model_inputs(inputs) {
-        return add_token_types(inputs);
-    }
+    return_token_type_ids = true;
 }
 export class HerbertTokenizer extends PreTrainedTokenizer {
-    /** @type {add_token_types} */
-    prepare_model_inputs(inputs) {
-        return add_token_types(inputs);
-    }
+    return_token_type_ids = true;
 }
 export class ConvBertTokenizer extends PreTrainedTokenizer {
-    /** @type {add_token_types} */
-    prepare_model_inputs(inputs) {
-        return add_token_types(inputs);
-    }
+    return_token_type_ids = true;
 }
 export class RoFormerTokenizer extends PreTrainedTokenizer {
-    /** @type {add_token_types} */
-    prepare_model_inputs(inputs) {
-        return add_token_types(inputs);
-    }
+    return_token_type_ids = true;
 }
 export class DistilBertTokenizer extends PreTrainedTokenizer { }
 export class CamembertTokenizer extends PreTrainedTokenizer { }
 
 export class XLMTokenizer extends PreTrainedTokenizer {
+    return_token_type_ids = true;
+
     constructor(tokenizerJSON, tokenizerConfig) {
         super(tokenizerJSON, tokenizerConfig);
         console.warn('WARNING: `XLMTokenizer` is not yet supported by Hugging Face\'s "fast" tokenizers library. Therefore, you may experience slightly inaccurate results.')
     }
-
-    /** @type {add_token_types} */
-    prepare_model_inputs(inputs) {
-        return add_token_types(inputs);
-    }
 }
 export class ElectraTokenizer extends PreTrainedTokenizer {
-    /** @type {add_token_types} */
-    prepare_model_inputs(inputs) {
-        return add_token_types(inputs);
-    }
+    return_token_type_ids = true;
 }
 
 export class T5Tokenizer extends PreTrainedTokenizer { }
@@ -3172,7 +3194,7 @@ function _build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs) {
 
     // In the same way as the Python library, we override the post-processor
     // to force the source language to be first:
-    for (let item of self.post_processor.config.single) {
+    for (const item of self.post_processor.config.single) {
         if ('SpecialToken' in item && self.languageRegex.test(item.SpecialToken.id)) {
             item.SpecialToken.id = self.lang_to_token(src_lang_token);
             break;
         }
@@ -3442,7 +3464,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
 
         const all_special_ids = new Set(this.all_special_ids);
 
-        for (let output of sequences) {
+        for (const output of sequences) {
             // NOTE: python version has batches, so it uses [0]
             const token_ids = output.tokens;
             const token_timestamps = returnWordTimestamps ? output.token_timestamps : null;
@@ -3660,9 +3682,9 @@
             }
         }
         if (returnWordTimestamps) {
-            let new_chunks = [];
-            for (let chunk of chunks) {
-                for (let word of chunk.words) {
+            const new_chunks = [];
+            for (const chunk of chunks) {
+                for (const word of chunk.words) {
                     new_chunks.push(word);
                 }
             }
@@ -3773,9 +3795,9 @@
 
     /** @private */
     collateWordTimestamps(tokens, token_timestamps, language) {
 
-        let [words, _, token_indices] = this.combineTokensIntoWords(tokens, language);
+        const [words, _, token_indices] = this.combineTokensIntoWords(tokens, language);
 
-        let timings = [];
+        const timings = [];
         for (let i = 0; i < words.length; ++i) {
             const indices = token_indices[i];
             timings.push({
@@ -3847,10 +3869,9 @@
         const timestamp_begin = Array.from(this.all_special_ids).at(-1) + 1;
         /**@type {Array} */
         let outputs = [[]];
-        for (let token of token_ids) {
+        for (const token of token_ids) {
             if (token >= timestamp_begin) {
-                let timestamp = (token - timestamp_begin) * time_precision;
-                timestamp = round(timestamp, 2);
+                const timestamp = round((token - timestamp_begin) * time_precision, 2);
                 outputs.push(`<|${timestamp}|>`);
                 outputs.push([]);
             } else {
@@ -3883,9 +3904,9 @@
         });
 
         const replacement_char = '\uFFFD';
-        let words = []
-        let word_tokens = []
-        let token_indices = []
+        const words = []
+        const word_tokens = []
+        const token_indices = []
         let current_tokens = []
         let current_indices = []
         let unicode_offset = 0
@@ -3922,11 +3943,11 @@
      */
     splitTokensOnSpaces(tokens) {
 
-        let [subwords, subword_tokens_list, subword_indices_list] = this.splitTokensOnUnicode(tokens);
+        const [subwords, subword_tokens_list, subword_indices_list] = this.splitTokensOnUnicode(tokens);
 
-        let words = []
-        let word_tokens = []
-        let token_indices = []
+        const words = []
+        const word_tokens = []
+        const token_indices = []
 
         const punctuationRegex = new RegExp(`^[${PUNCTUATION_REGEX}]$`, 'gu');
 
@@ -3969,9 +3990,9 @@
      */
     mergePunctuations(words, tokens, indices, prepended, appended) {
 
-        let newWords = structuredClone(words);
-        let newTokens = structuredClone(tokens);
-        let newIndices = structuredClone(indices);
+        const newWords = structuredClone(words);
+        const newTokens = structuredClone(tokens);
+        const newIndices = structuredClone(indices);
 
         // prepend punctuations
@@ -4025,8 +4046,8 @@
      * **Example: Get ids for a language**
     * ```javascript
      * // instantiate the tokenizer and set the prefix token to Spanish
-     * let tokenizer = await WhisperTokenizer.from_pretrained('Xenova/whisper-tiny');
-     * let forced_decoder_ids = tokenizer.get_decoder_prompt_ids({ language: 'spanish' });
+     * const tokenizer = await WhisperTokenizer.from_pretrained('Xenova/whisper-tiny');
+     * const forced_decoder_ids = tokenizer.get_decoder_prompt_ids({ language: 'spanish' });
      * // [(1, 50262), (2, 50363)]
     * ```
      *
@@ -4049,7 +4070,7 @@
 
         // <|lang_id|> <|task|> <|notimestamps|>
-        let forced_decoder_ids = [];
+        const forced_decoder_ids = [];
 
         if (language) {
             // User wishes to specify the language
@@ -4074,7 +4095,7 @@
                 }
             }
 
-            let language_token_id = this.model.tokens_to_ids.get(`<|${language_code}|>`);
+            const language_token_id = this.model.tokens_to_ids.get(`<|${language_code}|>`);
             if (language_token_id === undefined) {
                 throw new Error(`Unable to find language "${language_code}" in model vocabulary. Please report this issue at https://github.com/xenova/transformers.js/issues/new/choose.`)
             }
@@ -4091,7 +4112,7 @@
                 throw new Error(`Task "${task}" is not supported. Must be one of: ["transcribe", "translate"]`);
             }
 
-            let task_token_id = this.model.tokens_to_ids.get(`<|${task}|>`);
+            const task_token_id = this.model.tokens_to_ids.get(`<|${task}|>`);
             if (task_token_id === undefined) {
                 throw new Error(`Unable to find task "${task}" in model vocabulary. Please report this issue at https://github.com/xenova/transformers.js/issues/new/choose.`)
             }
@@ -4103,7 +4124,7 @@
         }
 
         if (no_timestamps) {
-            let no_timestamps_id = this.model.tokens_to_ids.get(`<|notimestamps|>`);
+            const no_timestamps_id = this.model.tokens_to_ids.get(`<|notimestamps|>`);
             if (no_timestamps_id === undefined) {
                 throw new Error('Unable to find "<|notimestamps|>" in model vocabulary. Please report this issue at https://github.com/xenova/transformers.js/issues/new/choose.')
             }
@@ -4153,7 +4174,7 @@ export class MarianTokenizer extends PreTrainedTokenizer {
         if (text === null) return null;
 
         // Check if text starts with language code:
-        let [matchInfo, ...remainder] = text.trim().split(this.languageRegex);
+        const [matchInfo, ...remainder] = text.trim().split(this.languageRegex);
 
         if (remainder.length === 0) {
             // No language code, encode normally
@@ -4161,7 +4182,7 @@
 
         } else if (remainder.length === 2) {
             // Text starts with language code, so we do not encode it with sentencepiece.
-            let [language, text] = remainder;
+            const [language, text] = remainder;
 
             if (!this.supported_language_codes.includes(language)) {
                 console.warn(`Unsupported language code "${language}" detected, which may lead to unexpected behavior. Should be one of: ${JSON.stringify(this.supported_language_codes)}`)
             }
 
@@ -4197,7 +4218,7 @@ export class VitsTokenizer extends PreTrainedTokenizer {
  * The chosen tokenizer class is determined by the type specified in the tokenizer config.
  *
  * @example
- * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
+ * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
  */
 export class AutoTokenizer {
     static TOKENIZER_CLASS_MAPPING = {
@@ -4272,7 +4293,7 @@
         legacy = null,
     } = {}) {
 
-        let [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, {
+        const [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, {
             quantized,
             progress_callback,
             config,
@@ -4283,7 +4304,7 @@
         })
 
         // Some tokenizers are saved with the "Fast" suffix, so we remove that if present.
-        let tokenizerName = tokenizerConfig.tokenizer_class?.replace(/Fast$/, '') ?? 'PreTrainedTokenizer';
+        const tokenizerName = tokenizerConfig.tokenizer_class?.replace(/Fast$/, '') ?? 'PreTrainedTokenizer';
 
         let cls = this.TOKENIZER_CLASS_MAPPING[tokenizerName];
         if (!cls) {
diff --git a/tests/generate_tests.py b/tests/generate_tests.py
index e047802e9..d9c457b9d 100644
--- a/tests/generate_tests.py
+++ b/tests/generate_tests.py
@@ -169,6 +169,25 @@
     },
 }
 
+TOKENIZER_TEXT_PAIR_TEST_DATA = [
+    {
+        'text': 'a',
+        'text_pair': 'b'
+    },
+    {
+        'text': 'a b',
+        'text_pair': 'c d e'
+    },
+    {
+        'text': ['a b c', 'd'],
+        'text_pair': ['e f', 'g h'],
+    },
+    {
+        'text': ['a', 'b c', 'd e f'],
+        'text_pair': ['g h i', 'j k', 'l'],
+    }
+]
+
 CHAT_MESSAGES_EXAMPLES = {
     'basic': [
         {"role": "user", "content": "Hello, how are you?"},
@@ -292,13 +311,23 @@ def generate_tokenizer_tests():
 
         tokenizer_results = []
 
+        for data in TOKENIZER_TEXT_PAIR_TEST_DATA:
+            try:
+                output = tokenizer(**data).data
+            except Exception:
+                # Ignore testing tokenizers which fail in the python library
+                continue
+            tokenizer_results.append(dict(
+                input=data,
+                output=output,
+            ))
+
         shared_texts = TOKENIZER_TEST_DATA["shared"]
         custom_texts = TOKENIZER_TEST_DATA["custom"].get(
             tokenizer_name, [])
 
         # Run tokenizer on test cases
         for text in shared_texts + custom_texts + custom_by_model_type_texts:
-            # TODO: add with_pair option
             try:
                 encoded = tokenizer(text).data
             except Exception:
diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js
index 6bafa2365..40fed05d1 100644
--- a/tests/tokenizers.test.js
+++ b/tests/tokenizers.test.js
@@ -12,32 +12,46 @@ const { tokenization, templates } = await (await getFile('./tests/data/tokenizer
 
 // Dynamic tests to ensure transformers.js (JavaScript) matches transformers (Python)
 describe('Tokenizers (dynamic)', () => {
-    for (let [tokenizerName, tests] of Object.entries(tokenization)) {
+    for (const [tokenizerName, tests] of Object.entries(tokenization)) {
 
         it(tokenizerName, async () => {
-            let tokenizer = await AutoTokenizer.from_pretrained(m(tokenizerName));
+            const tokenizer = await AutoTokenizer.from_pretrained(m(tokenizerName));
 
-            for (let test of tests) {
+            for (const test of tests) {
+                // Two kinds of tests:
+                // 1. text w/o text_pair
+                // 2. text w/ text_pair
 
-                // Test encoding
-                let encoded = tokenizer(test.input, {
-                    return_tensor: false
-                });
+                if (typeof test.input === 'string') {
+
+                    // Test encoding
+                    const encoded = tokenizer(test.input, {
+                        return_tensor: false
+                    });
 
-                // Add the input text to the encoded object for easier debugging
-                test.encoded.input = encoded.input = test.input;
+                    // Add the input text to the encoded object for easier debugging
+                    test.encoded.input = encoded.input = test.input;
 
-                expect(encoded).toEqual(test.encoded);
+                    expect(encoded).toEqual(test.encoded);
 
-                // Skip decoding tests if encoding produces zero tokens
-                if (test.encoded.input_ids.length === 0) continue;
+                    // Skip decoding tests if encoding produces zero tokens
+                    if (test.encoded.input_ids.length === 0) continue;
 
-                // Test decoding
-                let decoded_with_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: false });
-                expect(decoded_with_special).toEqual(test.decoded_with_special);
+                    // Test decoding
+                    const decoded_with_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: false });
+                    expect(decoded_with_special).toEqual(test.decoded_with_special);
 
-                let decoded_without_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: true });
-                expect(decoded_without_special).toEqual(test.decoded_without_special);
+                    const decoded_without_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: true });
+                    expect(decoded_without_special).toEqual(test.decoded_without_special);
+
+                } else {
+                    const { text, text_pair } = test.input;
+                    const encoded = tokenizer(text, {
+                        text_pair,
+                        return_tensor: false,
+                    });
+                    compare(encoded, test.output);
+                }
             }
         }, MAX_TEST_EXECUTION_TIME);
     }
@@ -140,6 +154,142 @@ describe('Tokenizers (hard-coded)', () => {
     }
 });
 
+describe('Tokenizer padding/truncation', () => {
+    const inputs = ['a', 'b c'];
+    const text_pair = ['d e', 'f g h'];
+
+    it('should create a jagged array', async () => {
+        const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
+
+        { // support jagged array if `return_tensor=false`
+            const output = tokenizer(inputs, {
+                return_tensor: false,
+            })
+            const expected = {
+                input_ids: [[101, 1037, 102], [101, 1038, 1039, 102]],
+                attention_mask: [[1, 1, 1], [1, 1, 1, 1]],
+                token_type_ids: [[0, 0, 0], [0, 0, 0, 0]]
+            }
+            compare(output, expected);
+        }
+
+        {
+            const output = tokenizer(inputs, {
+                return_tensor: false,
+                truncation: true,
+                add_special_tokens: false,
+            })
+            const expected = {
+                input_ids: [[1037], [1038, 1039]],
+                attention_mask: [[1], [1, 1]],
+                token_type_ids: [[0], [0, 0]]
+            }
+            compare(output, expected);
+        }
+    })
+
+    it('should create a tensor', async () => {
+        const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
+
+        { // Expected to throw error if jagged array
+            expect(() => tokenizer(inputs)).toThrowError('Unable to create tensor');
+        }
+
+        { // Truncation
+            const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
+                truncation: true,
+                max_length: 1,
+                add_special_tokens: false,
+            })
+
+            expect(input_ids.tolist()).toEqual([[1037n], [1038n]])
+            expect(attention_mask.tolist()).toEqual([[1n], [1n]])
+            expect(token_type_ids.tolist()).toEqual([[0n], [0n]])
+        }
+        { // Truncation w/ text pair
+            // TODO
+        }
+
+        { // Padding
+            const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
+                padding: true,
+                add_special_tokens: false,
+            })
+
+            expect(input_ids.tolist()).toEqual([[1037n, 0n], [1038n, 1039n]])
+            expect(attention_mask.tolist()).toEqual([[1n, 0n], [1n, 1n]])
+            expect(token_type_ids.tolist()).toEqual([[0n, 0n], [0n, 0n]])
+        }
+        { // Padding w/ text pair
+            const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
+                text_pair,
+                padding: true,
+                add_special_tokens: false,
+            })
+
+            expect(input_ids.tolist()).toEqual([
+                [1037n, 1040n, 1041n, 0n, 0n],
+                [1038n, 1039n, 1042n, 1043n, 1044n],
+            ]);
+            expect(attention_mask.tolist()).toEqual([
+                [1n, 1n, 1n, 0n, 0n],
+                [1n, 1n, 1n, 1n, 1n],
+            ]);
+            expect(token_type_ids.tolist()).toEqual([
+                [0n, 1n, 1n, 0n, 0n],
+                [0n, 0n, 1n, 1n, 1n],
+            ]);
+        }
+
+        { // Truncation + padding
+            const { input_ids, attention_mask, token_type_ids } = tokenizer(['a', 'b c', 'd e f'], {
+                padding: true,
+                truncation: true,
+                add_special_tokens: false,
+                max_length: 2,
+            })
+
+            expect(input_ids.tolist()).toEqual([[1037n, 0n], [1038n, 1039n], [1040n, 1041n]])
+            expect(attention_mask.tolist()).toEqual([[1n, 0n], [1n, 1n], [1n, 1n]])
+            expect(token_type_ids.tolist()).toEqual([[0n, 0n], [0n, 0n], [0n, 0n]])
+        }
+    }, MAX_TEST_EXECUTION_TIME);
+});
+
+describe('Token type ids', () => {
+    it('should correctly add token type ids', async () => {
+        const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
+
+        const model_inputs = tokenizer(
+            ['a b c', 'd'],
+            {
+                text_pair: ['e f', 'g h'],
+                padding: true,
+                truncation: true,
+                return_tensor: false,
+            }
+        );
+
+        const expected = {
+            input_ids: [
+                [101, 1037, 1038, 1039, 102, 1041, 1042, 102],
+                [101, 1040, 102, 1043, 1044, 102, 0, 0],
+            ],
+            token_type_ids: [
+                [0, 0, 0, 0, 0, 1, 1, 1],
+                [0, 0, 0, 1, 1, 1, 0, 0],
+            ],
+            attention_mask: [
+                [1, 1, 1, 1, 1, 1, 1, 1],
+                [1, 1, 1, 1, 1, 1, 0, 0],
+            ],
+        }
+
+        compare(model_inputs, expected);
+
+    }, MAX_TEST_EXECUTION_TIME);
+});
+
 describe('Edge cases', () => {
     it('should not crash when encoding a very long string', async () => {
         let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');