diff --git a/src/core/PreTokenizer.ts b/src/core/PreTokenizer.ts
index 02cf93c..c87ec79 100644
--- a/src/core/PreTokenizer.ts
+++ b/src/core/PreTokenizer.ts
@@ -6,7 +6,10 @@ import type { PreTokenizeTextOptions } from "@static/tokenizer";
  * A callable class representing a pre-tokenizer used in tokenization. Subclasses
  * should implement the `pre_tokenize_text` method to define the specific pre-tokenization logic.
  */
-abstract class PreTokenizer extends Callable<[string | string[], any?], string[]> {
+abstract class PreTokenizer extends Callable<
+  [string | string[], any?],
+  string[]
+> {
   /**
    * Method that should be implemented by subclasses to define the specific pre-tokenization logic.
    *
@@ -14,7 +17,10 @@ abstract class PreTokenizer extends Callable<[string | string[], any?], string[]
    * @param options Additional options for the pre-tokenization logic.
    * @returns The pre-tokenized text.
    */
-  abstract pre_tokenize_text(text: string, options?: PreTokenizeTextOptions): string[];
+  abstract pre_tokenize_text(
+    text: string,
+    options?: PreTokenizeTextOptions,
+  ): string[];
 
   /**
    * Tokenizes the given text into pre-tokens.
diff --git a/src/core/Tokenizer.ts b/src/core/Tokenizer.ts
index 4086af4..37b9772 100644
--- a/src/core/Tokenizer.ts
+++ b/src/core/Tokenizer.ts
@@ -19,7 +19,11 @@ import type PreTokenizer from "./PreTokenizer";
 import type TokenizerModel from "./TokenizerModel";
 import type PostProcessor from "./PostProcessor";
 import type Decoder from "./Decoder";
-import type { TokenConfig, TokenizerConfig, TokenizerJSON } from "@static/tokenizer";
+import type {
+  TokenConfig,
+  TokenizerConfig,
+  TokenizerJSON,
+} from "@static/tokenizer";
 
 interface EncodeOptions {
   text_pair?: string | null;
@@ -292,6 +296,36 @@ class Tokenizer {
       ? this.post_processor(tokens1, tokens2, add_special_tokens)
       : { tokens: merge_arrays(tokens1 ?? [], tokens2 ?? []) };
   }
+
+  /**
+   * Converts a token string to its corresponding token ID.
+   * @param token The token string to convert.
+   * @returns The token ID, or undefined if the token is not in the vocabulary.
+   */
+  public token_to_id(token: string): number | undefined {
+    return this.model.tokens_to_ids.get(token);
+  }
+
+  /**
+   * Converts a token ID to its corresponding token string.
+   * @param id The token ID to convert.
+   * @returns The token string, or undefined if the ID is not in the vocabulary.
+   */
+  public id_to_token(id: number): string | undefined {
+    return this.model.vocab[id];
+  }
+
+  /**
+   * Returns a mapping of token IDs to AddedToken objects for all added tokens.
+   * @returns A Map where keys are token IDs and values are AddedToken objects.
+   */
+  public get_added_tokens_decoder(): Map<number, AddedToken> {
+    const decoder = new Map<number, AddedToken>();
+    for (const token of this.added_tokens) {
+      decoder.set(token.id, token);
+    }
+    return decoder;
+  }
 }
 
 export default Tokenizer;
diff --git a/src/core/decoder/create_decoder.ts b/src/core/decoder/create_decoder.ts
index 2775dfb..4abccc9 100644
--- a/src/core/decoder/create_decoder.ts
+++ b/src/core/decoder/create_decoder.ts
@@ -1,4 +1,3 @@
-
 import ByteLevel from "./ByteLevel";
 import WordPiece from "./WordPiece";
 import Metaspace from "./Metaspace";
diff --git a/src/core/normalizer/create_normalizer.ts b/src/core/normalizer/create_normalizer.ts
index 14c6e5a..fbe3682 100644
--- a/src/core/normalizer/create_normalizer.ts
+++ b/src/core/normalizer/create_normalizer.ts
@@ -1,4 +1,3 @@
-
 import BertNormalizer from "./BertNormalizer";
 import Precompiled from "./Precompiled";
 import Sequence from "./Sequence";
diff --git a/src/core/postProcessor/create_post_processor.ts b/src/core/postProcessor/create_post_processor.ts
index c960163..06cee45 100644
--- a/src/core/postProcessor/create_post_processor.ts
+++ b/src/core/postProcessor/create_post_processor.ts
@@ -1,4 +1,3 @@
-
 import TemplateProcessing from "./TemplateProcessing";
 import ByteLevel from "./ByteLevel";
 import BertProcessing from "./BertProcessing";
diff --git a/src/index.ts b/src/index.ts
index 81573ea..d1c639d 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -1,4 +1,5 @@
 export { default as Tokenizer } from "./core/Tokenizer";
+export { default as AddedToken } from "./core/AddedToken";
 export type { Encoding } from "./static/types";
 
 // Export all components
diff --git a/tests/tokenizers.test.ts b/tests/tokenizers.test.ts
index 387aacc..9980b46 100644
--- a/tests/tokenizers.test.ts
+++ b/tests/tokenizers.test.ts
@@ -1,5 +1,5 @@
 import fetchConfigById from "./utils/fetchConfigById";
-import { Tokenizer } from "../src";
+import { Tokenizer, AddedToken } from "../src";
 import collectTests from "./utils/collectTests";
 
 const TOKENIZER_TESTS = await collectTests();
@@ -43,3 +43,203 @@ describe("Tokenizers (model-specific)", () => {
     });
   }
 });
+
+describe("Tokenizer methods", () => {
+  // Create a simple BPE tokenizer for testing
+  // Vocab size: 10 tokens
+  // - 3 special tokens: <s>, </s>, <pad>
+  // - 1 unk token: <unk>
+  // - 5 regular tokens: a, b, c, ab, bc
+  // - 1 non-special added token: "<added_token>"
+  const unk_token = "<unk>";
+  const bos_token = "<s>";
+  const eos_token = "</s>";
+  const pad_token = "<pad>";
+  const added_token = "<added_token>";
+
+  const added_tokens = [
+    new AddedToken({
+      id: 0,
+      content: unk_token,
+      special: true,
+    }),
+    new AddedToken({
+      id: 1,
+      content: bos_token,
+      special: true,
+    }),
+    new AddedToken({
+      id: 2,
+      content: eos_token,
+      special: true,
+    }),
+    new AddedToken({
+      id: 3,
+      content: pad_token,
+      special: true,
+    }),
+    new AddedToken({
+      id: 9,
+      content: added_token,
+      special: false, // regular added token
+    }),
+  ];
+
+  const tokenizerJson = {
+    version: "1.0",
+    truncation: null,
+    padding: null,
+    added_tokens,
+    normalizer: null,
+    pre_tokenizer: null,
+    post_processor: null,
+    decoder: null,
+    model: {
+      type: "BPE",
+      dropout: null,
+      unk_token,
+      continuing_subword_prefix: null,
+      end_of_word_suffix: null,
+      fuse_unk: false,
+      byte_fallback: false,
+      ignore_merges: false,
+      vocab: {
+        [unk_token]: 0,
+        [bos_token]: 1,
+        [eos_token]: 2,
+        [pad_token]: 3,
+        a: 4,
+        b: 5,
+        c: 6,
+        ab: 7,
+        bc: 8,
+      },
+      merges: [
+        ["a", "b"],
+        ["b", "c"],
+      ],
+    },
+  } as any;
+
+  const tokenizerConfig = {
+    add_bos_token: false,
+    add_prefix_space: false,
+    added_tokens_decoder: Object.fromEntries(added_tokens.map((token) => [String(token.id), { id: token.id, content: token.content, special: token.special }])),
+    bos_token,
+    clean_up_tokenization_spaces: false,
+    eos_token,
+    legacy: true,
+    model_max_length: 1000000000000000,
+    pad_token,
+    sp_model_kwargs: {},
+    spaces_between_special_tokens: false,
+    tokenizer_class: "LlamaTokenizer",
+    unk_token,
+  };
+
+  let tokenizer: Tokenizer;
+
+  beforeAll(() => {
+    tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
+  });
+
+  describe("token_to_id", () => {
+    test("should return correct ID for regular token", () => {
+      expect(tokenizer.token_to_id("a")).toBe(4);
+      expect(tokenizer.token_to_id("b")).toBe(5);
+      expect(tokenizer.token_to_id("c")).toBe(6);
+    });
+
+    test("should return correct ID for merged token", () => {
+      expect(tokenizer.token_to_id("ab")).toBe(7);
+      expect(tokenizer.token_to_id("bc")).toBe(8);
+    });
+
+    test("should return correct ID for special tokens", () => {
+      expect(tokenizer.token_to_id(unk_token)).toBe(0);
+      expect(tokenizer.token_to_id(bos_token)).toBe(1);
+      expect(tokenizer.token_to_id(eos_token)).toBe(2);
+      expect(tokenizer.token_to_id(pad_token)).toBe(3);
+      expect(tokenizer.token_to_id(added_token)).toBe(9);
+    });
+
+    test("should return undefined for non-existing token", () => {
+      expect(tokenizer.token_to_id("xyz")).toBeUndefined();
+    });
+  });
+
+  describe("id_to_token", () => {
+    test("should return correct token for regular token ID", () => {
+      expect(tokenizer.id_to_token(4)).toBe("a");
+      expect(tokenizer.id_to_token(5)).toBe("b");
+      expect(tokenizer.id_to_token(6)).toBe("c");
+    });
+
+    test("should return correct token for merged token ID", () => {
+      expect(tokenizer.id_to_token(7)).toBe("ab");
+      expect(tokenizer.id_to_token(8)).toBe("bc");
+    });
+
+    test("should return correct token for special/added token ID", () => {
+      expect(tokenizer.id_to_token(0)).toBe(unk_token);
+      expect(tokenizer.id_to_token(1)).toBe(bos_token);
+      expect(tokenizer.id_to_token(2)).toBe(eos_token);
+      expect(tokenizer.id_to_token(3)).toBe(pad_token);
+      expect(tokenizer.id_to_token(9)).toBe(added_token);
+    });
+
+    test("should return undefined for non-existing ID", () => {
+      expect(tokenizer.id_to_token(999)).toBeUndefined();
+    });
+  });
+
+  describe("get_added_tokens_decoder", () => {
+    test("should return a Map", () => {
+      const decoder = tokenizer.get_added_tokens_decoder();
+      expect(decoder).toBeInstanceOf(Map);
+    });
+
+    test("should contain all special tokens", () => {
+      const decoder = tokenizer.get_added_tokens_decoder();
+      expect(decoder.size).toBe(5);
+      expect(decoder.has(0)).toBe(true);
+      expect(decoder.has(1)).toBe(true);
+      expect(decoder.has(2)).toBe(true);
+      expect(decoder.has(3)).toBe(true);
+      expect(decoder.has(9)).toBe(true);
+    });
+
+    test("should return AddedToken objects with correct properties", () => {
+      const decoder = tokenizer.get_added_tokens_decoder();
+      const unkToken = decoder.get(0);
+      expect(unkToken).toBeDefined();
+      expect(unkToken?.content).toBe(unk_token);
+      expect(unkToken?.special).toBe(true);
+      expect(unkToken).toBeInstanceOf(AddedToken);
+
+      const bosToken = decoder.get(1);
+      expect(bosToken?.content).toBe(bos_token);
+      expect(bosToken?.special).toBe(true);
+    });
+
+    test("should not contain regular tokens", () => {
+      const decoder = tokenizer.get_added_tokens_decoder();
+      expect(decoder.has(4)).toBe(false);
+      expect(decoder.has(5)).toBe(false);
+      expect(decoder.has(6)).toBe(false);
+    });
+  });
+
+  describe("roundtrip conversions", () => {
+    test("token_to_id and id_to_token should be inverse operations", () => {
+      const tokens = [unk_token, bos_token, eos_token, pad_token, "a", "b", "c", "ab", "bc", added_token];
+
+      for (const token of tokens) {
+        const id = tokenizer.token_to_id(token);
+        expect(id).toBeDefined();
+        const tokenBack = tokenizer.id_to_token(id!);
+        expect(tokenBack).toBe(token);
+      }
+    });
+  });
+});
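
Usage note (not part of the patch): a minimal sketch of the three new public methods, assuming the same import path and the tokenizerJson/tokenizerConfig fixtures defined in tests/tokenizers.test.ts above.

    import { Tokenizer, AddedToken } from "../src";

    // Construct a tokenizer from the test fixtures above.
    const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);

    // String <-> ID lookups against the model vocabulary.
    tokenizer.token_to_id("ab");   // 7
    tokenizer.id_to_token(7);      // "ab"
    tokenizer.token_to_id("xyz");  // undefined (not in the vocabulary)

    // Map of token ID -> AddedToken for every added token (special or not).
    const added: Map<number, AddedToken> = tokenizer.get_added_tokens_decoder();
    for (const [id, token] of added) {
      console.log(id, token.content, token.special);
    }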