Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,28 @@ struct TransformersCLI: AsyncParsableCommand {
@Option(
help: """
When enabled, two generation passes are ran, one to 'warm up' and another to collect \
benchmark metrics.
benchmark metrics.
""")
var warmup: Bool = false

@Option(help: "Enable sampling mode (true) or use greedy decoding (false)")
var doSample: Bool = false

@Option(help: "Temperature for sampling (lower = more deterministic, typical: 0.1-2.0)")
var temperature: Float?

@Option(help: "Top-k filtering - only consider k most likely tokens (typical: 5-50)")
var topK: Int?

@Option(help: "Top-p (nucleus) sampling - cumulative probability threshold (typical: 0.9-0.95)")
var topP: Float?

@Option(help: "Min-p sampling - minimum probability threshold scaled by top token (typical: 0.01-0.2)")
var minP: Float?

@Option(help: "Repetition penalty to discourage repeating tokens (typical: 1.0-2.0, 1.0 = no penalty)")
var repetitionPenalty: Float?

func generate(
model: LanguageModel,
config: GenerationConfig,
Expand Down Expand Up @@ -88,11 +106,26 @@ struct TransformersCLI: AsyncParsableCommand {
print("Loading model \(compiledURL)")
let model = try LanguageModel.loadCompiled(url: compiledURL, computeUnits: computeUnits.asMLComputeUnits)

// Using greedy generation for now
var config = model.defaultGenerationConfig
config.doSample = false
config.doSample = doSample
config.maxNewTokens = maxLength

if let temperature = temperature {
config.temperature = temperature
}
if let topK = topK {
config.topK = topK
}
if let topP = topP {
config.topP = topP
}
if let minP = minP {
config.minP = minP
}
if let repetitionPenalty = repetitionPenalty {
config.repetitionPenalty = repetitionPenalty
}

// Given the size of the out-of-model computation, dispatch all
// tensor operations to the CPU.

Expand Down
1 change: 1 addition & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ let package = Package(
.target(name: "Hub", dependencies: [.product(name: "Jinja", package: "swift-jinja")], resources: [.process("Resources")], swiftSettings: swiftSettings),
.target(name: "Models", dependencies: ["Tokenizers", "Generation"]),
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")]),
.testTarget(name: "GenerationTests", dependencies: ["Generation"]),
.testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")], swiftSettings: swiftSettings),
.testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]),
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources")]),
Expand Down
57 changes: 42 additions & 15 deletions Sources/Generation/Decoders.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,51 @@ import CoreML

@available(macOS 15.0, iOS 18.0, *)
func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor {
    // Pick the highest-scoring token id along the vocab axis, shaped [1, 1]
    // so it can be concatenated onto the running token sequence.
    let indices = scores.argmax(alongAxis: -1).reshaped(to: [1, 1])
    // Ensure indices are Int32 for concatenation with input tokens
    return indices.scalarType == Int32.self ? indices : indices.cast(to: Int32.self)
}

// MARK: Top-K Sampling
// MARK: Sampling

/// Performs multinomial sampling from processed logits.
///
/// Assumes logits have already been processed by LogitsProcessorList
/// (temperature, top-k, top-p, etc. already applied).
///
/// - Parameter scores: Processed logits tensor [batch_size, vocab_size]
/// - Returns: Sampled token ID tensor [batch_size, 1]
@available(macOS 15.0, iOS 18.0, *)
func selectNextTokenUsingSampling(from scores: MLTensor) -> MLTensor {
    // Convert processed logits to a probability distribution.
    let probs = scores.softmax(alongAxis: -1)

    // Multinomial sampling via the inverse-CDF method:
    // 1. Draw r ~ U[0, 1)
    // 2. Compute the cumulative sum of probabilities
    // 3. Select the first index where cumsum >= r
    // This is equivalent to torch.multinomial() but uses available MLTensor ops.
    let batchSize = scores.shape[0]
    let rndTensor = MLTensor(randomUniform: [batchSize, 1], in: 0..<1, scalarType: Float.self)
    let cumulativeProbs = probs.cumulativeSum(alongAxis: -1)

    // Ensure random tensor matches the type of cumulativeProbs before comparing.
    let rnd = cumulativeProbs.scalarType == Float.self ? rndTensor : rndTensor.cast(to: cumulativeProbs.scalarType)

    // Positions with cumsum < r are pushed up by a large constant (1000.0),
    // so argmin lands on the first position where cumsum >= r.
    let mask = cumulativeProbs .< rnd
    let penalized = mask * 1000.0
    let indexed = penalized + cumulativeProbs

    let sampledIndex = indexed.argmin(alongAxis: -1).reshaped(to: [1, 1])
    // Ensure indices are Int32 for concatenation with input tokens.
    return sampledIndex.scalarType == Int32.self ? sampledIndex : sampledIndex.cast(to: Int32.self)
}
#endif // canImport(CoreML)
70 changes: 63 additions & 7 deletions Sources/Generation/Generation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,27 @@ extension Generation {
) async -> GenerationOutput {
let tokens = tokens.map { Int32($0) }
var outputTokens = MLTensor(tokens).expandingShape(at: 0)
while outputTokens.shape[1] < config.maxLength {

// Create logits processor list based on config
let logitsProcessorList = createLogitsProcessorList(config: config)

let inputLength = outputTokens.shape[1]
let maxTotalLength = min(config.maxLength, inputLength + config.maxNewTokens)

while outputTokens.shape[1] < maxTotalLength {
// Get raw logits from model
let nextTokenScores = await model(outputTokens, config)

// Apply logits processors
let processedScores = await logitsProcessorList(outputTokens, nextTokenScores)

// Select next token based on generation mode
let nextToken =
switch config.generationMode {
case .greedy:
selectNextTokenUsingGreedyDecoding(from: nextTokenScores)
selectNextTokenUsingGreedyDecoding(from: processedScores)
case .sample:
selectNextTokenUsingTopKSampling(
from: nextTokenScores,
temperature: config.temperature,
topK: config.topK
)
selectNextTokenUsingSampling(from: processedScores)
default:
fatalError("Generation mode \(config.generationMode) not implemented yet")
}
Expand All @@ -101,6 +110,53 @@ extension Generation {
return await tensorToGenerationOutput(outputTokens)
}

/// Builds the chain of logits processors implied by a generation configuration.
///
/// Processors are appended in the order they should run: repetition penalty
/// first, then the sampling warpers (temperature, top-k, top-p, min-p).
///
/// - Parameter config: Generation configuration specifying which processors to apply.
/// - Returns: A `LogitsProcessorList` that applies each enabled transform in order.
private func createLogitsProcessorList(config: GenerationConfig) -> LogitsProcessorList {
    var processors: [any LogitsProcessor] = []

    // Repetition penalty runs before the sampling warpers.
    if config.repetitionPenalty != 1.0,
        let processor = try? RepetitionPenaltyLogitsProcessor(penalty: Float(config.repetitionPenalty))
    {
        processors.append(processor)
    }

    // Temperature scaling, skipped when it would be a no-op.
    if config.temperature > 0, config.temperature != 1.0,
        let processor = try? TemperatureLogitsWarper(temperature: config.temperature)
    {
        processors.append(processor)
    }

    // Top-k filtering. The vocab size is unknown here, so TopKLogitsWarper
    // itself handles the case where topK >= vocabSize.
    if config.topK > 0, config.topK < Int.max,
        let processor = try? TopKLogitsWarper(topK: config.topK)
    {
        processors.append(processor)
    }

    // Top-p (nucleus) filtering.
    if config.topP < 1.0, let processor = try? TopPLogitsWarper(topP: Float(config.topP)) {
        processors.append(processor)
    }

    // Min-p filtering, applied after temperature scaling.
    if let minP = config.minP, let processor = try? MinPLogitsWarper(minP: Float(minP)) {
        processors.append(processor)
    }

    return LogitsProcessorList(processors: processors)
}

/// Converts a token-id tensor into the `[Int]` generation output.
private func tensorToGenerationOutput(_ tensor: MLTensor) async -> GenerationOutput {
    let scalars = await tensor.shapedArray(of: Int32.self).scalars
    return scalars.map(Int.init)
}
Expand Down
18 changes: 12 additions & 6 deletions Sources/Generation/GenerationConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,13 @@ public struct GenerationConfig {
public var topK = 50

/// Cumulative probability threshold for top-p sampling.
public var topP = 1.0
public var topP: Float = 1.0

/// Minimum token probability threshold, scaled by the most likely token's probability.
public var minP: Float?

/// Penalty for token repetition (1.0 means no penalty).
public var repetitionPenalty = 1.0
public var repetitionPenalty: Float = 1.0

/// Token ID used for padding sequences.
public var padTokenId: Int?
Expand All @@ -65,6 +68,7 @@ public struct GenerationConfig {
/// - temperature: Sampling temperature
/// - topK: Top-k sampling parameter
/// - topP: Top-p sampling parameter
/// - minP: Min-p sampling parameter
/// - repetitionPenalty: Repetition penalty factor
public init(
maxLength: Int = 20,
Expand All @@ -73,20 +77,22 @@ public struct GenerationConfig {
numBeams: Int = 1,
numBeamGroups: Int = 1,
penaltyAlpha: Double? = nil,
temperature: Double = 1.0,
temperature: Float = 1.0,
topK: Int = 50,
topP: Double = 1.0,
repetitionPenalty: Double = 1.0
topP: Float = 1.0,
minP: Float? = nil,
repetitionPenalty: Float = 1.0
) {
self.maxLength = maxLength
self.maxNewTokens = maxNewTokens
self.doSample = doSample
self.numBeams = numBeams
self.numBeamGroups = numBeamGroups
self.penaltyAlpha = penaltyAlpha
self.temperature = Float(temperature)
self.temperature = temperature
self.topK = topK
self.topP = topP
self.minP = minP
self.repetitionPenalty = repetitionPenalty
}
}
Expand Down
59 changes: 59 additions & 0 deletions Sources/Generation/LogitsWarper/LogitsProcessor.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#if canImport(CoreML)
import CoreML

/// Protocol adopted by all logits processors that can be applied during generation.
///
/// Logits processors modify the probability distribution over vocabulary tokens by transforming
/// the raw logit scores produced by language models. This enables various sampling strategies
/// such as temperature scaling, top-k/top-p filtering, and repetition penalties.
///
/// Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py
@available(macOS 15.0, iOS 18.0, *)
public protocol LogitsProcessor {
    /// Processes logits for next token prediction.
    ///
    /// - Parameters:
    ///   - inputIds: Tensor of input token IDs with shape `[batch_size, sequence_length]`
    ///   - scores: Tensor of raw logit scores with shape `[batch_size, vocab_size]`
    /// - Returns: Processed logits tensor with shape `[batch_size, vocab_size]`
    ///
    /// - Note: The `inputIds` parameter provides context for processors that need to examine
    ///   the generated sequence (e.g., repetition penalty). Processors that don't need this
    ///   context (e.g., temperature) can ignore it.
    func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor
}

/// An ordered collection of logits processors applied one after another.
///
/// This struct chains multiple logits processors together: the output of each
/// processor becomes the input of the next, in the order they appear in
/// `processors`.
@available(macOS 15.0, iOS 18.0, *)
public struct LogitsProcessorList {
    public var processors: [any LogitsProcessor]

    public init(processors: [any LogitsProcessor]) {
        self.processors = processors
    }

    /// Runs every processor in sequence over the given scores.
    ///
    /// Follows the transformers convention of doing all logits math in
    /// Float32: scores are cast to Float32 once up front and cast back to
    /// their original scalar type after the final processor.
    ///
    /// - Parameters:
    ///   - inputIds: Tensor of input token IDs with shape `[batch_size, sequence_length]`
    ///   - scores: Tensor of raw logit scores with shape `[batch_size, vocab_size]`
    /// - Returns: Processed logits tensor with shape `[batch_size, vocab_size]`
    public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor {
        let sourceType = scores.scalarType
        var current = sourceType == Float.self ? scores : scores.cast(to: Float.self)

        for processor in processors {
            current = await processor(inputIds, current)
        }

        return sourceType == Float.self ? current : current.cast(to: sourceType)
    }
}
#endif
Loading