diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift index e1fe2530..26bd5ed0 100644 --- a/Sources/Tokenizers/Tokenizer.swift +++ b/Sources/Tokenizers/Tokenizer.swift @@ -278,9 +278,9 @@ public class PreTrainedTokenizer: Tokenizer { public var unknownTokenId: Int? { model.unknownTokenId } public var fuseUnknownTokens: Bool { model.fuseUnknownTokens } - private let addedTokens: Set - private let specialTokens: [String: Int] - private let addedTokensRegex: NSRegularExpression? + let addedTokens: Set + let specialTokens: [String: Int] + let addedTokensRegex: NSRegularExpression? private let preTokenizer: PreTokenizer? private let normalizer: Normalizer? @@ -721,4 +721,18 @@ class LlamaPreTrainedTokenizer: PreTrainedTokenizer { let updatedData = Config(configDictionary) try super.init(tokenizerConfig: tokenizerConfig, tokenizerData: updatedData, strict: strict) } + + /// If `isLegacy` is `False`, a prefix token is added unless the first token is special. + /// https://github.com/huggingface/transformers/blob/e6dcf8abd6f65bb4b6dfc1831b20d9ba49ce00e2/src/transformers/models/t5/tokenization_t5.py#L374-L387 + override func tokenize(text: String) -> [String] { + if isLegacy || text.isEmpty { + return super.tokenize(text: text) + } + + let tokens = super.tokenize(text: sentencePieceUnderline + text.replacingOccurrences(of: sentencePieceUnderline, with: " ")) + if tokens.first == sentencePieceUnderline, let second = tokens.dropFirst().first, specialTokens[second] != nil { + return Array(tokens[1...]) + } + return tokens + } } diff --git a/Tests/TokenizersTests/TokenizerTests.swift b/Tests/TokenizersTests/TokenizerTests.swift index 821df186..62ea3a0d 100644 --- a/Tests/TokenizersTests/TokenizerTests.swift +++ b/Tests/TokenizersTests/TokenizerTests.swift @@ -118,6 +118,17 @@ class PhiSimpleTests: XCTestCase { XCTAssertEqual(tokenizer.encode(text: "hello world"), [15339, 1917]) XCTAssertEqual(tokenizer.encode(text: "<|im_start|>user<|im_sep|>Who are you?<|im_end|><|im_start|>assistant<|im_sep|>"), [100264, 882, 100266, 15546, 527, 499, 30, 100265, 100264, 78191, 100266]) } + + /// https://github.com/huggingface/swift-transformers/issues/96 + func testLegacyLlamaBehaviour() async throws { + guard let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed") as? PreTrainedTokenizer else { + XCTFail() + return + } + + let inputIds = tokenizer(" Hi") + XCTAssertEqual(inputIds, [1, 29871, 6324]) + } } class UnregisteredTokenizerTests: XCTestCase {