diff --git a/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift b/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift index 0ff53b5..77732dc 100644 --- a/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift +++ b/Examples/transformers-cli/Sources/transformers-cli/Transformers.swift @@ -27,10 +27,28 @@ struct TransformersCLI: AsyncParsableCommand { @Option( help: """ When enabled, two generation passes are ran, one to 'warm up' and another to collect \ - benchmark metrics. + benchmark metrics. """) var warmup: Bool = false + @Option(help: "Enable sampling mode (true) or use greedy decoding (false)") + var doSample: Bool = false + + @Option(help: "Temperature for sampling (lower = more deterministic, typical: 0.1-2.0)") + var temperature: Float? + + @Option(help: "Top-k filtering - only consider k most likely tokens (typical: 5-50)") + var topK: Int? + + @Option(help: "Top-p (nucleus) sampling - cumulative probability threshold (typical: 0.9-0.95)") + var topP: Float? + + @Option(help: "Min-p sampling - minimum probability threshold scaled by top token (typical: 0.01-0.2)") + var minP: Float? + + @Option(help: "Repetition penalty to discourage repeating tokens (typical: 1.0-2.0, 1.0 = no penalty)") + var repetitionPenalty: Float? 
+
     func generate(
         model: LanguageModel,
         config: GenerationConfig,
@@ -88,11 +106,44 @@ struct TransformersCLI: AsyncParsableCommand {
         print("Loading model \(compiledURL)")
         let model = try LanguageModel.loadCompiled(url: compiledURL, computeUnits: computeUnits.asMLComputeUnits)
 
-        // Using greedy generation for now
         var config = model.defaultGenerationConfig
-        config.doSample = false
+        config.doSample = doSample
         config.maxNewTokens = maxLength
 
+        // Apply sampling overrides from the command line. Out-of-range values are
+        // rejected up front: the generation layer constructs its logits processors
+        // with `try?`, so an invalid setting would otherwise be dropped silently.
+        if let temperature = temperature {
+            guard temperature > 0 else {
+                throw ValidationError("--temperature must be strictly positive, got \(temperature)")
+            }
+            config.temperature = temperature
+        }
+        if let topK = topK {
+            guard topK > 0 else {
+                throw ValidationError("--top-k must be strictly positive, got \(topK)")
+            }
+            config.topK = topK
+        }
+        if let topP = topP {
+            guard (0...1).contains(topP) else {
+                throw ValidationError("--top-p must be in [0, 1], got \(topP)")
+            }
+            config.topP = topP
+        }
+        if let minP = minP {
+            guard (0...1).contains(minP) else {
+                throw ValidationError("--min-p must be in [0, 1], got \(minP)")
+            }
+            config.minP = minP
+        }
+        if let repetitionPenalty = repetitionPenalty {
+            guard repetitionPenalty > 0 else {
+                throw ValidationError("--repetition-penalty must be strictly positive, got \(repetitionPenalty)")
+            }
+            config.repetitionPenalty = repetitionPenalty
+        }
+
         // Given the size of the out-of-model computation, dispatch all
         // tensor operations to the CPU.
 
diff --git a/Package.swift b/Package.swift
index dc6c3f5..fb69f9f 100644
--- a/Package.swift
+++ b/Package.swift
@@ -24,6 +24,7 @@ let package = Package(
         .target(name: "Hub", dependencies: [.product(name: "Jinja", package: "swift-jinja")], resources: [.process("Resources")], swiftSettings: swiftSettings),
         .target(name: "Models", dependencies: ["Tokenizers", "Generation"]),
         .target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")]),
+        .testTarget(name: "GenerationTests", dependencies: ["Generation", "Tokenizers"]),
         .testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")], swiftSettings: swiftSettings),
         .testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]),
         .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources")]),
diff --git a/Sources/Generation/Decoders.swift b/Sources/Generation/Decoders.swift
index fafbd1b..e6dfd73 100644
--- a/Sources/Generation/Decoders.swift
+++ b/Sources/Generation/Decoders.swift
@@ -5,24 +5,51 @@ import CoreML
@available(macOS 15.0, iOS 18.0, *)
 func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor {
-    scores.argmax(alongAxis: -1).reshaped(to: [1, 1])
+    let indices = scores.argmax(alongAxis: -1).reshaped(to: [1, 1])
+    // Ensure indices are Int32 for concatenation with input tokens
+    return indices.scalarType == Int32.self ? indices : indices.cast(to: Int32.self)
 }
 
-// MARK: Top-K Sampling
+// MARK: Sampling
 
+/// Draws the next token by multinomial sampling over already-processed logits.
+///
+/// Assumes logits have already been processed by LogitsProcessorList
+/// (temperature, top-k, top-p, etc. already applied).
+///
+/// - Parameter scores: Processed logits tensor [batch_size, vocab_size]
+/// - Returns: Sampled token ID tensor [batch_size, 1]
 @available(macOS 15.0, iOS 18.0, *)
-func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float, topK: Int) -> MLTensor {
-    let temperatureAdjustedScores = scores / temperature
-    let (topKScores, topKIndices) = temperatureAdjustedScores.topK(topK)
-    let topKProbs = topKScores.softmax(alongAxis: -1)
-    let rnd = topKProbs.sum() * Float.random(in: 0..<1)
-    var accumTopKProbs = topKProbs.cumulativeSum(alongAxis: -1)
-    accumTopKProbs += (accumTopKProbs .< rnd) * 100.0
-    let topKIndex = accumTopKProbs.argsort()[..., 0]
-    let nextTokenTensor = topKIndices.gathering(
-        atIndices: topKIndex,
-        alongAxis: topKIndices.rank - 1
-    )
-    return nextTokenTensor.reshaped(to: [1, 1])
+func selectNextTokenUsingSampling(from scores: MLTensor) -> MLTensor {
+    let batchSize = scores.shape[0]
+
+    // Convert logits to probabilities and accumulate them along the vocabulary axis
+    let probs = scores.softmax(alongAxis: -1)
+    let cumulativeProbs = probs.cumulativeSum(alongAxis: -1)
+
+    // Inverse-CDF sampling: draw u in [0, 1) for each batch row and select the first
+    // index whose cumulative probability is strictly greater than u.
+    let uniform = MLTensor(randomUniform: [batchSize, 1], in: 0..<1, scalarType: Float.self)
+    // Ensure random tensor matches the type of cumulativeProbs
+    let rnd = cumulativeProbs.scalarType == Float.self ? uniform : uniform.cast(to: cumulativeProbs.scalarType)
+
+    // Positions whose cumulative probability has not yet passed u. The comparison is
+    // `<=` rather than `<` so that u == 0 cannot land on a leading zero-probability token.
+    let notReached = cumulativeProbs .<= rnd
+
+    // Build a ranking whose argmin is the sampled index:
+    //   reached     -> cumsum         (the first reached index has the smallest value)
+    //   not reached -> 1000 - cumsum  (larger than any reached value)
+    // When rounding leaves the final cumsum at or below u, no position is reached; the
+    // minimum of `1000 - cumsum` is then the position where the total mass was attained,
+    // whereas a flat penalty would make argmin fall back to index 0.
+    // NOTE(review): assumes argmin returns the first index among equal values — confirm.
+    let penalty = notReached * 1000.0
+    let ranking = cumulativeProbs + penalty - (notReached * 2.0) * cumulativeProbs
+
+    // One index per batch row, matching the documented [batch_size, 1] result
+    let sampledIndex = ranking.argmin(alongAxis: -1).reshaped(to: [batchSize, 1])
+    // Ensure indices are Int32 for concatenation with input tokens
+    return sampledIndex.scalarType == Int32.self ? sampledIndex : sampledIndex.cast(to: Int32.self)
 }
 #endif // canImport(CoreML)
diff --git a/Sources/Generation/Generation.swift b/Sources/Generation/Generation.swift
index 0cdfd37..837b836 100644
--- a/Sources/Generation/Generation.swift
+++ b/Sources/Generation/Generation.swift
@@ -72,18 +72,27 @@ extension Generation { ) async -> GenerationOutput { let tokens = tokens.map { Int32($0) } var outputTokens = MLTensor(tokens).expandingShape(at: 0) - while outputTokens.shape[1] < config.maxLength { + + // Create logits processor list based on config + let logitsProcessorList = createLogitsProcessorList(config: config) + + let inputLength = outputTokens.shape[1] + let maxTotalLength = min(config.maxLength, inputLength + config.maxNewTokens) + + while outputTokens.shape[1] < maxTotalLength { + // Get raw logits from model let nextTokenScores = await model(outputTokens, config) + + // Apply logits processors + let processedScores = await logitsProcessorList(outputTokens, nextTokenScores) + + // Select next token based on generation mode let nextToken = switch config.generationMode { case .greedy: - selectNextTokenUsingGreedyDecoding(from: nextTokenScores) + selectNextTokenUsingGreedyDecoding(from: processedScores) case .sample: - selectNextTokenUsingTopKSampling( - from: nextTokenScores, - temperature: config.temperature, - topK: config.topK - ) + selectNextTokenUsingSampling(from: processedScores) default: fatalError("Generation mode \(config.generationMode) not implemented yet") }
@@ -101,6 +110,53 @@ extension Generation { return await tensorToGenerationOutput(outputTokens) } + /// Creates a list of logits processors based on generation configuration.
+    ///
+    /// - Parameter config: Generation configuration specifying which processors to apply
+    /// - Returns: List of logits processors to apply during generation
+    private func createLogitsProcessorList(config: GenerationConfig) -> LogitsProcessorList {
+        var processors: [any LogitsProcessor] = []
+
+        // Repetition penalty can change the top-scoring token, so it applies in every mode.
+        // NOTE: if an initializer rejects the configured value, `try?` yields nil and that
+        // processor is skipped; this function has no way to report the error.
+        if config.repetitionPenalty != 1.0,
+           let processor = try? RepetitionPenaltyLogitsProcessor(penalty: config.repetitionPenalty)
+        {
+            processors.append(processor)
+        }
+
+        // The remaining warpers keep the highest-scoring token, so greedy decoding gains
+        // nothing from them: skip them (and their sorting) unless sampling is enabled.
+        guard config.doSample else {
+            return LogitsProcessorList(processors: processors)
+        }
+
+        // Temperature scaling (1.0 is the identity; non-positive values are ignored)
+        if config.temperature > 0, config.temperature != 1.0,
+           let processor = try? TemperatureLogitsWarper(temperature: config.temperature)
+        {
+            processors.append(processor)
+        }
+
+        // Top-K filtering; the warper itself is a no-op when topK >= vocabulary size
+        if config.topK > 0, config.topK < Int.max, let processor = try? TopKLogitsWarper(topK: config.topK) {
+            processors.append(processor)
+        }
+
+        // Top-P (nucleus) sampling; 1.0 keeps the whole distribution
+        if config.topP < 1.0, let processor = try? TopPLogitsWarper(topP: config.topP) {
+            processors.append(processor)
+        }
+
+        // Min-P sampling (applied after temperature scaling)
+        if let minP = config.minP, let processor = try? MinPLogitsWarper(minP: minP) {
+            processors.append(processor)
+        }
+
+        return LogitsProcessorList(processors: processors)
+    }
+
     private func tensorToGenerationOutput(_ tensor: MLTensor) async -> GenerationOutput {
         await tensor.shapedArray(of: Int32.self).scalars.map { Int($0) }
     }
diff --git a/Sources/Generation/GenerationConfig.swift b/Sources/Generation/GenerationConfig.swift
index e7389fb..db125a9 100644
--- a/Sources/Generation/GenerationConfig.swift
+++ b/Sources/Generation/GenerationConfig.swift
@@ -39,10 +39,13 @@ public struct GenerationConfig { public var topK = 50 /// Cumulative probability threshold for top-p sampling. - public var topP = 1.0 + public var topP: Float = 1.0 + + /// Minimum token probability threshold, scaled by the most likely token's probability. + public var minP: Float? /// Penalty for token repetition (1.0 means no penalty). - public var repetitionPenalty = 1.0 + public var repetitionPenalty: Float = 1.0 /// Token ID used for padding sequences. public var padTokenId: Int?
@@ -65,6 +68,7 @@ public struct GenerationConfig { /// - temperature: Sampling temperature /// - topK: Top-k sampling parameter /// - topP: Top-p sampling parameter + /// - minP: Min-p sampling parameter /// - repetitionPenalty: Repetition penalty factor public init( maxLength: Int = 20,
@@ -73,10 +77,11 @@ public struct GenerationConfig { numBeams: Int = 1, numBeamGroups: Int = 1, penaltyAlpha: Double? = nil, - temperature: Double = 1.0, + temperature: Float = 1.0, topK: Int = 50, - topP: Double = 1.0, - repetitionPenalty: Double = 1.0 + topP: Float = 1.0, + minP: Float? 
= nil, + repetitionPenalty: Float = 1.0 ) { self.maxLength = maxLength self.maxNewTokens = maxNewTokens
@@ -84,9 +89,10 @@ public struct GenerationConfig { self.numBeams = numBeams self.numBeamGroups = numBeamGroups self.penaltyAlpha = penaltyAlpha - self.temperature = Float(temperature) + self.temperature = temperature self.topK = topK self.topP = topP + self.minP = minP self.repetitionPenalty = repetitionPenalty } }
diff --git a/Sources/Generation/LogitsWarper/LogitsProcessor.swift b/Sources/Generation/LogitsWarper/LogitsProcessor.swift
new file mode 100644
index 0000000..a2a042b
--- /dev/null
+++ b/Sources/Generation/LogitsWarper/LogitsProcessor.swift
@@ -0,0 +1,64 @@
+#if canImport(CoreML)
+import CoreML
+
+/// A transformation applied to next-token logits during generation.
+///
+/// Conforming types modify the probability distribution over vocabulary tokens by transforming
+/// the raw logit scores produced by language models. This enables various sampling strategies
+/// such as temperature scaling, top-k/top-p filtering, and repetition penalties.
+///
+/// Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py
+@available(macOS 15.0, iOS 18.0, *)
+public protocol LogitsProcessor {
+    /// Processes logits for next token prediction.
+    ///
+    /// - Parameters:
+    ///   - inputIds: Tensor of input token IDs with shape `[batch_size, sequence_length]`
+    ///   - scores: Tensor of raw logit scores with shape `[batch_size, vocab_size]`
+    /// - Returns: Processed logits tensor with shape `[batch_size, vocab_size]`
+    ///
+    /// - Note: The `inputIds` parameter provides context for processors that need to examine
+    ///   the generated sequence (e.g., repetition penalty). Processors that don't need this
+    ///   context (e.g., temperature) can ignore it.
+    func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor
+}
+
+/// An ordered chain of logits processors, applied sequentially.
+///
+/// Each processor runs in array order on the logits tensor, with the output of one
+/// processor becoming the input to the next.
+@available(macOS 15.0, iOS 18.0, *)
+public struct LogitsProcessorList {
+    /// Processors in the order they are applied.
+    public var processors: [any LogitsProcessor]
+
+    /// Creates a list that applies `processors` in the given order.
+    public init(processors: [any LogitsProcessor]) {
+        self.processors = processors
+    }
+
+    /// Applies all logits processors sequentially to the input scores.
+    ///
+    /// - Parameters:
+    ///   - inputIds: Tensor of input token IDs with shape `[batch_size, sequence_length]`
+    ///   - scores: Tensor of raw logit scores with shape `[batch_size, vocab_size]`
+    /// - Returns: Processed logits tensor with shape `[batch_size, vocab_size]`, in the scalar
+    ///   type of `scores`
+    public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor {
+        // Nothing to apply: hand the scores back untouched and skip the Float32 round-trip.
+        guard !processors.isEmpty else { return scores }
+
+        // Following transformers convention, all logits processing happens in Float32:
+        // cast once up front, run the chain, then cast back to the caller's scalar type.
+        let originalScalarType = scores.scalarType
+        let isFloat32 = originalScalarType == Float.self
+        var processedScores = isFloat32 ? scores : scores.cast(to: Float.self)
+
+        for processor in processors {
+            processedScores = await processor(inputIds, processedScores)
+        }
+
+        return isFloat32 ? processedScores : processedScores.cast(to: originalScalarType)
+    }
+}
+#endif
diff --git a/Sources/Generation/LogitsWarper/MinPLogitsWarper.swift b/Sources/Generation/LogitsWarper/MinPLogitsWarper.swift
new file mode 100644
index 0000000..9b2867b
--- /dev/null
+++ b/Sources/Generation/LogitsWarper/MinPLogitsWarper.swift
@@ -0,0 +1,114 @@
+#if canImport(CoreML)
+import CoreML
+
+/// LogitsProcessor that performs min-p filtering on the logits.
+/// +/// Min-p keeps all tokens that are above a minimum probability, scaled by the probability +/// of the most likely token. As a result, the filter becomes more aggressive in the presence +/// of high-probability tokens, which is a sign of a confident output that we shouldn't deviate from. +/// +/// Often used together with `TemperatureLogitsWarper`. Used as an alternative to `TopPLogitsWarper` +/// and `TopKLogitsWarper`. +/// +/// Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in +/// the 0.99-0.8 range (use the opposite of normal `top_p` values). +/// +/// Based on: +/// - https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L460 +@available(macOS 15.0, iOS 18.0, *) +public struct MinPLogitsWarper: LogitsProcessor { + public let minP: Float + public let minTokensToKeep: Int + public let filterValue: Float + + /// Creates a min-p logits warper. + /// + /// - Parameters: + /// - minP: Minimum token probability, which will be scaled by the probability of the most likely token. + /// Must be between 0 and 1. Typical values are 0.01-0.2. + /// - minTokensToKeep: Minimum number of tokens that cannot be filtered. + /// - filterValue: Value to set for filtered tokens (default: -infinity) + /// - Throws: If parameters are invalid + public init(minP: Float, minTokensToKeep: Int = 1, filterValue: Float = -Float.infinity) throws { + guard minP >= 0 && minP <= 1.0 else { + throw LogitsProcessorError.invalidParameter("minP must be in [0, 1], got \(minP)") + } + guard minTokensToKeep >= 1 else { + throw LogitsProcessorError.invalidParameter("minTokensToKeep must be >= 1, got \(minTokensToKeep)") + } + self.minP = minP + self.minTokensToKeep = minTokensToKeep + self.filterValue = filterValue + } + + public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { + // Algorithm (following transformers implementation): + // 1. Compute probabilities from logits + // 2. 
Find max probability per batch + // 3. Create threshold = minP * maxProb + // 4. Sort logits and mask tokens where prob < threshold + // 5. Keep at least minTokensToKeep + // 6. Scatter back to original order + + let vocabSize = scores.shape[scores.rank - 1] + + // Compute probabilities + let probs = scores.softmax(alongAxis: -1) + + // Sort probabilities descending to get max (first element) + let sortedProbIndices = probs.argsort(alongAxis: -1, descendingOrder: true) + let sortedProbs = probs.gathering(atIndices: sortedProbIndices, alongAxis: -1) + + // Extract max prob per batch: first element of each sorted sequence + // Do this on CPU to avoid complex broadcasting issues + let sortedProbsArray = await sortedProbs.shapedArray(of: Float.self) + let batchSize = scores.shape[0] + var thresholdScalars = [Float]() + thresholdScalars.reserveCapacity(batchSize * vocabSize) + for batchIdx in 0..= minTokensToKeep AND shouldRemove) + let beyondMinimum = positions .>= Int32(minTokensToKeep) + let finalRemoveMask = sortedTokensToRemove .& beyondMinimum + + // Apply filter in sorted space + let sortedScores = scores.gathering(atIndices: sortedScoreIndices, alongAxis: -1) + let filterTensor = MLTensor(repeating: filterValue, shape: sortedScores.shape, scalarType: Float.self) + let filteredSorted = sortedScores.replacing(with: filterTensor, where: finalRemoveMask) + + // Scatter back to original order + return filteredSorted.gathering(atIndices: inversePermutation, alongAxis: -1) + } +} +#endif diff --git a/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift b/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift new file mode 100644 index 0000000..006bac3 --- /dev/null +++ b/Sources/Generation/LogitsWarper/RepetitionPenaltyLogitsProcessor.swift @@ -0,0 +1,92 @@ +#if canImport(CoreML) +import CoreML + +/// Error thrown by logits processors +public enum LogitsProcessorError: Error { + case invalidParameter(String) +} + +/// LogitsProcessor 
that prevents repetition of previous tokens through a penalty. +/// +/// For each token that has already appeared in the sequence: +/// - If the token's logit is negative: multiply by penalty (further suppresses it) +/// - If the token's logit is positive: divide by penalty (suppresses it) +/// +/// This penalty is applied at most once per token, regardless of how many times it appears. +/// +/// Typical penalty values: +/// - 1.0: No penalty +/// - > 1.0: Penalize repetition (e.g., 1.2 for balanced generation) +/// - 0.0 - 1.0: Encourage repetition +/// +/// Based on: +/// - https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L297 +/// - Paper: https://arxiv.org/abs/1909.05858 +@available(macOS 15.0, iOS 18.0, *) +public struct RepetitionPenaltyLogitsProcessor: LogitsProcessor { + public let penalty: Float + + /// Creates a repetition penalty logits processor. + /// + /// - Parameter penalty: Penalty factor. Values > 1.0 penalize repetition, values < 1.0 encourage it. + /// - Throws: If penalty is not strictly positive + public init(penalty: Float) throws { + guard penalty > 0 else { + throw LogitsProcessorError.invalidParameter("penalty must be strictly positive, got \(penalty)") + } + self.penalty = penalty + } + + public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { + guard penalty != 1.0 else { return scores } + + // Optimized implementation following transformers: + // 1. Gather scores for tokens that appear in input_ids + // 2. Apply conditional penalty: if score < 0: *= penalty, else: /= penalty + // 3. 
Scatter penalized values back to original positions + + // Gather scores for tokens that appear in input_ids + let gatheredScores = scores.gathering(atIndices: inputIds, alongAxis: -1) + + // Apply conditional penalty based on sign (vectorized) + let negativeScores = gatheredScores .< 0.0 + let penalizedScores = negativeScores.cast(to: Float.self) * (gatheredScores * penalty) + (1.0 - negativeScores.cast(to: Float.self)) * (gatheredScores / penalty) + + // Scatter penalized values back to original positions + // Note: MLTensor doesn't have direct scatter, so we use CPU operations for this step + let vocabSize = scores.shape[scores.rank - 1] + let batchSize = scores.shape[0] + + let inputIdsArray = await inputIds.shapedArray(of: Int32.self) + let penalizedArray = await penalizedScores.shapedArray(of: Float.self) + var scoresArray = await scores.shapedArray(of: Float.self) + + for batchIdx in 0..= 0 && tokenId < vocabSize else { continue } + + // For rank-2: [batch_size, vocab_size] + if scores.rank == 2 { + let scoreIdx = batchOffset + tokenId + let penalizedIdx = seqStart + tokenIdx + scoresArray.scalars[scoreIdx] = penalizedArray.scalars[penalizedIdx] + } + // For rank-3: [batch_size, seq_len, vocab_size] - update last position + else if scores.rank == 3 { + let lastSeqPos = scores.shape[1] - 1 + let scoreIdx = batchOffset + lastSeqPos * vocabSize + tokenId + let penalizedIdx = seqStart + tokenIdx + scoresArray.scalars[scoreIdx] = penalizedArray.scalars[penalizedIdx] + } + } + } + + return MLTensor(shape: scores.shape, scalars: scoresArray.scalars, scalarType: Float.self) + } +} +#endif diff --git a/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift b/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift new file mode 100644 index 0000000..e3f9383 --- /dev/null +++ b/Sources/Generation/LogitsWarper/TemperatureLogitsWarper.swift @@ -0,0 +1,34 @@ +#if canImport(CoreML) +import CoreML + +/// LogitsProcessor for temperature scaling, which 
effectively controls the randomness +/// of predicted tokens by modulating the logits distribution. +/// +/// Temperature < 1.0 makes the model more confident (sharper distribution). +/// Temperature > 1.0 makes the model less confident (flatter distribution). +/// Temperature = 1.0 leaves the distribution unchanged. +/// +/// Often used together with `TopPLogitsWarper` and `TopKLogitsWarper`. +/// +/// Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L231 +@available(macOS 15.0, iOS 18.0, *) +public struct TemperatureLogitsWarper: LogitsProcessor { + public let temperature: Float + + /// Creates a temperature logits warper. + /// + /// - Parameter temperature: Strictly positive float value used to modulate the logits distribution. + /// Must be > 0. Values close to 0 approximate greedy decoding. + /// - Throws: If temperature is not strictly positive + public init(temperature: Float) throws { + guard temperature > 0 else { + throw LogitsProcessorError.invalidParameter("temperature must be strictly positive, got \(temperature)") + } + self.temperature = temperature + } + + public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { + scores / temperature + } +} +#endif diff --git a/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift b/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift new file mode 100644 index 0000000..5472c82 --- /dev/null +++ b/Sources/Generation/LogitsWarper/TopKLogitsWarper.swift @@ -0,0 +1,85 @@ +#if canImport(CoreML) +import CoreML + +/// LogitsProcessor that performs top-k filtering, restricting to the k highest probability elements. +/// +/// Filters out all tokens except the k most likely ones by setting their logits to -inf. +/// This reduces the risk of low-probability tokens being sampled. +/// +/// Often used together with `TemperatureLogitsWarper` and `TopPLogitsWarper`. +/// Pro tip: In practice, LLMs use top_k in the 5-50 range. 
+/// +/// Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L532 +@available(macOS 15.0, iOS 18.0, *) +public struct TopKLogitsWarper: LogitsProcessor { + public let topK: Int + public let filterValue: Float + public let minTokensToKeep: Int + + /// Creates a top-k logits warper. + /// + /// - Parameters: + /// - topK: Number of highest probability tokens to keep + /// - filterValue: Value to set filtered tokens to (default: -infinity) + /// - minTokensToKeep: Minimum tokens that cannot be filtered (default: 1) + /// - Throws: If topK is not strictly positive or if minTokensToKeep is less than 1 + /// - Note: If topK is larger than the vocabulary size, no filtering is applied. + public init(topK: Int, filterValue: Float = -.infinity, minTokensToKeep: Int = 1) throws { + guard topK > 0 else { + throw LogitsProcessorError.invalidParameter("topK must be strictly positive, got \(topK)") + } + guard minTokensToKeep >= 1 else { + throw LogitsProcessorError.invalidParameter("minTokensToKeep must be at least 1, got \(minTokensToKeep)") + } + + self.topK = max(topK, minTokensToKeep) + self.filterValue = filterValue + self.minTokensToKeep = minTokensToKeep + } + + public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { + let vocabSize = scores.shape[scores.rank - 1] + let k = min(topK, vocabSize) // Safety check + + // If k equals vocabSize, no filtering is needed + if k >= vocabSize { + return scores + } + + // Get the k-th highest score (the threshold) + let (topKValues, _) = scores.topK(k) + + // The threshold is the smallest value in the top-k (last element) + // We need to get the value at index [k-1] along the last dimension + let thresholdScores = await topKValues.shapedArray(of: Float.self) + + // For each batch item, get the k-th largest score + let batchSize = scores.shape[0] + var thresholds = [Float]() + + for batchIdx in 0..= 0 && topP <= 1.0 else { + throw 
LogitsProcessorError.invalidParameter("topP must be in [0, 1], got \(topP)") + } + guard minTokensToKeep >= 1 else { + throw LogitsProcessorError.invalidParameter("minTokensToKeep must be at least 1, got \(minTokensToKeep)") + } + + self.topP = topP + self.filterValue = filterValue + self.minTokensToKeep = minTokensToKeep + } + + public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor { + // Algorithm (following transformers implementation): + // 1. Sort logits in descending order + // 2. Compute softmax probabilities on sorted logits + // 3. Compute cumulative sum of probabilities + // 4. Remove tokens with cumulative probability > top_p + // 5. Keep at least min_tokens_to_keep + // 6. Scatter mask back to original indexing using inverse permutation + + let vocabSize = scores.shape[scores.rank - 1] + + // Scores are already in Float32 (handled by LogitsProcessorList) + // Sort in descending order (highest scores first) + let sortedIndices = scores.argsort(alongAxis: -1, descendingOrder: true) + + // Build inverse permutation for scattering back + let inversePermutation = sortedIndices.argsort(alongAxis: -1) + + // Gather scores in sorted order + let sortedScores = scores.gathering(atIndices: sortedIndices, alongAxis: -1) + + // Compute probabilities and cumulative sum in sorted order + let sortedProbs = sortedScores.softmax(alongAxis: -1) + let cumulativeProbs = sortedProbs.cumulativeSum(alongAxis: -1) + + // Shift cumsum to exclude current token (HuggingFace convention) + // This ensures we include the first token that pushes us over the threshold + let cumulativeProbsShifted = cumulativeProbs - sortedProbs + + // Create position tensor [0, 1, 2, ..., vocabSize-1] for minTokensToKeep check + let baseShape = Array(repeating: 1, count: sortedScores.rank - 1) + [vocabSize] + var multiples = sortedScores.shape + multiples[multiples.count - 1] = 1 + + let positions = MLTensor( + rangeFrom: Int32(0), + to: Int32(vocabSize), + by: 1, + 
scalarType: Int32.self + ) + .reshaped(to: baseShape) + .tiled(multiples: multiples) + .cast(to: Float.self) + + // Create mask in sorted order: + // Remove if: position >= minTokensToKeep AND cumsum_shifted > topP + let beyondMinimum = positions .>= Float(minTokensToKeep) + let exceedsThreshold = cumulativeProbsShifted .> topP + let removeMaskSorted = beyondMinimum .& exceedsThreshold + + // Apply filter value in sorted space + let filterTensor = MLTensor( + repeating: filterValue, + shape: sortedScores.shape, + scalarType: Float.self + ) + let filteredSorted = sortedScores.replacing(with: filterTensor, where: removeMaskSorted) + + // Scatter back to original indexing using inverse permutation + return filteredSorted.gathering(atIndices: inversePermutation, alongAxis: -1) + } +} +#endif diff --git a/Tests/GenerationTests/GenerationIntegrationTests.swift b/Tests/GenerationTests/GenerationIntegrationTests.swift new file mode 100644 index 0000000..8963a65 --- /dev/null +++ b/Tests/GenerationTests/GenerationIntegrationTests.swift @@ -0,0 +1,358 @@ +import CoreML +import Tokenizers +import XCTest + +@testable import Generation + +@available(macOS 15.0, iOS 18.0, *) +final class GenerationIntegrationTests: XCTestCase { + + // MARK: - Mock Model for Testing + + /// Mock language model that returns predictable logits for testing + class MockLanguageModel { + var callCount = 0 + var logitsHistory: [MLTensor] = [] + + func predictNextToken(_ inputTokens: MLTensor, _ config: GenerationConfig) async -> MLTensor { + callCount += 1 + + // Return different logits based on the sequence length + let seqLength = inputTokens.shape[1] + + // Simulate a vocabulary of 10 tokens + // Create logits that favor certain tokens based on context + let vocabSize = 10 + var logits = [Float](repeating: 0.0, count: vocabSize) + + switch seqLength { + case 1: + // First generation: favor token 5 + logits = [1.0, 1.5, 2.0, 2.5, 3.0, 10.0, 3.0, 2.5, 2.0, 1.5] + case 2: + // Second generation: favor 
token 3 + logits = [1.0, 2.0, 3.0, 8.0, 3.0, 2.0, 2.0, 1.5, 1.0, 0.5] + case 3: + // Third generation: create a more uniform distribution + logits = [2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 3.5, 3.0, 2.5, 2.0] + default: + // Default: slightly favor middle tokens + logits = [1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.5] + } + + let tensor = MLTensor(shape: [1, vocabSize], scalars: logits, scalarType: Float.self) + logitsHistory.append(tensor) + return tensor + } + } + + // MARK: - Integration Tests + + func testGreedyGenerationWithoutProcessors() async throws { + let model = MockLanguageModel() + + var config = GenerationConfig(maxNewTokens: 3) + config.doSample = false // Greedy mode + config.eosTokenId = -1 // Disable early stopping + config.maxLength = 10 + + let generation = TestGeneration() + let startTokens = [0] // Start with token 0 + + let output = await generation.generate( + config: config, + tokens: startTokens, + model: model.predictNextToken + ) + + // Greedy should always pick the highest logit + // Token 0 -> Token 5 (logit 10.0) -> Token 3 (logit 8.0) -> Token 5 (logit 4.5) + XCTAssertEqual(output.count, 4, "Should generate 3 new tokens + initial token") + XCTAssertEqual(output[0], 0, "First token should be the start token") + XCTAssertEqual(output[1], 5, "Second token should be 5 (highest logit)") + XCTAssertEqual(output[2], 3, "Third token should be 3 (highest logit)") + XCTAssertEqual(output[3], 5, "Fourth token should be 5 (highest logit)") + + XCTAssertEqual(model.callCount, 3, "Model should be called 3 times") + } + + func testSamplingWithTemperature() async throws { + let model = MockLanguageModel() + + var config = GenerationConfig(maxNewTokens: 3) + config.doSample = true // Sampling mode + config.temperature = 0.1 // Low temperature = more deterministic + config.eosTokenId = -1 + + let generation = TestGeneration() + let startTokens = [0] + + let output = await generation.generate( + config: config, + tokens: startTokens, + model: 
model.predictNextToken + ) + + XCTAssertEqual(output.count, 4, "Should generate 3 new tokens + initial token") + XCTAssertEqual(output[0], 0, "First token should be the start token") + + // With low temperature, sampling should still prefer high-probability tokens + // We can't assert exact tokens due to randomness, but can verify structure + XCTAssertTrue(output[1] < 10, "Generated token should be in vocab range") + XCTAssertTrue(output[2] < 10, "Generated token should be in vocab range") + XCTAssertTrue(output[3] < 10, "Generated token should be in vocab range") + } + + func testTopKFiltering() async throws { + let model = MockLanguageModel() + + var config = GenerationConfig(maxNewTokens: 3) + config.doSample = true // Sampling mode + config.topK = 3 // Only consider top 3 tokens + config.temperature = 1.0 + config.eosTokenId = -1 + + let generation = TestGeneration() + let startTokens = [0] + + // Run generation multiple times to test top-k filtering + for _ in 0..<5 { + model.callCount = 0 + model.logitsHistory = [] + + let output = await generation.generate( + config: config, + tokens: startTokens, + model: model.predictNextToken + ) + + XCTAssertEqual(output.count, 4, "Should generate 3 new tokens + initial token") + + // Verify that generated tokens are within valid range + for token in output[1...] 
{ + XCTAssertTrue(token >= 0 && token < 10, "Token \(token) should be in vocab range") + } + } + } + + func testTopPFiltering() async throws { + let model = MockLanguageModel() + + var config = GenerationConfig(maxNewTokens: 2) + config.doSample = true // Sampling mode + config.topP = 0.9 // Top 90% probability mass + config.temperature = 1.0 + config.eosTokenId = -1 + + let generation = TestGeneration() + let startTokens = [0] + + let output = await generation.generate( + config: config, + tokens: startTokens, + model: model.predictNextToken + ) + + XCTAssertEqual(output.count, 3, "Should generate 2 new tokens + initial token") + + // Top-P should filter out low-probability tokens + for token in output[1...] { + XCTAssertTrue(token >= 0 && token < 10, "Token should be in vocab range") + } + } + + func testRepetitionPenalty() async throws { + // Create a model that always returns the same high-scoring token + class RepetitiveModel { + func predict(_ inputTokens: MLTensor, _ config: GenerationConfig) async -> MLTensor { + // Always favor token 7 + let logits: [Float] = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 10.0, 1.0, 1.0] + return MLTensor(shape: [1, 10], scalars: logits, scalarType: Float.self) + } + } + + let model = RepetitiveModel() + + // Test WITHOUT repetition penalty + var configNoPenalty = GenerationConfig(maxNewTokens: 3) + configNoPenalty.doSample = false // Greedy mode + configNoPenalty.repetitionPenalty = 1.0 // No penalty + configNoPenalty.eosTokenId = -1 + + let generation = TestGeneration() + let startTokens = [0] + + let outputNoPenalty = await generation.generate( + config: configNoPenalty, + tokens: startTokens, + model: model.predict + ) + + // Without penalty, should keep selecting token 7 + XCTAssertEqual(outputNoPenalty, [0, 7, 7, 7], "Without penalty, should repeat token 7") + + // Test WITH repetition penalty + var configWithPenalty = GenerationConfig(maxNewTokens: 3) + configWithPenalty.doSample = false // Greedy mode + 
configWithPenalty.repetitionPenalty = 10.0 // Strong penalty (10.0 / 10.0 = 1.0, so other tokens win) + configWithPenalty.eosTokenId = -1 + + let outputWithPenalty = await generation.generate( + config: configWithPenalty, + tokens: startTokens, + model: model.predict + ) + + // With strong penalty, should select token 7 first time, but then avoid it + XCTAssertEqual(outputWithPenalty[0], 0, "First token should be start token") + XCTAssertEqual(outputWithPenalty[1], 7, "Second token should be 7 (highest score)") + + // After token 7 is penalized, other tokens should be chosen + // (exact tokens depend on penalty calculation, but should NOT all be 7) + let uniqueTokens = Set(outputWithPenalty[1...]) + XCTAssertTrue(uniqueTokens.count > 1, "With repetition penalty, should generate diverse tokens, got \(outputWithPenalty)") + } + + func testCombinedProcessors() async throws { + let model = MockLanguageModel() + + var config = GenerationConfig(maxNewTokens: 3) + config.doSample = true // Sampling mode + config.temperature = 0.8 // Slightly focused + config.topK = 5 // Top 5 tokens + config.topP = 0.95 // 95% probability mass + config.repetitionPenalty = 1.1 // Slight penalty + config.eosTokenId = -1 + + let generation = TestGeneration() + let startTokens = [0] + + let output = await generation.generate( + config: config, + tokens: startTokens, + model: model.predictNextToken + ) + + XCTAssertEqual(output.count, 4, "Should generate 3 new tokens + initial token") + XCTAssertEqual(output[0], 0, "First token should be the start token") + + // All processors should work together + // Can't assert exact tokens due to randomness, but verify structure + for token in output[1...] 
{ + XCTAssertTrue(token >= 0 && token < 10, "Token should be in vocab range") + } + + // Verify model was called correct number of times + XCTAssertEqual(model.callCount, 3, "Model should be called 3 times") + } + + func testMinPFiltering() async throws { + let model = MockLanguageModel() + + // Test with min-p: keep tokens with prob >= minP * max_prob + var configWithMinP = GenerationConfig(maxNewTokens: 3) + configWithMinP.doSample = true // Sampling mode + configWithMinP.temperature = 1.0 // No temperature adjustment + configWithMinP.minP = 0.05 // Relatively permissive threshold + configWithMinP.eosTokenId = -1 + configWithMinP.maxLength = 10 + + let generation = TestGeneration() + let startTokens = [0] + + let output = await generation.generate( + config: configWithMinP, + tokens: startTokens, + model: model.predictNextToken + ) + + XCTAssertEqual(output.count, 4, "Should generate 3 new tokens + initial token") + XCTAssertEqual(output[0], 0, "First token should be the start token") + + // All tokens should be valid + for token in output[1...] 
{ + XCTAssertTrue(token >= 0 && token < 10, "Token should be in vocab range") + } + + // Test with more aggressive min-p + model.callCount = 0 + var configAggressiveMinP = GenerationConfig(maxNewTokens: 3) + configAggressiveMinP.doSample = true + configAggressiveMinP.temperature = 1.0 + configAggressiveMinP.minP = 0.5 // Much more aggressive + configAggressiveMinP.eosTokenId = -1 + configAggressiveMinP.maxLength = 10 + + let outputAggressive = await generation.generate( + config: configAggressiveMinP, + tokens: startTokens, + model: model.predictNextToken + ) + + XCTAssertEqual(outputAggressive.count, 4, "Should generate 3 new tokens + initial token") + // With aggressive min-p, should sample from fewer options + // (exact behavior depends on model, but verify it doesn't crash) + } + + func testEarlyStoppingWithEOS() async throws { + // Create a model that returns EOS token after 2 generations + class EOSModel { + var callCount = 0 + + func predict(_ inputTokens: MLTensor, _ config: GenerationConfig) async -> MLTensor { + callCount += 1 + + let vocabSize = 10 + var logits = [Float](repeating: 1.0, count: vocabSize) + + if callCount >= 2 { + // After 2 calls, strongly favor EOS token (which we'll set as token 9) + logits[9] = 100.0 + } else { + // Before that, favor token 5 + logits[5] = 10.0 + } + + return MLTensor(shape: [1, vocabSize], scalars: logits, scalarType: Float.self) + } + } + + let model = EOSModel() + + var config = GenerationConfig(maxNewTokens: 10) // Request many tokens + config.doSample = false // Greedy mode + config.eosTokenId = 9 // Token 9 is EOS + + let generation = TestGeneration() + let startTokens = [0] + + let output = await generation.generate( + config: config, + tokens: startTokens, + model: model.predict + ) + + // Should stop early when EOS is encountered + XCTAssertLessThan(output.count, 11, "Should stop before generating all 10 tokens") + XCTAssertEqual(output[0], 0, "First token should be start token") + + // Model should be called 
fewer times due to early stopping + XCTAssertLessThan(model.callCount, 10, "Model should be called fewer times due to EOS") + } +} + +// MARK: - Test Helper + +@available(macOS 15.0, iOS 18.0, *) +struct TestGeneration: Generation { + func generate( + config: GenerationConfig, + prompt: String, + model: NextTokenModel, + tokenizer: Tokenizers.Tokenizer, + callback: PredictionStringCallback? + ) async -> String { + // Not used in these tests + return "" + } +} diff --git a/Tests/GenerationTests/LogitsProcessorTests.swift b/Tests/GenerationTests/LogitsProcessorTests.swift new file mode 100644 index 0000000..44d16f2 --- /dev/null +++ b/Tests/GenerationTests/LogitsProcessorTests.swift @@ -0,0 +1,335 @@ +import CoreML +import XCTest + +@testable import Generation + +@available(macOS 15.0, iOS 18.0, *) +final class LogitsProcessorTests: XCTestCase { + private let accuracy: Float = 0.0001 + + // MARK: - Temperature Tests + + func testTemperatureWarper() async throws { + let warper = try TemperatureLogitsWarper(temperature: 2.0) + + // Create input: batch_size=1, seq_len=3 + let inputIds = MLTensor(shape: [1, 3], scalars: [Int32(1), Int32(2), Int32(3)], scalarType: Int32.self) + // Create scores: batch_size=1, vocab_size=3 + let scores = MLTensor(shape: [1, 3], scalars: [Float(2.0), Float(4.0), Float(6.0)], scalarType: Float.self) + + let result = await warper(inputIds, scores) + let expected: [Float] = [1.0, 2.0, 3.0] // Each score divided by 2.0 + + await assertMLTensorEqual(result, expected: expected, accuracy: accuracy) + } + + func testTemperatureWarperWithDifferentValues() async throws { + // Test temperature < 1 (sharper distribution) + let sharper = try TemperatureLogitsWarper(temperature: 0.5) + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 2], scalars: [Float(1.0), Float(2.0)], scalarType: Float.self) + + let result = await sharper(inputIds, scores) + let expected: [Float] = [2.0, 4.0] // 
Divided by 0.5 = multiplied by 2 + + await assertMLTensorEqual(result, expected: expected, accuracy: accuracy) + } + + // MARK: - Top-K Tests + + func testTopKWarper() async throws { + let warper = try TopKLogitsWarper(topK: 3) + + let inputIds = MLTensor(shape: [1, 2], scalars: [Int32(1), Int32(2)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) + + let result = await warper(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + // Top 3 tokens (5, 4, 3) should remain, others should be -inf + XCTAssertTrue(resultArray[0].isInfinite && resultArray[0] < 0, "Token 0 should be -inf") + XCTAssertTrue(resultArray[1].isInfinite && resultArray[1] < 0, "Token 1 should be -inf") + XCTAssertEqual(resultArray[2], 3.0, accuracy: accuracy, "Token 2 should be kept") + XCTAssertEqual(resultArray[3], 4.0, accuracy: accuracy, "Token 3 should be kept") + XCTAssertEqual(resultArray[4], 5.0, accuracy: accuracy, "Token 4 should be kept") + } + + func testTopKWarperWithSmallK() async throws { + let warper = try TopKLogitsWarper(topK: 1) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 3], scalars: [Float(1.0), Float(5.0), Float(3.0)], scalarType: Float.self) + + let result = await warper(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + // Only token with score 5.0 should remain + XCTAssertTrue(resultArray[0].isInfinite && resultArray[0] < 0) + XCTAssertEqual(resultArray[1], 5.0, accuracy: accuracy) + XCTAssertTrue(resultArray[2].isInfinite && resultArray[2] < 0) + } + + // MARK: - Top-P Tests + + func testTopPWarper() async throws { + let warper = try TopPLogitsWarper(topP: 0.9) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + // Create a distribution where top tokens dominate: [0.0, 1.0, 
2.0, 3.0, 10.0] + // After softmax, token 4 will have ~99.7% probability + let scores = MLTensor(shape: [1, 5], scalars: [Float(0.0), Float(1.0), Float(2.0), Float(3.0), Float(10.0)], scalarType: Float.self) + + let result = await warper(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + // Token 4 (score 10.0) should definitely be kept (highest probability) + XCTAssertFalse(resultArray[4].isInfinite, "Highest probability token should be kept") + + // Some lower tokens should be filtered to -inf + let filteredCount = resultArray.filter { $0.isInfinite && $0 < 0 }.count + XCTAssertTrue(filteredCount > 0, "Top-P should filter some low-probability tokens") + } + + func testTopPWarperWithHighThreshold() async throws { + // With topP=0.99, almost all tokens should be kept + let warper = try TopPLogitsWarper(topP: 0.99) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) + + let result = await warper(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + // With high topP and relatively uniform distribution, most tokens should be kept + let keptCount = resultArray.filter { !($0.isInfinite && $0 < 0) }.count + XCTAssertTrue(keptCount >= 4, "High topP should keep most tokens") + } + + // MARK: - Repetition Penalty Tests + + func testRepetitionPenaltyProcessor() async throws { + let processor = try RepetitionPenaltyLogitsProcessor(penalty: 2.0) + + // Input sequence with tokens [1, 2, 3] + let inputIds = MLTensor(shape: [1, 3], scalars: [Int32(1), Int32(2), Int32(3)], scalarType: Int32.self) + + // Scores for vocab of size 5: [0.5, -0.5, 1.0, -1.0, 2.0] + let scores = MLTensor(shape: [1, 5], scalars: [Float(0.5), Float(-0.5), Float(1.0), Float(-1.0), Float(2.0)], scalarType: Float.self) + + let result = await processor(inputIds, 
scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + // Token 0: not in sequence, unchanged + XCTAssertEqual(resultArray[0], 0.5, accuracy: accuracy, "Token 0 should be unchanged") + + // Token 1 (score -0.5 < 0): multiplied by penalty = -1.0 + XCTAssertEqual(resultArray[1], -1.0, accuracy: accuracy, "Token 1 should be penalized (negative)") + + // Token 2 (score 1.0 > 0): divided by penalty = 0.5 + XCTAssertEqual(resultArray[2], 0.5, accuracy: accuracy, "Token 2 should be penalized (positive)") + + // Token 3 (score -1.0 < 0): multiplied by penalty = -2.0 + XCTAssertEqual(resultArray[3], -2.0, accuracy: accuracy, "Token 3 should be penalized (negative)") + + // Token 4: not in sequence, unchanged + XCTAssertEqual(resultArray[4], 2.0, accuracy: accuracy, "Token 4 should be unchanged") + } + + func testRepetitionPenaltyWithNoPenalty() async throws { + let processor = try RepetitionPenaltyLogitsProcessor(penalty: 1.0) + + let inputIds = MLTensor(shape: [1, 2], scalars: [Int32(1), Int32(2)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) + + let result = await processor(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + let expectedArray = await scores.shapedArray(of: Float.self).scalars + + // With penalty=1.0, scores should be unchanged + XCTAssertEqual(resultArray, expectedArray, "Penalty of 1.0 should not change scores") + } + + func testRepetitionPenaltyWithRank3Scores() async throws { + let processor = try RepetitionPenaltyLogitsProcessor(penalty: 2.0) + + // Input sequence with tokens [1, 2, 3] + let inputIds = MLTensor(shape: [1, 3], scalars: [Int32(1), Int32(2), Int32(3)], scalarType: Int32.self) + + // Scores shaped as [batch, sequence_length, vocab] -> [1, 1, 5] + let scores = MLTensor( + shape: [1, 1, 5], + scalars: [Float(0.5), Float(-0.5), Float(1.0), Float(-1.0), Float(2.0)], + 
scalarType: Float.self + ) + + let result = await processor(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + let expected: [Float] = [0.5, -1.0, 0.5, -2.0, 2.0] + XCTAssertEqual(resultArray.count, expected.count, "Flattened tensor mismatch") + for (value, exp) in zip(resultArray, expected) { + XCTAssertEqual(value, exp, accuracy: accuracy) + } + } + + // MARK: - Processor List Tests + + func testLogitsProcessorList() async throws { + let temp = try TemperatureLogitsWarper(temperature: 2.0) + let topK = try TopKLogitsWarper(topK: 3) + let processorList = LogitsProcessorList(processors: [temp, topK]) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 5], scalars: [Float(2.0), Float(4.0), Float(6.0), Float(8.0), Float(10.0)], scalarType: Float.self) + + // First temperature divides by 2: [1, 2, 3, 4, 5] + // Then top-k keeps top 3: [-inf, -inf, 3, 4, 5] + let result = await processorList(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + XCTAssertTrue(resultArray[0].isInfinite && resultArray[0] < 0) + XCTAssertTrue(resultArray[1].isInfinite && resultArray[1] < 0) + XCTAssertEqual(resultArray[2], 3.0, accuracy: accuracy) + XCTAssertEqual(resultArray[3], 4.0, accuracy: accuracy) + XCTAssertEqual(resultArray[4], 5.0, accuracy: accuracy) + } + + func testEmptyProcessorList() async throws { + let processorList = LogitsProcessorList(processors: []) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 3], scalars: [Float(1.0), Float(2.0), Float(3.0)], scalarType: Float.self) + + let result = await processorList(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + let expectedArray = await scores.shapedArray(of: Float.self).scalars + + // Should be unchanged + XCTAssertEqual(resultArray, expectedArray) + } + + // MARK: - 
Min-P Tests + + func testMinPWarper() async throws { + let warper = try MinPLogitsWarper(minP: 0.1) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + // Scores: [1.0, 2.0, 3.0, 4.0, 5.0] + // After softmax, probabilities will be computed + // Max prob will be for score=5.0 + // Min threshold = 0.1 * max_prob + // Tokens with prob < threshold should be filtered + let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) + + let result = await warper(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + // Compute expected: softmax probabilities manually + let scoresArray = await scores.shapedArray(of: Float.self).scalars + let expScores = scoresArray.map { exp($0) } + let sumExp = expScores.reduce(0, +) + let probs = expScores.map { $0 / sumExp } + let maxProb = probs.max()! + let threshold = 0.1 * maxProb + + // Check that low probability tokens are filtered + for (idx, prob) in probs.enumerated() { + if prob < threshold { + XCTAssertTrue(resultArray[idx].isInfinite && resultArray[idx] < 0, "Token \(idx) with prob \(prob) should be filtered") + } else { + XCTAssertEqual(resultArray[idx], scoresArray[idx], accuracy: accuracy, "Token \(idx) should not be filtered") + } + } + } + + func testMinPWarperKeepsMinTokens() async throws { + // Even with aggressive minP, should keep at least minTokensToKeep tokens + let warper = try MinPLogitsWarper(minP: 0.99, minTokensToKeep: 2) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) + + let result = await warper(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + // Count non-infinite values + let nonInfiniteCount = resultArray.filter { !$0.isInfinite }.count + 
XCTAssertGreaterThanOrEqual(nonInfiniteCount, 2, "Should keep at least 2 tokens") + } + + func testMinPWarperWithLowThreshold() async throws { + // With very low minP, most tokens should pass + let warper = try MinPLogitsWarper(minP: 0.001) + + let inputIds = MLTensor(shape: [1, 1], scalars: [Int32(1)], scalarType: Int32.self) + let scores = MLTensor(shape: [1, 5], scalars: [Float(1.0), Float(2.0), Float(3.0), Float(4.0), Float(5.0)], scalarType: Float.self) + + let result = await warper(inputIds, scores) + let resultArray = await result.shapedArray(of: Float.self).scalars + + // Most or all tokens should remain + let nonInfiniteCount = resultArray.filter { !$0.isInfinite }.count + XCTAssertGreaterThanOrEqual(nonInfiniteCount, 4, "With low minP, most tokens should pass") + } + + func testMinPWarperInvalidParameters() { + // Test invalid minP + XCTAssertThrowsError(try MinPLogitsWarper(minP: -0.1)) + XCTAssertThrowsError(try MinPLogitsWarper(minP: 1.5)) + + // Test invalid minTokensToKeep + XCTAssertThrowsError(try MinPLogitsWarper(minP: 0.1, minTokensToKeep: 0)) + XCTAssertThrowsError(try MinPLogitsWarper(minP: 0.1, minTokensToKeep: -1)) + } + + // MARK: - Parameter Validation Tests + + func testTemperatureWarperInvalidParameters() { + // Test invalid temperature values + XCTAssertThrowsError(try TemperatureLogitsWarper(temperature: 0.0)) + XCTAssertThrowsError(try TemperatureLogitsWarper(temperature: -1.0)) + } + + func testTopKWarperInvalidParameters() { + // Test invalid topK values + XCTAssertThrowsError(try TopKLogitsWarper(topK: 0)) + XCTAssertThrowsError(try TopKLogitsWarper(topK: -1)) + + // Test invalid minTokensToKeep + XCTAssertThrowsError(try TopKLogitsWarper(topK: 5, minTokensToKeep: 0)) + XCTAssertThrowsError(try TopKLogitsWarper(topK: 5, minTokensToKeep: -1)) + } + + func testTopPWarperInvalidParameters() { + // Test invalid topP values + XCTAssertThrowsError(try TopPLogitsWarper(topP: -0.1)) + XCTAssertThrowsError(try TopPLogitsWarper(topP: 1.5)) + 
+ // Test invalid minTokensToKeep + XCTAssertThrowsError(try TopPLogitsWarper(topP: 0.9, minTokensToKeep: 0)) + XCTAssertThrowsError(try TopPLogitsWarper(topP: 0.9, minTokensToKeep: -1)) + } + + func testRepetitionPenaltyInvalidParameters() { + // Test invalid penalty values + XCTAssertThrowsError(try RepetitionPenaltyLogitsProcessor(penalty: 0.0)) + XCTAssertThrowsError(try RepetitionPenaltyLogitsProcessor(penalty: -1.0)) + } +} + +// MARK: - Test Helpers + +@available(macOS 15.0, iOS 18.0, *) +func assertMLTensorEqual( + _ tensor: MLTensor, + expected: [Float], + accuracy: Float, + file: StaticString = #filePath, + line: UInt = #line +) async { + let actual = await tensor.shapedArray(of: Float.self).scalars + XCTAssertEqual(actual.count, expected.count, "Tensor size mismatch", file: file, line: line) + for (a, e) in zip(actual, expected) { + XCTAssertEqual(a, e, accuracy: accuracy, file: file, line: line) + } +}