diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift
index 1f6607c2..b410f3e5 100644
--- a/Sources/Tokenizers/Tokenizer.swift
+++ b/Sources/Tokenizers/Tokenizer.swift
@@ -285,6 +285,14 @@ public class PreTrainedTokenizer: Tokenizer {
 
     private let cleanUpTokenizationSpaces: Bool
 
+    /// Cache for compiled Jinja templates keyed by their literal template string.
+    /// Only read or written while `compiledChatTemplateCacheLock` is held.
+    private var compiledChatTemplateCache: [String: Template] = [:]
+
+    /// Serializes access to `compiledChatTemplateCache` so one tokenizer instance can
+    /// apply chat templates from several threads or tasks at once (`NSLock` is from Foundation).
+    private let compiledChatTemplateCacheLock = NSLock()
+
     public required init(tokenizerConfig: Config, tokenizerData: Config) throws {
         var addedTokens: [String: Int] = [:]
         var specialTokens: [String: Int] = [:]
@@ -332,6 +340,26 @@ public class PreTrainedTokenizer: Tokenizer {
         model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
     }
 
+    /// Returns the compiled Jinja template for `templateString`, compiling it only on a cache miss.
+    ///
+    /// The lock is held for the whole lookup-compile-store sequence, so concurrent callers never
+    /// mutate `compiledChatTemplateCache` at the same time and each distinct template string is
+    /// compiled at most once per tokenizer instance.
+    ///
+    /// - Parameter templateString: The literal Jinja template source; also used as the cache key.
+    /// - Returns: The cached or newly compiled `Template`.
+    /// - Throws: Any error thrown by `Template(templateString)`; nothing is cached in that case.
+    private func compiledTemplate(for templateString: String) throws -> Template {
+        compiledChatTemplateCacheLock.lock()
+        defer { compiledChatTemplateCacheLock.unlock() }
+        if let cached = compiledChatTemplateCache[templateString] {
+            return cached
+        }
+        let compiled = try Template(templateString)
+        compiledChatTemplateCache[templateString] = compiled
+        return compiled
+    }
+
     func preTokenize(_ text: String, options: PreTokenizerOptions) -> [String] {
         guard let preTokenizer else { return [text] }
         return preTokenizer(text: text, options: options)
@@ -530,7 +558,7 @@ public class PreTrainedTokenizer: Tokenizer {
             throw TokenizerError.missingChatTemplate
         }
 
-        let template = try Template(selectedChatTemplate)
+        let template = try compiledTemplate(for: selectedChatTemplate)
         var context: [String: Any] = [
             "messages": messages,
             "add_generation_prompt": addGenerationPrompt,
diff --git a/Tests/TokenizersTests/ChatTemplateTests.swift b/Tests/TokenizersTests/ChatTemplateTests.swift
index eb18fc8b..4887d457 100644
--- a/Tests/TokenizersTests/ChatTemplateTests.swift
+++ b/Tests/TokenizersTests/ChatTemplateTests.swift
@@ -5,6 +5,7 @@
 // Created by Anthony DePasquale on 2/10/24.
// +import Foundation import Tokenizers import XCTest @@ -277,4 +278,32 @@ class ChatTemplateTests: XCTestCase { } } } + + /// Performance: cached vs uncached template application + func testApplyChatTemplatePerformanceCached() async throws { + let tokenizer = try await Self.sharedPhiTokenizer() + + // Purposely reuse the same template literal to hit the memoized compiled template + let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + + // Prime cache once + _ = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate) + + measure(metrics: [XCTClockMetric()]) { + _ = try! 
tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate) + } + } + + /// Performance: simulate uncached runs by varying the template to bypass memoization + func testApplyChatTemplatePerformanceUncached() async throws { + let tokenizer = try await Self.sharedPhiTokenizer() + + let baseTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + + measure(metrics: [XCTClockMetric()]) { + // Make the template string unique each iteration to force a fresh compilation + let uniqueTemplate = baseTemplate + "{# perf \(UUID().uuidString) #}" + _ = try! tokenizer.applyChatTemplate(messages: messages, chatTemplate: uniqueTemplate) + } + } }