Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 30 additions & 33 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -102,37 +102,42 @@ public protocol PreTrainedTokenizerModel: TokenizingModel {
struct TokenizerModel {
static let knownTokenizers: [String: PreTrainedTokenizerModel.Type] = [
"BertTokenizer": BertTokenizer.self,
"CodeGenTokenizer": BPETokenizer.self,
"CodeLlamaTokenizer": BPETokenizer.self,
"CohereTokenizer": BPETokenizer.self,
"DistilbertTokenizer": BertTokenizer.self,
"DistilBertTokenizer": BertTokenizer.self,
"FalconTokenizer": BPETokenizer.self,
"GemmaTokenizer": BPETokenizer.self,
"GPT2Tokenizer": BPETokenizer.self,
"LlamaTokenizer": BPETokenizer.self,
"RobertaTokenizer": BPETokenizer.self,
"CodeGenTokenizer": CodeGenTokenizer.self,
"CodeLlamaTokenizer": CodeLlamaTokenizer.self,
"FalconTokenizer": FalconTokenizer.self,
"GemmaTokenizer": GemmaTokenizer.self,
"GPT2Tokenizer": GPT2Tokenizer.self,
"LlamaTokenizer": LlamaTokenizer.self,
"T5Tokenizer": T5Tokenizer.self,
"WhisperTokenizer": WhisperTokenizer.self,
"CohereTokenizer": CohereTokenizer.self,
"Qwen2Tokenizer": Qwen2Tokenizer.self,
"PreTrainedTokenizer": BPETokenizer.self,
"Qwen2Tokenizer": BPETokenizer.self,
"WhisperTokenizer": BPETokenizer.self,
]

/// Returns the unknown-token string from a tokenizer configuration, if one is set.
///
/// Tries `unkToken.content` first (presumably the form used when `unk_token`
/// is serialized as an added-token object — TODO confirm against config files),
/// then falls back to reading `unkToken` directly as a plain string.
/// Returns `nil` when neither form yields a string.
static func unknownToken(from tokenizerConfig: Config) -> String? {
tokenizerConfig.unkToken.content.string() ?? tokenizerConfig.unkToken.string()
}

static func from(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws -> TokenizingModel {
static func from(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int], strict: Bool = true) throws -> TokenizingModel {
guard let tokenizerClassName = tokenizerConfig.tokenizerClass.string() else {
throw TokenizerError.missingTokenizerClassInConfig
}

// Some tokenizer_class entries use a Fast suffix
let tokenizerName = tokenizerClassName.replacingOccurrences(of: "Fast", with: "")
guard let tokenizerClass = TokenizerModel.knownTokenizers[tokenizerName] else {
throw TokenizerError.unsupportedTokenizer(tokenizerName)
// Fallback to BPETokenizer if class is not explicitly registered
let tokenizerClass = TokenizerModel.knownTokenizers[tokenizerName] ?? BPETokenizer.self
if TokenizerModel.knownTokenizers[tokenizerName] == nil {
if strict {
throw TokenizerError.unsupportedTokenizer(tokenizerName)
} else {
print("Warning: Tokenizer model class \(tokenizerName) is not registered, falling back to a standard BPE implementation.")
}
}

return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
}
}
Expand Down Expand Up @@ -288,7 +293,7 @@ public class PreTrainedTokenizer: Tokenizer {
/// Cache for compiled Jinja templates keyed by their literal template string
private var compiledChatTemplateCache: [String: Template] = [:]

public required init(tokenizerConfig: Config, tokenizerData: Config) throws {
public required init(tokenizerConfig: Config, tokenizerData: Config, strict: Bool = true) throws {
var addedTokens: [String: Int] = [:]
var specialTokens: [String: Int] = [:]
for addedToken in tokenizerData["addedTokens"].array(or: []) {
Expand Down Expand Up @@ -331,7 +336,7 @@ public class PreTrainedTokenizer: Tokenizer {
cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces.boolean(or: true)
self.tokenizerConfig = tokenizerConfig

model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens, strict: strict)
}

private func compiledTemplate(for templateString: String) throws -> Template {
Expand Down Expand Up @@ -615,46 +620,38 @@ public extension AutoTokenizer {
return PreTrainedTokenizer.self
}

static func from(tokenizerConfig: Config, tokenizerData: Config) throws -> Tokenizer {
static func from(tokenizerConfig: Config, tokenizerData: Config, strict: Bool = true) throws -> Tokenizer {
let tokenizerClass = tokenizerClass(for: tokenizerConfig)
return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, strict: strict)
}

static func from(
pretrained model: String,
hubApi: HubApi = .shared
hubApi: HubApi = .shared,
strict: Bool = true
) async throws -> Tokenizer {
let config = LanguageModelConfigurationFromHub(modelName: model, hubApi: hubApi)
guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
let tokenizerData = try await config.tokenizerData

return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, strict: strict)
}

static func from(
modelFolder: URL,
hubApi: HubApi = .shared
hubApi: HubApi = .shared,
strict: Bool = true
) async throws -> Tokenizer {
let config = LanguageModelConfigurationFromHub(modelFolder: modelFolder, hubApi: hubApi)
guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
let tokenizerData = try await config.tokenizerData

return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, strict: strict)
}
}

// MARK: - Tokenizer model classes

class GPT2Tokenizer: BPETokenizer { }
class FalconTokenizer: BPETokenizer { }
class LlamaTokenizer: BPETokenizer { }
class CodeGenTokenizer: BPETokenizer { }
class WhisperTokenizer: BPETokenizer { }
class GemmaTokenizer: BPETokenizer { }
class CodeLlamaTokenizer: BPETokenizer { }
class CohereTokenizer: BPETokenizer { }
class Qwen2Tokenizer: BPETokenizer { }

class T5Tokenizer: UnigramTokenizer { }

// MARK: - PreTrainedTokenizer classes
Expand Down Expand Up @@ -707,7 +704,7 @@ func maybeUpdatePostProcessor(tokenizerConfig: Config, processorConfig: Config?)
class LlamaPreTrainedTokenizer: PreTrainedTokenizer {
let isLegacy: Bool

required init(tokenizerConfig: Config, tokenizerData: Config) throws {
required init(tokenizerConfig: Config, tokenizerData: Config, strict: Bool = true) throws {
isLegacy = tokenizerConfig.legacy.boolean(or: true)
var configDictionary = tokenizerData.dictionary(or: [:])
if !isLegacy {
Expand All @@ -722,6 +719,6 @@ class LlamaPreTrainedTokenizer: PreTrainedTokenizer {
}

let updatedData = Config(configDictionary)
try super.init(tokenizerConfig: tokenizerConfig, tokenizerData: updatedData)
try super.init(tokenizerConfig: tokenizerConfig, tokenizerData: updatedData, strict: strict)
}
}
21 changes: 21 additions & 0 deletions Tests/TokenizersTests/TokenizerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,27 @@ class PhiSimpleTests: XCTestCase {
}
}

/// Tests loading a model whose tokenizer class is not explicitly registered,
/// exercising both strict (default) and non-strict loading paths.
class UnregisteredTokenizerTests: XCTestCase {
/// Strict mode must throw for an unregistered tokenizer class; non-strict
/// mode must fall back to a working implementation and produce stable ids.
func testNllbTokenizer() async throws {
do {
_ = try await AutoTokenizer.from(pretrained: "Xenova/nllb-200-distilled-600M")
XCTFail("Expected AutoTokenizer.from to throw for strict mode")
} catch {
// Expected: in the default strict mode, loading a model with an
// unregistered tokenizer class throws (presumably
// `TokenizerError.unsupportedTokenizer` — verify against loader).
}

// Non-strict mode proceeds instead of throwing, falling back to a
// standard implementation for the unregistered class.
guard let tokenizer = try await AutoTokenizer.from(pretrained: "Xenova/nllb-200-distilled-600M", strict: false) as? PreTrainedTokenizer else {
XCTFail()
return
}

// Reference ids for this prompt; pins the fallback tokenizer's output.
let ids = tokenizer.encode(text: "Why did the chicken cross the road?")
let expected = [256047, 24185, 4077, 349, 1001, 22690, 83580, 349, 82801, 248130, 2]
XCTAssertEqual(ids, expected)
}
}

class LlamaPostProcessorOverrideTests: XCTestCase {
/// Deepseek needs a post-processor override to add a bos token as in the reference implementation
func testDeepSeek() async throws {
Expand Down