diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index 51d5120..ba38550 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -5,12 +5,12 @@ import Observation public final class LanguageModelSession: @unchecked Sendable { public var isResponding: Bool { access(keyPath: \.isResponding) - return state.access { $0.isResponding } + return state.withLock { $0.isResponding } } public var transcript: Transcript { access(keyPath: \.transcript) - return state.access { $0.transcript } + return state.withLock { $0.transcript } } @ObservationIgnored private let state: Locked @@ -103,13 +103,13 @@ public final class LanguageModelSession: @unchecked Sendable { nonisolated private func beginResponding() { withMutation(keyPath: \.isResponding) { - state.access { $0.beginResponding() } + state.withLock { $0.beginResponding() } } } nonisolated private func endResponding() { withMutation(keyPath: \.isResponding) { - state.access { $0.endResponding() } + state.withLock { $0.endResponding() } } } @@ -159,7 +159,7 @@ public final class LanguageModelSession: @unchecked Sendable { ) ) session.withMutation(keyPath: \.transcript) { - session.state.access { $0.transcript.append(responseEntry) } + session.state.withLock { $0.transcript.append(responseEntry) } } } } catch { @@ -209,7 +209,7 @@ public final class LanguageModelSession: @unchecked Sendable { ) ) withMutation(keyPath: \.transcript) { - state.access { $0.transcript.append(promptEntry) } + state.withLock { $0.transcript.append(promptEntry) } } let response = try await model.respond( @@ -237,9 +237,9 @@ public final class LanguageModelSession: @unchecked Sendable { // Add tool entries and response to transcript withMutation(keyPath: \.transcript) { - state.access { state in - state.transcript.append(contentsOf: response.transcriptEntries) - state.transcript.append(responseEntry) + state.withLock { lockedState in + lockedState.transcript.append(contentsOf: response.transcriptEntries) + lockedState.transcript.append(responseEntry) } } @@ -262,7 +262,7 @@ public final class LanguageModelSession: @unchecked Sendable { ) ) withMutation(keyPath: \.transcript) { - state.access { $0.transcript.append(promptEntry) } + state.withLock { $0.transcript.append(promptEntry) } } return wrapStream( @@ -558,7 +558,7 @@ extension LanguageModelSession { ) ) withMutation(keyPath: \.transcript) { - state.access { $0.transcript.append(promptEntry) } + state.withLock { $0.transcript.append(promptEntry) } } // Extract text content for the Prompt parameter @@ -589,9 +589,9 @@ extension LanguageModelSession { // Add tool entries and response to transcript withMutation(keyPath: \.transcript) { - state.access { state in - state.transcript.append(contentsOf: response.transcriptEntries) - state.transcript.append(responseEntry) + state.withLock { lockedState in + lockedState.transcript.append(contentsOf: response.transcriptEntries) + lockedState.transcript.append(responseEntry) } } @@ -664,7 +664,7 @@ extension LanguageModelSession { ) ) withMutation(keyPath: \.transcript) { - state.access { $0.transcript.append(promptEntry) } + state.withLock { $0.transcript.append(promptEntry) } } // Extract text content for the Prompt parameter diff --git a/Sources/AnyLanguageModel/Locked.swift b/Sources/AnyLanguageModel/Locked.swift deleted file mode 100644 index 25d93d4..0000000 --- a/Sources/AnyLanguageModel/Locked.swift +++ /dev/null @@ -1,16 +0,0 @@ -import Foundation - -final class Locked { - private let lock = NSLock() - private var state: State - - init(_ state: State) { - self.state = state - } - - func access(_ block: (inout State) throws -> T) rethrows -> T { - try lock.withLock { try block(&self.state) } - } -} - -extension Locked: @unchecked Sendable where State: Sendable {} diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 130b103..4ffb877 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -27,8 +27,7 @@ import Foundation /// Coordinates a bounded in-memory cache with structured, coalesced loading. private final class ModelContextCache { private let cache: NSCache - private let lock = NSLock() - private var inFlight: [String: Task] = [:] + private let inFlight = Locked<[String: Task]>([:]) /// Creates a cache with a count-based eviction limit. init(countLimit: Int) { @@ -90,37 +89,31 @@ import Foundation } private func inFlightTask(for key: String) -> Task? { - lock.lock() - defer { lock.unlock() } - return inFlight[key] + inFlight.withLock { $0[key] } } private func setInFlight(_ task: Task, for key: String) { - lock.lock() - inFlight[key] = task - lock.unlock() + inFlight.withLock { $0[key] = task } } private func clearInFlight(for key: String) { - lock.lock() - inFlight[key] = nil - lock.unlock() + inFlight.withLock { $0[key] = nil } } private func removeInFlight(for key: String) -> Task? { - lock.lock() - defer { lock.unlock() } - let task = inFlight[key] - inFlight[key] = nil - return task + inFlight.withLock { + let task = $0[key] + $0[key] = nil + return task + } } private func removeAllInFlight() -> [Task] { - lock.lock() - defer { lock.unlock() } - let tasks = Array(inFlight.values) - inFlight.removeAll() - return tasks + inFlight.withLock { + let tasks = Array($0.values) + $0.removeAll() + return tasks + } } } diff --git a/Sources/AnyLanguageModel/Shared/Locked.swift b/Sources/AnyLanguageModel/Shared/Locked.swift new file mode 100644 index 0000000..2ea92f4 --- /dev/null +++ b/Sources/AnyLanguageModel/Shared/Locked.swift @@ -0,0 +1,24 @@ +import Foundation + +/// Protects shared mutable state behind an `NSLock`. +final class Locked { + private let lock = NSLock() + private var state: State + + /// Creates a locked container with the given initial state. + init(_ state: State) { + self.state = state + } + + /// Executes `body` while holding the lock. + /// + /// - Parameter body: A closure that reads or mutates the protected state. + /// - Returns: The value returned by `body`. + /// - Throws: Rethrows any error from `body`. + /// - Note: Keep critical sections small and synchronous. + func withLock(_ body: (inout State) throws -> T) rethrows -> T { + try lock.withLock { try body(&self.state) } + } +} + +extension Locked: @unchecked Sendable where State: Sendable {} diff --git a/Sources/AnyLanguageModel/StructuredGeneration.swift b/Sources/AnyLanguageModel/Shared/StructuredGeneration.swift similarity index 99% rename from Sources/AnyLanguageModel/StructuredGeneration.swift rename to Sources/AnyLanguageModel/Shared/StructuredGeneration.swift index 4ac3e75..aae832e 100644 --- a/Sources/AnyLanguageModel/StructuredGeneration.swift +++ b/Sources/AnyLanguageModel/Shared/StructuredGeneration.swift @@ -45,19 +45,14 @@ private final class StringTokenCache: @unchecked Sendable { let sampleTexts: [String] } - private var cache: [Key: Set] = [:] - private let lock = NSLock() + private let tokensByKey = Locked<[Key: Set]>([:]) func tokens(for key: Key) -> Set? { - lock.lock() - defer { lock.unlock() } - return cache[key] + tokensByKey.withLock { $0[key] } } func store(_ tokens: Set, for key: Key) { - lock.lock() - cache[key] = tokens - lock.unlock() + tokensByKey.withLock { $0[key] = tokens } } } diff --git a/Tests/AnyLanguageModelTests/LockedTests.swift b/Tests/AnyLanguageModelTests/LockedTests.swift index b81eda9..10426a8 100644 --- a/Tests/AnyLanguageModelTests/LockedTests.swift +++ b/Tests/AnyLanguageModelTests/LockedTests.swift @@ -3,32 +3,32 @@ import Testing @testable import AnyLanguageModel -@Suite("Locked") +@Suite("Locked Tests") struct LockedTests { @Test("Read access returns the initial value") func readAccess() { let locked = Locked(42) - let value = locked.access { $0 } + let value = locked.withLock { $0 } #expect(value == 42) } @Test("Write access mutates the state") func writeAccess() { let locked = Locked(0) - locked.access { $0 = 99 } - let value = locked.access { $0 } + locked.withLock { $0 = 99 } + let value = locked.withLock { $0 } #expect(value == 99) } @Test("Access returns the value from the closure") func returnValue() { let locked = Locked("hello") - let result = locked.access { state -> Int in + let result = locked.withLock { state -> Int in state += " world" return state.count } #expect(result == 11) - #expect(locked.access { $0 } == "hello world") + #expect(locked.withLock { $0 } == "hello world") } @Test("Access propagates thrown errors") @@ -37,7 +37,7 @@ struct LockedTests { let locked = Locked(0) #expect(throws: TestError.self) { - try locked.access { _ in throw TestError() } + try locked.withLock { _ in throw TestError() } } } @@ -50,14 +50,14 @@ struct LockedTests { } let locked = Locked(State(name: "initial", count: 0, tags: [])) - locked.access { state in + locked.withLock { state in state.name = "updated" state.count = 5 state.tags.append("a") state.tags.append("b") } - let snapshot = locked.access { $0 } + let snapshot = locked.withLock { $0 } #expect(snapshot.name == "updated") #expect(snapshot.count == 5) #expect(snapshot.tags == ["a", "b"]) @@ -70,11 +70,11 @@ struct LockedTests { await withTaskGroup(of: Void.self) { group in for _ in 0 ..< iterations { - group.addTask { locked.access { $0 += 1 } } + group.addTask { locked.withLock { $0 += 1 } } } } - let finalValue = locked.access { $0 } + let finalValue = locked.withLock { $0 } #expect(finalValue == iterations) } @@ -87,12 +87,12 @@ struct LockedTests { for i in 0 ..< iterations { let priority: TaskPriority = i.isMultiple(of: 2) ? .high : .background group.addTask(priority: priority) { - locked.access { $0 += 1 } + locked.withLock { $0 += 1 } } } } - let finalValue = locked.access { $0 } + let finalValue = locked.withLock { $0 } #expect(finalValue == iterations) } @@ -103,11 +103,11 @@ struct LockedTests { await withTaskGroup(of: Void.self) { group in for i in 0 ..< iterations { - group.addTask { locked.access { $0.append(i) } } + group.addTask { locked.withLock { $0.append(i) } } } } - let finalArray = locked.access { $0 } + let finalArray = locked.withLock { $0 } #expect(finalArray.count == iterations) } @@ -119,13 +119,13 @@ struct LockedTests { await withTaskGroup(of: Void.self) { group in for _ in 0 ..< iterations { - group.addTask { lockedA.access { $0 += 1 } } - group.addTask { lockedB.access { $0 += 1 } } + group.addTask { lockedA.withLock { $0 += 1 } } + group.addTask { lockedB.withLock { $0 += 1 } } } } - #expect(lockedA.access { $0 } == iterations) - #expect(lockedB.access { $0 } == iterations) + #expect(lockedA.withLock { $0 } == iterations) + #expect(lockedB.withLock { $0 } == iterations) } @Test("Can wrap a non-Sendable type") @@ -136,8 +136,8 @@ struct LockedTests { } let locked = Locked(Box(10)) - locked.access { $0.value += 5 } - let result = locked.access { $0.value } + locked.withLock { $0.value += 5 } + let result = locked.withLock { $0.value } #expect(result == 15) } @@ -145,8 +145,8 @@ struct LockedTests { func copySharesStorage() { let original = Locked(0) let copy = original - original.access { $0 = 42 } - let value = copy.access { $0 } + original.withLock { $0 = 42 } + let value = copy.withLock { $0 } #expect(value == 42) } } diff --git a/Tests/AnyLanguageModelTests/ObservationTests.swift b/Tests/AnyLanguageModelTests/ObservationTests.swift index 470a65a..a8e3c92 100644 --- a/Tests/AnyLanguageModelTests/ObservationTests.swift +++ b/Tests/AnyLanguageModelTests/ObservationTests.swift @@ -13,11 +13,11 @@ struct ObservationTests { withObservationTracking { _ = session.transcript } onChange: { - changed.access { $0 = true } + changed.withLock { $0 = true } } try await session.respond(to: "Hi") - #expect(changed.access { $0 } == true) + #expect(changed.withLock { $0 } == true) } @Test("Tracking transcript fires onChange when streamResponse mutates it") @@ -28,14 +28,14 @@ struct ObservationTests { withObservationTracking { _ = session.transcript } onChange: { - changed.access { $0 = true } + changed.withLock { $0 = true } } let stream = session.streamResponse(to: "Hi") for try await _ in stream {} try await Task.sleep(for: .milliseconds(10)) - #expect(changed.access { $0 } == true) + #expect(changed.withLock { $0 } == true) } @Test("Tracking isResponding fires onChange when respond mutates it") @@ -46,11 +46,11 @@ struct ObservationTests { withObservationTracking { _ = session.isResponding } onChange: { - changed.access { $0 = true } + changed.withLock { $0 = true } } try await session.respond(to: "Hi") - #expect(changed.access { $0 } == true) + #expect(changed.withLock { $0 } == true) } @Test("No onChange fires when no properties are tracked") @@ -61,11 +61,11 @@ struct ObservationTests { withObservationTracking { // Intentionally do not read any session properties } onChange: { - changed.access { $0 = true } + changed.withLock { $0 = true } } try await session.respond(to: "Hi") - #expect(changed.access { $0 } == false) + #expect(changed.withLock { $0 } == false) } @Test("Re-registering observation tracks subsequent changes") @@ -76,20 +76,20 @@ struct ObservationTests { withObservationTracking { _ = session.transcript } onChange: { - changeCount.access { $0 += 1 } + changeCount.withLock { $0 += 1 } } try await session.respond(to: "First") - #expect(changeCount.access { $0 } == 1) + #expect(changeCount.withLock { $0 } == 1) // Re-register after the first change fires withObservationTracking { _ = session.transcript } onChange: { - changeCount.access { $0 += 1 } + changeCount.withLock { $0 += 1 } } try await session.respond(to: "Second") - #expect(changeCount.access { $0 } == 2) + #expect(changeCount.withLock { $0 } == 2) } }