From 9f27c059d67f73d74beae68eb8241d38cefe6340 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Jan 2024 03:45:30 +0200 Subject: [PATCH 01/13] Formatting --- scripts/extra/esm.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/scripts/extra/esm.py b/scripts/extra/esm.py index ac42f8a14..5bfdded52 100644 --- a/scripts/extra/esm.py +++ b/scripts/extra/esm.py @@ -2,33 +2,36 @@ from tokenizers import Tokenizer, pre_tokenizers, processors from tokenizers.models import WordPiece + class EsmConverter(Converter): def converted(self) -> Tokenizer: vocab = self.original_tokenizer.vocab - tokenizer = Tokenizer(WordPiece(vocab, continuing_subword_prefix='', max_input_chars_per_word=int(1e10), unk_token=str(self.original_tokenizer.unk_token))) + tokenizer = Tokenizer(WordPiece(vocab, continuing_subword_prefix='', max_input_chars_per_word=int( + 1e10), unk_token=str(self.original_tokenizer.unk_token))) tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit() - + cls = str(self.original_tokenizer.cls_token) cls_token_id = self.original_tokenizer.cls_token_id - sep = str(self.original_tokenizer.eos_token) # No sep token in ESM vocabulary + # No sep token in ESM vocabulary + sep = str(self.original_tokenizer.eos_token) sep_token_id = self.original_tokenizer.eos_token_id if sep_token_id is None: - tokenizer.post_processor = processors.TemplateProcessing( - single=f"{cls}:0 $A:0", - special_tokens=[ - (cls, cls_token_id), - ], - ) + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{cls}:0 $A:0", + special_tokens=[ + (cls, cls_token_id), + ], + ) else: - tokenizer.post_processor = processors.TemplateProcessing( - single=f"{cls}:0 $A:0 {sep}:0", - special_tokens=[ - (cls, cls_token_id), - (sep, sep_token_id), - ], - ) + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{cls}:0 $A:0 {sep}:0", + special_tokens=[ + (cls, cls_token_id), + (sep, sep_token_id), + ], + ) # For some reason, all tokens are added: none of them are special, but they all need special splitting. 
# See https://github.com/huggingface/transformers/blob/df5c5c62ae253055336f5bb0828ca8e3e15ab6bd/src/transformers/models/esm/tokenization_esm.py#L79-L80 @@ -44,6 +47,7 @@ def converted(self) -> Tokenizer: tokenizer.add_tokens(other_tokens) return tokenizer + def generate_fast_tokenizer(tokenizer): tokenizer.vocab = tokenizer._token_to_id return EsmConverter(tokenizer).converted() From 526a2efc1ae4b69059b553219beb4f2bf0c929c0 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Jan 2024 04:06:31 +0200 Subject: [PATCH 02/13] Update ESM pair template --- scripts/extra/esm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/extra/esm.py b/scripts/extra/esm.py index 5bfdded52..e4cba50c6 100644 --- a/scripts/extra/esm.py +++ b/scripts/extra/esm.py @@ -27,6 +27,7 @@ def converted(self) -> Tokenizer: else: tokenizer.post_processor = processors.TemplateProcessing( single=f"{cls}:0 $A:0 {sep}:0", + pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1", special_tokens=[ (cls, cls_token_id), (sep, sep_token_id), From c051aa24eab8eeff60e0d22b8fce91ccf3b36c7b Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Jan 2024 04:59:15 +0200 Subject: [PATCH 03/13] Fix token type ids --- src/tokenizers.js | 312 +++++++++++++++++++++++---------------- tests/generate_tests.py | 31 +++- tests/tokenizers.test.js | 82 +++++++--- 3 files changed, 279 insertions(+), 146 deletions(-) diff --git a/src/tokenizers.js b/src/tokenizers.js index 94f5def35..0358fcf52 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -196,20 +196,21 @@ function lowercase_and_remove_accent(text) { /** * Helper function to fuse consecutive values in an array equal to the specified value. - * @param {Array} arr The input array + * @param {string[]} arr The input array * @param {any} value The value to fuse on. + * @param {Map} mapping The mapping from input domain to value. */ -function fuse(arr, value) { +function fuse(arr, value, mapping) { let fused = []; let i = 0; while (i < arr.length) { fused.push(arr[i]) - if (arr[i] !== value) { + if ((mapping.get(arr[i]) ?? value) !== value) { ++i; continue; } - while (i < arr.length && arr[i] === value) { + while (i < arr.length && (mapping.get(arr[i]) ?? value) === value) { ++i; } } @@ -321,7 +322,12 @@ export class TokenizerModel extends Callable { * @returns {string[]} The encoded token IDs. */ _call(tokens) { - return this.encode(tokens); + let ids = this.encode(tokens); + if (this.fuse_unk) { + // Fuse unknown tokens + ids = fuse(ids, this.unk_token_id, this.tokens_to_ids); + } + return ids; } /** @@ -340,13 +346,7 @@ export class TokenizerModel extends Callable { * @returns {number[]} The converted token IDs. */ convert_tokens_to_ids(tokens) { - let ids = tokens.map(t => this.tokens_to_ids.get(t) ?? this.unk_token_id); - - if (this.fuse_unk) { - // Fuse unknown tokens - ids = fuse(ids, this.unk_token_id); - } - return ids; + return tokens.map(t => this.tokens_to_ids.get(t) ?? this.unk_token_id); } /** @@ -1526,6 +1526,21 @@ class DigitsPreTokenizer extends PreTokenizer { } } +/** + * @typedef {Object} PostProcessedOutput + * @property {string[]} tokens + * @property {number[]} [token_type_ids] + */ + + +/** + * @typedef {Object} EncodingSingle + * @property {number[]} input_ids + * @property {number[]} attention_mask + * @property {number[]} [token_type_ids] + */ + + /** * @extends Callable */ @@ -1570,7 +1585,7 @@ class PostProcessor extends Callable { * * @param {Array} tokens The input tokens to be post-processed. 
* @param {...*} args Additional arguments required by the post-processing logic. - * @returns {Array} The post-processed tokens. + * @returns {PostProcessedOutput} The post-processed tokens. * @throws {Error} If the method is not implemented in subclass. */ post_process(tokens, ...args) { @@ -1581,7 +1596,7 @@ class PostProcessor extends Callable { * Alias for {@link PostProcessor#post_process}. * @param {Array} tokens The text or array of texts to post-process. * @param {...*} args Additional arguments required by the post-processing logic. - * @returns {Array} An array of post-processed tokens. + * @returns {PostProcessedOutput} The post-processed tokens. */ _call(tokens, ...args) { return this.post_process(tokens, ...args); @@ -1608,18 +1623,20 @@ class BertProcessing extends PostProcessor { /** * Adds the special tokens to the beginning and end of the input. * @param {string[]} tokens The input tokens. - * @param {string[]|null} tokens_pair An optional second set of input tokens. - * @returns {string[]} The input tokens with the special tokens added to the beginning and end. + * @param {string[]} [tokens_pair=null] An optional second set of input tokens. + * @returns {PostProcessedOutput} The post-processed tokens with the special tokens added to the beginning and end. */ post_process(tokens, tokens_pair = null) { tokens = mergeArrays([this.cls], tokens, [this.sep]); - - // NOTE: It is intended to add 2 EOS tokens after the first set of tokens - // https://github.com/huggingface/tokenizers/issues/983 + let token_type_ids = new Array(tokens.length).fill(0); if (tokens_pair !== null) { - tokens = mergeArrays(tokens, [this.sep], tokens_pair, [this.sep]); + const middle = this instanceof RobertaProcessing ? [this.sep] : []; + // NOTE: It is intended to add 2 EOS tokens after the first set of tokens + // https://github.com/huggingface/tokenizers/issues/983 + tokens = mergeArrays(tokens, middle, tokens_pair, [this.sep]); + token_type_ids = mergeArrays(token_type_ids, new Array(tokens_pair.length + 1 + middle.length).fill(1)); } - return tokens; + return { tokens, token_type_ids }; } } class RobertaProcessing extends BertProcessing { } // NOTE: extends BertProcessing @@ -1644,28 +1661,32 @@ class TemplateProcessing extends PostProcessor { /** * Replaces special tokens in the template with actual tokens. - * @param {Array} tokens The list of tokens for the first sequence. - * @param {Array} [tokens_pair=null] The list of tokens for the second sequence (optional). - * @returns {Array} The list of tokens with the special tokens replaced with actual tokens. + * @param {string[]} tokens The list of tokens for the first sequence. + * @param {string[]} [tokens_pair=null] The list of tokens for the second sequence (optional). + * @returns {PostProcessedOutput} An object containing the list of tokens with the special tokens replaced with actual tokens. */ post_process(tokens, tokens_pair = null) { let type = tokens_pair === null ? 
this.single : this.pair - let toReturn = []; + let processedTokens = []; + let types = []; for (let item of type) { if ('SpecialToken' in item) { - toReturn.push(item.SpecialToken.id); + processedTokens.push(item.SpecialToken.id); + types.push(item.SpecialToken.type_id); } else if ('Sequence' in item) { if (item.Sequence.id === 'A') { - toReturn = mergeArrays(toReturn, tokens); + processedTokens = mergeArrays(processedTokens, tokens); + types = mergeArrays(types, new Array(tokens.length).fill(item.Sequence.type_id)); } else if (item.Sequence.id === 'B') { - toReturn = mergeArrays(toReturn, tokens_pair); + processedTokens = mergeArrays(processedTokens, tokens_pair); + types = mergeArrays(types, new Array(tokens_pair.length).fill(item.Sequence.type_id)); } } } - return toReturn; + return { tokens: processedTokens, token_type_ids: types }; } } @@ -1676,11 +1697,15 @@ class TemplateProcessing extends PostProcessor { class ByteLevelPostProcessor extends PostProcessor { /** * Post process the given tokens. - * @param {string[]} tokens The tokens to be post processed. - * @returns {string[]} The post processed tokens. + * @param {string[]} tokens The list of tokens for the first sequence. + * @param {string[]} [tokens_pair=null] The list of tokens for the second sequence (optional). + * @returns {PostProcessedOutput} An object containing the post-processed tokens. */ - post_process(tokens) { - return tokens; + post_process(tokens, tokens_pair = null) { + if (tokens_pair) { + tokens = mergeArrays(tokens, tokens_pair); + } + return { tokens }; } } @@ -2327,6 +2352,8 @@ const SPECIAL_TOKEN_ATTRIBUTES = [ ] export class PreTrainedTokenizer extends Callable { + return_token_type_ids = false; + _default_chat_template = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}`; /** @@ -2512,8 +2539,8 @@ export class PreTrainedTokenizer extends Callable { } = {}, ) { - /** @type {number[]|number[][]|Tensor} */ - let tokens; + /** @type {EncodingSingle[]} */ + let encodedTokens; if (Array.isArray(text)) { if (text.length === 0) { @@ -2528,12 +2555,12 @@ export class PreTrainedTokenizer extends Callable { throw Error('text and text_pair must have the same length') } - tokens = text.map( - (t, i) => this.encode(t, text_pair[i], { add_special_tokens }) + encodedTokens = text.map( + (t, i) => this._encode_plus(t, text_pair[i], { add_special_tokens }) ) } else { - tokens = text.map(x => this.encode(x, null, { add_special_tokens })); + encodedTokens = text.map(x => this._encode_plus(x, null, { add_special_tokens })); } } else { @@ -2546,7 +2573,7 @@ export class PreTrainedTokenizer extends Callable { } // For single input, we just wrap in an array, and then unwrap later. - tokens = [this.encode(text, text_pair, { add_special_tokens })]; + encodedTokens = [this._encode_plus(text, text_pair, { add_special_tokens })]; } // At this point, tokens is batched: [batch_size, tokens] // However, array may be jagged. 
So, we pad to max_length @@ -2556,60 +2583,80 @@ export class PreTrainedTokenizer extends Callable { max_length = this.model_max_length; } else { // Calculate max length from sequences - max_length = max(tokens.map(x => x.length))[0]; + max_length = max(encodedTokens.map(x => x.input_ids.length))[0]; } } // Ensure it is less than model max length max_length = Math.min(max_length, this.model_max_length) - /** @type {any[]|Tensor} */ - let attention_mask = []; if (padding || truncation) { // Perform padding and/or truncation - for (let i = 0; i < tokens.length; ++i) { - if (tokens[i].length === max_length) { - attention_mask.push(new Array(tokens[i].length).fill(1)) + for (let i = 0; i < encodedTokens.length; ++i) { + if (encodedTokens[i].input_ids.length === max_length) { continue; - } else if (tokens[i].length > max_length) { + } else if (encodedTokens[i].input_ids.length > max_length) { // possibly truncate if (truncation) { - tokens[i] = tokens[i].slice(0, max_length); + encodedTokens[i].input_ids = encodedTokens[i].input_ids.slice(0, max_length); + encodedTokens[i].attention_mask = encodedTokens[i].attention_mask.slice(0, max_length); + + if (encodedTokens[i].token_type_ids) { + encodedTokens[i].token_type_ids = encodedTokens[i].token_type_ids.slice(0, max_length); + } } - attention_mask.push(new Array(tokens[i].length).fill(1)) } else { // t.length < max_length + // possibly pad if (padding) { - let diff = max_length - tokens[i].length; + let diff = max_length - encodedTokens[i].input_ids.length; if (this.padding_side === 'right') { - attention_mask.push( - (new Array(tokens[i].length).fill(1)).concat(new Array(diff).fill(0)) + encodedTokens[i].input_ids = mergeArrays( + encodedTokens[i].input_ids, new Array(diff).fill(this.pad_token_id), + ) + encodedTokens[i].attention_mask = mergeArrays( + encodedTokens[i].attention_mask, new Array(diff).fill(0), ) - tokens[i].push(...new Array(diff).fill(this.pad_token_id)) + if (encodedTokens[i].token_type_ids) { + encodedTokens[i].token_type_ids = mergeArrays( + encodedTokens[i].token_type_ids, new Array(diff).fill(0), + ) + } + } else { // left - attention_mask.push( - (new Array(diff).fill(0)).concat(new Array(tokens[i].length).fill(1)) + encodedTokens[i].input_ids = mergeArrays( + new Array(diff).fill(this.pad_token_id), encodedTokens[i].input_ids, ) - tokens[i].unshift(...new Array(diff).fill(this.pad_token_id)) + encodedTokens[i].attention_mask = mergeArrays( + new Array(diff).fill(0), encodedTokens[i].attention_mask, + ) + if (encodedTokens[i].token_type_ids) { + encodedTokens[i].token_type_ids = mergeArrays( + new Array(diff).fill(0), encodedTokens[i].token_type_ids, + ) + } } - - } else { - attention_mask.push(new Array(tokens[i].length).fill(1)) } } } - } else { - attention_mask = tokens.map(x => new Array(x.length).fill(1)) } + let input_ids; + let attention_mask; + let token_type_ids; + if (return_tensor) { if (!(padding && truncation)) { // Not, guaranteed that all items have same length, so // we perform additional check - if (tokens.some(x => x.length !== tokens[0].length)) { + if ( + encodedTokens.some(x => x.input_ids.length !== encodedTokens[0].input_ids.length) + || encodedTokens.some(x => x.attention_mask.length !== encodedTokens[0].attention_mask.length) + || encodedTokens.some(x => x.token_type_ids && x.token_type_ids.length !== encodedTokens[0].token_type_ids.length) + ) { throw Error( "Unable to create tensor, you should probably activate truncation and/or padding " + "with 'padding=true' and 'truncation=true' to have 
batched tensors with the same length." @@ -2620,37 +2667,52 @@ export class PreTrainedTokenizer extends Callable { // Now we actually convert to tensor // NOTE: In the same way as the python library, we return a batched tensor, regardless of // whether we have a single input or multiple inputs. - let dims = [tokens.length, tokens[0].length]; + let dims = [encodedTokens.length, encodedTokens[0].input_ids.length]; - tokens = new Tensor('int64', - BigInt64Array.from(tokens.flat().map(BigInt)), + input_ids = new Tensor('int64', + BigInt64Array.from(encodedTokens.flatMap(x => x.input_ids).map(BigInt)), dims ); attention_mask = new Tensor( 'int64', - BigInt64Array.from(attention_mask.flat().map(BigInt)), + BigInt64Array.from(encodedTokens.flatMap(x => x.attention_mask).map(BigInt)), dims - ) + ); + + if (encodedTokens[0].token_type_ids) { + token_type_ids = new Tensor( + 'int64', + BigInt64Array.from(encodedTokens.flatMap(x => x.token_type_ids).map(BigInt)), + dims + ); + } + } else { + input_ids = encodedTokens.map(x => x.input_ids); + attention_mask = encodedTokens.map(x => x.attention_mask); + + if (encodedTokens[0].token_type_ids) { + token_type_ids = encodedTokens.map(x => x.token_type_ids); + } + // If not returning a tensor, we match the input type if (!Array.isArray(text)) { // Input was not batched, so we unwrap - tokens = tokens[0]; + input_ids = input_ids[0]; attention_mask = attention_mask[0]; + + if (token_type_ids) { + token_type_ids = token_type_ids[0]; + } } } - - // Finally, add attention mask, and possibly model-specific parameters - let modelInputs = { - input_ids: tokens, - attention_mask: attention_mask + let modelInputs = { input_ids, attention_mask } + if (this.return_token_type_ids && token_type_ids) { + modelInputs.token_type_ids = token_type_ids; } - // Optional post-processing - modelInputs = this.prepare_model_inputs(modelInputs); - return modelInputs } @@ -2705,22 +2767,49 @@ export class PreTrainedTokenizer extends Callable { * @param {string|null} text_pair The optional second text to encode. * @param {Object} options An optional object containing the following properties: * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model. - * @returns {number[]} An array of token IDs representing the encoded text(s). + * @returns {EncodingSingle} An object containing the encoded text. + * @private */ - encode(text, text_pair = null, { + _encode_plus(text, text_pair = null, { add_special_tokens = true, } = {}) { // Function called by users to encode possibly multiple texts - let tokens = this._encode_text(text); - let tokens2 = this._encode_text(text_pair); + const tokens = this._encode_text(text); + const tokens2 = this._encode_text(text_pair); // TODO improve `add_special_tokens` and ensure correctness - let combinedTokens = (this.post_processor !== null && add_special_tokens) + const combinedTokens = (this.post_processor !== null && add_special_tokens) ? this.post_processor(tokens, tokens2) - : mergeArrays(tokens ?? [], tokens2 ?? []); + : { tokens: mergeArrays(tokens ?? [], tokens2 ?? 
[]) }; - let ids = this.model.convert_tokens_to_ids(combinedTokens); - return ids; + const input_ids = this.model.convert_tokens_to_ids(combinedTokens.tokens); + + const result = { + input_ids, + attention_mask: new Array(input_ids.length).fill(1), + } + if (combinedTokens.token_type_ids) { + result.token_type_ids = combinedTokens.token_type_ids; + } + return result; + } + + /** + * Encodes a single text or a pair of texts using the model's tokenizer. + * + * @param {string} text The text to encode. + * @param {string|null} text_pair The optional second text to encode. + * @param {Object} options An optional object containing the following properties: + * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model. + * @returns {number[]} An array of token IDs representing the encoded text(s). + */ + encode(text, text_pair = null, { + add_special_tokens = true, + } = {}) { + const encoded = this._encode_plus(text, text_pair, { + add_special_tokens, + }); + return encoded.input_ids; } /** @@ -2951,81 +3040,48 @@ export function add_token_types(inputs) { * @extends PreTrainedTokenizer */ export class BertTokenizer extends PreTrainedTokenizer { - /** @type {add_token_types} */ - prepare_model_inputs(inputs) { - return add_token_types(inputs); - } + return_token_type_ids = true; } /** * Albert tokenizer * @extends PreTrainedTokenizer */ export class AlbertTokenizer extends PreTrainedTokenizer { - /** @type {add_token_types} */ - prepare_model_inputs(inputs) { - return add_token_types(inputs); - } + return_token_type_ids = true; } export class MobileBertTokenizer extends PreTrainedTokenizer { - /** @type {add_token_types} */ - prepare_model_inputs(inputs) { - return add_token_types(inputs); - } + return_token_type_ids = true; } export class SqueezeBertTokenizer extends PreTrainedTokenizer { - /** @type {add_token_types} */ - prepare_model_inputs(inputs) { - return add_token_types(inputs); - } + return_token_type_ids = true; } export class DebertaTokenizer extends PreTrainedTokenizer { - /** @type {add_token_types} */ - prepare_model_inputs(inputs) { - return add_token_types(inputs); - } + return_token_type_ids = true; } export class DebertaV2Tokenizer extends PreTrainedTokenizer { - /** @type {add_token_types} */ - prepare_model_inputs(inputs) { - return add_token_types(inputs); - } + return_token_type_ids = true; } export class HerbertTokenizer extends PreTrainedTokenizer { - /** @type {add_token_types} */ - prepare_model_inputs(inputs) { - return add_token_types(inputs); - } + return_token_type_ids = true; } export class ConvBertTokenizer extends PreTrainedTokenizer { - /** @type {add_token_types} */ - prepare_model_inputs(inputs) { - return add_token_types(inputs); - } + return_token_type_ids = true; } export class RoFormerTokenizer extends PreTrainedTokenizer { - /** @type {add_token_types} */ - prepare_model_inputs(inputs) { - return add_token_types(inputs); - } + return_token_type_ids = true; } export class DistilBertTokenizer extends PreTrainedTokenizer { } export class CamembertTokenizer extends PreTrainedTokenizer { } export class XLMTokenizer extends PreTrainedTokenizer { + return_token_type_ids = true; + constructor(tokenizerJSON, tokenizerConfig) { super(tokenizerJSON, tokenizerConfig); console.warn('WARNING: `XLMTokenizer` is not yet supported by Hugging Face\'s "fast" tokenizers library. 
Therefore, you may experience slightly inaccurate results.') } - - /** @type {add_token_types} */ - prepare_model_inputs(inputs) { - return add_token_types(inputs); - } } export class ElectraTokenizer extends PreTrainedTokenizer { - /** @type {add_token_types} */ - prepare_model_inputs(inputs) { - return add_token_types(inputs); - } + return_token_type_ids = true; } export class T5Tokenizer extends PreTrainedTokenizer { } diff --git a/tests/generate_tests.py b/tests/generate_tests.py index e047802e9..d9c457b9d 100644 --- a/tests/generate_tests.py +++ b/tests/generate_tests.py @@ -169,6 +169,25 @@ }, } +TOKENIZER_TEXT_PAIR_TEST_DATA = [ + { + 'text': 'a', + 'text_pair': 'b' + }, + { + 'text': 'a b', + 'text_pair': 'c d e' + }, + { + 'text': ['a b c', 'd'], + 'text_pair': ['e f', 'g h'], + }, + { + 'text': ['a', 'b c', 'd e f'], + 'text_pair': ['g h i', 'j k', 'l'], + } +] + CHAT_MESSAGES_EXAMPLES = { 'basic': [ {"role": "user", "content": "Hello, how are you?"}, @@ -292,13 +311,23 @@ def generate_tokenizer_tests(): tokenizer_results = [] + for data in TOKENIZER_TEXT_PAIR_TEST_DATA: + try: + output = tokenizer(**data).data + except Exception: + # Ignore testing tokenizers which fail in the python library + continue + tokenizer_results.append(dict( + input=data, + output=output, + )) + shared_texts = TOKENIZER_TEST_DATA["shared"] custom_texts = TOKENIZER_TEST_DATA["custom"].get( tokenizer_name, []) # Run tokenizer on test cases for text in shared_texts + custom_texts + custom_by_model_type_texts: - # TODO: add with_pair option try: encoded = tokenizer(text).data except Exception: diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js index 6bafa2365..cc2afb614 100644 --- a/tests/tokenizers.test.js +++ b/tests/tokenizers.test.js @@ -12,32 +12,46 @@ const { tokenization, templates } = await (await getFile('./tests/data/tokenizer // Dynamic tests to ensure transformers.js (JavaScript) matches transformers (Python) describe('Tokenizers (dynamic)', () => { - for (let [tokenizerName, tests] of Object.entries(tokenization)) { + for (const [tokenizerName, tests] of Object.entries(tokenization)) { it(tokenizerName, async () => { - let tokenizer = await AutoTokenizer.from_pretrained(m(tokenizerName)); + const tokenizer = await AutoTokenizer.from_pretrained(m(tokenizerName)); - for (let test of tests) { + for (const test of tests) { + // Two kinds of tests: + // 1. text w/o text_pair + // 2. 
text w text_pair - // Test encoding - let encoded = tokenizer(test.input, { - return_tensor: false - }); + if (typeof test.input === 'string') { + + // Test encoding + const encoded = tokenizer(test.input, { + return_tensor: false + }); + + // Add the input text to the encoded object for easier debugging + test.encoded.input = encoded.input = test.input; - // Add the input text to the encoded object for easier debugging - test.encoded.input = encoded.input = test.input; + expect(encoded).toEqual(test.encoded); - expect(encoded).toEqual(test.encoded); + // Skip decoding tests if encoding produces zero tokens + if (test.encoded.input_ids.length === 0) continue; - // Skip decoding tests if encoding produces zero tokens - if (test.encoded.input_ids.length === 0) continue; + // Test decoding + const decoded_with_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: false }); + expect(decoded_with_special).toEqual(test.decoded_with_special); - // Test decoding - let decoded_with_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: false }); - expect(decoded_with_special).toEqual(test.decoded_with_special); + const decoded_without_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: true }); + expect(decoded_without_special).toEqual(test.decoded_without_special); - let decoded_without_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: true }); - expect(decoded_without_special).toEqual(test.decoded_without_special); + } else { + const { text, text_pair } = test.input; + const encoded = tokenizer(text, { + text_pair, + return_tensor: false, + }); + compare(encoded, test.output); + } } }, MAX_TEST_EXECUTION_TIME); } @@ -140,6 +154,40 @@ describe('Tokenizers (hard-coded)', () => { } }); +describe('Token type ids', () => { + it('should correctly add token type ids', async () => { + const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); + + const model_inputs = tokenizer( + ['a b c', 'd'], + { + text_pair: ['e f', 'g h'], + padding: true, + truncation: true, + return_tensor: false, + } + ); + + const expected = { + input_ids: [ + [101, 1037, 1038, 1039, 102, 1041, 1042, 102], + [101, 1040, 102, 1043, 1044, 102, 0, 0], + ], + token_type_ids: [ + [0, 0, 0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 0, 0], + ], + attention_mask: [ + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 0, 0], + ], + } + + compare(model_inputs, expected); + + }, MAX_TEST_EXECUTION_TIME); +}); + describe('Edge cases', () => { it('should not crash when encoding a very long string', async () => { let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small'); From 873a41d8804016aea887bab3f62e7ce17e17b9f9 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Jan 2024 05:00:13 +0200 Subject: [PATCH 04/13] Update JSDoc --- src/tokenizers.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tokenizers.js b/src/tokenizers.js index 0358fcf52..1e985df7a 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -1535,9 +1535,9 @@ class DigitsPreTokenizer extends PreTokenizer { /** * @typedef {Object} EncodingSingle - * @property {number[]} input_ids - * @property {number[]} attention_mask - * @property {number[]} [token_type_ids] + * @property {number[]} input_ids List of token ids to be fed to a model. 
+ * @property {number[]} attention_mask List of indices specifying which tokens should be attended to by the model + * @property {number[]} [token_type_ids] List of token type ids to be fed to a model */ From e3bbd4c2aed12c2b8c6cddef7c5dc695726e2f44 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Jan 2024 13:54:37 +0200 Subject: [PATCH 05/13] Cleanup --- src/tokenizers.js | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/tokenizers.js b/src/tokenizers.js index 1e985df7a..cc76883c0 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -1528,8 +1528,8 @@ class DigitsPreTokenizer extends PreTokenizer { /** * @typedef {Object} PostProcessedOutput - * @property {string[]} tokens - * @property {number[]} [token_type_ids] + * @property {string[]} tokens List of tokens produced by the post-processor. + * @property {number[]} [token_type_ids] List of token type ids produced by the post-processor. */ @@ -2369,11 +2369,8 @@ export class PreTrainedTokenizer extends Callable { // Construct parts of the tokenizer from the JSON this.normalizer = Normalizer.fromConfig(tokenizerJSON.normalizer); this.pre_tokenizer = PreTokenizer.fromConfig(tokenizerJSON.pre_tokenizer); - this.model = TokenizerModel.fromConfig(tokenizerJSON.model, tokenizerConfig); this.post_processor = PostProcessor.fromConfig(tokenizerJSON.post_processor); - - // TODO: maybe, allow this to be null; in which case, we use model as decoder too? this.decoder = Decoder.fromConfig(tokenizerJSON.decoder); // Add added_tokens to model @@ -2539,10 +2536,12 @@ export class PreTrainedTokenizer extends Callable { } = {}, ) { + const isBatched = Array.isArray(text); + /** @type {EncodingSingle[]} */ let encodedTokens; - if (Array.isArray(text)) { + if (isBatched) { if (text.length === 0) { throw Error('text array must be non-empty') } @@ -2667,7 +2666,7 @@ export class PreTrainedTokenizer extends Callable { // Now we actually convert to tensor // NOTE: In the same way as the python library, we return a batched tensor, regardless of // whether we have a single input or multiple inputs. 
- let dims = [encodedTokens.length, encodedTokens[0].input_ids.length]; + const dims = [encodedTokens.length, encodedTokens[0].input_ids.length]; input_ids = new Tensor('int64', BigInt64Array.from(encodedTokens.flatMap(x => x.input_ids).map(BigInt)), @@ -2680,7 +2679,7 @@ export class PreTrainedTokenizer extends Callable { dims ); - if (encodedTokens[0].token_type_ids) { + if (this.return_token_type_ids && encodedTokens[0].token_type_ids) { token_type_ids = new Tensor( 'int64', BigInt64Array.from(encodedTokens.flatMap(x => x.token_type_ids).map(BigInt)), @@ -2692,24 +2691,24 @@ export class PreTrainedTokenizer extends Callable { input_ids = encodedTokens.map(x => x.input_ids); attention_mask = encodedTokens.map(x => x.attention_mask); - if (encodedTokens[0].token_type_ids) { + if (this.return_token_type_ids && encodedTokens[0].token_type_ids) { token_type_ids = encodedTokens.map(x => x.token_type_ids); } // If not returning a tensor, we match the input type - if (!Array.isArray(text)) { + if (!isBatched) { // Input was not batched, so we unwrap input_ids = input_ids[0]; attention_mask = attention_mask[0]; - if (token_type_ids) { + if (this.return_token_type_ids && token_type_ids) { token_type_ids = token_type_ids[0]; } } } let modelInputs = { input_ids, attention_mask } - if (this.return_token_type_ids && token_type_ids) { + if (token_type_ids) { modelInputs.token_type_ids = token_type_ids; } @@ -2806,10 +2805,10 @@ export class PreTrainedTokenizer extends Callable { encode(text, text_pair = null, { add_special_tokens = true, } = {}) { - const encoded = this._encode_plus(text, text_pair, { + const { input_ids } = this._encode_plus(text, text_pair, { add_special_tokens, }); - return encoded.input_ids; + return input_ids; } /** From aacc018bd7200254b9e093a0765076afc260b5bb Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Jan 2024 13:55:05 +0200 Subject: [PATCH 06/13] Remove unused `prepare_model_inputs` function --- src/tokenizers.js | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/tokenizers.js b/src/tokenizers.js index cc76883c0..2813a4172 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -2499,16 +2499,6 @@ export class PreTrainedTokenizer extends Callable { return new this(...info); } - /** - * This function can be overridden by a subclass to apply additional preprocessing - * to a model's input data. - * @param {Object} inputs An object containing input data as properties. - * @returns {Object} The modified inputs object. - */ - prepare_model_inputs(inputs) { - return inputs; - } - /** * Encode/tokenize the given text(s). * @param {string|string[]} text The text to tokenize. From 3bb8a429e46cd8d1dab1973afd4f96d9d9bf83b5 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Jan 2024 14:28:11 +0200 Subject: [PATCH 07/13] Move pad and truncate logic to helper functions --- src/tokenizers.js | 80 ++++++++++++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/src/tokenizers.js b/src/tokenizers.js index 2813a4172..c0154959d 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -2351,6 +2351,44 @@ const SPECIAL_TOKEN_ATTRIBUTES = [ // additional_special_tokens (TODO) ] +/** + * + * Helper function for padding values of an object, which are each arrays. + * NOTE: No additional checks are made here for validity of arguments. + * @param {Record} item The input object. + * @param {number} length The length to pad to. 
+ * @param {(key: string) => any} value_fn Determine the value to fill the array, based on its key. + * @param {'right'|'left'} side Which side to pad the array. + * @private + */ +function padHelper(item, length, value_fn, side) { + for (const key in Object.keys(item)) { + const diff = length - item[key].length; + const value = value_fn(key); + + const padData = new Array(diff).fill(value); + item[key] = side === 'right' + ? mergeArrays(item[key], padData) + : mergeArrays(padData, item[key]); + } +} + +/** + * Helper function for truncating values of an object, which are each arrays. + * NOTE: No additional checks are made here for validity of arguments. + * @param {Record} item The input object. + * @param {number} length The length to truncate to. + * @private + */ +function truncateHelper(item, length) { + // Setting .length to a lower value truncates the array in-place: + // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/length + for (const key in Object.keys(item)) { + item[key].length = length; + } +} + + export class PreTrainedTokenizer extends Callable { return_token_type_ids = false; @@ -2435,6 +2473,7 @@ export class PreTrainedTokenizer extends Callable { this.do_lowercase_and_remove_accent = tokenizerConfig.do_lowercase_and_remove_accent ?? false; // TODO allow user to change this + /** @type {'right'|'left'} */ this.padding_side = 'right'; this.legacy = false; @@ -2580,6 +2619,7 @@ export class PreTrainedTokenizer extends Callable { max_length = Math.min(max_length, this.model_max_length) if (padding || truncation) { + // Perform padding and/or truncation for (let i = 0; i < encodedTokens.length; ++i) { if (encodedTokens[i].input_ids.length === max_length) { @@ -2588,45 +2628,15 @@ export class PreTrainedTokenizer extends Callable { } else if (encodedTokens[i].input_ids.length > max_length) { // possibly truncate if (truncation) { - encodedTokens[i].input_ids = encodedTokens[i].input_ids.slice(0, max_length); - encodedTokens[i].attention_mask = encodedTokens[i].attention_mask.slice(0, max_length); - - if (encodedTokens[i].token_type_ids) { - encodedTokens[i].token_type_ids = encodedTokens[i].token_type_ids.slice(0, max_length); - } + truncateHelper(encodedTokens[i], max_length); } } else { // t.length < max_length // possibly pad if (padding) { - let diff = max_length - encodedTokens[i].input_ids.length; - - if (this.padding_side === 'right') { - encodedTokens[i].input_ids = mergeArrays( - encodedTokens[i].input_ids, new Array(diff).fill(this.pad_token_id), - ) - encodedTokens[i].attention_mask = mergeArrays( - encodedTokens[i].attention_mask, new Array(diff).fill(0), - ) - if (encodedTokens[i].token_type_ids) { - encodedTokens[i].token_type_ids = mergeArrays( - encodedTokens[i].token_type_ids, new Array(diff).fill(0), - ) - } - - } else { // left - encodedTokens[i].input_ids = mergeArrays( - new Array(diff).fill(this.pad_token_id), encodedTokens[i].input_ids, - ) - encodedTokens[i].attention_mask = mergeArrays( - new Array(diff).fill(0), encodedTokens[i].attention_mask, - ) - if (encodedTokens[i].token_type_ids) { - encodedTokens[i].token_type_ids = mergeArrays( - new Array(diff).fill(0), encodedTokens[i].token_type_ids, - ) - } - } + padHelper(encodedTokens[i], max_length, (key) => { + key === 'input_ids' ? 
this.pad_token_id : 0 + }, this.padding_side); } } } @@ -2697,7 +2707,7 @@ export class PreTrainedTokenizer extends Callable { } } - let modelInputs = { input_ids, attention_mask } + const modelInputs = { input_ids, attention_mask }; if (token_type_ids) { modelInputs.token_type_ids = token_type_ids; } From fef0b868211ebe980f49c256c17242e72b4b93ba Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Jan 2024 16:41:38 +0200 Subject: [PATCH 08/13] Add static padding/truncation unit tests --- tests/tokenizers.test.js | 102 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js index cc2afb614..40fed05d1 100644 --- a/tests/tokenizers.test.js +++ b/tests/tokenizers.test.js @@ -154,6 +154,108 @@ describe('Tokenizers (hard-coded)', () => { } }); +describe('Tokenizer padding/truncation', () => { + const inputs = ['a', 'b c']; + const text_pair = ['d e', 'f g h']; + + it('should create a jagged array', async () => { + const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); + + { // support jagged array if `return_tensor=false` + const output = tokenizer(inputs, { + return_tensor: false, + }) + const expected = { + input_ids: [[101, 1037, 102], [101, 1038, 1039, 102]], + attention_mask: [[1, 1, 1], [1, 1, 1, 1]], + token_type_ids: [[0, 0, 0], [0, 0, 0, 0]] + } + compare(output, expected); + } + + { + const output = tokenizer(inputs, { + return_tensor: false, + truncation: true, + add_special_tokens: false, + }) + const expected = { + input_ids: [[1037], [1038, 1039]], + attention_mask: [[1], [1, 1]], + token_type_ids: [[0], [0, 0]] + } + compare(output, expected); + } + }) + + it('should create a tensor', async () => { + const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); + + { // Expected to throw error if jagged array + expect(() => tokenizer(inputs)).toThrowError('Unable to create tensor'); + } + + { // Truncation + const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, { + truncation: true, + max_length: 1, + add_special_tokens: false, + }) + + expect(input_ids.tolist()).toEqual([[1037n], [1038n]]) + expect(attention_mask.tolist()).toEqual([[1n], [1n]]) + expect(token_type_ids.tolist()).toEqual([[0n], [0n]]) + } + { // Truncation w/ text pair + // TODO + } + + { // Padding + const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, { + padding: true, + add_special_tokens: false, + }) + + expect(input_ids.tolist()).toEqual([[1037n, 0n], [1038n, 1039n]]) + expect(attention_mask.tolist()).toEqual([[1n, 0n], [1n, 1n]]) + expect(token_type_ids.tolist()).toEqual([[0n, 0n], [0n, 0n]]) + } + { // Padding w/ text pair + const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, { + text_pair, + padding: true, + add_special_tokens: false, + }) + + expect(input_ids.tolist()).toEqual([ + [1037n, 1040n, 1041n, 0n, 0n], + [1038n, 1039n, 1042n, 1043n, 1044n], + ]); + expect(attention_mask.tolist()).toEqual([ + [1n, 1n, 1n, 0n, 0n], + [1n, 1n, 1n, 1n, 1n], + ]); + expect(token_type_ids.tolist()).toEqual([ + [0n, 1n, 1n, 0n, 0n], + [0n, 0n, 1n, 1n, 1n], + ]); + } + + { // Truncation + padding + const { input_ids, attention_mask, token_type_ids } = tokenizer(['a', 'b c', 'd e f'], { + padding: true, + truncation: true, + add_special_tokens: false, + max_length: 2, + }) + + expect(input_ids.tolist()).toEqual([[1037n, 0n], [1038n, 1039n], [1040n, 1041n]]) + expect(attention_mask.tolist()).toEqual([[1n, 0n], [1n, 1n], [1n, 1n]]) 
+ expect(token_type_ids.tolist()).toEqual([[0n, 0n], [0n, 0n], [0n, 0n]]) + } + }, MAX_TEST_EXECUTION_TIME); +}); + describe('Token type ids', () => { it('should correctly add token type ids', async () => { const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); From c4f3123ca03219a8170db98390e57e075b75094f Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Jan 2024 16:43:54 +0200 Subject: [PATCH 09/13] Fix padding/truncation --- src/tokenizers.js | 63 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 20 deletions(-) diff --git a/src/tokenizers.js b/src/tokenizers.js index c0154959d..222ade6af 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -1626,15 +1626,24 @@ class BertProcessing extends PostProcessor { * @param {string[]} [tokens_pair=null] An optional second set of input tokens. * @returns {PostProcessedOutput} The post-processed tokens with the special tokens added to the beginning and end. */ - post_process(tokens, tokens_pair = null) { - tokens = mergeArrays([this.cls], tokens, [this.sep]); + post_process(tokens, tokens_pair = null, { + add_special_tokens = true, + } = {}) { + if (add_special_tokens) { + tokens = mergeArrays([this.cls], tokens, [this.sep]); + } + let token_type_ids = new Array(tokens.length).fill(0); if (tokens_pair !== null) { - const middle = this instanceof RobertaProcessing ? [this.sep] : []; // NOTE: It is intended to add 2 EOS tokens after the first set of tokens // https://github.com/huggingface/tokenizers/issues/983 - tokens = mergeArrays(tokens, middle, tokens_pair, [this.sep]); - token_type_ids = mergeArrays(token_type_ids, new Array(tokens_pair.length + 1 + middle.length).fill(1)); + const middle = (add_special_tokens && this instanceof RobertaProcessing) + ? [this.sep] + : []; + const after = add_special_tokens ? [this.sep] : []; + + tokens = mergeArrays(tokens, middle, tokens_pair, after); + token_type_ids = mergeArrays(token_type_ids, new Array(tokens_pair.length + middle.length + after.length).fill(1)); } return { tokens, token_type_ids }; } @@ -1665,16 +1674,19 @@ class TemplateProcessing extends PostProcessor { * @param {string[]} [tokens_pair=null] The list of tokens for the second sequence (optional). * @returns {PostProcessedOutput} An object containing the list of tokens with the special tokens replaced with actual tokens. */ - post_process(tokens, tokens_pair = null) { + post_process(tokens, tokens_pair = null, { + add_special_tokens = true, + } = {}) { let type = tokens_pair === null ? 
this.single : this.pair let processedTokens = []; let types = []; for (let item of type) { if ('SpecialToken' in item) { - processedTokens.push(item.SpecialToken.id); - types.push(item.SpecialToken.type_id); - + if (add_special_tokens) { + processedTokens.push(item.SpecialToken.id); + types.push(item.SpecialToken.type_id); + } } else if ('Sequence' in item) { if (item.Sequence.id === 'A') { processedTokens = mergeArrays(processedTokens, tokens); @@ -2362,7 +2374,7 @@ const SPECIAL_TOKEN_ATTRIBUTES = [ * @private */ function padHelper(item, length, value_fn, side) { - for (const key in Object.keys(item)) { + for (const key of Object.keys(item)) { const diff = length - item[key].length; const value = value_fn(key); @@ -2383,7 +2395,7 @@ function padHelper(item, length, value_fn, side) { function truncateHelper(item, length) { // Setting .length to a lower value truncates the array in-place: // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/length - for (const key in Object.keys(item)) { + for (const key of Object.keys(item)) { item[key].length = length; } } @@ -2613,6 +2625,10 @@ export class PreTrainedTokenizer extends Callable { // Calculate max length from sequences max_length = max(encodedTokens.map(x => x.input_ids.length))[0]; } + } else { + if (!truncation) { + console.warn(`Truncation was not explicitly activated but \`max_length\` is provided a specific value, please use \`truncation=true\` to explicitly truncate examples to max length.`) + } } // Ensure it is less than model max length @@ -2634,9 +2650,12 @@ export class PreTrainedTokenizer extends Callable { } else { // t.length < max_length // possibly pad if (padding) { - padHelper(encodedTokens[i], max_length, (key) => { - key === 'input_ids' ? this.pad_token_id : 0 - }, this.padding_side); + padHelper( + encodedTokens[i], + max_length, + key => key === 'input_ids' ? this.pad_token_id : 0, + this.padding_side + ); } } } @@ -2652,9 +2671,14 @@ export class PreTrainedTokenizer extends Callable { // we perform additional check if ( - encodedTokens.some(x => x.input_ids.length !== encodedTokens[0].input_ids.length) - || encodedTokens.some(x => x.attention_mask.length !== encodedTokens[0].attention_mask.length) - || encodedTokens.some(x => x.token_type_ids && x.token_type_ids.length !== encodedTokens[0].token_type_ids.length) + encodedTokens.some(x => { + for (const key of Object.keys(x)) { + if (x[key].length !== encodedTokens[0][key]?.length) { + return true; + } + } + return false; + }) ) { throw Error( "Unable to create tensor, you should probably activate truncation and/or padding " + @@ -2776,9 +2800,8 @@ export class PreTrainedTokenizer extends Callable { const tokens = this._encode_text(text); const tokens2 = this._encode_text(text_pair); - // TODO improve `add_special_tokens` and ensure correctness - const combinedTokens = (this.post_processor !== null && add_special_tokens) - ? this.post_processor(tokens, tokens2) + const combinedTokens = this.post_processor + ? this.post_processor(tokens, tokens2, { add_special_tokens }) : { tokens: mergeArrays(tokens ?? [], tokens2 ?? 
[]) }; const input_ids = this.model.convert_tokens_to_ids(combinedTokens.tokens); From 02d9ca1de3c0132925c23561373e247ef4f36340 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Jan 2024 16:46:28 +0200 Subject: [PATCH 10/13] Remove unused `add_token_types` function --- src/models.js | 12 ++++++------ src/tokenizers.js | 30 ------------------------------ 2 files changed, 6 insertions(+), 36 deletions(-) diff --git a/src/models.js b/src/models.js index 21617249c..802f05722 100644 --- a/src/models.js +++ b/src/models.js @@ -42,10 +42,6 @@ import { AutoConfig, } from './configs.js'; -import { - add_token_types, -} from './tokenizers.js'; - import { Callable, isIntegralNumber, @@ -512,9 +508,13 @@ async function encoderForward(self, model_inputs) { encoderFeeds[key] = model_inputs[key]; } if (self.session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) { - // Assign default `token_type_ids` to the `encoderFeeds` if the model expects it, + // Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects it, // but they weren't created by the tokenizer. - add_token_types(encoderFeeds); + encoderFeeds.token_type_ids = new Tensor( + 'int64', + new BigInt64Array(encoderFeeds.input_ids.data.length), + encoderFeeds.input_ids.dims + ) } return await sessionRun(self.session, encoderFeeds); } diff --git a/src/tokenizers.js b/src/tokenizers.js index 222ade6af..974c9cfbf 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -3027,36 +3027,6 @@ export class PreTrainedTokenizer extends Callable { } } -/** -* Helper method for adding `token_type_ids` to model inputs -* @param {Object} inputs An object containing the input ids and attention mask. -* @returns {Object} The prepared inputs object. -*/ -export function add_token_types(inputs) { - // TODO ensure correctness when token pair is present - if (inputs.input_ids instanceof Tensor) { - inputs.token_type_ids = new Tensor( - 'int64', - new BigInt64Array(inputs.input_ids.data.length), - inputs.input_ids.dims - ) - } else if (Array.isArray(inputs.input_ids)) { - - if (Array.isArray(inputs.input_ids[0])) { - // This means input is batched, so we need to batch the token_type_ids as well - inputs.token_type_ids = inputs.input_ids.map( - x => new Array(x.length).fill(0) - ) - } else { - inputs.token_type_ids = new Array(inputs.input_ids.length).fill(0); - } - } else { - throw new Error('Input ids must be a Tensor or an Array') - } - - return inputs; -} - /** * BertTokenizer is a class used to tokenize text for BERT models. * @extends PreTrainedTokenizer From e329fba2fb2a1eb2494dd909b5bbe8487b1ba268 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Jan 2024 17:05:48 +0200 Subject: [PATCH 11/13] Reduce duplication --- src/tokenizers.js | 56 +++++++++++++++++------------------------------ 1 file changed, 20 insertions(+), 36 deletions(-) diff --git a/src/tokenizers.js b/src/tokenizers.js index 974c9cfbf..86628bc5f 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -2550,6 +2550,15 @@ export class PreTrainedTokenizer extends Callable { return new this(...info); } + /** + * @typedef {number[]|number[][]|Tensor} BatchEncodingItem + * + * @typedef {Object} BatchEncoding Holds the output of the tokenizer's call function. + * @property {BatchEncodingItem} input_ids List of token ids to be fed to a model. + * @property {BatchEncodingItem} attention_mask List of indices specifying which tokens should be attended to by the model. 
+ * @property {BatchEncodingItem} [token_type_ids] List of token type ids to be fed to a model. + */ + /** * Encode/tokenize the given text(s). * @param {string|string[]} text The text to tokenize. @@ -2560,7 +2569,7 @@ export class PreTrainedTokenizer extends Callable { * @param {boolean} [options.truncation=null] Whether to truncate the input sequences. * @param {number} [options.max_length=null] Maximum length of the returned list and optionally padding length. * @param {boolean} [options.return_tensor=true] Whether to return the results as Tensors or arrays. - * @returns {{ input_ids: number[]|number[][]|Tensor, attention_mask: any[]|Tensor }} Object to be passed to the model. + * @returns {BatchEncoding} Object to be passed to the model. */ _call( // Required positional arguments @@ -2661,9 +2670,7 @@ export class PreTrainedTokenizer extends Callable { } } - let input_ids; - let attention_mask; - let token_type_ids; + const result = {}; if (return_tensor) { if (!(padding && truncation)) { @@ -2692,51 +2699,28 @@ export class PreTrainedTokenizer extends Callable { // whether we have a single input or multiple inputs. const dims = [encodedTokens.length, encodedTokens[0].input_ids.length]; - input_ids = new Tensor('int64', - BigInt64Array.from(encodedTokens.flatMap(x => x.input_ids).map(BigInt)), - dims - ); - - attention_mask = new Tensor( - 'int64', - BigInt64Array.from(encodedTokens.flatMap(x => x.attention_mask).map(BigInt)), - dims - ); - - if (this.return_token_type_ids && encodedTokens[0].token_type_ids) { - token_type_ids = new Tensor( - 'int64', - BigInt64Array.from(encodedTokens.flatMap(x => x.token_type_ids).map(BigInt)), + for (const key of Object.keys(encodedTokens[0])) { + result[key] = new Tensor('int64', + BigInt64Array.from(encodedTokens.flatMap(x => x[key]).map(BigInt)), dims ); } } else { - input_ids = encodedTokens.map(x => x.input_ids); - attention_mask = encodedTokens.map(x => x.attention_mask); - - if (this.return_token_type_ids && encodedTokens[0].token_type_ids) { - token_type_ids = encodedTokens.map(x => x.token_type_ids); + for (const key of Object.keys(encodedTokens[0])) { + result[key] = encodedTokens.map(x => x[key]); } // If not returning a tensor, we match the input type if (!isBatched) { // Input was not batched, so we unwrap - input_ids = input_ids[0]; - attention_mask = attention_mask[0]; - - if (this.return_token_type_ids && token_type_ids) { - token_type_ids = token_type_ids[0]; + for (const key of Object.keys(result)) { + result[key] = result[key][0]; } } } - const modelInputs = { input_ids, attention_mask }; - if (token_type_ids) { - modelInputs.token_type_ids = token_type_ids; - } - - return modelInputs + return /** @type {BatchEncoding} */(result); } /** @@ -2810,7 +2794,7 @@ export class PreTrainedTokenizer extends Callable { input_ids, attention_mask: new Array(input_ids.length).fill(1), } - if (combinedTokens.token_type_ids) { + if (this.return_token_type_ids && combinedTokens.token_type_ids) { result.token_type_ids = combinedTokens.token_type_ids; } return result; From 83b416e8682da2851043c46a808c5c641e07c5fd Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Jan 2024 17:36:13 +0200 Subject: [PATCH 12/13] `let` -> `const` where possible --- src/tokenizers.js | 177 ++++++++++++++++++++++------------------------ 1 file changed, 83 insertions(+), 94 deletions(-) diff --git a/src/tokenizers.js b/src/tokenizers.js index 86628bc5f..a786cd1ed 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -7,8 +7,8 @@ * ```javascript * import { 
AutoTokenizer } from '@xenova/transformers'; * - * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); - * let { input_ids } = await tokenizer('I love transformers!'); + * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); + * const { input_ids } = await tokenizer('I love transformers!'); * // Tensor { * // data: BigInt64Array(6) [101n, 1045n, 2293n, 19081n, 999n, 102n], * // dims: [1, 6], @@ -58,7 +58,7 @@ import { Template } from '@huggingface/jinja'; */ async function loadTokenizer(pretrained_model_name_or_path, options) { - let info = await Promise.all([ + const info = await Promise.all([ getModelJSON(pretrained_model_name_or_path, 'tokenizer.json', true, options), getModelJSON(pretrained_model_name_or_path, 'tokenizer_config.json', true, options), ]) @@ -201,7 +201,7 @@ function lowercase_and_remove_accent(text) { * @param {Map} mapping The mapping from input domain to value. */ function fuse(arr, value, mapping) { - let fused = []; + const fused = []; let i = 0; while (i < arr.length) { fused.push(arr[i]) @@ -413,9 +413,9 @@ class WordPieceTokenizer extends TokenizerModel { * @returns {string[]} An array of encoded tokens. */ encode(tokens) { - let outputTokens = []; - for (let token of tokens) { - let chars = [...token]; + const outputTokens = []; + for (const token of tokens) { + const chars = [...token]; if (chars.length > this.max_input_chars_per_word) { outputTokens.push(this.unk_token); continue; @@ -423,7 +423,7 @@ class WordPieceTokenizer extends TokenizerModel { let isUnknown = false; let start = 0; - let subTokens = []; + const subTokens = []; while (start < chars.length) { let end = chars.length; @@ -553,12 +553,12 @@ class Unigram extends TokenizerModel { /** * Encodes an array of tokens using Unigram encoding. - * @param {Array} tokens The tokens to encode. - * @returns {Array} An array of encoded tokens. + * @param {string[]} tokens The tokens to encode. + * @returns {string[]} An array of encoded tokens. */ encode(tokens) { - let toReturn = []; - for (let token of tokens) { + const toReturn = []; + for (const token of tokens) { const tokenized = this.tokenize(token); toReturn.push(...tokenized); } @@ -582,7 +582,7 @@ const BYTES_TO_UNICODE = (() => { ...Array.from({ length: "¬".charCodeAt(0) - "¡".charCodeAt(0) + 1 }, (_, i) => i + "¡".charCodeAt(0)), ...Array.from({ length: "ÿ".charCodeAt(0) - "®".charCodeAt(0) + 1 }, (_, i) => i + "®".charCodeAt(0)), ]; - let cs = bs.slice(); + const cs = bs.slice(); let n = 0; for (let b = 0; b < 256; ++b) { if (!bs.includes(b)) { @@ -591,7 +591,7 @@ const BYTES_TO_UNICODE = (() => { n += 1; } } - let ccs = cs.map(n => String.fromCharCode(n)); + const ccs = cs.map(n => String.fromCharCode(n)); return Object.fromEntries(bs.map((b, i) => [b, ccs[i]])); })(); @@ -809,12 +809,12 @@ class BPE extends TokenizerModel { * @returns {string[]} The resulting subword tokens after applying the BPE algorithm to the input sequence of tokens. */ encode(tokens) { - let outputTokens = []; + const outputTokens = []; - for (let token of tokens) { - let bpe_token_list = this.bpe(token); + for (const token of tokens) { + const bpe_token_list = this.bpe(token); - for (let t of bpe_token_list) { + for (const t of bpe_token_list) { if (this.tokens_to_ids.has(t)) { outputTokens.push(t); } else { @@ -962,14 +962,10 @@ class Replace extends Normalizer { * @returns {string} The normalized text after replacing the pattern with the content. 
*/ normalize(text) { - let pattern = createPattern(this.config.pattern); - if (pattern === null) { - return text; - } - - text = text.replaceAll(pattern, this.config.content) - - return text; + const pattern = createPattern(this.config.pattern); + return pattern === null + ? text + : text.replaceAll(pattern, this.config.content); } } @@ -1132,10 +1128,10 @@ class BertNormalizer extends Normalizer { */ _tokenize_chinese_chars(text) { /* Adds whitespace around any CJK character. */ - let output = []; + const output = []; for (let i = 0; i < text.length; ++i) { - let char = text[i]; - let cp = char.charCodeAt(0); + const char = text[i]; + const cp = char.charCodeAt(0); if (this._is_chinese_char(cp)) { output.push(" "); output.push(char); @@ -1318,13 +1314,10 @@ class PreTokenizer extends Callable { * @returns {string[]} An array of pre-tokens. */ pre_tokenize(text, options) { - let result = []; - if (Array.isArray(text)) { - result = text.map(x => this.pre_tokenize_text(x, options)) - } else { - result = this.pre_tokenize_text(text, options); - } - return result.flat(); + return (Array.isArray(text) + ? text.map(x => this.pre_tokenize_text(x, options)) + : this.pre_tokenize_text(text, options) + ).flat(); } /** @@ -1417,7 +1410,7 @@ class ByteLevelPreTokenizer extends PreTokenizer { } // Split on whitespace and punctuation - let tokens = this.use_regex ? (text.match(this.pattern) || []) : [text]; + const tokens = this.use_regex ? (text.match(this.pattern) || []) : [text]; // Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) return tokens.map( @@ -1677,11 +1670,11 @@ class TemplateProcessing extends PostProcessor { post_process(tokens, tokens_pair = null, { add_special_tokens = true, } = {}) { - let type = tokens_pair === null ? this.single : this.pair + const type = tokens_pair === null ? this.single : this.pair let processedTokens = []; let types = []; - for (let item of type) { + for (const item of type) { if ('SpecialToken' in item) { if (add_special_tokens) { processedTokens.push(item.SpecialToken.id); @@ -1816,12 +1809,10 @@ class ReplaceDecoder extends Decoder { /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { - let pattern = createPattern(this.config.pattern); - if (pattern === null) { - return tokens; - } - - return tokens.map(token => token.replaceAll(pattern, this.config.content)) + const pattern = createPattern(this.config.pattern); + return pattern === null + ? 
tokens + : tokens.map(token => token.replaceAll(pattern, this.config.content)) } } @@ -1836,13 +1827,13 @@ class ByteFallback extends Decoder { /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { - let new_tokens = []; + const new_tokens = []; let previous_byte_tokens = []; - for (let token of tokens) { + for (const token of tokens) { let bytes = null; if (token.length === 6 && token.startsWith('<0x') && token.endsWith('>')) { - let byte = parseInt(token.slice(3, 5), 16); + const byte = parseInt(token.slice(3, 5), 16); if (!isNaN(byte)) { bytes = byte; } @@ -1851,7 +1842,7 @@ class ByteFallback extends Decoder { previous_byte_tokens.push(bytes); } else { if (previous_byte_tokens.length > 0) { - let string = this.text_decoder.decode(Uint8Array.from(previous_byte_tokens)); + const string = this.text_decoder.decode(Uint8Array.from(previous_byte_tokens)); new_tokens.push(string); previous_byte_tokens = []; } @@ -1859,7 +1850,7 @@ class ByteFallback extends Decoder { } } if (previous_byte_tokens.length > 0) { - let string = this.text_decoder.decode(Uint8Array.from(previous_byte_tokens)); + const string = this.text_decoder.decode(Uint8Array.from(previous_byte_tokens)); new_tokens.push(string); previous_byte_tokens = []; } @@ -1985,10 +1976,9 @@ class ByteLevelDecoder extends Decoder { * @returns {string} The decoded string. */ convert_tokens_to_string(tokens) { - let text = tokens.join(''); - - let byteArray = new Uint8Array([...text].map(c => this.byte_decoder[c])); - let decoded_text = this.text_decoder.decode(byteArray); + const text = tokens.join(''); + const byteArray = new Uint8Array([...text].map(c => this.byte_decoder[c])); + const decoded_text = this.text_decoder.decode(byteArray); return decoded_text; } @@ -2000,9 +1990,9 @@ class ByteLevelDecoder extends Decoder { // To avoid mixing byte-level and unicode for byte-level BPT // we need to build string separately for added tokens and byte-level tokens // cf. 
https://github.com/huggingface/transformers/issues/1133 - let sub_texts = []; + const sub_texts = []; let current_sub_text = []; - for (let token of tokens) { + for (const token of tokens) { // tokens sent here are already filtered, so we don't need to do this // if (skip_special_tokens && this.all_special_ids.includes(token)) { // continue; @@ -2050,7 +2040,7 @@ class CTCDecoder extends Decoder { if (tokens.length === 0) return ''; // group same tokens into non-repeating tokens in CTC style decoding - let grouped_tokens = [tokens[0]]; + const grouped_tokens = [tokens[0]]; for (let i = 1; i < tokens.length; ++i) { if (tokens[i] !== grouped_tokens.at(-1)) { grouped_tokens.push(tokens[i]); @@ -2058,7 +2048,7 @@ class CTCDecoder extends Decoder { } // filter self.pad_token which is used as CTC-blank token - let filtered_tokens = grouped_tokens.filter(token => token !== this.pad_token); + const filtered_tokens = grouped_tokens.filter(token => token !== this.pad_token); let text = filtered_tokens.join(''); if (this.cleanup) { @@ -2206,7 +2196,7 @@ class MetaspaceDecoder extends Decoder { /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { - let result = []; + const result = []; for (let i = 0; i < tokens.length; ++i) { let normalized = tokens[i].replaceAll(this.replacement, ' '); if (this.addPrefixSpace && i == 0 && normalized.startsWith(' ')) { @@ -2429,7 +2419,7 @@ export class PreTrainedTokenizer extends Callable { /** @type {AddedToken[]} */ this.added_tokens = []; - for (let addedToken of tokenizerJSON.added_tokens) { + for (const addedToken of tokenizerJSON.added_tokens) { const token = new AddedToken(addedToken); this.added_tokens.push(token); @@ -2501,8 +2491,8 @@ export class PreTrainedTokenizer extends Callable { * @throws {Error} If an object is found for a matching key and its __type property is not "AddedToken". */ getToken(...keys) { - for (let key of keys) { - let item = this._tokenizer_config[key]; + for (const key of keys) { + const item = this._tokenizer_config[key]; if (!item) continue; @@ -2537,7 +2527,7 @@ export class PreTrainedTokenizer extends Callable { legacy = null, } = {}) { - let info = await loadTokenizer(pretrained_model_name_or_path, { + const info = await loadTokenizer(pretrained_model_name_or_path, { progress_callback, config, cache_dir, @@ -3204,7 +3194,7 @@ function _build_translation_inputs(self, raw_inputs, tokenizer_options, generate // In the same way as the Python library, we override the post-processor // to force the source language to be first: - for (let item of self.post_processor.config.single) { + for (const item of self.post_processor.config.single) { if ('SpecialToken' in item && self.languageRegex.test(item.SpecialToken.id)) { item.SpecialToken.id = self.lang_to_token(src_lang_token); break; @@ -3474,7 +3464,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer { const all_special_ids = new Set(this.all_special_ids); - for (let output of sequences) { + for (const output of sequences) { // NOTE: python version has batches, so it uses [0] const token_ids = output.tokens; const token_timestamps = returnWordTimestamps ? 
output.token_timestamps : null; @@ -3692,9 +3682,9 @@ export class WhisperTokenizer extends PreTrainedTokenizer { } } if (returnWordTimestamps) { - let new_chunks = []; - for (let chunk of chunks) { - for (let word of chunk.words) { + const new_chunks = []; + for (const chunk of chunks) { + for (const word of chunk.words) { new_chunks.push(word); } } @@ -3805,9 +3795,9 @@ export class WhisperTokenizer extends PreTrainedTokenizer { /** @private */ collateWordTimestamps(tokens, token_timestamps, language) { - let [words, _, token_indices] = this.combineTokensIntoWords(tokens, language); + const [words, _, token_indices] = this.combineTokensIntoWords(tokens, language); - let timings = []; + const timings = []; for (let i = 0; i < words.length; ++i) { const indices = token_indices[i]; timings.push({ @@ -3879,10 +3869,9 @@ export class WhisperTokenizer extends PreTrainedTokenizer { const timestamp_begin = Array.from(this.all_special_ids).at(-1) + 1; /**@type {Array} */ let outputs = [[]]; - for (let token of token_ids) { + for (const token of token_ids) { if (token >= timestamp_begin) { - let timestamp = (token - timestamp_begin) * time_precision; - timestamp = round(timestamp, 2); + const timestamp = round((token - timestamp_begin) * time_precision, 2); outputs.push(`<|${timestamp}|>`); outputs.push([]); } else { @@ -3915,9 +3904,9 @@ export class WhisperTokenizer extends PreTrainedTokenizer { }); const replacement_char = '\uFFFD'; - let words = [] - let word_tokens = [] - let token_indices = [] + const words = [] + const word_tokens = [] + const token_indices = [] let current_tokens = [] let current_indices = [] let unicode_offset = 0 @@ -3954,11 +3943,11 @@ export class WhisperTokenizer extends PreTrainedTokenizer { */ splitTokensOnSpaces(tokens) { - let [subwords, subword_tokens_list, subword_indices_list] = this.splitTokensOnUnicode(tokens); + const [subwords, subword_tokens_list, subword_indices_list] = this.splitTokensOnUnicode(tokens); - let words = [] - let word_tokens = [] - let token_indices = [] + const words = [] + const word_tokens = [] + const token_indices = [] const punctuationRegex = new RegExp(`^[${PUNCTUATION_REGEX}]$`, 'gu'); @@ -4001,9 +3990,9 @@ export class WhisperTokenizer extends PreTrainedTokenizer { */ mergePunctuations(words, tokens, indices, prepended, appended) { - let newWords = structuredClone(words); - let newTokens = structuredClone(tokens); - let newIndices = structuredClone(indices); + const newWords = structuredClone(words); + const newTokens = structuredClone(tokens); + const newIndices = structuredClone(indices); // prepend punctuations @@ -4057,8 +4046,8 @@ export class WhisperTokenizer extends PreTrainedTokenizer { * **Example: Get ids for a language** * ```javascript * // instantiate the tokenizer and set the prefix token to Spanish - * let tokenizer = await WhisperTokenizer.from_pretrained('Xenova/whisper-tiny'); - * let forced_decoder_ids = tokenizer.get_decoder_prompt_ids({ language: 'spanish' }); + * const tokenizer = await WhisperTokenizer.from_pretrained('Xenova/whisper-tiny'); + * const forced_decoder_ids = tokenizer.get_decoder_prompt_ids({ language: 'spanish' }); * // [(1, 50262), (2, 50363)] * ``` * @@ -4081,7 +4070,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer { // <|lang_id|> <|task|> <|notimestamps|> - let forced_decoder_ids = []; + const forced_decoder_ids = []; if (language) { // User wishes to specify the language @@ -4106,7 +4095,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer { } } - let 
language_token_id = this.model.tokens_to_ids.get(`<|${language_code}|>`); + const language_token_id = this.model.tokens_to_ids.get(`<|${language_code}|>`); if (language_token_id === undefined) { throw new Error(`Unable to find language "${language_code}" in model vocabulary. Please report this issue at https://github.com/xenova/transformers.js/issues/new/choose.`) } @@ -4123,7 +4112,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer { throw new Error(`Task "${task}" is not supported. Must be one of: ["transcribe", "translate"]`); } - let task_token_id = this.model.tokens_to_ids.get(`<|${task}|>`); + const task_token_id = this.model.tokens_to_ids.get(`<|${task}|>`); if (task_token_id === undefined) { throw new Error(`Unable to find task "${task}" in model vocabulary. Please report this issue at https://github.com/xenova/transformers.js/issues/new/choose.`) } @@ -4135,7 +4124,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer { } if (no_timestamps) { - let no_timestamps_id = this.model.tokens_to_ids.get(`<|notimestamps|>`); + const no_timestamps_id = this.model.tokens_to_ids.get(`<|notimestamps|>`); if (no_timestamps_id === undefined) { throw new Error('Unable to find "<|notimestamps|>" in model vocabulary. Please report this issue at https://github.com/xenova/transformers.js/issues/new/choose.') } @@ -4185,7 +4174,7 @@ export class MarianTokenizer extends PreTrainedTokenizer { if (text === null) return null; // Check if text starts with language code: - let [matchInfo, ...remainder] = text.trim().split(this.languageRegex); + const [matchInfo, ...remainder] = text.trim().split(this.languageRegex); if (remainder.length === 0) { // No language code, encode normally @@ -4193,7 +4182,7 @@ export class MarianTokenizer extends PreTrainedTokenizer { } else if (remainder.length === 2) { // Text starts with language code, so we do not encode it with sentencepiece. - let [language, text] = remainder; + const [language, text] = remainder; if (!this.supported_language_codes.includes(language)) { console.warn(`Unsupported language code "${language}" detected, which may lead to unexpected behavior. Should be one of: ${JSON.stringify(this.supported_language_codes)}`) @@ -4229,7 +4218,7 @@ export class VitsTokenizer extends PreTrainedTokenizer { * The chosen tokenizer class is determined by the type specified in the tokenizer config. * * @example - * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); + * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); */ export class AutoTokenizer { static TOKENIZER_CLASS_MAPPING = { @@ -4304,7 +4293,7 @@ export class AutoTokenizer { legacy = null, } = {}) { - let [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, { + const [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, { quantized, progress_callback, config, @@ -4315,7 +4304,7 @@ export class AutoTokenizer { }) // Some tokenizers are saved with the "Fast" suffix, so we remove that if present. - let tokenizerName = tokenizerConfig.tokenizer_class?.replace(/Fast$/, '') ?? 'PreTrainedTokenizer'; + const tokenizerName = tokenizerConfig.tokenizer_class?.replace(/Fast$/, '') ?? 
'PreTrainedTokenizer'; let cls = this.TOKENIZER_CLASS_MAPPING[tokenizerName]; if (!cls) { From b6fe318f8fff9c7a4c1dbbb80d923ad200f6ef14 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Jan 2024 18:10:20 +0200 Subject: [PATCH 13/13] Add cross-encoder models --- scripts/supported_models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/scripts/supported_models.py b/scripts/supported_models.py index 0ea5c8c47..b5d0912df 100644 --- a/scripts/supported_models.py +++ b/scripts/supported_models.py @@ -109,6 +109,11 @@ 'unitary/toxic-bert', 'BAAI/bge-reranker-large', 'BAAI/bge-reranker-base', + 'cross-encoder/ms-marco-TinyBERT-L-2-v2', + 'cross-encoder/ms-marco-MiniLM-L-2-v2', + 'cross-encoder/ms-marco-MiniLM-L-4-v2', + 'cross-encoder/ms-marco-MiniLM-L-6-v2', + 'cross-encoder/ms-marco-MiniLM-L-12-v2', ], # Token classification
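
For context, the cross-encoder checkpoints added above are plain sequence-classification rerankers, so once converted they can be driven through the same `AutoTokenizer` / `AutoModelForSequenceClassification` path already used for the existing BAAI/bge-reranker entries. A minimal usage sketch follows; the `Xenova/ms-marco-MiniLM-L-6-v2` repo id (assumed to be the converted ONNX export of the checkpoint listed in the patch) and the query/passage strings are illustrative assumptions, not part of the patch:

```javascript
import { AutoTokenizer, AutoModelForSequenceClassification } from '@xenova/transformers';

// Assumed converted (ONNX) export of cross-encoder/ms-marco-MiniLM-L-6-v2.
const model_id = 'Xenova/ms-marco-MiniLM-L-6-v2';
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await AutoModelForSequenceClassification.from_pretrained(model_id);

// Score a (query, passage) pair; a higher logit means a better match.
const query = 'How many people live in Berlin?';
const passage = 'Berlin has a population of 3,520,031 registered inhabitants.';
const inputs = tokenizer(query, { text_pair: passage, padding: true, truncation: true });
const { logits } = await model(inputs);
console.log(logits.data); // e.g. Float32Array(1) holding a single relevance score
```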