From 486e58ffcb70ea842411d711b47a3e19620211a9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 27 Sep 2025 22:21:39 +0200 Subject: [PATCH 01/14] inital --- Package.swift | 1 + Sources/Generation/Decoders.swift | 36 +++- Sources/Generation/Generation.swift | 48 ++++- .../LogitsWarper/LogitsProcessor.swift | 53 +++++ .../RepetitionPenaltyLogitsProcessor.swift | 74 +++++++ .../TemperatureLogitsWarper.swift | 31 +++ .../LogitsWarper/TopKLogitsWarper.swift | 74 +++++++ .../LogitsWarper/TopPLogitsWarper.swift | 104 +++++++++ Sources/Models/LanguageModel.swift | 2 +- .../LogitsProcessorTests.swift | 203 ++++++++++++++++++ 10 files changed, 618 insertions(+), 8 deletions(-) create mode 100644 Sources/Generation/LogitsWarper/LogitsProcessor.swift create mode 100644 Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift create mode 100644 Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift create mode 100644 Sources/Generation/LogitsWarper/TopKLogitsWarper.swift create mode 100644 Sources/Generation/LogitsWarper/TopPLogitsWarper.swift create mode 100644 Tests/GenerationTests/LogitsProcessorTests.swift diff --git a/Package.swift b/Package.swift index dc6c3f5..fb69f9f 100644 --- a/Package.swift +++ b/Package.swift @@ -24,6 +24,7 @@ let package = Package( .target(name: "Hub", dependencies: [.product(name: "Jinja", package: "swift-jinja")], resources: [.process("Resources")], swiftSettings: swiftSettings), .target(name: "Models", dependencies: ["Tokenizers", "Generation"]), .target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")]), + .testTarget(name: "GenerationTests", dependencies: ["Generation"]), .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")], swiftSettings: swiftSettings), .testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]), .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", 
"Hub"], resources: [.process("Resources")]), diff --git a/Sources/Generation/Decoders.swift b/Sources/Generation/Decoders.swift index fafbd1b..81a52d7 100644 --- a/Sources/Generation/Decoders.swift +++ b/Sources/Generation/Decoders.swift @@ -8,8 +8,42 @@ func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor { scores.argmax(alongAxis: -1).reshaped(to: [1, 1]) } -// MARK: Top-K Sampling +// MARK: Sampling +/// Performs multinomial sampling from processed logits. +/// +/// Assumes logits have already been processed by LogitsProcessorList +/// (temperature, top-k, top-p, etc. already applied). +/// +/// - Parameter scores: Processed logits tensor [batch_size, vocab_size] +/// - Returns: Sampled token ID tensor [batch_size, 1] +@available(macOS 15.0, iOS 18.0, *) +func selectNextTokenUsingSampling(from scores: MLTensor) -> MLTensor { + // Convert logits to probabilities + let probs = scores.softmax(alongAxis: -1) + + // Multinomial sampling using cumulative sum method + // 1. Generate random number in [0, 1) + // 2. Compute cumulative sum of probabilities + // 3. Find first index where cumsum >= random_number + + let rnd = Float.random(in: 0..<1) + var cumulativeProbs = probs.cumulativeSum(alongAxis: -1) + + // Mark all positions where cumsum >= rnd with a large value + // Then argsort will give us the first such position + cumulativeProbs = cumulativeProbs + (cumulativeProbs .< rnd) * 100.0 + + let sampledIndex = cumulativeProbs.argsort(alongAxis: -1)[..., 0] + return sampledIndex.reshaped(to: [1, 1]) +} + +// MARK: Legacy Top-K Sampling (deprecated, use LogitsProcessorList instead) + +/// Legacy top-k sampling function that combines temperature, top-k, and sampling. +/// +/// - Note: This function is deprecated. Use `selectNextTokenUsingSampling` with +/// `TemperatureLogitsWarper` and `TopKLogitsWarper` in a `LogitsProcessorList` instead. 
@available(macOS 15.0, iOS 18.0, *) func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float, topK: Int) -> MLTensor { let temperatureAdjustedScores = scores / temperature diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift index 0cdfd37..45d8908 100644 --- a/Sources/Generation/Generation.swift +++ b/Sources/Generation/Generation.swift @@ -72,18 +72,24 @@ extension Generation { ) async -> GenerationOutput { let tokens = tokens.map { Int32($0) } var outputTokens = MLTensor(tokens).expandingShape(at: 0) + + // Create logits processor list based on config + let logitsProcessorList = createLogitsProcessorList(config: config) + while outputTokens.shape[1] < config.maxLength { + // Get raw logits from model let nextTokenScores = await model(outputTokens, config) + + // Apply logits processors + let processedScores = await logitsProcessorList(outputTokens, nextTokenScores) + + // Select next token based on generation mode let nextToken = switch config.generationMode { case .greedy: - selectNextTokenUsingGreedyDecoding(from: nextTokenScores) + selectNextTokenUsingGreedyDecoding(from: processedScores) case .sample: - selectNextTokenUsingTopKSampling( - from: nextTokenScores, - temperature: config.temperature, - topK: config.topK - ) + selectNextTokenUsingSampling(from: processedScores) default: fatalError("Generation mode \(config.generationMode) not implemented yet") } @@ -101,6 +107,36 @@ extension Generation { return await tensorToGenerationOutput(outputTokens) } + /// Creates a list of logits processors based on generation configuration. 
+ /// + /// - Parameter config: Generation configuration specifying which processors to apply + /// - Returns: List of logits processors to apply during generation + private func createLogitsProcessorList(config: GenerationConfig) -> LogitsProcessorList { + var processors: [any LogitsProcessor] = [] + + // Temperature scaling (if not default) + if config.temperature > 0 && config.temperature != 1.0 { + processors.append(TemperatureLogitsWarper(temperature: config.temperature)) + } + + // Top-K filtering + if config.topK > 0 && config.topK < Int.max { + processors.append(TopKLogitsWarper(topK: config.topK)) + } + + // Top-P (nucleus) sampling + if config.topP < 1.0 { + processors.append(TopPLogitsWarper(topP: Float(config.topP))) + } + + // Repetition penalty + if config.repetitionPenalty != 1.0 { + processors.append(RepetitionPenaltyLogitsProcessor(penalty: Float(config.repetitionPenalty))) + } + + return LogitsProcessorList(processors: processors) + } + private func tensorToGenerationOutput(_ tensor: MLTensor) async -> GenerationOutput { await tensor.shapedArray(of: Int32.self).scalars.map { Int($0) } } diff --git a/Sources/Generation/LogitsWarper/LogitsProcessor.swift b/Sources/Generation/LogitsWarper/LogitsProcessor.swift new file mode 100644 index 0000000..9af7091 --- /dev/null +++ b/Sources/Generation/LogitsWarper/LogitsProcessor.swift @@ -0,0 +1,53 @@ +#if canImport(CoreML) +import CoreML + +/// Abstract base class for all logits processors that can be applied during generation. +/// +/// Logits processors modify the probability distribution over vocabulary tokens by transforming +/// the raw logit scores produced by language models. This enables various sampling strategies +/// such as temperature scaling, top-k/top-p filtering, and repetition penalties. 
+/// +/// Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py +@available(macOS 15.0, iOS 18.0, *) +public protocol LogitsProcessor { + /// Processes logits for next token prediction. + /// + /// - Parameters: + /// - inputIds: Tensor of input token IDs with shape `[batch_size, sequence_length]` + /// - scores: Tensor of raw logit scores with shape `[batch_size, vocab_size]` + /// - Returns: Processed logits tensor with shape `[batch_size, vocab_size]` + /// + /// - Note: The `inputIds` parameter provides context for processors that need to examine + /// the generated sequence (e.g., repetition penalty). Processors that don't need this + /// context (e.g., temperature) can ignore it. + func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor +} + +/// A list of logits processors that applies each processor sequentially. +/// +/// This class provides a convenient way to chain multiple logits processors together. +/// Each processor is applied in order to the logits tensor, with the output of one +/// processor becoming the input to the next. +@available(macOS 15.0, iOS 18.0, *) +public struct LogitsProcessorList { + public var processors: [any LogitsProcessor] + + public init(processors: [any LogitsProcessor]) { + self.processors = processors + } + + /// Applies all logits processors sequentially to the input scores. 
+ /// + /// - Parameters: + /// - inputIds: Tensor of input token IDs with shape `[batch_size, sequence_length]` + /// - scores: Tensor of raw logit scores with shape `[batch_size, vocab_size]` + /// - Returns: Processed logits tensor with shape `[batch_size, vocab_size]` + public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { + var processedScores = scores + for processor in processors { + processedScores = await processor(inputIds, processedScores) + } + return processedScores + } +} +#endif \ No newline at end of file diff --git a/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift b/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift new file mode 100644 index 0000000..e2f4215 --- /dev/null +++ b/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift @@ -0,0 +1,74 @@ +#if canImport(CoreML) +import CoreML + +/// LogitsProcessor that prevents repetition of previous tokens through a penalty. +/// +/// For each token that has already appeared in the sequence: +/// - If the token's logit is negative: multiply by penalty (further suppresses it) +/// - If the token's logit is positive: divide by penalty (suppresses it) +/// +/// This penalty is applied at most once per token, regardless of how many times it appears. +/// +/// Typical penalty values: +/// - 1.0: No penalty +/// - > 1.0: Penalize repetition (e.g., 1.2 for balanced generation) +/// - 0.0 - 1.0: Encourage repetition +/// +/// Based on: +/// - https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L297 +/// - Paper: https://arxiv.org/abs/1909.05858 +@available(macOS 15.0, iOS 18.0, *) +public struct RepetitionPenaltyLogitsProcessor: LogitsProcessor { + public let penalty: Float + + /// Creates a repetition penalty logits processor. + /// + /// - Parameter penalty: Penalty factor. Values > 1.0 penalize repetition, values < 1.0 encourage it. 
+ public init(penalty: Float) { + precondition(penalty > 0, "penalty must be strictly positive, got \(penalty)") + self.penalty = penalty + } + + public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { + guard penalty != 1.0 else { return scores } + + // Implementation approach (following transformers): + // 1. Get unique token IDs from inputIds + // 2. For each unique token, gather its logit value + // 3. Apply conditional penalty: if logit < 0: *= penalty, else: /= penalty + // 4. Scatter penalized values back to original positions + + // Convert to CPU for gather/scatter operations + let scoresArray = await scores.shapedArray(of: Float.self) + let inputIdsArray = await inputIds.shapedArray(of: Int32.self) + + // Process each batch item + var scoresData = scoresArray.scalars + let batchSize = scores.shape[0] + let vocabSize = scores.shape[1] + + for batchIdx in 0..= 0 && tokenId < vocabSize else { continue } + + let scoreIdx = seqStart + tokenId + let score = scoresData[scoreIdx] + + // Apply penalty based on sign (following transformers implementation) + scoresData[scoreIdx] = score < 0 ? score * penalty : score / penalty + } + } + + // Create new tensor with penalized scores + return MLTensor(shape: scores.shape, scalars: scoresData, scalarType: Float.self) + } +} +#endif \ No newline at end of file diff --git a/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift b/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift new file mode 100644 index 0000000..6707cbc --- /dev/null +++ b/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift @@ -0,0 +1,31 @@ +#if canImport(CoreML) +import CoreML + +/// LogitsProcessor for temperature scaling, which effectively controls the randomness +/// of predicted tokens by modulating the logits distribution. +/// +/// Temperature < 1.0 makes the model more confident (sharper distribution). +/// Temperature > 1.0 makes the model less confident (flatter distribution). 
+/// Temperature = 1.0 leaves the distribution unchanged. +/// +/// Often used together with `TopPLogitsWarper` and `TopKLogitsWarper`. +/// +/// Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L231 +@available(macOS 15.0, iOS 18.0, *) +public struct TemperatureLogitsWarper: LogitsProcessor { + public let temperature: Float + + /// Creates a temperature logits warper. + /// + /// - Parameter temperature: Strictly positive float value used to modulate the logits distribution. + /// Must be > 0. Values close to 0 approximate greedy decoding. + public init(temperature: Float) { + precondition(temperature > 0, "temperature must be strictly positive, got \(temperature)") + self.temperature = temperature + } + + public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { + scores / temperature + } +} +#endif \ No newline at end of file diff --git a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift new file mode 100644 index 0000000..6a3f832 --- /dev/null +++ b/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift @@ -0,0 +1,74 @@ +#if canImport(CoreML) +import CoreML + +/// LogitsProcessor that performs top-k filtering, restricting to the k highest probability elements. +/// +/// Filters out all tokens except the k most likely ones by setting their logits to -inf. +/// This reduces the risk of low-probability tokens being sampled. +/// +/// Often used together with `TemperatureLogitsWarper` and `TopPLogitsWarper`. +/// Pro tip: In practice, LLMs use top_k in the 5-50 range. +/// +/// Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L532 +@available(macOS 15.0, iOS 18.0, *) +public struct TopKLogitsWarper: LogitsProcessor { + public let topK: Int + public let filterValue: Float + public let minTokensToKeep: Int + + /// Creates a top-k logits warper. 
+ /// + /// - Parameters: + /// - topK: Number of highest probability tokens to keep + /// - filterValue: Value to set filtered tokens to (default: -infinity) + /// - minTokensToKeep: Minimum tokens that cannot be filtered (default: 1) + public init(topK: Int, filterValue: Float = -.infinity, minTokensToKeep: Int = 1) { + precondition(topK > 0, "topK must be strictly positive, got \(topK)") + precondition(minTokensToKeep >= 1, "minTokensToKeep must be at least 1, got \(minTokensToKeep)") + + self.topK = max(topK, minTokensToKeep) + self.filterValue = filterValue + self.minTokensToKeep = minTokensToKeep + } + + public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { + let vocabSize = scores.shape[scores.rank - 1] + let k = min(topK, vocabSize) // Safety check + + // Get the k-th highest score (the threshold) + let (topKValues, _) = scores.topK(k) + + // The threshold is the smallest value in the top-k (last element) + // We need to get the value at index [k-1] along the last dimension + let thresholdScores = await topKValues.shapedArray(of: Float.self) + + // For each batch item, get the k-th largest score + let batchSize = scores.shape[0] + var thresholds = [Float]() + + for batchIdx in 0..= 0 && topP <= 1.0, "topP must be in [0, 1], got \(topP)") + precondition(minTokensToKeep >= 1, "minTokensToKeep must be at least 1, got \(minTokensToKeep)") + + self.topP = topP + self.filterValue = filterValue + self.minTokensToKeep = minTokensToKeep + } + + public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { + // Algorithm (following transformers implementation): + // 1. Sort logits in ascending order (transformers uses descending=False) + // 2. Compute softmax probabilities on sorted logits + // 3. Compute cumulative sum of probabilities + // 4. Remove tokens with cumulative probability <= (1 - top_p) + // 5. Keep at least min_tokens_to_keep + // 6. 
Scatter mask back to original indexing + + let batchSize = scores.shape[0] + let vocabSize = scores.shape[scores.rank - 1] + + // Convert to CPU for processing + let scoresArray = await scores.shapedArray(of: Float.self) + var maskedScores = scoresArray.scalars + + // Process each batch item + for batchIdx in 0.. (1 - top_p) + let threshold = 1.0 - topP + var indicesToRemove = Set() + + for i in 0.. 0, "Top-P should filter some low-probability tokens") + } + + func testTopPWarperWithHighThreshold() async throws { + // With topP=0.99, almost all tokens should be kept + let warper = TopPLogitsWarper(topP: 0.99) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) + + let result = await warper(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + // With high topP and relatively uniform distribution, most tokens should be kept + let keptCount = resultArray.filter { !($0.isInfinite && $0 < 0) }.count + XCTAssertTrue(keptCount >= 4, "High topP should keep most tokens") + } + + // MARK: - Repetition Penalty Tests + + func testRepetitionPenaltyProcessor() async throws { + let processor = RepetitionPenaltyLogitsProcessor(penalty: 2.0) + + // Input sequence with tokens [1, 2, 3] + let inputIds = MLTensor(shape: [1, 3], scalars: [Int32(1), Int32(2), Int32(3)], scalarType: Int32.self) + + // Scores for vocab of size 5: [0.5, -0.5, 1.0, -1.0, 2.0] + let scores = MLTensor(shape: [1, 5], scalars: [Float(0.5), Float(-0.5), Float(1.0), Float(-1.0), Float(2.0)], scalarType: Float.self) + + let result = await processor(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + // Token 0: not in sequence, unchanged + XCTAssertEqual(resultArray[0], 0.5, accuracy: accuracy, "Token 0 should be unchanged") + + // Token 1 (score -0.5 < 0): multiplied by 
penalty = -1.0 + XCTAssertEqual(resultArray[1], -1.0, accuracy: accuracy, "Token 1 should be penalized (negative)") + + // Token 2 (score 1.0 > 0): divided by penalty = 0.5 + XCTAssertEqual(resultArray[2], 0.5, accuracy: accuracy, "Token 2 should be penalized (positive)") + + // Token 3 (score -1.0 < 0): multiplied by penalty = -2.0 + XCTAssertEqual(resultArray[3], -2.0, accuracy: accuracy, "Token 3 should be penalized (negative)") + + // Token 4: not in sequence, unchanged + XCTAssertEqual(resultArray[4], 2.0, accuracy: accuracy, "Token 4 should be unchanged") + } + + func testRepetitionPenaltyWithNoPenalty() async throws { + let processor = RepetitionPenaltyLogitsProcessor(penalty: 1.0) + + let inputIds = MLTensor(shape: [1, 2], scalars: [Int32(1), Int32(2)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) + + let result = await processor(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + let expectedArray = await scores.shapedArray(of: Float.self).scalars + + // With penalty=1.0, scores should be unchanged + XCTAssertEqual(resultArray, expectedArray, "Penalty of 1.0 should not change scores") + } + + // MARK: - Processor List Tests + + func testLogitsProcessorList() async throws { + let temp = TemperatureLogitsWarper(temperature: 2.0) + let topK = TopKLogitsWarper(topK: 3) + let processorList = LogitsProcessorList(processors: [temp, topK]) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 5], scalars: [Float(2.0), Float(4.0), Float(6.0), Float(8.0), Float(10.0)], scalarType: Float.self) + + // First temperature divides by 2: [1, 2, 3, 4, 5] + // Then top-k keeps top 3: [-inf, -inf, 3, 4, 5] + let result = await processorList(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + 
XCTAssertTrue(resultArray[0].isInfinite && resultArray[0] < 0) + XCTAssertTrue(resultArray[1].isInfinite && resultArray[1] < 0) + XCTAssertEqual(resultArray[2], 3.0, accuracy: accuracy) + XCTAssertEqual(resultArray[3], 4.0, accuracy: accuracy) + XCTAssertEqual(resultArray[4], 5.0, accuracy: accuracy) + } + + func testEmptyProcessorList() async throws { + let processorList = LogitsProcessorList(processors: []) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 3], scalars: [Float(1.0), Float(2.0), Float(3.0)], scalarType: Float.self) + + let result = await processorList(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + let expectedArray = await scores.shapedArray(of: Float.self).scalars + + // Should be unchanged + XCTAssertEqual(resultArray, expectedArray) + } +} + +// MARK: - Test Helpers + +@available(macOS 15.0, iOS 18.0, *) +func assertMLTensorEqual( + _ tensor: MLTensor, + expected: [Float], + accuracy: Float, + file: StaticString = #filePath, + line: UInt = #line +) async { + let actual = await tensor.shapedArray(of: Float.self).scalars + XCTAssertEqual(actual.count, expected.count, "Tensor size mismatch", file: file, line: line) + for (a, e) in zip(actual, expected) { + XCTAssertEqual(a, e, accuracy: accuracy, file: file, line: line) + } +} \ No newline at end of file From e550fa45aa8cab875c1ec72b2cf549fbdd5676b9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 27 Sep 2025 22:23:44 +0200 Subject: [PATCH 02/14] formatting --- Sources/Generation/LogitsWarper/LogitsProcessor.swift | 2 +- .../LogitsWarper/RepetitionPenaltyLogitsProcessor.swift | 2 +- Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift | 2 +- Sources/Generation/LogitsWarper/TopKLogitsWarper.swift | 4 ++-- Sources/Generation/LogitsWarper/TopPLogitsWarper.swift | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git 
a/Sources/Generation/LogitsWarper/LogitsProcessor.swift b/Sources/Generation/LogitsWarper/LogitsProcessor.swift index 9af7091..cb8f7a0 100644 --- a/Sources/Generation/LogitsWarper/LogitsProcessor.swift +++ b/Sources/Generation/LogitsWarper/LogitsProcessor.swift @@ -50,4 +50,4 @@ public struct LogitsProcessorList { return processedScores } } -#endif \ No newline at end of file +#endif diff --git a/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift b/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift index e2f4215..1f565ba 100644 --- a/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift +++ b/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift @@ -71,4 +71,4 @@ public struct RepetitionPenaltyLogitsProcessor: LogitsProcessor { return MLTensor(shape: scores.shape, scalars: scoresData, scalarType: Float.self) } } -#endif \ No newline at end of file +#endif diff --git a/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift b/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift index 6707cbc..ad27849 100644 --- a/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift +++ b/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift @@ -28,4 +28,4 @@ public struct TemperatureLogitsWarper: LogitsProcessor { scores / temperature } } -#endif \ No newline at end of file +#endif diff --git a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift index 6a3f832..16396a4 100644 --- a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift +++ b/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift @@ -33,7 +33,7 @@ public struct TopKLogitsWarper: LogitsProcessor { public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { let vocabSize = scores.shape[scores.rank - 1] - let k = min(topK, vocabSize) // Safety check + let k = min(topK, vocabSize) // Safety check // Get the k-th highest score (the 
threshold) let (topKValues, _) = scores.topK(k) @@ -71,4 +71,4 @@ public struct TopKLogitsWarper: LogitsProcessor { return MLTensor(shape: scores.shape, scalars: maskedScores, scalarType: Float.self) } } -#endif \ No newline at end of file +#endif diff --git a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift index aaf4732..73e35c7 100644 --- a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift +++ b/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift @@ -101,4 +101,4 @@ public struct TopPLogitsWarper: LogitsProcessor { return MLTensor(shape: scores.shape, scalars: maskedScores, scalarType: Float.self) } } -#endif \ No newline at end of file +#endif From d6f1aeecae716b6fd3974a1a1038ede0b6c2c2f8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 27 Sep 2025 22:56:28 +0200 Subject: [PATCH 03/14] added generation integration tests --- Sources/Generation/Generation.swift | 9 +- .../RepetitionPenaltyLogitsProcessor.swift | 31 +- .../LogitsWarper/TopKLogitsWarper.swift | 5 + .../GenerationIntegrationTests.swift | 309 ++++++++++++++++++ .../LogitsProcessorTests.swift | 32 +- 5 files changed, 370 insertions(+), 16 deletions(-) create mode 100644 Tests/GenerationTests/GenerationIntegrationTests.swift diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift index 45d8908..69ca751 100644 --- a/Sources/Generation/Generation.swift +++ b/Sources/Generation/Generation.swift @@ -76,7 +76,10 @@ extension Generation { // Create logits processor list based on config let logitsProcessorList = createLogitsProcessorList(config: config) - while outputTokens.shape[1] < config.maxLength { + let inputLength = outputTokens.shape[1] + let maxTotalLength = min(config.maxLength, inputLength + config.maxNewTokens) + + while outputTokens.shape[1] < maxTotalLength { // Get raw logits from model let nextTokenScores = await model(outputTokens, config) @@ -119,7 +122,9 @@ extension Generation { 
processors.append(TemperatureLogitsWarper(temperature: config.temperature)) } - // Top-K filtering + // Top-K filtering (only apply if topK is meaningful) + // Note: We can't determine vocab size here, so TopKLogitsWarper handles the case + // where topK >= vocabSize internally if config.topK > 0 && config.topK < Int.max { processors.append(TopKLogitsWarper(topK: config.topK)) } diff --git a/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift b/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift index 1f565ba..ae0b125 100644 --- a/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift +++ b/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift @@ -44,26 +44,37 @@ public struct RepetitionPenaltyLogitsProcessor: LogitsProcessor { // Process each batch item var scoresData = scoresArray.scalars - let batchSize = scores.shape[0] - let vocabSize = scores.shape[1] + let shape = scores.shape + precondition(!shape.isEmpty, "scores tensor must have at least one dimension") + + let batchSize = shape[0] + let vocabSize = shape[shape.count - 1] + let elementsPerBatch = shape.dropFirst().reduce(1, *) + let vocabBlocksPerBatch = max(elementsPerBatch / max(vocabSize, 1), 1) for batchIdx in 0..= 0 && tokenId < vocabSize else { continue } + // Apply penalty to each token that appeared in the sequence across all vocab blocks + for blockIdx in 0..= 0 && tokenId < vocabSize else { continue } + + let scoreIdx = blockOffset + tokenId + guard scoreIdx < scoresData.count else { continue } - let scoreIdx = seqStart + tokenId - let score = scoresData[scoreIdx] + let score = scoresData[scoreIdx] - // Apply penalty based on sign (following transformers implementation) - scoresData[scoreIdx] = score < 0 ? score * penalty : score / penalty + // Apply penalty based on sign (following transformers implementation) + scoresData[scoreIdx] = score < 0 ? 
score * penalty : score / penalty + } } } diff --git a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift index 16396a4..2dc153c 100644 --- a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift +++ b/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift @@ -35,6 +35,11 @@ public struct TopKLogitsWarper: LogitsProcessor { let vocabSize = scores.shape[scores.rank - 1] let k = min(topK, vocabSize) // Safety check + // If k equals vocabSize, no filtering is needed + if k >= vocabSize { + return scores + } + // Get the k-th highest score (the threshold) let (topKValues, _) = scores.topK(k) diff --git a/Tests/GenerationTests/GenerationIntegrationTests.swift b/Tests/GenerationTests/GenerationIntegrationTests.swift new file mode 100644 index 0000000..6a0809e --- /dev/null +++ b/Tests/GenerationTests/GenerationIntegrationTests.swift @@ -0,0 +1,309 @@ +import CoreML +import Tokenizers +import XCTest + +@testable import Generation + +@available(macOS 15.0, iOS 18.0, *) +final class GenerationIntegrationTests: XCTestCase { + + // MARK: - Mock Model for Testing + + /// Mock language model that returns predictable logits for testing + class MockLanguageModel { + var callCount = 0 + var logitsHistory: [MLTensor] = [] + + func predictNextToken(_ inputTokens: MLTensor, _ config: GenerationConfig) async -> MLTensor { + callCount += 1 + + // Return different logits based on the sequence length + let seqLength = inputTokens.shape[1] + + // Simulate a vocabulary of 10 tokens + // Create logits that favor certain tokens based on context + let vocabSize = 10 + var logits = [Float](repeating: 0.0, count: vocabSize) + + switch seqLength { + case 1: + // First generation: favor token 5 + logits = [1.0, 1.5, 2.0, 2.5, 3.0, 10.0, 3.0, 2.5, 2.0, 1.5] + case 2: + // Second generation: favor token 3 + logits = [1.0, 2.0, 3.0, 8.0, 3.0, 2.0, 2.0, 1.5, 1.0, 0.5] + case 3: + // Third generation: create a more uniform distribution + 
logits = [2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 3.5, 3.0, 2.5, 2.0] + default: + // Default: slightly favor middle tokens + logits = [1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.5] + } + + let tensor = MLTensor(shape: [1, vocabSize], scalars: logits, scalarType: Float.self) + logitsHistory.append(tensor) + return tensor + } + } + + // MARK: - Integration Tests + + func testGreedyGenerationWithoutProcessors() async throws { + let model = MockLanguageModel() + + var config = GenerationConfig(maxNewTokens: 3) + config.doSample = false // Greedy mode + config.eosTokenId = -1 // Disable early stopping + + let generation = TestGeneration() + let startTokens = [0] // Start with token 0 + + let output = await generation.generate( + config: config, + tokens: startTokens, + model: model.predictNextToken + ) + + // Greedy should always pick the highest logit + // Token 0 -> Token 5 (logit 10.0) -> Token 3 (logit 8.0) -> Token 5 (logit 4.5) + XCTAssertEqual(output.count, 4, "Should generate 3 new tokens + initial token") + XCTAssertEqual(output[0], 0, "First token should be the start token") + XCTAssertEqual(output[1], 5, "Second token should be 5 (highest logit)") + XCTAssertEqual(output[2], 3, "Third token should be 3 (highest logit)") + XCTAssertEqual(output[3], 5, "Fourth token should be 5 (highest logit)") + + XCTAssertEqual(model.callCount, 3, "Model should be called 3 times") + } + + func testSamplingWithTemperature() async throws { + let model = MockLanguageModel() + + var config = GenerationConfig(maxNewTokens: 3) + config.doSample = true // Sampling mode + config.temperature = 0.1 // Low temperature = more deterministic + config.eosTokenId = -1 + + let generation = TestGeneration() + let startTokens = [0] + + let output = await generation.generate( + config: config, + tokens: startTokens, + model: model.predictNextToken + ) + + XCTAssertEqual(output.count, 4, "Should generate 3 new tokens + initial token") + XCTAssertEqual(output[0], 0, "First token should be the start 
token") + + // With low temperature, sampling should still prefer high-probability tokens + // We can't assert exact tokens due to randomness, but can verify structure + XCTAssertTrue(output[1] < 10, "Generated token should be in vocab range") + XCTAssertTrue(output[2] < 10, "Generated token should be in vocab range") + XCTAssertTrue(output[3] < 10, "Generated token should be in vocab range") + } + + func testTopKFiltering() async throws { + let model = MockLanguageModel() + + var config = GenerationConfig(maxNewTokens: 3) + config.doSample = true // Sampling mode + config.topK = 3 // Only consider top 3 tokens + config.temperature = 1.0 + config.eosTokenId = -1 + + let generation = TestGeneration() + let startTokens = [0] + + // Run generation multiple times to test top-k filtering + for _ in 0..<5 { + model.callCount = 0 + model.logitsHistory = [] + + let output = await generation.generate( + config: config, + tokens: startTokens, + model: model.predictNextToken + ) + + XCTAssertEqual(output.count, 4, "Should generate 3 new tokens + initial token") + + // Verify that generated tokens are within valid range + for token in output[1...] { + XCTAssertTrue(token >= 0 && token < 10, "Token \(token) should be in vocab range") + } + } + } + + func testTopPFiltering() async throws { + let model = MockLanguageModel() + + var config = GenerationConfig(maxNewTokens: 2) + config.doSample = true // Sampling mode + config.topP = 0.9 // Top 90% probability mass + config.temperature = 1.0 + config.eosTokenId = -1 + + let generation = TestGeneration() + let startTokens = [0] + + let output = await generation.generate( + config: config, + tokens: startTokens, + model: model.predictNextToken + ) + + XCTAssertEqual(output.count, 3, "Should generate 2 new tokens + initial token") + + // Top-P should filter out low-probability tokens + for token in output[1...] 
{ + XCTAssertTrue(token >= 0 && token < 10, "Token should be in vocab range") + } + } + + func testRepetitionPenalty() async throws { + // Create a model that always returns the same high-scoring token + class RepetitiveModel { + func predict(_ inputTokens: MLTensor, _ config: GenerationConfig) async -> MLTensor { + // Always favor token 7 + let logits: [Float] = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 10.0, 1.0, 1.0] + return MLTensor(shape: [1, 10], scalars: logits, scalarType: Float.self) + } + } + + let model = RepetitiveModel() + + // Test WITHOUT repetition penalty + var configNoPenalty = GenerationConfig(maxNewTokens: 3) + configNoPenalty.doSample = false // Greedy mode + configNoPenalty.repetitionPenalty = 1.0 // No penalty + configNoPenalty.eosTokenId = -1 + + let generation = TestGeneration() + let startTokens = [0] + + let outputNoPenalty = await generation.generate( + config: configNoPenalty, + tokens: startTokens, + model: model.predict + ) + + // Without penalty, should keep selecting token 7 + XCTAssertEqual(outputNoPenalty, [0, 7, 7, 7], "Without penalty, should repeat token 7") + + // Test WITH repetition penalty + var configWithPenalty = GenerationConfig(maxNewTokens: 3) + configWithPenalty.doSample = false // Greedy mode + configWithPenalty.repetitionPenalty = 10.0 // Strong penalty (10.0 / 10.0 = 1.0, so other tokens win) + configWithPenalty.eosTokenId = -1 + + let outputWithPenalty = await generation.generate( + config: configWithPenalty, + tokens: startTokens, + model: model.predict + ) + + // With strong penalty, should select token 7 first time, but then avoid it + XCTAssertEqual(outputWithPenalty[0], 0, "First token should be start token") + XCTAssertEqual(outputWithPenalty[1], 7, "Second token should be 7 (highest score)") + + // After token 7 is penalized, other tokens should be chosen + // (exact tokens depend on penalty calculation, but should NOT all be 7) + let uniqueTokens = Set(outputWithPenalty[1...]) + XCTAssertTrue(uniqueTokens.count 
> 1, "With repetition penalty, should generate diverse tokens, got \(outputWithPenalty)") + } + + func testCombinedProcessors() async throws { + let model = MockLanguageModel() + + var config = GenerationConfig(maxNewTokens: 3) + config.doSample = true // Sampling mode + config.temperature = 0.8 // Slightly focused + config.topK = 5 // Top 5 tokens + config.topP = 0.95 // 95% probability mass + config.repetitionPenalty = 1.1 // Slight penalty + config.eosTokenId = -1 + + let generation = TestGeneration() + let startTokens = [0] + + let output = await generation.generate( + config: config, + tokens: startTokens, + model: model.predictNextToken + ) + + XCTAssertEqual(output.count, 4, "Should generate 3 new tokens + initial token") + XCTAssertEqual(output[0], 0, "First token should be the start token") + + // All processors should work together + // Can't assert exact tokens due to randomness, but verify structure + for token in output[1...] { + XCTAssertTrue(token >= 0 && token < 10, "Token should be in vocab range") + } + + // Verify model was called correct number of times + XCTAssertEqual(model.callCount, 3, "Model should be called 3 times") + } + + func testEarlyStoppingWithEOS() async throws { + // Create a model that returns EOS token after 2 generations + class EOSModel { + var callCount = 0 + + func predict(_ inputTokens: MLTensor, _ config: GenerationConfig) async -> MLTensor { + callCount += 1 + + let vocabSize = 10 + var logits = [Float](repeating: 1.0, count: vocabSize) + + if callCount >= 2 { + // After 2 calls, strongly favor EOS token (which we'll set as token 9) + logits[9] = 100.0 + } else { + // Before that, favor token 5 + logits[5] = 10.0 + } + + return MLTensor(shape: [1, vocabSize], scalars: logits, scalarType: Float.self) + } + } + + let model = EOSModel() + + var config = GenerationConfig(maxNewTokens: 10) // Request many tokens + config.doSample = false // Greedy mode + config.eosTokenId = 9 // Token 9 is EOS + + let generation = 
TestGeneration() + let startTokens = [0] + + let output = await generation.generate( + config: config, + tokens: startTokens, + model: model.predict + ) + + // Should stop early when EOS is encountered + XCTAssertLessThan(output.count, 11, "Should stop before generating all 10 tokens") + XCTAssertEqual(output[0], 0, "First token should be start token") + + // Model should be called fewer times due to early stopping + XCTAssertLessThan(model.callCount, 10, "Model should be called fewer times due to EOS") + } +} + +// MARK: - Test Helper + +@available(macOS 15.0, iOS 18.0, *) +struct TestGeneration: Generation { + func generate( + config: GenerationConfig, + prompt: String, + model: NextTokenModel, + tokenizer: Tokenizers.Tokenizer, + callback: PredictionStringCallback? + ) async -> String { + // Not used in these tests + return "" + } +} diff --git a/Tests/GenerationTests/LogitsProcessorTests.swift b/Tests/GenerationTests/LogitsProcessorTests.swift index 1e27248..701f0c4 100644 --- a/Tests/GenerationTests/LogitsProcessorTests.swift +++ b/Tests/GenerationTests/LogitsProcessorTests.swift @@ -1,5 +1,6 @@ -import XCTest import CoreML +import XCTest + @testable import Generation @available(macOS 15.0, iOS 18.0, *) @@ -17,7 +18,7 @@ final class LogitsProcessorTests: XCTestCase { let scores = MLTensor(shape: [1, 3], scalars: [Float(2.0), Float(4.0), Float(6.0)], scalarType: Float.self) let result = await warper(inputIds, scores) - let expected: [Float] = [1.0, 2.0, 3.0] // Each score divided by 2.0 + let expected: [Float] = [1.0, 2.0, 3.0] // Each score divided by 2.0 await assertMLTensorEqual(result, expected: expected, accuracy: accuracy) } @@ -29,7 +30,7 @@ final class LogitsProcessorTests: XCTestCase { let scores = MLTensor(shape: [1, 2], scalars: [Float(1.0), Float(2.0)], scalarType: Float.self) let result = await sharper(inputIds, scores) - let expected: [Float] = [2.0, 4.0] // Divided by 0.5 = multiplied by 2 + let expected: [Float] = [2.0, 4.0] // Divided by 0.5 = 
multiplied by 2 await assertMLTensorEqual(result, expected: expected, accuracy: accuracy) } @@ -148,6 +149,29 @@ final class LogitsProcessorTests: XCTestCase { XCTAssertEqual(resultArray, expectedArray, "Penalty of 1.0 should not change scores") } + func testRepetitionPenaltyWithRank3Scores() async throws { + let processor = RepetitionPenaltyLogitsProcessor(penalty: 2.0) + + // Input sequence with tokens [1, 2, 3] + let inputIds = MLTensor(shape: [1, 3], scalars: [Int32(1), Int32(2), Int32(3)], scalarType: Int32.self) + + // Scores shaped as [batch, sequence_length, vocab] -> [1, 1, 5] + let scores = MLTensor( + shape: [1, 1, 5], + scalars: [Float(0.5), Float(-0.5), Float(1.0), Float(-1.0), Float(2.0)], + scalarType: Float.self + ) + + let result = await processor(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + let expected: [Float] = [0.5, -1.0, 0.5, -2.0, 2.0] + XCTAssertEqual(resultArray.count, expected.count, "Flattened tensor mismatch") + for (value, exp) in zip(resultArray, expected) { + XCTAssertEqual(value, exp, accuracy: accuracy) + } + } + // MARK: - Processor List Tests func testLogitsProcessorList() async throws { @@ -200,4 +224,4 @@ func assertMLTensorEqual( for (a, e) in zip(actual, expected) { XCTAssertEqual(a, e, accuracy: accuracy, file: file, line: line) } -} \ No newline at end of file +} From 55b534f1a24a3abeee1a546bd09d3b8b8db2c616 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 27 Sep 2025 23:07:01 +0200 Subject: [PATCH 04/14] use MLTensor --- .../LogitsWarper/TopPLogitsWarper.swift | 79 ++++++++----------- 1 file changed, 35 insertions(+), 44 deletions(-) diff --git a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift index 73e35c7..9eee77e 100644 --- a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift +++ b/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift @@ -37,65 +37,56 @@ public struct TopPLogitsWarper: LogitsProcessor 
{ public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { // Algorithm (following transformers implementation): - // 1. Sort logits in ascending order (transformers uses descending=False) + // 1. Sort logits in descending order // 2. Compute softmax probabilities on sorted logits // 3. Compute cumulative sum of probabilities - // 4. Remove tokens with cumulative probability <= (1 - top_p) + // 4. Remove tokens with cumulative probability > top_p // 5. Keep at least min_tokens_to_keep - // 6. Scatter mask back to original indexing + // 6. Scatter mask back to original indexing using argsort indices - let batchSize = scores.shape[0] let vocabSize = scores.shape[scores.rank - 1] - // Convert to CPU for processing + // Get sorted indices (descending order - highest scores first) + let sortedIndices = scores.argsort(alongAxis: -1, descendingOrder: true) + + // Gather scores in sorted order + let sortedScores = scores.gathering(atIndices: sortedIndices, alongAxis: -1) + + // Compute softmax on sorted scores + let sortedProbs = sortedScores.softmax(alongAxis: -1) + + // Compute cumulative sum + let cumulativeProbs = sortedProbs.cumulativeSum(alongAxis: -1) + + // Create mask: remove tokens where cumsum > topP + // The HuggingFace implementation removes tokens where cumsum - current_prob > topP + // This ensures we include the first token that pushes us over the threshold + + // Shift cumsum to get cumsum of previous tokens + // For first token, this will be 0 + let cumulativeProbsShifted = cumulativeProbs - sortedProbs + + // Need CPU fallback for scatter and minTokensToKeep logic + let cumulativeProbsArray = await cumulativeProbsShifted.shapedArray(of: Float.self) + let indicesArray = await sortedIndices.shapedArray(of: Int32.self) let scoresArray = await scores.shapedArray(of: Float.self) + var maskedScores = scoresArray.scalars + let batchSize = scores.shape[0] - // Process each batch item for batchIdx in 0.. 
(1 - top_p) - let threshold = 1.0 - topP - var indicesToRemove = Set() + // Keep first minTokensToKeep tokens, otherwise check cumsum threshold + let shouldRemove = (i >= minTokensToKeep) && (cumulativeProbsArray.scalars[sortedIdx] > topP) - for i in 0.. Date: Sun, 28 Sep 2025 15:44:38 +0200 Subject: [PATCH 05/14] fix top-p and do-sample --- .../transformers-cli/Transformers.swift | 33 ++++++++- Sources/Generation/Decoders.swift | 33 ++++++--- .../LogitsWarper/LogitsProcessor.swift | 10 ++- .../LogitsWarper/TopPLogitsWarper.swift | 73 ++++++++++--------- 4 files changed, 101 insertions(+), 48 deletions(-) diff --git a/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift b/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift index 0ff53b5..e02c621 100644 --- a/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift +++ b/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift @@ -27,10 +27,25 @@ struct TransformersCLI: AsyncParsableCommand { @Option( help: """ When enabled, two generation passes are ran, one to 'warm up' and another to collect \ - benchmark metrics. + benchmark metrics. """) var warmup: Bool = false + @Option(help: "Enable sampling mode (true) or use greedy decoding (false)") + var doSample: Bool = false + + @Option(help: "Temperature for sampling (lower = more deterministic, typical: 0.7-1.0)") + var temperature: Float? + + @Option(help: "Top-k filtering - only consider k most likely tokens (typical: 5-50)") + var topK: Int? + + @Option(help: "Top-p (nucleus) sampling - cumulative probability threshold (typical: 0.9-0.95)") + var topP: Double? + + @Option(help: "Repetition penalty to discourage repeating tokens (typical: 1.0-2.0, 1.0 = no penalty)") + var repetitionPenalty: Double? 
+ func generate( model: LanguageModel, config: GenerationConfig, @@ -88,11 +103,23 @@ struct TransformersCLI: AsyncParsableCommand { print("Loading model \(compiledURL)") let model = try LanguageModel.loadCompiled(url: compiledURL, computeUnits: computeUnits.asMLComputeUnits) - // Using greedy generation for now var config = model.defaultGenerationConfig - config.doSample = false + config.doSample = doSample config.maxNewTokens = maxLength + if let temperature = temperature { + config.temperature = temperature + } + if let topK = topK { + config.topK = topK + } + if let topP = topP { + config.topP = topP + } + if let repetitionPenalty = repetitionPenalty { + config.repetitionPenalty = repetitionPenalty + } + // Given the size of the out-of-model computation, dispatch all // tensor operations to the CPU. diff --git a/Sources/Generation/Decoders.swift b/Sources/Generation/Decoders.swift index 81a52d7..a76a101 100644 --- a/Sources/Generation/Decoders.swift +++ b/Sources/Generation/Decoders.swift @@ -5,7 +5,9 @@ import CoreML @available(macOS 15.0, iOS 18.0, *) func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor { - scores.argmax(alongAxis: -1).reshaped(to: [1, 1]) + let indices = scores.argmax(alongAxis: -1).reshaped(to: [1, 1]) + // Ensure indices are Int32 for concatenation with input tokens + return indices.scalarType == Int32.self ? indices : indices.cast(to: Int32.self) } // MARK: Sampling @@ -22,20 +24,33 @@ func selectNextTokenUsingSampling(from scores: MLTensor) -> MLTensor { // Convert logits to probabilities let probs = scores.softmax(alongAxis: -1) - // Multinomial sampling using cumulative sum method + // Multinomial sampling using cumulative sum method: // 1. Generate random number in [0, 1) // 2. Compute cumulative sum of probabilities // 3. 
Find first index where cumsum >= random_number + // + // This is equivalent to torch.multinomial() but using available MLTensor ops - let rnd = Float.random(in: 0..<1) - var cumulativeProbs = probs.cumulativeSum(alongAxis: -1) + let batchSize = scores.shape[0] + let rndTensor = MLTensor(randomUniform: [batchSize, 1], in: 0..<1, scalarType: Float.self) + let cumulativeProbs = probs.cumulativeSum(alongAxis: -1) - // Mark all positions where cumsum >= rnd with a large value - // Then argsort will give us the first such position - cumulativeProbs = cumulativeProbs + (cumulativeProbs .< rnd) * 100.0 + // Ensure random tensor matches the type of cumulativeProbs + let rnd = cumulativeProbs.scalarType == Float.self ? rndTensor : rndTensor.cast(to: cumulativeProbs.scalarType) - let sampledIndex = cumulativeProbs.argsort(alongAxis: -1)[..., 0] - return sampledIndex.reshaped(to: [1, 1]) + // Create mask where cumsum >= rnd (these are candidates) + // We want the FIRST position where this is true + // Strategy: Set all positions where cumsum < rnd to a large value (1000.0) + // Set all positions where cumsum >= rnd to their index value + // Then argmin will give us the first qualifying index + + let mask = cumulativeProbs .< rnd + let penalized = mask * 1000.0 // Large value for positions to skip + let indexed = penalized + cumulativeProbs // Positions >= rnd will have small values + + let sampledIndex = indexed.argmin(alongAxis: -1).reshaped(to: [1, 1]) + // Ensure indices are Int32 for concatenation with input tokens + return sampledIndex.scalarType == Int32.self ? 
sampledIndex : sampledIndex.cast(to: Int32.self) } // MARK: Legacy Top-K Sampling (deprecated, use LogitsProcessorList instead) diff --git a/Sources/Generation/LogitsWarper/LogitsProcessor.swift b/Sources/Generation/LogitsWarper/LogitsProcessor.swift index cb8f7a0..a2a042b 100644 --- a/Sources/Generation/LogitsWarper/LogitsProcessor.swift +++ b/Sources/Generation/LogitsWarper/LogitsProcessor.swift @@ -43,11 +43,17 @@ public struct LogitsProcessorList { /// - scores: Tensor of raw logit scores with shape `[batch_size, vocab_size]` /// - Returns: Processed logits tensor with shape `[batch_size, vocab_size]` public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { - var processedScores = scores + // Following transformers convention: all logits processing happens in Float32 + // Cast to Float32 once at the start, process, then cast back to original type at the end + let originalScalarType = scores.scalarType + var processedScores = scores.scalarType == Float.self ? scores : scores.cast(to: Float.self) + for processor in processors { processedScores = await processor(inputIds, processedScores) } - return processedScores + + // Cast back to original type if needed + return originalScalarType == Float.self ? processedScores : processedScores.cast(to: originalScalarType) } } #endif diff --git a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift index 9eee77e..725589b 100644 --- a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift +++ b/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift @@ -42,54 +42,59 @@ public struct TopPLogitsWarper: LogitsProcessor { // 3. Compute cumulative sum of probabilities // 4. Remove tokens with cumulative probability > top_p // 5. Keep at least min_tokens_to_keep - // 6. Scatter mask back to original indexing using argsort indices + // 6. 
Scatter mask back to original indexing using inverse permutation let vocabSize = scores.shape[scores.rank - 1] - // Get sorted indices (descending order - highest scores first) + // Scores are already in Float32 (handled by LogitsProcessorList) + // Sort in descending order (highest scores first) let sortedIndices = scores.argsort(alongAxis: -1, descendingOrder: true) + // Build inverse permutation for scattering back + let inversePermutation = sortedIndices.argsort(alongAxis: -1) + // Gather scores in sorted order let sortedScores = scores.gathering(atIndices: sortedIndices, alongAxis: -1) - // Compute softmax on sorted scores + // Compute probabilities and cumulative sum in sorted order let sortedProbs = sortedScores.softmax(alongAxis: -1) - - // Compute cumulative sum let cumulativeProbs = sortedProbs.cumulativeSum(alongAxis: -1) - // Create mask: remove tokens where cumsum > topP - // The HuggingFace implementation removes tokens where cumsum - current_prob > topP + // Shift cumsum to exclude current token (HuggingFace convention) // This ensures we include the first token that pushes us over the threshold - - // Shift cumsum to get cumsum of previous tokens - // For first token, this will be 0 let cumulativeProbsShifted = cumulativeProbs - sortedProbs - // Need CPU fallback for scatter and minTokensToKeep logic - let cumulativeProbsArray = await cumulativeProbsShifted.shapedArray(of: Float.self) - let indicesArray = await sortedIndices.shapedArray(of: Int32.self) - let scoresArray = await scores.shapedArray(of: Float.self) - - var maskedScores = scoresArray.scalars - let batchSize = scores.shape[0] - - for batchIdx in 0..= minTokensToKeep) && (cumulativeProbsArray.scalars[sortedIdx] > topP) - - if shouldRemove { - maskedScores[batchStart + originalIdx] = filterValue - } - } - } - - return MLTensor(shape: scores.shape, scalars: maskedScores, scalarType: Float.self) + // Create position tensor [0, 1, 2, ..., vocabSize-1] for minTokensToKeep check + let baseShape 
= Array(repeating: 1, count: sortedScores.rank - 1) + [vocabSize] + var multiples = sortedScores.shape + multiples[multiples.count - 1] = 1 + + let positions = MLTensor( + rangeFrom: Int32(0), + to: Int32(vocabSize), + by: 1, + scalarType: Int32.self + ) + .reshaped(to: baseShape) + .tiled(multiples: multiples) + .cast(to: Float.self) + + // Create mask in sorted order: + // Remove if: position >= minTokensToKeep AND cumsum_shifted > topP + let beyondMinimum = positions .>= Float(minTokensToKeep) + let exceedsThreshold = cumulativeProbsShifted .> topP + let removeMaskSorted = beyondMinimum .& exceedsThreshold + + // Apply filter value in sorted space + let filterTensor = MLTensor( + repeating: filterValue, + shape: sortedScores.shape, + scalarType: Float.self + ) + let filteredSorted = sortedScores.replacing(with: filterTensor, where: removeMaskSorted) + + // Scatter back to original indexing using inverse permutation + return filteredSorted.gathering(atIndices: inversePermutation, alongAxis: -1) } } #endif From 8e4db8f9477ac3caaec76f69b0ec42500232af17 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 28 Sep 2025 15:49:53 +0200 Subject: [PATCH 06/14] fix CI issue --- Sources/Models/LanguageModel.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index 9a962af..072b982 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -129,7 +129,7 @@ extension LanguageModel { static let valueCache = "valueCache" // Output keys static let logits = "logits" - static let presentKeys = "presentKeys" + static let present = "presentKeys" static let presentValues = "presentValues" } } @@ -265,7 +265,7 @@ public extension LanguageModel { } let kCacheInput = model.modelDescription.inputDescriptionsByName[Keys.keyCache] != nil let vCacheInput = model.modelDescription.inputDescriptionsByName[Keys.valueCache] != nil - let kCacheOutput = 
model.modelDescription.outputDescriptionsByName[Keys.presentKeys] != nil + let kCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.present] != nil let vCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.presentValues] != nil guard Set([kCacheInput, vCacheInput, kCacheOutput, vCacheOutput]).count == 1 else { From 4d012b9198f424f06316147c5016f61a90898c9e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 28 Sep 2025 15:55:21 +0200 Subject: [PATCH 07/14] undo changes to LanguageModel.swift --- Sources/Models/LanguageModel.swift | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index 072b982..044fd19 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -129,7 +129,7 @@ extension LanguageModel { static let valueCache = "valueCache" // Output keys static let logits = "logits" - static let present = "presentKeys" + static let presentKeys = "presentKeys" static let presentValues = "presentValues" } } @@ -265,7 +265,7 @@ public extension LanguageModel { } let kCacheInput = model.modelDescription.inputDescriptionsByName[Keys.keyCache] != nil let vCacheInput = model.modelDescription.inputDescriptionsByName[Keys.valueCache] != nil - let kCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.present] != nil + let kCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.presentKeys] != nil let vCacheOutput = model.modelDescription.outputDescriptionsByName[Keys.presentValues] != nil guard Set([kCacheInput, vCacheInput, kCacheOutput, vCacheOutput]).count == 1 else { @@ -408,7 +408,7 @@ extension LanguageModel: TextGenerationModel { /// Provides sensible defaults based on the model type, with model-specific /// optimizations for known architectures like GPT models. 
public var defaultGenerationConfig: GenerationConfig { - var config: GenerationConfig = GenerationConfig(maxNewTokens: 2048) + var config = GenerationConfig(maxNewTokens: 2048) switch modelName.lowercased() { case let x where x.contains("gpt"): config.doSample = true From 9492d473860e2b0e16f0f8d85737b4c18ac0d648 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 28 Sep 2025 16:09:41 +0200 Subject: [PATCH 08/14] use MLTensor --- .../RepetitionPenaltyLogitsProcessor.swift | 74 +++++++++---------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift b/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift index ae0b125..43c90f6 100644 --- a/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift +++ b/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift @@ -32,54 +32,54 @@ public struct RepetitionPenaltyLogitsProcessor: LogitsProcessor { public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { guard penalty != 1.0 else { return scores } - // Implementation approach (following transformers): - // 1. Get unique token IDs from inputIds - // 2. For each unique token, gather its logit value - // 3. Apply conditional penalty: if logit < 0: *= penalty, else: /= penalty - // 4. Scatter penalized values back to original positions + // Optimized implementation following transformers: + // 1. Gather scores for tokens that appear in input_ids + // 2. Apply conditional penalty: if score < 0: *= penalty, else: /= penalty + // 3. 
Scatter penalized values back to original positions - // Convert to CPU for gather/scatter operations - let scoresArray = await scores.shapedArray(of: Float.self) - let inputIdsArray = await inputIds.shapedArray(of: Int32.self) - - // Process each batch item - var scoresData = scoresArray.scalars - let shape = scores.shape - precondition(!shape.isEmpty, "scores tensor must have at least one dimension") - - let batchSize = shape[0] - let vocabSize = shape[shape.count - 1] - let elementsPerBatch = shape.dropFirst().reduce(1, *) - let vocabBlocksPerBatch = max(elementsPerBatch / max(vocabSize, 1), 1) - - for batchIdx in 0..= 0 && tokenId < vocabSize else { continue } + let inputIdsArray = await inputIds.shapedArray(of: Int32.self) + let penalizedArray = await penalizedScores.shapedArray(of: Float.self) + var scoresArray = await scores.shapedArray(of: Float.self) - let scoreIdx = blockOffset + tokenId - guard scoreIdx < scoresData.count else { continue } + for batchIdx in 0..= 0 && tokenId < vocabSize else { continue } - // Apply penalty based on sign (following transformers implementation) - scoresData[scoreIdx] = score < 0 ? 
score * penalty : score / penalty + // For rank-2: [batch_size, vocab_size] + if scores.rank == 2 { + let scoreIdx = batchOffset + tokenId + let penalizedIdx = seqStart + tokenIdx + scoresArray.scalars[scoreIdx] = penalizedArray.scalars[penalizedIdx] + } + // For rank-3: [batch_size, seq_len, vocab_size] - update last position + else if scores.rank == 3 { + let lastSeqPos = scores.shape[1] - 1 + let scoreIdx = batchOffset + lastSeqPos * vocabSize + tokenId + let penalizedIdx = seqStart + tokenIdx + scoresArray.scalars[scoreIdx] = penalizedArray.scalars[penalizedIdx] } } } - // Create new tensor with penalized scores - return MLTensor(shape: scores.shape, scalars: scoresData, scalarType: Float.self) + return MLTensor(shape: scores.shape, scalars: scoresArray.scalars, scalarType: Float.self) } } #endif From 83164c299986cb55c3449f35310548b27ef2d383 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 29 Sep 2025 21:18:49 +0200 Subject: [PATCH 09/14] Update Tests/GenerationTests/GenerationIntegrationTests.swift Co-authored-by: Pedro Cuenca --- Tests/GenerationTests/GenerationIntegrationTests.swift | 1 + 1 file changed, 1 insertion(+) diff --git a/Tests/GenerationTests/GenerationIntegrationTests.swift b/Tests/GenerationTests/GenerationIntegrationTests.swift index 6a0809e..637760f 100644 --- a/Tests/GenerationTests/GenerationIntegrationTests.swift +++ b/Tests/GenerationTests/GenerationIntegrationTests.swift @@ -54,6 +54,7 @@ final class GenerationIntegrationTests: XCTestCase { var config = GenerationConfig(maxNewTokens: 3) config.doSample = false // Greedy mode config.eosTokenId = -1 // Disable early stopping + config.maxLength = 10 let generation = TestGeneration() let startTokens = [0] // Start with token 0 From fff7a3cd9bc249435c03bdc460cd2a708271ad66 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 29 Sep 2025 21:24:54 +0200 Subject: [PATCH 10/14] Throws: If penalty is not strictly positive --- .../RepetitionPenaltyLogitsProcessor.swift | 15 +++++++++++---- 
Tests/GenerationTests/LogitsProcessorTests.swift | 6 +++--- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift b/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift index 43c90f6..006bac3 100644 --- a/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift +++ b/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift @@ -1,6 +1,11 @@ #if canImport(CoreML) import CoreML +/// Error thrown by logits processors +public enum LogitsProcessorError: Error { + case invalidParameter(String) +} + /// LogitsProcessor that prevents repetition of previous tokens through a penalty. /// /// For each token that has already appeared in the sequence: @@ -24,8 +29,11 @@ public struct RepetitionPenaltyLogitsProcessor: LogitsProcessor { /// Creates a repetition penalty logits processor. /// /// - Parameter penalty: Penalty factor. Values > 1.0 penalize repetition, values < 1.0 encourage it. 
- public init(penalty: Float) { - precondition(penalty > 0, "penalty must be strictly positive, got \(penalty)") + /// - Throws: If penalty is not strictly positive + public init(penalty: Float) throws { + guard penalty > 0 else { + throw LogitsProcessorError.invalidParameter("penalty must be strictly positive, got \(penalty)") + } self.penalty = penalty } @@ -42,8 +50,7 @@ public struct RepetitionPenaltyLogitsProcessor: LogitsProcessor { // Apply conditional penalty based on sign (vectorized) let negativeScores = gatheredScores .< 0.0 - let penalizedScores = negativeScores.cast(to: Float.self) * (gatheredScores * penalty) + - (1.0 - negativeScores.cast(to: Float.self)) * (gatheredScores / penalty) + let penalizedScores = negativeScores.cast(to: Float.self) * (gatheredScores * penalty) + (1.0 - negativeScores.cast(to: Float.self)) * (gatheredScores / penalty) // Scatter penalized values back to original positions // Note: MLTensor doesn't have direct scatter, so we use CPU operations for this step diff --git a/Tests/GenerationTests/LogitsProcessorTests.swift b/Tests/GenerationTests/LogitsProcessorTests.swift index 701f0c4..c2e8283 100644 --- a/Tests/GenerationTests/LogitsProcessorTests.swift +++ b/Tests/GenerationTests/LogitsProcessorTests.swift @@ -108,7 +108,7 @@ final class LogitsProcessorTests: XCTestCase { // MARK: - Repetition Penalty Tests func testRepetitionPenaltyProcessor() async throws { - let processor = RepetitionPenaltyLogitsProcessor(penalty: 2.0) + let processor = try RepetitionPenaltyLogitsProcessor(penalty: 2.0) // Input sequence with tokens [1, 2, 3] let inputIds = MLTensor(shape: [1, 3], scalars: [Int32(1), Int32(2), Int32(3)], scalarType: Int32.self) @@ -136,7 +136,7 @@ final class LogitsProcessorTests: XCTestCase { } func testRepetitionPenaltyWithNoPenalty() async throws { - let processor = RepetitionPenaltyLogitsProcessor(penalty: 1.0) + let processor = try RepetitionPenaltyLogitsProcessor(penalty: 1.0) let inputIds = MLTensor(shape: [1, 2], 
scalars: [Int32(1), Int32(2)], scalarType: Int32.self) let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) @@ -150,7 +150,7 @@ final class LogitsProcessorTests: XCTestCase { } func testRepetitionPenaltyWithRank3Scores() async throws { - let processor = RepetitionPenaltyLogitsProcessor(penalty: 2.0) + let processor = try RepetitionPenaltyLogitsProcessor(penalty: 2.0) // Input sequence with tokens [1, 2, 3] let inputIds = MLTensor(shape: [1, 3], scalars: [Int32(1), Int32(2), Int32(3)], scalarType: Int32.self) From 9754942b64182d072b154a0c287266a04adb5740 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 29 Sep 2025 21:25:09 +0200 Subject: [PATCH 11/14] remove unused selectNextTokenUsingTopKSampling --- Sources/Generation/Decoders.swift | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/Sources/Generation/Decoders.swift b/Sources/Generation/Decoders.swift index a76a101..e6dfd73 100644 --- a/Sources/Generation/Decoders.swift +++ b/Sources/Generation/Decoders.swift @@ -52,26 +52,4 @@ func selectNextTokenUsingSampling(from scores: MLTensor) -> MLTensor { // Ensure indices are Int32 for concatenation with input tokens return sampledIndex.scalarType == Int32.self ? sampledIndex : sampledIndex.cast(to: Int32.self) } - -// MARK: Legacy Top-K Sampling (deprecated, use LogitsProcessorList instead) - -/// Legacy top-k sampling function that combines temperature, top-k, and sampling. -/// -/// - Note: This function is deprecated. Use `selectNextTokenUsingSampling` with -/// `TemperatureLogitsWarper` and `TopKLogitsWarper` in a `LogitsProcessorList` instead. 
-@available(macOS 15.0, iOS 18.0, *) -func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float, topK: Int) -> MLTensor { - let temperatureAdjustedScores = scores / temperature - let (topKScores, topKIndices) = temperatureAdjustedScores.topK(topK) - let topKProbs = topKScores.softmax(alongAxis: -1) - let rnd = topKProbs.sum() * Float.random(in: 0..<1) - var accumTopKProbs = topKProbs.cumulativeSum(alongAxis: -1) - accumTopKProbs += (accumTopKProbs .< rnd) * 100.0 - let topKIndex = accumTopKProbs.argsort()[..., 0] - let nextTokenTensor = topKIndices.gathering( - atIndices: topKIndex, - alongAxis: topKIndices.rank - 1 - ) - return nextTokenTensor.reshaped(to: [1, 1]) -} #endif // canImport(CoreML) From d6a7526b58bc4bff20f5b5835bb1ab358c9991af Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 29 Sep 2025 21:25:37 +0200 Subject: [PATCH 12/14] make sure ordering of warpers is that of transformers --- Sources/Generation/Generation.swift | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift index 69ca751..09f363d 100644 --- a/Sources/Generation/Generation.swift +++ b/Sources/Generation/Generation.swift @@ -117,6 +117,13 @@ extension Generation { private func createLogitsProcessorList(config: GenerationConfig) -> LogitsProcessorList { var processors: [any LogitsProcessor] = [] + // Repetition penalty (applied before sampling warpers) + if config.repetitionPenalty != 1.0 { + if let processor = try? 
RepetitionPenaltyLogitsProcessor(penalty: Float(config.repetitionPenalty)) { + processors.append(processor) + } + } + // Temperature scaling (if not default) if config.temperature > 0 && config.temperature != 1.0 { processors.append(TemperatureLogitsWarper(temperature: config.temperature)) @@ -134,11 +141,6 @@ extension Generation { processors.append(TopPLogitsWarper(topP: Float(config.topP))) } - // Repetition penalty - if config.repetitionPenalty != 1.0 { - processors.append(RepetitionPenaltyLogitsProcessor(penalty: Float(config.repetitionPenalty))) - } - return LogitsProcessorList(processors: processors) } From aec446deee34f9168b07f61394990881c2b99ce7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 30 Sep 2025 09:25:02 +0200 Subject: [PATCH 13/14] throw --- Sources/Generation/Generation.swift | 12 +++++++++--- .../LogitsWarper/TemperatureLogitsWarper.swift | 7 +++++-- .../LogitsWarper/TopKLogitsWarper.swift | 12 +++++++++--- .../LogitsWarper/TopPLogitsWarper.swift | 11 ++++++++--- Tests/GenerationTests/LogitsProcessorTests.swift | 16 ++++++++-------- 5 files changed, 39 insertions(+), 19 deletions(-) diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift index 09f363d..72a93a5 100644 --- a/Sources/Generation/Generation.swift +++ b/Sources/Generation/Generation.swift @@ -126,19 +126,25 @@ extension Generation { // Temperature scaling (if not default) if config.temperature > 0 && config.temperature != 1.0 { - processors.append(TemperatureLogitsWarper(temperature: config.temperature)) + if let processor = try? TemperatureLogitsWarper(temperature: config.temperature) { + processors.append(processor) + } } // Top-K filtering (only apply if topK is meaningful) // Note: We can't determine vocab size here, so TopKLogitsWarper handles the case // where topK >= vocabSize internally if config.topK > 0 && config.topK < Int.max { - processors.append(TopKLogitsWarper(topK: config.topK)) + if let processor = try? 
TopKLogitsWarper(topK: config.topK) { + processors.append(processor) + } } // Top-P (nucleus) sampling if config.topP < 1.0 { - processors.append(TopPLogitsWarper(topP: Float(config.topP))) + if let processor = try? TopPLogitsWarper(topP: Float(config.topP)) { + processors.append(processor) + } } return LogitsProcessorList(processors: processors) diff --git a/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift b/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift index ad27849..e3f9383 100644 --- a/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift +++ b/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift @@ -19,8 +19,11 @@ public struct TemperatureLogitsWarper: LogitsProcessor { /// /// - Parameter temperature: Strictly positive float value used to modulate the logits distribution. /// Must be > 0. Values close to 0 approximate greedy decoding. - public init(temperature: Float) { - precondition(temperature > 0, "temperature must be strictly positive, got \(temperature)") + /// - Throws: If temperature is not strictly positive + public init(temperature: Float) throws { + guard temperature > 0 else { + throw LogitsProcessorError.invalidParameter("temperature must be strictly positive, got \(temperature)") + } self.temperature = temperature } diff --git a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift index 2dc153c..5472c82 100644 --- a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift +++ b/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift @@ -22,9 +22,15 @@ public struct TopKLogitsWarper: LogitsProcessor { /// - topK: Number of highest probability tokens to keep /// - filterValue: Value to set filtered tokens to (default: -infinity) /// - minTokensToKeep: Minimum tokens that cannot be filtered (default: 1) - public init(topK: Int, filterValue: Float = -.infinity, minTokensToKeep: Int = 1) { - precondition(topK > 0, "topK must be strictly positive, got 
\(topK)") - precondition(minTokensToKeep >= 1, "minTokensToKeep must be at least 1, got \(minTokensToKeep)") + /// - Throws: If topK is not strictly positive or if minTokensToKeep is less than 1 + /// - Note: If topK is larger than the vocabulary size, no filtering is applied. + public init(topK: Int, filterValue: Float = -.infinity, minTokensToKeep: Int = 1) throws { + guard topK > 0 else { + throw LogitsProcessorError.invalidParameter("topK must be strictly positive, got \(topK)") + } + guard minTokensToKeep >= 1 else { + throw LogitsProcessorError.invalidParameter("minTokensToKeep must be at least 1, got \(minTokensToKeep)") + } self.topK = max(topK, minTokensToKeep) self.filterValue = filterValue diff --git a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift index 725589b..e3445a4 100644 --- a/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift +++ b/Sources/Generation/LogitsWarper/TopPLogitsWarper.swift @@ -26,9 +26,14 @@ public struct TopPLogitsWarper: LogitsProcessor { /// - topP: Cumulative probability threshold. Must be between 0 and 1. 
/// - filterValue: Value to set filtered tokens to (default: -infinity) /// - minTokensToKeep: Minimum tokens that cannot be filtered (default: 1) - public init(topP: Float, filterValue: Float = -.infinity, minTokensToKeep: Int = 1) { - precondition(topP >= 0 && topP <= 1.0, "topP must be in [0, 1], got \(topP)") - precondition(minTokensToKeep >= 1, "minTokensToKeep must be at least 1, got \(minTokensToKeep)") + /// - Throws: If topP is not in [0, 1] or if minTokensToKeep is less than 1 + public init(topP: Float, filterValue: Float = -.infinity, minTokensToKeep: Int = 1) throws { + guard topP >= 0 && topP <= 1.0 else { + throw LogitsProcessorError.invalidParameter("topP must be in [0, 1], got \(topP)") + } + guard minTokensToKeep >= 1 else { + throw LogitsProcessorError.invalidParameter("minTokensToKeep must be at least 1, got \(minTokensToKeep)") + } self.topP = topP self.filterValue = filterValue diff --git a/Tests/GenerationTests/LogitsProcessorTests.swift b/Tests/GenerationTests/LogitsProcessorTests.swift index c2e8283..2791023 100644 --- a/Tests/GenerationTests/LogitsProcessorTests.swift +++ b/Tests/GenerationTests/LogitsProcessorTests.swift @@ -10,7 +10,7 @@ final class LogitsProcessorTests: XCTestCase { // MARK: - Temperature Tests func testTemperatureWarper() async throws { - let warper = TemperatureLogitsWarper(temperature: 2.0) + let warper = try TemperatureLogitsWarper(temperature: 2.0) // Create input: batch_size=1, seq_len=3 let inputIds = MLTensor(shape: [1, 3], scalars: [Int32(1), Int32(2), Int32(3)], scalarType: Int32.self) @@ -25,7 +25,7 @@ final class LogitsProcessorTests: XCTestCase { func testTemperatureWarperWithDifferentValues() async throws { // Test temperature < 1 (sharper distribution) - let sharper = TemperatureLogitsWarper(temperature: 0.5) + let sharper = try TemperatureLogitsWarper(temperature: 0.5) let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) let scores = MLTensor(shape: [1, 2], scalars: 
[Float(1.0), Float(2.0)], scalarType: Float.self) @@ -38,7 +38,7 @@ final class LogitsProcessorTests: XCTestCase { // MARK: - Top-K Tests func testTopKWarper() async throws { - let warper = TopKLogitsWarper(topK: 3) + let warper = try TopKLogitsWarper(topK: 3) let inputIds = MLTensor(shape: [1, 2], scalars: [Int32(1), Int32(2)], scalarType: Int32.self) let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) @@ -55,7 +55,7 @@ final class LogitsProcessorTests: XCTestCase { } func testTopKWarperWithSmallK() async throws { - let warper = TopKLogitsWarper(topK: 1) + let warper = try TopKLogitsWarper(topK: 1) let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) let scores = MLTensor(shape: [1, 3], scalars: [Float(1.0), Float(5.0), Float(3.0)], scalarType: Float.self) @@ -72,7 +72,7 @@ final class LogitsProcessorTests: XCTestCase { // MARK: - Top-P Tests func testTopPWarper() async throws { - let warper = TopPLogitsWarper(topP: 0.9) + let warper = try TopPLogitsWarper(topP: 0.9) let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) // Create a distribution where top tokens dominate: [0.0, 1.0, 2.0, 3.0, 10.0] @@ -92,7 +92,7 @@ final class LogitsProcessorTests: XCTestCase { func testTopPWarperWithHighThreshold() async throws { // With topP=0.99, almost all tokens should be kept - let warper = TopPLogitsWarper(topP: 0.99) + let warper = try TopPLogitsWarper(topP: 0.99) let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) @@ -175,8 +175,8 @@ final class LogitsProcessorTests: XCTestCase { // MARK: - Processor List Tests func testLogitsProcessorList() async throws { - let temp = TemperatureLogitsWarper(temperature: 2.0) - let topK = TopKLogitsWarper(topK: 3) + let temp = try 
TemperatureLogitsWarper(temperature: 2.0) + let topK = try TopKLogitsWarper(topK: 3) let processorList = LogitsProcessorList(processors: [temp, topK]) let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) From 501f192ea2c8389e5c4b1088eb0cb0ce9d7e9cab Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 30 Sep 2025 09:41:23 +0200 Subject: [PATCH 14/14] add Min-P --- .../transformers-cli/Transformers.swift | 12 +- Sources/Generation/Generation.swift | 7 ++ Sources/Generation/GenerationConfig.swift | 18 ++- .../LogitsWarper/MinPLogitsWarper.swift | 114 ++++++++++++++++++ .../GenerationIntegrationTests.swift | 48 ++++++++ .../LogitsProcessorTests.swift | 108 +++++++++++++++++ 6 files changed, 298 insertions(+), 9 deletions(-) create mode 100644 Sources/Generation/LogitsWarper/MinPLogitsWarper.swift diff --git a/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift b/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift index e02c621..77732dc 100644 --- a/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift +++ b/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift @@ -34,17 +34,20 @@ struct TransformersCLI: AsyncParsableCommand { @Option(help: "Enable sampling mode (true) or use greedy decoding (false)") var doSample: Bool = false - @Option(help: "Temperature for sampling (lower = more deterministic, typical: 0.7-1.0)") + @Option(help: "Temperature for sampling (lower = more deterministic, typical: 0.1-2.0)") var temperature: Float? @Option(help: "Top-k filtering - only consider k most likely tokens (typical: 5-50)") var topK: Int? @Option(help: "Top-p (nucleus) sampling - cumulative probability threshold (typical: 0.9-0.95)") - var topP: Double? + var topP: Float? + + @Option(help: "Min-p sampling - minimum probability threshold scaled by top token (typical: 0.01-0.2)") + var minP: Float? 
@Option(help: "Repetition penalty to discourage repeating tokens (typical: 1.0-2.0, 1.0 = no penalty)") - var repetitionPenalty: Double? + var repetitionPenalty: Float? func generate( model: LanguageModel, @@ -116,6 +119,9 @@ struct TransformersCLI: AsyncParsableCommand { if let topP = topP { config.topP = topP } + if let minP = minP { + config.minP = minP + } if let repetitionPenalty = repetitionPenalty { config.repetitionPenalty = repetitionPenalty } diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift index 72a93a5..837b836 100644 --- a/Sources/Generation/Generation.swift +++ b/Sources/Generation/Generation.swift @@ -147,6 +147,13 @@ extension Generation { } } + // Min-P sampling (applied after temperature scaling) + if let minP = config.minP { + if let processor = try? MinPLogitsWarper(minP: Float(minP)) { + processors.append(processor) + } + } + return LogitsProcessorList(processors: processors) } diff --git a/Sources/Generation/GenerationConfig.swift b/Sources/Generation/GenerationConfig.swift index e7389fb..db125a9 100644 --- a/Sources/Generation/GenerationConfig.swift +++ b/Sources/Generation/GenerationConfig.swift @@ -39,10 +39,13 @@ public struct GenerationConfig { public var topK = 50 /// Cumulative probability threshold for top-p sampling. - public var topP = 1.0 + public var topP: Float = 1.0 + + /// Minimum token probability threshold, scaled by the most likely token's probability. + public var minP: Float? /// Penalty for token repetition (1.0 means no penalty). - public var repetitionPenalty = 1.0 + public var repetitionPenalty: Float = 1.0 /// Token ID used for padding sequences. public var padTokenId: Int? 
@@ -65,6 +68,7 @@ public struct GenerationConfig { /// - temperature: Sampling temperature /// - topK: Top-k sampling parameter /// - topP: Top-p sampling parameter + /// - minP: Min-p sampling parameter /// - repetitionPenalty: Repetition penalty factor public init( maxLength: Int = 20, @@ -73,10 +77,11 @@ public struct GenerationConfig { numBeams: Int = 1, numBeamGroups: Int = 1, penaltyAlpha: Double? = nil, - temperature: Double = 1.0, + temperature: Float = 1.0, topK: Int = 50, - topP: Double = 1.0, - repetitionPenalty: Double = 1.0 + topP: Float = 1.0, + minP: Float? = nil, + repetitionPenalty: Float = 1.0 ) { self.maxLength = maxLength self.maxNewTokens = maxNewTokens @@ -84,9 +89,10 @@ public struct GenerationConfig { self.numBeams = numBeams self.numBeamGroups = numBeamGroups self.penaltyAlpha = penaltyAlpha - self.temperature = Float(temperature) + self.temperature = temperature self.topK = topK self.topP = topP + self.minP = minP self.repetitionPenalty = repetitionPenalty } } diff --git a/Sources/Generation/LogitsWarper/MinPLogitsWarper.swift b/Sources/Generation/LogitsWarper/MinPLogitsWarper.swift new file mode 100644 index 0000000..9b2867b --- /dev/null +++ b/Sources/Generation/LogitsWarper/MinPLogitsWarper.swift @@ -0,0 +1,114 @@ +#if canImport(CoreML) +import CoreML + +/// LogitsProcessor that performs min-p filtering on the logits. +/// +/// Min-p keeps all tokens that are above a minimum probability, scaled by the probability +/// of the most likely token. As a result, the filter becomes more aggressive in the presence +/// of high-probability tokens, which is a sign of a confident output that we shouldn't deviate from. +/// +/// Often used together with `TemperatureLogitsWarper`. Used as an alternative to `TopPLogitsWarper` +/// and `TopKLogitsWarper`. +/// +/// Typical values are in the 0.01-0.2 range, roughly as selective as setting `top_p` in +/// the 0.99-0.8 range (i.e. roughly one minus the usual `top_p` value). 
+/// +/// Based on: +/// - https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L460 +@available(macOS 15.0, iOS 18.0, *) +public struct MinPLogitsWarper: LogitsProcessor { + public let minP: Float + public let minTokensToKeep: Int + public let filterValue: Float + + /// Creates a min-p logits warper. + /// + /// - Parameters: + /// - minP: Minimum token probability, which will be scaled by the probability of the most likely token. + /// Must be between 0 and 1. Typical values are 0.01-0.2. + /// - minTokensToKeep: Minimum number of tokens that cannot be filtered. + /// - filterValue: Value to set for filtered tokens (default: -infinity) + /// - Throws: If minP is not in [0, 1] or if minTokensToKeep is less than 1 + public init(minP: Float, minTokensToKeep: Int = 1, filterValue: Float = -Float.infinity) throws { + guard minP >= 0 && minP <= 1.0 else { + throw LogitsProcessorError.invalidParameter("minP must be in [0, 1], got \(minP)") + } + guard minTokensToKeep >= 1 else { + throw LogitsProcessorError.invalidParameter("minTokensToKeep must be >= 1, got \(minTokensToKeep)") + } + self.minP = minP + self.minTokensToKeep = minTokensToKeep + self.filterValue = filterValue + } + + public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { + // Algorithm (following transformers implementation): + // 1. Compute probabilities from logits + // 2. Find max probability per batch + // 3. Create threshold = minP * maxProb + // 4. Sort logits and mask tokens where prob < threshold + // 5. Keep at least minTokensToKeep + // 6. 
Scatter back to original order + + let vocabSize = scores.shape[scores.rank - 1] + + // Compute probabilities + let probs = scores.softmax(alongAxis: -1) + + // Sort probabilities descending to get max (first element) + let sortedProbIndices = probs.argsort(alongAxis: -1, descendingOrder: true) + let sortedProbs = probs.gathering(atIndices: sortedProbIndices, alongAxis: -1) + + // Extract max prob per batch: first element of each sorted sequence + // Do this on CPU to avoid complex broadcasting issues + let sortedProbsArray = await sortedProbs.shapedArray(of: Float.self) + let batchSize = scores.shape[0] + var thresholdScalars = [Float]() + thresholdScalars.reserveCapacity(batchSize * vocabSize) + for batchIdx in 0..= minTokensToKeep AND shouldRemove) + let beyondMinimum = positions .>= Int32(minTokensToKeep) + let finalRemoveMask = sortedTokensToRemove .& beyondMinimum + + // Apply filter in sorted space + let sortedScores = scores.gathering(atIndices: sortedScoreIndices, alongAxis: -1) + let filterTensor = MLTensor(repeating: filterValue, shape: sortedScores.shape, scalarType: Float.self) + let filteredSorted = sortedScores.replacing(with: filterTensor, where: finalRemoveMask) + + // Scatter back to original order + return filteredSorted.gathering(atIndices: inversePermutation, alongAxis: -1) + } +} +#endif diff --git a/Tests/GenerationTests/GenerationIntegrationTests.swift b/Tests/GenerationTests/GenerationIntegrationTests.swift index 637760f..8963a65 100644 --- a/Tests/GenerationTests/GenerationIntegrationTests.swift +++ b/Tests/GenerationTests/GenerationIntegrationTests.swift @@ -246,6 +246,54 @@ final class GenerationIntegrationTests: XCTestCase { XCTAssertEqual(model.callCount, 3, "Model should be called 3 times") } + func testMinPFiltering() async throws { + let model = MockLanguageModel() + + // Test with min-p: keep tokens with prob >= minP * max_prob + var configWithMinP = GenerationConfig(maxNewTokens: 3) + configWithMinP.doSample = true // Sampling 
mode + configWithMinP.temperature = 1.0 // No temperature adjustment + configWithMinP.minP = 0.05 // Relatively permissive threshold + configWithMinP.eosTokenId = -1 + configWithMinP.maxLength = 10 + + let generation = TestGeneration() + let startTokens = [0] + + let output = await generation.generate( + config: configWithMinP, + tokens: startTokens, + model: model.predictNextToken + ) + + XCTAssertEqual(output.count, 4, "Should generate 3 new tokens + initial token") + XCTAssertEqual(output[0], 0, "First token should be the start token") + + // All tokens should be valid + for token in output[1...] { + XCTAssertTrue(token >= 0 && token < 10, "Token should be in vocab range") + } + + // Test with more aggressive min-p + model.callCount = 0 + var configAggressiveMinP = GenerationConfig(maxNewTokens: 3) + configAggressiveMinP.doSample = true + configAggressiveMinP.temperature = 1.0 + configAggressiveMinP.minP = 0.5 // Much more aggressive + configAggressiveMinP.eosTokenId = -1 + configAggressiveMinP.maxLength = 10 + + let outputAggressive = await generation.generate( + config: configAggressiveMinP, + tokens: startTokens, + model: model.predictNextToken + ) + + XCTAssertEqual(outputAggressive.count, 4, "Should generate 3 new tokens + initial token") + // With aggressive min-p, should sample from fewer options + // (exact behavior depends on model, but verify it doesn't crash) + } + func testEarlyStoppingWithEOS() async throws { // Create a model that returns EOS token after 2 generations class EOSModel { diff --git a/Tests/GenerationTests/LogitsProcessorTests.swift b/Tests/GenerationTests/LogitsProcessorTests.swift index 2791023..44d16f2 100644 --- a/Tests/GenerationTests/LogitsProcessorTests.swift +++ b/Tests/GenerationTests/LogitsProcessorTests.swift @@ -207,6 +207,114 @@ final class LogitsProcessorTests: XCTestCase { // Should be unchanged XCTAssertEqual(resultArray, expectedArray) } + + // MARK: - Min-P Tests + + func testMinPWarper() async throws { + let warper = 
try MinPLogitsWarper(minP: 0.1) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + // Scores: [1.0, 2.0, 3.0, 4.0, 5.0] + // After softmax, probabilities will be computed + // Max prob will be for score=5.0 + // Min threshold = 0.1 * max_prob + // Tokens with prob < threshold should be filtered + let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) + + let result = await warper(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + // Compute expected: softmax probabilities manually + let scoresArray = await scores.shapedArray(of: Float.self).scalars + let expScores = scoresArray.map { exp($0) } + let sumExp = expScores.reduce(0, +) + let probs = expScores.map { $0 / sumExp } + let maxProb = probs.max()! + let threshold = 0.1 * maxProb + + // Check that low probability tokens are filtered + for (idx, prob) in probs.enumerated() { + if prob < threshold { + XCTAssertTrue(resultArray[idx].isInfinite && resultArray[idx] < 0, "Token \(idx) with prob \(prob) should be filtered") + } else { + XCTAssertEqual(resultArray[idx], scoresArray[idx], accuracy: accuracy, "Token \(idx) should not be filtered") + } + } + } + + func testMinPWarperKeepsMinTokens() async throws { + // Even with aggressive minP, should keep at least minTokensToKeep tokens + let warper = try MinPLogitsWarper(minP: 0.99, minTokensToKeep: 2) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) + + let result = await warper(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + // Count non-infinite values + let nonInfiniteCount = resultArray.filter { !$0.isInfinite }.count + XCTAssertGreaterThanOrEqual(nonInfiniteCount, 2, "Should keep at least 2 
tokens") + } + + func testMinPWarperWithLowThreshold() async throws { + // With very low minP, most tokens should pass + let warper = try MinPLogitsWarper(minP: 0.001) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) + + let result = await warper(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + // Most or all tokens should remain + let nonInfiniteCount = resultArray.filter { !$0.isInfinite }.count + XCTAssertGreaterThanOrEqual(nonInfiniteCount, 4, "With low minP, most tokens should pass") + } + + func testMinPWarperInvalidParameters() { + // Test invalid minP + XCTAssertThrowsError(try MinPLogitsWarper(minP: -0.1)) + XCTAssertThrowsError(try MinPLogitsWarper(minP: 1.5)) + + // Test invalid minTokensToKeep + XCTAssertThrowsError(try MinPLogitsWarper(minP: 0.1, minTokensToKeep: 0)) + XCTAssertThrowsError(try MinPLogitsWarper(minP: 0.1, minTokensToKeep: -1)) + } + + // MARK: - Parameter Validation Tests + + func testTemperatureWarperInvalidParameters() { + // Test invalid temperature values + XCTAssertThrowsError(try TemperatureLogitsWarper(temperature: 0.0)) + XCTAssertThrowsError(try TemperatureLogitsWarper(temperature: -1.0)) + } + + func testTopKWarperInvalidParameters() { + // Test invalid topK values + XCTAssertThrowsError(try TopKLogitsWarper(topK: 0)) + XCTAssertThrowsError(try TopKLogitsWarper(topK: -1)) + + // Test invalid minTokensToKeep + XCTAssertThrowsError(try TopKLogitsWarper(topK: 5, minTokensToKeep: 0)) + XCTAssertThrowsError(try TopKLogitsWarper(topK: 5, minTokensToKeep: -1)) + } + + func testTopPWarperInvalidParameters() { + // Test invalid topP values + XCTAssertThrowsError(try TopPLogitsWarper(topP: -0.1)) + XCTAssertThrowsError(try TopPLogitsWarper(topP: 1.5)) + + // Test invalid minTokensToKeep + XCTAssertThrowsError(try 
TopPLogitsWarper(topP: 0.9, minTokensToKeep: 0)) + XCTAssertThrowsError(try TopPLogitsWarper(topP: 0.9, minTokensToKeep: -1)) + } + + func testRepetitionPenaltyInvalidParameters() { + // Test invalid penalty values + XCTAssertThrowsError(try RepetitionPenaltyLogitsProcessor(penalty: 0.0)) + XCTAssertThrowsError(try RepetitionPenaltyLogitsProcessor(penalty: -1.0)) + } } // MARK: - Test Helpers