Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ public class PreTrainedTokenizer: Tokenizer {

private let cleanUpTokenizationSpaces: Bool

/// Cache for compiled Jinja templates keyed by their literal template string
private var compiledChatTemplateCache: [String: Template] = [:]

public required init(tokenizerConfig: Config, tokenizerData: Config) throws {
var addedTokens: [String: Int] = [:]
var specialTokens: [String: Int] = [:]
Expand Down Expand Up @@ -332,6 +335,15 @@ public class PreTrainedTokenizer: Tokenizer {
model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
}

/// Returns the compiled `Template` for `templateString`, compiling it on first use.
///
/// Looks the string up in `compiledChatTemplateCache`; on a miss the template is
/// compiled, stored under its literal source string, and returned.
///
/// - Parameter templateString: The template source text, also used as the cache key.
/// - Returns: The cached or newly compiled template.
/// - Throws: Whatever `Template(_:)` throws; nothing is cached in that case.
private func compiledTemplate(for templateString: String) throws -> Template {
    guard let existing = compiledChatTemplateCache[templateString] else {
        // Cache miss: compile once and remember the result for later calls.
        let fresh = try Template(templateString)
        compiledChatTemplateCache[templateString] = fresh
        return fresh
    }
    return existing
}

func preTokenize(_ text: String, options: PreTokenizerOptions) -> [String] {
guard let preTokenizer else { return [text] }
return preTokenizer(text: text, options: options)
Expand Down Expand Up @@ -530,7 +542,7 @@ public class PreTrainedTokenizer: Tokenizer {
throw TokenizerError.missingChatTemplate
}

let template = try Template(selectedChatTemplate)
let template = try compiledTemplate(for: selectedChatTemplate)
var context: [String: Any] = [
"messages": messages,
"add_generation_prompt": addGenerationPrompt,
Expand Down
29 changes: 29 additions & 0 deletions Tests/TokenizersTests/ChatTemplateTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// Created by Anthony DePasquale on 2/10/24.
//

import Foundation
import Tokenizers
import XCTest

Expand Down Expand Up @@ -277,4 +278,32 @@ class ChatTemplateTests: XCTestCase {
}
}
}

/// Performance: measures `applyChatTemplate` when the same template literal is reused.
///
/// The template is applied once before measuring so that every measured iteration
/// passes a template string that has already been seen by the tokenizer.
func testApplyChatTemplatePerformanceCached() async throws {
    let tokenizer = try await Self.sharedPhiTokenizer()

    // Purposely reuse the same template literal to hit the memoized compiled template
    let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"

    // Prime cache once; a failure here is thrown and fails the test before measuring.
    _ = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate)

    measure(metrics: [XCTClockMetric()]) {
        // The measure block cannot throw. Record errors as test failures rather than
        // using `try!`, which would crash the whole test process.
        do {
            _ = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate)
        } catch {
            XCTFail("applyChatTemplate failed with a reused template: \(error)")
        }
    }
}

/// Performance: measures `applyChatTemplate` when every call receives a new template string.
///
/// A Jinja comment containing a fresh UUID is appended on each iteration, so the
/// template text differs between calls and cannot match a previously used string.
func testApplyChatTemplatePerformanceUncached() async throws {
    let tokenizer = try await Self.sharedPhiTokenizer()

    let baseTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"

    measure(metrics: [XCTClockMetric()]) {
        // Make the template string unique each iteration to force a fresh compilation
        let uniqueTemplate = baseTemplate + "{# perf \(UUID().uuidString) #}"

        // The measure block cannot throw. Record errors as test failures rather than
        // using `try!`, which would crash the whole test process.
        do {
            _ = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: uniqueTemplate)
        } catch {
            XCTFail("applyChatTemplate failed with a unique template: \(error)")
        }
    }
}
}
Loading