Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions Libraries/MLXLMCommon/Evaluate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,21 @@ public struct ArgMaxSampler: LogitSampler {
public struct TopPSampler: LogitSampler {
let temp: MLXArray
let topP: MLXArray
let randomState: MLXRandom.RandomState

public init(temperature: Float, topP: Float) {
self.temp = MLXArray(temperature)
self.topP = MLXArray(topP)
self.randomState = MLXRandom.RandomState()
}

private let compiledTopPSampling: (MLXArray, MLXArray, MLXArray) -> MLXArray = {
compile(inputs: [MLXRandom.globalState], outputs: [MLXRandom.globalState]) {
logits, topP, temp in
public func sample(logits: MLXArray) -> MLXArray {
var logits = logits
if logits.dtype == .bfloat16 {
logits = logits.asType(.float32)
}

return withRandomState(randomState) {
let probs = softmax(logits / temp, axis: -1)
let sortedIndices = argSort(probs, axis: -1)

Expand All @@ -158,34 +164,23 @@ public struct TopPSampler: LogitSampler {
let sortedToken = categorical(log(topProbs))
return sortedIndices.squeezed(axis: 0)[sortedToken]
}
}()

public func sample(logits: MLXArray) -> MLXArray {
var logits = logits
if logits.dtype == .bfloat16 {
logits = logits.asType(.float32)
}

return compiledTopPSampling(logits, topP, temp)
}
}

/// Processor that uses `temperature` to sample the logits
/// Processor that uses `temperature` to sample the logits.
///
/// Each sampler owns a private `MLXRandom.RandomState` so that concurrent
/// generations do not contend on (or perturb) the global random state.
public struct CategoricalSampler: LogitSampler {
    /// Temperature as an `MLXArray` so it can participate in lazy graph ops.
    let temp: MLXArray
    /// Per-sampler random state; isolates sampling from other tasks.
    let randomState: MLXRandom.RandomState

    public init(temperature: Float) {
        self.temp = MLXArray(temperature)
        self.randomState = MLXRandom.RandomState()
    }

    /// Samples a token index from `logits` scaled by `1 / temperature`.
    /// Runs under this sampler's own random state rather than the global one.
    public func sample(logits: MLXArray) -> MLXArray {
        return withRandomState(randomState) {
            categorical(logits * (1 / temp))
        }
    }
}

Expand Down
125 changes: 125 additions & 0 deletions Tests/MLXLMTests/EvalTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,131 @@ public class EvalTests: XCTestCase {
XCTAssertEqual(output.shape, [1, 5, 100])
}

/// Verifies that a single quantized model can be evaluated from several
/// concurrent tasks and that each task produces an output of the expected shape.
func testConcurrentEvaluation() async throws {
    let config = LlamaConfiguration(
        hiddenSize: 64, hiddenLayers: 4, intermediateSize: 128, attentionHeads: 8,
        rmsNormEps: 0.00001, vocabularySize: 100, kvHeads: 4)
    let model = LlamaModel(config)
    quantize(model: model, groupSize: 64, bits: 4)

    // Force evaluation of all model weights before concurrent usage.
    // This ensures all weight promises are realized and avoids race conditions.
    eval(model)

    let numTasks = 3
    let results = await withTaskGroup(of: MLXArray.self) { group in
        var allResults: [MLXArray] = []

        for taskId in 0 ..< numTasks {
            group.addTask {
                // Offset token ids per task so inputs differ across tasks.
                let input = MLXArray([
                    1 + taskId, 2 + taskId, 3 + taskId, 4 + taskId, 5 + taskId,
                ])[.newAxis, .ellipsis]
                let output = model.callAsFunction(input, cache: nil)
                return output
            }
        }

        for await result in group {
            allResults.append(result)
        }

        return allResults
    }

    XCTAssertEqual(results.count, numTasks)

    // [batch, sequence, vocabulary] — matches the 5-token input and vocab of 100.
    // (Fixed: previous `enumerated()` bound an unused `index`.)
    for result in results {
        XCTAssertEqual(result.shape, [1, 5, 100])
    }
}

/// Verifies that several tasks can sample from shared logits concurrently,
/// each under its own seeded random state, and that every sampled token id
/// lies within the vocabulary range.
func testConcurrentSampling() async throws {
    let vocabSize = 100
    let logits = MLXRandom.normal([1, vocabSize])

    let numSamplers = 4
    let results = try await withThrowingTaskGroup(of: Int.self) { group -> [Int] in
        for id in 0 ..< numSamplers {
            group.addTask {
                // Each task gets an isolated, deterministically seeded state.
                try withRandomState(MLXRandom.RandomState(seed: UInt64(id))) {
                    // Alternate strategies: even ids draw categorically,
                    // odd ids take the argmax.
                    id.isMultiple(of: 2)
                        ? categorical(logits).item(Int.self)
                        : logits.argMax(axis: -1).item(Int.self)
                }
            }
        }

        var collected: [Int] = []
        for try await token in group {
            collected.append(token)
        }
        return collected
    }

    XCTAssertEqual(results.count, numSamplers)

    for token in results {
        XCTAssertGreaterThanOrEqual(token, 0)
        XCTAssertLessThan(token, vocabSize)
    }
}

/// Verifies that per-task random states are isolated: multiple tasks sample
/// from the same logits, each under its own seeded `RandomState`, and each
/// task collects the expected number of tokens.
///
/// Fixed: the previous version built and `eval`'d a `LlamaModel` that was
/// never used by this test (dead code copied from `testConcurrentEvaluation`).
func testRandomStateIsolation() async throws {
    let sharedLogits = MLXArray.ones([1, 50])
    let numSamplers = 5
    let samplesPerTask = 10

    let allResults = try await withThrowingTaskGroup(of: [Int].self) { group in
        var results: [[Int]] = []

        for samplerId in 0 ..< numSamplers {
            group.addTask {
                var taskResults: [Int] = []
                let sampler = CategoricalSampler(temperature: 1.0)

                for sampleId in 0 ..< samplesPerTask {
                    // Unique seed per (task, sample) so sequences are
                    // deterministic yet distinct across tasks.
                    let token = try withRandomState(
                        MLXRandom.RandomState(seed: UInt64(samplerId * 1000 + sampleId))
                    ) {
                        return sampler.sample(logits: sharedLogits)
                    }
                    taskResults.append(token.item(Int.self))
                }

                return taskResults
            }
        }

        for try await result in group {
            results.append(result)
        }

        return results
    }

    XCTAssertEqual(allResults.count, numSamplers)

    for samplerResults in allResults {
        XCTAssertEqual(samplerResults.count, samplesPerTask)
    }

    // NOTE(review): this assertion is vacuously true for any non-empty
    // `allResults`; a stronger check would compare sequences across tasks
    // (e.g. expect them to differ given distinct seeds) — confirm intent
    // before tightening, since distinctness is probabilistic.
    let uniqueSequences = Set(allResults.map { $0.description })
    XCTAssertGreaterThan(uniqueSequences.count, 0)
}
}

struct TestTokenizer: Tokenizer {
Expand Down