diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift
index f8662c34..e1fe2530 100644
--- a/Sources/Tokenizers/Tokenizer.swift
+++ b/Sources/Tokenizers/Tokenizer.swift
@@ -102,37 +102,42 @@ public protocol PreTrainedTokenizerModel: TokenizingModel {
 struct TokenizerModel {
     static let knownTokenizers: [String: PreTrainedTokenizerModel.Type] = [
         "BertTokenizer": BertTokenizer.self,
+        "CodeGenTokenizer": BPETokenizer.self,
+        "CodeLlamaTokenizer": BPETokenizer.self,
+        "CohereTokenizer": BPETokenizer.self,
         "DistilbertTokenizer": BertTokenizer.self,
         "DistilBertTokenizer": BertTokenizer.self,
+        "FalconTokenizer": BPETokenizer.self,
+        "GemmaTokenizer": BPETokenizer.self,
+        "GPT2Tokenizer": BPETokenizer.self,
+        "LlamaTokenizer": BPETokenizer.self,
         "RobertaTokenizer": BPETokenizer.self,
-        "CodeGenTokenizer": CodeGenTokenizer.self,
-        "CodeLlamaTokenizer": CodeLlamaTokenizer.self,
-        "FalconTokenizer": FalconTokenizer.self,
-        "GemmaTokenizer": GemmaTokenizer.self,
-        "GPT2Tokenizer": GPT2Tokenizer.self,
-        "LlamaTokenizer": LlamaTokenizer.self,
         "T5Tokenizer": T5Tokenizer.self,
-        "WhisperTokenizer": WhisperTokenizer.self,
-        "CohereTokenizer": CohereTokenizer.self,
-        "Qwen2Tokenizer": Qwen2Tokenizer.self,
         "PreTrainedTokenizer": BPETokenizer.self,
+        "Qwen2Tokenizer": BPETokenizer.self,
+        "WhisperTokenizer": BPETokenizer.self,
     ]
 
     static func unknownToken(from tokenizerConfig: Config) -> String? {
         tokenizerConfig.unkToken.content.string() ?? tokenizerConfig.unkToken.string()
     }
 
-    static func from(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws -> TokenizingModel {
+    static func from(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int], strict: Bool = true) throws -> TokenizingModel {
         guard let tokenizerClassName = tokenizerConfig.tokenizerClass.string() else {
             throw TokenizerError.missingTokenizerClassInConfig
         }
 
         // Some tokenizer_class entries use a Fast suffix
         let tokenizerName = tokenizerClassName.replacingOccurrences(of: "Fast", with: "")
-        guard let tokenizerClass = TokenizerModel.knownTokenizers[tokenizerName] else {
-            throw TokenizerError.unsupportedTokenizer(tokenizerName)
+        // Single lookup: an unregistered class throws when `strict`, otherwise warns and falls back to BPETokenizer
+        let registeredClass = TokenizerModel.knownTokenizers[tokenizerName]
+        let tokenizerClass = registeredClass ?? BPETokenizer.self
+        if registeredClass == nil {
+            guard !strict else {
+                throw TokenizerError.unsupportedTokenizer(tokenizerName)
+            }
+            print("Warning: Tokenizer model class \(tokenizerName) is not registered, falling back to a standard BPE implementation.")
         }
-
         return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
     }
 }
@@ -288,7 +293,7 @@ public class PreTrainedTokenizer: Tokenizer {
     /// Cache for compiled Jinja templates keyed by their literal template string
     private var compiledChatTemplateCache: [String: Template] = [:]
 
-    public required init(tokenizerConfig: Config, tokenizerData: Config) throws {
+    public required init(tokenizerConfig: Config, tokenizerData: Config, strict: Bool = true) throws {
         var addedTokens: [String: Int] = [:]
         var specialTokens: [String: Int] = [:]
         for addedToken in tokenizerData["addedTokens"].array(or: []) {
@@ -331,7 +336,7 @@ public class PreTrainedTokenizer: Tokenizer {
         cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces.boolean(or: true)
         self.tokenizerConfig = tokenizerConfig
 
-        model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
+        model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens, strict: strict)
     }
 
     private func compiledTemplate(for templateString: String) throws -> Template {
@@ -615,46 +620,38 @@ public extension AutoTokenizer {
         return PreTrainedTokenizer.self
     }
 
-    static func from(tokenizerConfig: Config, tokenizerData: Config) throws -> Tokenizer {
+    static func from(tokenizerConfig: Config, tokenizerData: Config, strict: Bool = true) throws -> Tokenizer {
         let tokenizerClass = tokenizerClass(for: tokenizerConfig)
-        return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
+        return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, strict: strict)
     }
 
     static func from(
         pretrained model: String,
-        hubApi: HubApi = .shared
+        hubApi: HubApi = .shared,
+        strict: Bool = true
     ) async throws -> Tokenizer {
         let config = LanguageModelConfigurationFromHub(modelName: model, hubApi: hubApi)
         guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
         let tokenizerData = try await config.tokenizerData
 
-        return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
+        return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, strict: strict)
     }
 
     static func from(
         modelFolder: URL,
-        hubApi: HubApi = .shared
+        hubApi: HubApi = .shared,
+        strict: Bool = true
     ) async throws -> Tokenizer {
         let config = LanguageModelConfigurationFromHub(modelFolder: modelFolder, hubApi: hubApi)
         guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
         let tokenizerData = try await config.tokenizerData
 
-        return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
+        return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, strict: strict)
     }
 }
 
 // MARK: - Tokenizer model classes
 
-class GPT2Tokenizer: BPETokenizer { }
-class FalconTokenizer: BPETokenizer { }
-class LlamaTokenizer: BPETokenizer { }
-class CodeGenTokenizer: BPETokenizer { }
-class WhisperTokenizer: BPETokenizer { }
-class GemmaTokenizer: BPETokenizer { }
-class CodeLlamaTokenizer: BPETokenizer { }
-class CohereTokenizer: BPETokenizer { }
-class Qwen2Tokenizer: BPETokenizer { }
-
 class T5Tokenizer: UnigramTokenizer { }
 
 // MARK: - PreTrainedTokenizer classes
@@ -707,7 +704,7 @@ func maybeUpdatePostProcessor(tokenizerConfig: Config, processorConfig: Config?)
 class LlamaPreTrainedTokenizer: PreTrainedTokenizer {
     let isLegacy: Bool
 
-    required init(tokenizerConfig: Config, tokenizerData: Config) throws {
+    required init(tokenizerConfig: Config, tokenizerData: Config, strict: Bool = true) throws {
         isLegacy = tokenizerConfig.legacy.boolean(or: true)
         var configDictionary = tokenizerData.dictionary(or: [:])
         if !isLegacy {
@@ -722,6 +719,6 @@ class LlamaPreTrainedTokenizer: PreTrainedTokenizer {
         }
 
         let updatedData = Config(configDictionary)
-        try super.init(tokenizerConfig: tokenizerConfig, tokenizerData: updatedData)
+        try super.init(tokenizerConfig: tokenizerConfig, tokenizerData: updatedData, strict: strict)
     }
 }
diff --git a/Tests/TokenizersTests/TokenizerTests.swift b/Tests/TokenizersTests/TokenizerTests.swift
index 93cf3ae1..821df186 100644
--- a/Tests/TokenizersTests/TokenizerTests.swift
+++ b/Tests/TokenizersTests/TokenizerTests.swift
@@ -120,6 +120,27 @@ class PhiSimpleTests: XCTestCase {
     }
 }
 
+class UnregisteredTokenizerTests: XCTestCase {
+    func testNllbTokenizer() async throws {
+        do {
+            _ = try await AutoTokenizer.from(pretrained: "Xenova/nllb-200-distilled-600M")
+            XCTFail("Expected AutoTokenizer.from to throw for strict mode")
+        } catch TokenizerError.unsupportedTokenizer {
+            // Expected in strict mode; any other error (e.g. a failed download) propagates and fails the test
+        }
+
+        // Non-strict mode falls back to the generic BPE model and proceeds
+        guard let tokenizer = try await AutoTokenizer.from(pretrained: "Xenova/nllb-200-distilled-600M", strict: false) as? PreTrainedTokenizer else {
+            XCTFail()
+            return
+        }
+
+        let ids = tokenizer.encode(text: "Why did the chicken cross the road?")
+        let expected = [256047, 24185, 4077, 349, 1001, 22690, 83580, 349, 82801, 248130, 2]
+        XCTAssertEqual(ids, expected)
+    }
+}
+
 class LlamaPostProcessorOverrideTests: XCTestCase {
     /// Deepseek needs a post-processor override to add a bos token as in the reference implementation
     func testDeepSeek() async throws {