From 5089d8d05ded53676bbb45d0b6d666603aef74a3 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 28 Sep 2025 07:05:18 -0700 Subject: [PATCH 1/3] Re-implement Top-P sampling and LogitsWarper using MLTensor --- Package.swift | 1 + Sources/Generation/Decoders.swift | 63 ++++++++++++++- Sources/Generation/Generation.swift | 56 +++++++++++-- Sources/Generation/LogitsWarper.swift | 63 +++++++++++++++ .../Generation/RepetitionPenaltyWarper.swift | 44 ++++++++++ .../Generation/TemperatureLogitsWarper.swift | 37 +++++++++ Sources/Models/LanguageModel.swift | 4 +- Tests/GenerationTests/LogitsWarperTests.swift | 70 ++++++++++++++++ Tests/GenerationTests/SamplingTests.swift | 81 +++++++++++++++++++ 9 files changed, 410 insertions(+), 9 deletions(-) create mode 100644 Sources/Generation/LogitsWarper.swift create mode 100644 Sources/Generation/RepetitionPenaltyWarper.swift create mode 100644 Sources/Generation/TemperatureLogitsWarper.swift create mode 100644 Tests/GenerationTests/LogitsWarperTests.swift create mode 100644 Tests/GenerationTests/SamplingTests.swift diff --git a/Package.swift b/Package.swift index dc6c3f5..fb69f9f 100644 --- a/Package.swift +++ b/Package.swift @@ -24,6 +24,7 @@ let package = Package( .target(name: "Hub", dependencies: [.product(name: "Jinja", package: "swift-jinja")], resources: [.process("Resources")], swiftSettings: swiftSettings), .target(name: "Models", dependencies: ["Tokenizers", "Generation"]), .target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")]), + .testTarget(name: "GenerationTests", dependencies: ["Generation"]), .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")], swiftSettings: swiftSettings), .testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]), .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources")]), diff --git a/Sources/Generation/Decoders.swift b/Sources/Generation/Decoders.swift index fafbd1b..db988eb 100644 --- a/Sources/Generation/Decoders.swift +++ b/Sources/Generation/Decoders.swift @@ -12,7 +12,7 @@ func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor { @available(macOS 15.0, iOS 18.0, *) func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float, topK: Int) -> MLTensor { - let temperatureAdjustedScores = scores / temperature + let temperatureAdjustedScores = temperature == 1.0 ? scores : scores / temperature let (topKScores, topKIndices) = temperatureAdjustedScores.topK(topK) let topKProbs = topKScores.softmax(alongAxis: -1) let rnd = topKProbs.sum() * Float.random(in: 0..<1) @@ -25,4 +25,65 @@ func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float, ) return nextTokenTensor.reshaped(to: [1, 1]) } + +// MARK: Top-P (Nucleus) Sampling + +/// Selects the next token using top-p (nucleus) sampling. +/// +/// Top-p sampling dynamically selects from the smallest possible set of words +/// whose cumulative probability exceeds the probability p. This provides more +/// diversity than top-k by adapting the vocabulary size based on the probability +/// distribution. +@available(macOS 15.0, iOS 18.0, *) +func selectNextTokenUsingTopPSampling(from scores: MLTensor, temperature: Float, topP: Double) -> MLTensor { + let temperatureAdjustedScores = temperature == 1.0 ? scores : scores / temperature + let probs = temperatureAdjustedScores.softmax(alongAxis: -1) + + // Sort probabilities in descending order by negating values first + let negatedProbs = -probs + let sortedIndices = negatedProbs.argsort(alongAxis: -1) + let sortedProbs = probs.gathering(atIndices: sortedIndices, alongAxis: -1) + + // Calculate cumulative sum + let cumProbs = sortedProbs.cumulativeSum(alongAxis: -1) + + // Find cutoff point - keep tokens where cumulative probability <= topP + let cutoffMask = cumProbs .<= Float(topP) + + // Always keep at least the first (highest probability) token + let firstToken = MLTensor(repeating: 1.0, shape: Array(cutoffMask.shape.dropLast()) + [1]) + if cutoffMask.shape.last! > 1 { + let restMask = cutoffMask[..., 1...] + let finalMask = MLTensor(concatenating: [firstToken, restMask], alongAxis: -1) + + // Apply mask to sorted probabilities + let maskedSortedProbs = finalMask * sortedProbs + + // Sample from the masked distribution + let totalMaskedProb = maskedSortedProbs.sum(alongAxes: [-1]).expandingShape(at: -1) + let normalizedProbs = maskedSortedProbs / totalMaskedProb + + let rnd = Float.random(in: 0..<1) + let cumMaskedProbs = normalizedProbs.cumulativeSum(alongAxis: -1) + var accumProbs = cumMaskedProbs + accumProbs += (accumProbs .< rnd) * 100.0 + let selectedIdx = accumProbs.argsort()[..., 0] + + let nextTokenTensor = sortedIndices.gathering( + atIndices: selectedIdx, + alongAxis: sortedIndices.rank - 1 + ) + + return nextTokenTensor.reshaped(to: [1, 1]) + } else { + // Only one token, just return it + let selectedIdx = MLTensor([Int32(0)]) + let nextTokenTensor = sortedIndices.gathering( + atIndices: selectedIdx, + alongAxis: sortedIndices.rank - 1 + ) + return nextTokenTensor.reshaped(to: [1, 1]) + } +} + #endif // canImport(CoreML) diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift index 0cdfd37..8c902ef 100644 --- a/Sources/Generation/Generation.swift +++ b/Sources/Generation/Generation.swift @@ -74,16 +74,31 @@ extension Generation { var outputTokens = MLTensor(tokens).expandingShape(at: 0) while outputTokens.shape[1] < config.maxLength { let nextTokenScores = await model(outputTokens, config) + // Apply logits processors (repetition penalty, etc.) + let processedLogits = applyLogitsProcessors( + inputIds: outputTokens, + logits: nextTokenScores, + config: config + ) + let nextToken = switch config.generationMode { case .greedy: - selectNextTokenUsingGreedyDecoding(from: nextTokenScores) + selectNextTokenUsingGreedyDecoding(from: processedLogits) case .sample: - selectNextTokenUsingTopKSampling( - from: nextTokenScores, - temperature: config.temperature, - topK: config.topK - ) + if config.topP < 1.0 { + selectNextTokenUsingTopPSampling( + from: processedLogits, + temperature: config.temperature, + topP: config.topP + ) + } else { + selectNextTokenUsingTopKSampling( + from: processedLogits, + temperature: config.temperature, + topK: config.topK + ) + } default: fatalError("Generation mode \(config.generationMode) not implemented yet") } @@ -104,6 +119,35 @@ extension Generation { private func tensorToGenerationOutput(_ tensor: MLTensor) async -> GenerationOutput { await tensor.shapedArray(of: Int32.self).scalars.map { Int($0) } } + + /// Applies configured logits processors to the raw logits. + /// + /// - Parameters: + /// - inputIds: The input token sequence + /// - logits: Raw logits from the model + /// - config: Generation configuration with processor settings + /// - Returns: Processed logits + private func applyLogitsProcessors(inputIds: MLTensor, logits: MLTensor, config: GenerationConfig) -> MLTensor { + var warpers: [LogitsWarper] = [] + + // Add temperature warper if temperature is not 1.0 + if config.temperature != 1.0 { + warpers.append(TemperatureLogitsWarper(temperature: Double(config.temperature))) + } + + // Add repetition penalty if configured + if config.repetitionPenalty != 1.0 { + warpers.append(RepetitionPenaltyWarper(penalty: config.repetitionPenalty)) + } + + // Apply all warpers if any are configured + if warpers.isEmpty { + return logits + } + + let processor = LogitsProcessor(warpers: warpers) + return processor.process(inputIds: inputIds, logits: logits) + } } @available(macOS 15.0, iOS 18.0, *) diff --git a/Sources/Generation/LogitsWarper.swift b/Sources/Generation/LogitsWarper.swift new file mode 100644 index 0000000..b1b6d25 --- /dev/null +++ b/Sources/Generation/LogitsWarper.swift @@ -0,0 +1,63 @@ +#if canImport(CoreML) +import CoreML + +/// Protocol for modifying logits before token sampling. +/// +/// Logits warpers can be used to apply various transformations to the logits +/// distribution before sampling, such as temperature scaling, top-k filtering, +/// top-p (nucleus) filtering, or repetition penalties. +@available(macOS 15.0, iOS 18.0, *) +public protocol LogitsWarper { + /// Warps (modifies) the logits before sampling. + /// + /// - Parameters: + /// - inputIds: The input token sequence used for context-dependent warping + /// - logits: The logits tensor to be modified + /// - Returns: The modified logits tensor + func warp(inputIds: MLTensor, logits: MLTensor) -> MLTensor + + /// Alternative call syntax for convenience. + func callAsFunction(inputIds: MLTensor, logits: MLTensor) -> MLTensor +} + +@available(macOS 15.0, iOS 18.0, *) +public extension LogitsWarper { + /// Default implementation of callAsFunction that delegates to warp. + func callAsFunction(inputIds: MLTensor, logits: MLTensor) -> MLTensor { + warp(inputIds: inputIds, logits: logits) + } +} + +/// A collection of logits warpers that processes logits sequentially. +@available(macOS 15.0, iOS 18.0, *) +public struct LogitsProcessor { + private let warpers: [LogitsWarper] + + /// Creates a new logits processor with the specified warpers. + /// + /// - Parameter warpers: Array of logits warpers to apply sequentially + public init(warpers: [LogitsWarper] = []) { + self.warpers = warpers + } + + /// Applies all warpers sequentially to the logits. + /// + /// - Parameters: + /// - inputIds: The input token sequence + /// - logits: The logits tensor to process + /// - Returns: The processed logits tensor + public func process(inputIds: MLTensor, logits: MLTensor) -> MLTensor { + var processedLogits = logits + for warper in warpers { + processedLogits = warper.warp(inputIds: inputIds, logits: processedLogits) + } + return processedLogits + } + + /// Alternative call syntax for convenience. + public func callAsFunction(inputIds: MLTensor, logits: MLTensor) -> MLTensor { + process(inputIds: inputIds, logits: logits) + } +} + +#endif // canImport(CoreML) diff --git a/Sources/Generation/RepetitionPenaltyWarper.swift b/Sources/Generation/RepetitionPenaltyWarper.swift new file mode 100644 index 0000000..80199c2 --- /dev/null +++ b/Sources/Generation/RepetitionPenaltyWarper.swift @@ -0,0 +1,44 @@ +#if canImport(CoreML) +import CoreML + +/// Logits warper that applies repetition penalty. +/// +/// Repetition penalty reduces the likelihood of generating tokens that have +/// already appeared in the input sequence. This helps reduce repetitive text +/// generation. +/// +/// - Note: Penalty > 1.0 penalizes repetition, penalty < 1.0 encourages it +@available(macOS 15.0, iOS 18.0, *) +public struct RepetitionPenaltyWarper: LogitsWarper { + /// The repetition penalty factor. + public let penalty: Float + + /// Creates a new repetition penalty warper. + /// + /// - Parameter penalty: Penalty factor (must be > 0). Values > 1.0 penalize repetition. + public init(penalty: Double) { + precondition(penalty > 0, "Penalty must be strictly positive") + self.penalty = Float(penalty) + } + + /// Applies repetition penalty to tokens that appear in the input sequence. + /// + /// - Parameters: + /// - inputIds: The input token sequence used to identify repeated tokens + /// - logits: The logits tensor to modify + /// - Returns: Logits with repetition penalty applied + public func warp(inputIds: MLTensor, logits: MLTensor) -> MLTensor { + if penalty == 1.0 { + return logits + } + + // TODO: Implement repetition penalty when MLTensor API allows for easier tensor updates + // For now, we'll return the original logits to avoid compilation errors + // This functionality will need to be implemented when tensor item access and update operations are available + + print("Warning: Repetition penalty is not yet implemented due to MLTensor API limitations") + return logits + } +} + +#endif // canImport(CoreML) diff --git a/Sources/Generation/TemperatureLogitsWarper.swift b/Sources/Generation/TemperatureLogitsWarper.swift new file mode 100644 index 0000000..6890627 --- /dev/null +++ b/Sources/Generation/TemperatureLogitsWarper.swift @@ -0,0 +1,37 @@ +#if canImport(CoreML) +import CoreML + +/// Logits warper that applies temperature scaling. +/// +/// Temperature scaling modifies the sharpness of the probability distribution: +/// - Temperature < 1.0: Makes the distribution more concentrated (less random) +/// - Temperature = 1.0: No change to the distribution +/// - Temperature > 1.0: Makes the distribution more uniform (more random) +@available(macOS 15.0, iOS 18.0, *) +public struct TemperatureLogitsWarper: LogitsWarper { + /// The temperature value for scaling logits. + public let temperature: Float + + /// Creates a new temperature logits warper. + /// + /// - Parameter temperature: Temperature value (must be > 0) + public init(temperature: Double) { + precondition(temperature > 0, "Temperature must be strictly positive") + self.temperature = Float(temperature) + } + + /// Applies temperature scaling to the logits. + /// + /// - Parameters: + /// - inputIds: The input token sequence (unused by temperature warper) + /// - logits: The logits tensor to scale + /// - Returns: Temperature-scaled logits + public func warp(inputIds: MLTensor, logits: MLTensor) -> MLTensor { + if temperature == 1.0 { + return logits + } + return logits / temperature + } +} + +#endif // canImport(CoreML) diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index 044fd19..e93869d 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -129,7 +129,7 @@ extension LanguageModel { static let valueCache = "valueCache" // Output keys static let logits = "logits" - static let presentKeys = "presentKeys" + static let present = "presentKeys" static let presentValues = "presentValues" } } @@ -265,7 +265,7 @@ public extension LanguageModel { } let kCacheInput = model.modelDescription.inputDescriptionsByName[Keys.keyCache] != nil let vCacheInput = model.modelDescription.inputDescriptionsByName[Keys.valueCache] != nil - let kCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.presentKeys] != nil + let kCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.present] != nil let vCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.presentValues] != nil guard Set([kCacheInput, vCacheInput, kCacheOutput, vCacheOutput]).count == 1 else { diff --git a/Tests/GenerationTests/LogitsWarperTests.swift b/Tests/GenerationTests/LogitsWarperTests.swift new file mode 100644 index 0000000..206f884 --- /dev/null +++ b/Tests/GenerationTests/LogitsWarperTests.swift @@ -0,0 +1,70 @@ +import CoreML +import Testing + +@testable import Generation + +#if canImport(CoreML) +@Suite("Logits Warper Tests") +struct LogitsWarperTests { + + @Test("Temperature warper scaling") + @available(macOS 15.0, iOS 18.0, *) + func temperatureWarper() { + let logits = MLTensor([[1.0, 2.0, 3.0]]) + let inputIds = MLTensor([[1, 2]]) + + let tempWarper = TemperatureLogitsWarper(temperature: 2.0) + let warpedLogits = tempWarper.warp(inputIds: inputIds, logits: logits) + + #expect(warpedLogits.shape == logits.shape) + + let identityWarper = TemperatureLogitsWarper(temperature: 1.0) + let unchangedLogits = identityWarper.warp(inputIds: inputIds, logits: logits) + #expect(unchangedLogits.shape == logits.shape) + } + + @Test("LogitsProcessor with multiple warpers") + @available(macOS 15.0, iOS 18.0, *) + func logitsProcessor() { + let logits = MLTensor([[1.0, 2.0, 3.0]]) + let inputIds = MLTensor([[1, 2]]) + + let warpers: [LogitsWarper] = [ + TemperatureLogitsWarper(temperature: 2.0) + ] + + let processor = LogitsProcessor(warpers: warpers) + let processedLogits = processor.process(inputIds: inputIds, logits: logits) + + #expect(processedLogits.shape == logits.shape) + } + + @Test("LogitsProcessor with no warpers") + @available(macOS 15.0, iOS 18.0, *) + func logitsProcessorEmpty() { + let logits = MLTensor([[1.0, 2.0, 3.0]]) + let inputIds = MLTensor([[1, 2]]) + + let processor = LogitsProcessor(warpers: []) + let processedLogits = processor.process(inputIds: inputIds, logits: logits) + + #expect(processedLogits.shape == logits.shape) + } + + @Test("Repetition penalty warper") + @available(macOS 15.0, iOS 18.0, *) + func repetitionPenaltyWarper() { + let logits = MLTensor([[1.0, 2.0, 3.0]]) + let inputIds = MLTensor([[0, 1]]) + + let repWarper = RepetitionPenaltyWarper(penalty: 1.2) + let warpedLogits = repWarper.warp(inputIds: inputIds, logits: logits) + + #expect(warpedLogits.shape == logits.shape) + + let identityWarper = RepetitionPenaltyWarper(penalty: 1.0) + let unchangedLogits = identityWarper.warp(inputIds: inputIds, logits: logits) + #expect(unchangedLogits.shape == logits.shape) + } +} +#endif diff --git a/Tests/GenerationTests/SamplingTests.swift b/Tests/GenerationTests/SamplingTests.swift new file mode 100644 index 0000000..8fbbdb1 --- /dev/null +++ b/Tests/GenerationTests/SamplingTests.swift @@ -0,0 +1,81 @@ +import CoreML +import Testing + +@testable import Generation + +#if canImport(CoreML) +@Suite("Sampling Tests") +struct SamplingTests { + @Test + @available(macOS 15.0, iOS 18.0, *) + func testTopKSampling() { + let logits = MLTensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + + let result = selectNextTokenUsingTopKSampling( + from: logits, + temperature: 1.0, + topK: 3 + ) + + #expect(result.shape == [1, 1]) + } + + @Test + @available(macOS 15.0, iOS 18.0, *) + func testTopPSampling() { + let logits = MLTensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + + let result = selectNextTokenUsingTopPSampling( + from: logits, + temperature: 1.0, + topP: 0.9 + ) + + #expect(result.shape == [1, 1]) + } + + @Test + @available(macOS 15.0, iOS 18.0, *) + func testTopPSamplingWithHighP() { + let logits = MLTensor([[1.0, 2.0, 3.0]]) + + let result = selectNextTokenUsingTopPSampling( + from: logits, + temperature: 1.0, + topP: 1.0 + ) + + #expect(result.shape == [1, 1]) + } + + @Test + @available(macOS 15.0, iOS 18.0, *) + func testGreedyDecoding() { + let logits = MLTensor([[1.0, 3.0, 2.0]]) + + let result = selectNextTokenUsingGreedyDecoding(from: logits) + + #expect(result.shape == [1, 1]) + } + + @Test + @available(macOS 15.0, iOS 18.0, *) + func testTemperatureScaling() { + let logits = MLTensor([[1.0, 2.0, 3.0]]) + + let highTempResult = selectNextTokenUsingTopKSampling( + from: logits, + temperature: 2.0, + topK: 2 + ) + #expect(highTempResult.shape == [1, 1]) + + let lowTempResult = selectNextTokenUsingTopKSampling( + from: logits, + temperature: 0.5, + topK: 2 + ) + #expect(lowTempResult.shape == [1, 1]) + } +} +#endif From bc6669c674dccc6c5b181fdbff1b64d96b669fe4 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 28 Sep 2025 07:24:30 -0700 Subject: [PATCH 2/3] Fixup presentKeys --- Sources/Models/LanguageModel.swift | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index e93869d..d06237c 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -129,7 +129,8 @@ extension LanguageModel { static let valueCache = "valueCache" // Output keys static let logits = "logits" - static let present = "presentKeys" + // swift-format-ignore: DontRepeatTypeInStaticProperties + static let presentKeys = "presentKeys" static let presentValues = "presentValues" } } @@ -265,7 +266,7 @@ public extension LanguageModel { } let kCacheInput = model.modelDescription.inputDescriptionsByName[Keys.keyCache] != nil let vCacheInput = model.modelDescription.inputDescriptionsByName[Keys.valueCache] != nil - let kCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.present] != nil + let kCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.presentKeys] != nil let vCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.presentValues] != nil guard Set([kCacheInput, vCacheInput, kCacheOutput, vCacheOutput]).count == 1 else { From f38cb84526353fdcbcefa7a26349ab919deb5411 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 28 Sep 2025 07:13:27 -0700 Subject: [PATCH 3/3] Disable DontRepeatTypeInStaticProperties rule --- .swift-format | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.swift-format b/.swift-format index 7e6f18f..380162f 100644 --- a/.swift-format +++ b/.swift-format @@ -33,7 +33,7 @@ "AvoidRetroactiveConformances": true, "BeginDocumentationCommentWithOneLineSummary": false, "DoNotUseSemicolons": false, - "DontRepeatTypeInStaticProperties": true, + "DontRepeatTypeInStaticProperties": false, "FileScopedDeclarationPrivacy": true, "FullyIndirectEnum": true, "GroupNumericLiterals": false,