diff --git a/mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift b/mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift
index af8635a523..da4054f694 100644
--- a/mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift
+++ b/mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift
@@ -18,6 +18,7 @@ import Foundation
 public enum GenAiInferenceError: Error {
   case invalidResponse
   case illegalMethodCall
+  case modelNotFound
 }
 
 extension GenAiInferenceError: LocalizedError {
@@ -28,6 +29,8 @@
       return "The response returned by the model is invalid."
     case .illegalMethodCall:
       return "Response generation is already in progress."
+    case .modelNotFound:
+      return "No file found at the `modelPath` you provided."
     }
   }
 }
@@ -44,6 +47,8 @@
       return 0
     case .illegalMethodCall:
       return 1
+    case .modelNotFound:
+      return 2
     }
   }
 }
diff --git a/mediapipe/tasks/ios/genai/core/sources/LlmTaskRunner.swift b/mediapipe/tasks/ios/genai/core/sources/LlmTaskRunner.swift
index f21cea6d31..93199b362b 100644
--- a/mediapipe/tasks/ios/genai/core/sources/LlmTaskRunner.swift
+++ b/mediapipe/tasks/ios/genai/core/sources/LlmTaskRunner.swift
@@ -17,19 +17,62 @@ import MediaPipeTasksGenAIC
 /// This class is used to create and call appropriate methods on the C `LlmInferenceEngine_Session`
 /// to initialize, execute and terminate any MediaPipe `LlmInference` task.
-public final class LlmTaskRunner {
-  typealias CLlmSession = UnsafeMutableRawPointer
+/// Note: Tasks should not attempt to clear undeleted caches on initialization since the user can
+/// create multiple instances of the task and there is no way of knowing whether they are still
+/// active. Deleting caches of active task instances will result in crashes when the C++
+/// functions are invoked.
+/// Instead, tasks can encapsulate `clearAllCachedFiles()` to provide a function to delete
+/// any undeleted caches when the user wishes to.
+final class LlmTaskRunner {
+  private typealias CLlmSession = UnsafeMutableRawPointer
+
+  private static let cacheSuffix = ".cache"
+  private static let globalCacheDirectory = FileManager.default.temporaryDirectory
+    .versionIndependentAppending(component: "mediapipe.genai.inference.cache")
+  private static let cacheDirectory = LlmTaskRunner.globalCacheDirectory
+    .versionIndependentAppending(component: "\(UUID().uuidString)")
 
   private let cLlmSession: CLlmSession
+
+  private let modelCacheFile: URL
+
   /// Creates a new instance of `LlmTaskRunner` with the given session config.
   ///
   /// - Parameters:
   ///   - sessionConfig: C session config of type `LlmSessionConfig`.
-  public init(sessionConfig: LlmSessionConfig) {
-    /// No safe guards for session creation since the C APIs only throw fatal errors.
-    /// `LlmInferenceEngine_CreateSession()` will always return an llm session if the call
+  init(config: Config) throws {
+    guard FileManager.default.fileExists(atPath: config.modelPath),
+      let modelName = config.modelPath.components(separatedBy: "/").last
+    else {
+      throw GenAiInferenceError.modelNotFound
+    }
+
+    /// Adding a `UUID` prefix to the cache path to prevent the app from crashing if a model cache
+    /// is already found in the temporary directory.
+    /// Cache will be deleted when the task runner is de-allocated. Preferring deletion on
+    /// de-allocation to deleting all caches on initialization to prevent model caches of
+    /// other task runners from being de-allocated prematurely during their lifetime.
+    ///
+    /// Note: No safeguards for session creation since the C APIs only throw fatal errors.
+    /// `LlmInferenceEngine_CreateSession()` will always return an llm session if the call
     /// completes.
-    self.cLlmSession = withUnsafePointer(to: sessionConfig) { LlmInferenceEngine_CreateSession($0) }
+    cLlmSession = LlmTaskRunner.cacheDirectory.path.withCString { cCacheDir in
+      return config.modelPath.withCString { cModelPath in
+        let cSessionConfig = LlmSessionConfig(
+          model_path: cModelPath,
+          cache_dir: cCacheDir,
+          sequence_batch_size: Int(config.sequenceBatchSize),
+          num_decode_steps_per_sync: Int(config.numberOfDecodeStepsPerSync),
+          max_tokens: Int(config.maxTokens),
+          topk: Int(config.topk),
+          temperature: config.temperature,
+          random_seed: config.randomSeed)
+        return withUnsafePointer(to: cSessionConfig) { LlmInferenceEngine_CreateSession($0) }
+      }
+    }
+
+    modelCacheFile = LlmTaskRunner.cacheDirectory.versionIndependentAppending(
+      component: "\(modelName)\(LlmTaskRunner.cacheSuffix)")
   }
 
   /// Invokes the C inference engine with the given input text to generate an array of `String`
@@ -38,7 +81,7 @@
   /// - Parameters:
   ///   - inputText: A `String` that is used to query the LLM.
   /// - Throws: An error if the LLM's response is invalid.
-  public func predict(inputText: String) throws -> [String] {
+  func predict(inputText: String) throws -> [String] {
     /// No safe guards for the call since the C++ APIs only throw fatal errors.
     /// `LlmInferenceEngine_Session_PredictSync()` will always return a `LlmResponseContext` if the
     /// call completes.
@@ -60,7 +103,7 @@
     return responseStrings
   }
 
-  public func predict(
+  func predict(
     inputText: String, progress: @escaping (_ partialResult: [String]?, _ error: Error?) -> Void,
     completion: @escaping (() -> Void)
   ) {
@@ -99,12 +142,104 @@
     }
   }
 
+  /// Clears all cached files created by `LlmInference` to prevent exponential growth of your app
+  /// size. Please ensure that this method is not called during the lifetime of any instances of
+  /// `LlmTaskRunner`.
+  class func clearAllCachedFiles() {
+    // Delete the global cache directory.
+    do {
+      try FileManager.default.removeItem(at: LlmTaskRunner.globalCacheDirectory)
+    }
+    catch {
+      /// Errors thrown are not relevant to the user. They are usually not-found errors.
+    }
+  }
+
   deinit {
     LlmInferenceEngine_Session_Delete(cLlmSession)
+
+    /// Responsibly deleting the model cache.
+    /// Performing on the current thread since only one file needs to be deleted.
+    ///
+    /// Note: Implementation will have to be updated if the C++ core changes the cache prefix.
+    ///
+    /// Note: `deinit` does not get invoked in the following circumstances:
+    /// 1. If a crash occurs before the task runner is de-allocated.
+    /// 2. If an instance of the task is created from `main()` and the app is terminated.
+    /// For example, if the task is an instance variable of the main `ViewController` which doesn't
+    /// get destroyed until the app quits.
+    /// Task interfaces that use the task runner should additionally provide a function that
+    /// encapsulates `LlmTaskRunner.clearAllCachedFiles()` to clean up any undeleted caches to
+    /// avoid exponential growth in app size. The OS clears these directories only if the device
+    /// runs out of storage space.
+    /// Tasks should not attempt to clear undeleted caches on initialization since the user can
+    /// create multiple instances of the task and there is no way of knowing whether they are still
+    /// active. Deleting caches of active task instances will result in crashes when the C++
+    /// functions are invoked.
+    do {
+      try FileManager.default.removeItem(at: modelCacheFile)
+    } catch {
+      // Could not delete file. Common cause: file not found.
+    }
   }
 }
 
 extension LlmTaskRunner {
+  /// Configuration for setting up a `LlmTaskRunner`.
+  struct Config {
+    /// The absolute path to the model asset bundle stored locally on the device.
+    let modelPath: String
+
+    let sequenceBatchSize: UInt
+
+    let numberOfDecodeStepsPerSync: UInt
+
+    /// The total length of the kv-cache. In other words, this is the total number of input + output
+    /// tokens the model needs to handle.
+    let maxTokens: UInt
+
+    /// The top K number of tokens to be sampled from for each decoding step. A value of 1 means
+    /// greedy decoding. Defaults to 40.
+    let topk: UInt
+
+    /// The randomness when decoding the next token. A value of 0.0f means greedy decoding. Defaults
+    /// to 0.8.
+    let temperature: Float
+
+    /// The random seed for sampling tokens.
+    let randomSeed: Int
+
+    /// Creates a new instance of `Config` with the provided values.
+    ///
+    /// - Parameters:
+    ///   - modelPath: The absolute path to a model asset bundle stored locally on the device.
+    ///   - sequenceBatchSize: Sequence batch size for encoding. Used by GPU only. Number of
+    /// input tokens to process at a time for batch processing. Setting this value to 1 means both
+    /// the encoding and decoding share the same graph of sequence length of 1. Setting this value
+    /// to 0 means the batch size will be optimized
+    /// programmatically.
+    ///   - numberOfDecodeStepsPerSync: Number of decode steps per sync. Used by GPU only.
+    /// The default value is 3.
+    ///   - maxTokens: Maximum number of tokens for input and output.
+    ///   - topk: Top K number of tokens to be sampled from for each decoding step.
+    ///   - temperature: Randomness when decoding the next token, 0.0f means greedy decoding.
+    ///   - randomSeed: Random seed for sampling tokens.
+    init(
+      modelPath: String, sequenceBatchSize: UInt, numberOfDecodeStepsPerSync: UInt, maxTokens: UInt,
+      topk: UInt, temperature: Float, randomSeed: Int
+    ) {
+      self.modelPath = modelPath
+      self.sequenceBatchSize = sequenceBatchSize
+      self.numberOfDecodeStepsPerSync = numberOfDecodeStepsPerSync
+      self.maxTokens = maxTokens
+      self.topk = topk
+      self.temperature = temperature
+      self.randomSeed = randomSeed
+    }
+  }
+}
+
+private extension LlmTaskRunner {
   /// A wrapper class whose object will be used as the C++ callback context.
   /// The progress and completion callbacks cannot be invoked without a context.
   class CallbackInfo {
@@ -130,8 +265,8 @@
     }
   }
 
-extension LlmTaskRunner {
-  private class func responseStrings(from responseContext: LlmResponseContext) -> [String]? {
+private extension LlmTaskRunner {
+  class func responseStrings(from responseContext: LlmResponseContext) -> [String]? {
     guard let cResponseArray = responseContext.response_array else {
       return nil
     }
@@ -147,3 +282,13 @@
     return responseStrings
   }
 }
+
+fileprivate extension URL {
+  func versionIndependentAppending(component: String) -> URL {
+    if #available(iOS 16, *) {
+      return self.appending(component: component)
+    } else {
+      return self.appendingPathComponent(component)
+    }
+  }
+}
diff --git a/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift b/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift
index 8c5df2e7a4..5c2a8ac3e3 100644
--- a/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift
+++ b/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift
@@ -13,19 +13,23 @@
 // limitations under the License.
 
 import Foundation
-import MediaPipeTasksGenAIC
 
 /// A MediaPipe task that performs inference using a given Large Language Model.
 ///
 /// Note: Inherits from `NSObject` for Objective C interoperability.
-@objc(MPPLLMInference) public final class LlmInference: NSObject {
-  private static let numberOfDecodeStepsPerSync = 3
-  private static let sequenceBatchSize = 0
+@objc(MPPLlmInference) public final class LlmInference: NSObject {
+  private static let numberOfDecodeStepsPerSync: UInt = 3
+  private static let sequenceBatchSize: UInt = 0
+  private static let cacheCleanupQueueName = "com.google.mediapipe.genai.cacheCleanupQueue.\(UUID().uuidString)"
   private static let responseGenerationInProgressQueueName =
-    "com.google.mediapipe.genai.isResponseGenerationInProgressQueue"
+    "com.google.mediapipe.genai.isResponseGenerationInProgressQueue.\(UUID().uuidString)"
+  /// Serial queue for cache cleanup.
+  private static let cacheCleanupQueue = DispatchQueue(
+    label: cacheCleanupQueueName)
 
   private let llmTaskRunner: LlmTaskRunner
 
+  /// Concurrent queue to implement readers-writers lock on `responseGenerationInProgress`.
   private let responseGenerationInProgressQueue = DispatchQueue(
     label: LlmInference.responseGenerationInProgressQueueName,
     attributes: .concurrent)
@@ -52,25 +56,17 @@
   /// - Parameters:
   ///   - options: The options of type `LlmInference.Options` to use for configuring the
   /// `LlmInference`.
-  @objc public init(options: Options) {
-    let modelPath = strdup(options.modelPath)
-    let cacheDirectory = strdup(FileManager.default.temporaryDirectory.path)
-
-    defer {
-      free(modelPath)
-      free(cacheDirectory)
-    }
-
-    let sessionConfig = LlmSessionConfig(
-      model_path: modelPath,
-      cache_dir: cacheDirectory,
-      sequence_batch_size: LlmInference.sequenceBatchSize,
-      num_decode_steps_per_sync: LlmInference.numberOfDecodeStepsPerSync,
-      max_tokens: options.maxTokens,
+  @objc public init(options: Options) throws {
+    let taskRunnerConfig = LlmTaskRunner.Config(
+      modelPath: options.modelPath,
+      sequenceBatchSize: LlmInference.sequenceBatchSize,
+      numberOfDecodeStepsPerSync: LlmInference.numberOfDecodeStepsPerSync,
+      maxTokens: options.maxTokens,
       topk: options.topk,
       temperature: options.temperature,
-      random_seed: options.randomSeed)
-    llmTaskRunner = LlmTaskRunner(sessionConfig: sessionConfig)
+      randomSeed: options.randomSeed)
+
+    llmTaskRunner = try LlmTaskRunner(config: taskRunnerConfig)
 
     super.init()
   }
@@ -80,9 +76,9 @@
   ///
   /// - Parameters:
   ///   - modelPath: The absolute path to a model asset bundle stored locally on the device.
-  @objc public convenience init(modelPath: String) {
+  @objc public convenience init(modelPath: String) throws {
     let options = Options(modelPath: modelPath)
-    self.init(options: options)
+    try self.init(options: options)
   }
 
   /// Generates a response based on the input text.
@@ -149,6 +145,20 @@
     })
   }
 
+  /// Clears all cached files created by `LlmInference` to prevent exponential growth of your app
+  /// size. Please ensure that this method is not called during the lifetime of any instances of
+  /// `LlmInference`. If the cache is deleted while an instance of `LlmInference` is in scope,
+  /// calling one of its methods will result in undefined behavior and may lead to a crash.
+  public class func clearAllCachedFiles(completion: @escaping(() -> Void)) {
+    /// Asynchronously deleting the files to prevent blocking the current thread as there may be
+    /// multiple undeleted weight caches. Choosing a serial queue to let callers wait until the
+    /// previous call for deletion is completed.
+    cacheCleanupQueue.async {
+      LlmTaskRunner.clearAllCachedFiles()
+      completion()
+    }
+  }
+
   /// Throw error if response generation is in progress or update response generation state.
   private func shouldContinueWithResponseGeneration() throws {
     if responseGenerationInProgress {
@@ -158,7 +168,7 @@
     responseGenerationInProgress = true
   }
 
-  private static func humanReadableString(
+  private class func humanReadableString(
     llmResponses: [String], stripLeadingWhitespaces: Bool = true
   ) -> String? {
     guard let llmResponse = llmResponses.first else {
@@ -180,11 +190,11 @@
 
     /// The total length of the kv-cache. In other words, this is the total number of input + output
     /// tokens the model needs to handle.
-    @objc public var maxTokens: Int = 512
+    @objc public var maxTokens: UInt = 512
 
     /// The top K number of tokens to be sampled from for each decoding step. A value of 1 means
     /// greedy decoding. Defaults to 40.
-    @objc public var topk: Int = 40
+    @objc public var topk: UInt = 40
 
     /// The randomness when decoding the next token. A value of 0.0f means greedy decoding. Defaults
     /// to 0.8.
@@ -203,17 +213,18 @@
       self.modelPath = modelPath
      super.init()
     }
+
   }
 }
 
 /// An extension to `String` to add some utility functions.
-extension String {
+fileprivate extension String {
   private static let tokenSplitter = "▁"  /// Note this is NOT an underscore: ▁(U+2581)
   private static let newLine = "<0x0A>"
   private static let eod = "\\[eod\\]"
 
-  fileprivate func humanReadableString(stripLeadingWhitespaces: Bool = true) -> String? {
+  func humanReadableString(stripLeadingWhitespaces: Bool = true) -> String? {
     var humanReadableString = self.replacingOccurrences(of: String.tokenSplitter, with: " ")
       .replacingOccurrences(of: String.newLine, with: "\n")
     humanReadableString =