Added deletion of cache files and updated interface of iOS LlmTaskRunner and LlmInference #5191

Open · wants to merge 13 commits into base: master
@@ -18,6 +18,7 @@ import Foundation
public enum GenAiInferenceError: Error {
case invalidResponse
case illegalMethodCall
case modelNotFound
}

extension GenAiInferenceError: LocalizedError {
@@ -28,6 +29,8 @@ extension GenAiInferenceError: LocalizedError {
return "The response returned by the model is invalid."
case .illegalMethodCall:
return "Response generation is already in progress."
case .modelNotFound:
return "No file found at the `modelPath` you provided."
}
}
}
@@ -44,6 +47,8 @@ extension GenAiInferenceError: CustomNSError {
return 0
case .illegalMethodCall:
return 1
case .modelNotFound:
return 2
}
}
}
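For context, here is a minimal caller-side sketch (not part of this diff) showing how the new `modelNotFound` case surfaces through the `LocalizedError` and `CustomNSError` conformances above; the `MediaPipeTasksGenAI` module name is assumed.

```swift
import Foundation
import MediaPipeTasksGenAI  // assumed module name for the Swift GenAI tasks

func logGenAiError(_ error: GenAiInferenceError) {
  // `errorDescription` is provided by the `LocalizedError` conformance above.
  print(error.errorDescription ?? "Unknown error")
  // The bridged `NSError` code is provided by `CustomNSError`; `.modelNotFound` maps to 2.
  print((error as NSError).code)
}

logGenAiError(.modelNotFound)
// Prints:
// No file found at the `modelPath` you provided.
// 2
```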
165 changes: 155 additions & 10 deletions mediapipe/tasks/ios/genai/core/sources/LlmTaskRunner.swift
@@ -17,19 +17,62 @@ import MediaPipeTasksGenAIC

/// This class is used to create and call appropriate methods on the C `LlmInferenceEngine_Session`
/// to initialize, execute and terminate any MediaPipe `LlmInference` task.
public final class LlmTaskRunner {
typealias CLlmSession = UnsafeMutableRawPointer
/// Note: Tasks should not attempt to clear undeleted caches on initialization since users can
/// create multiple instances of the task and there is no way of knowing whether they are still
/// active. Deleting the caches of active task instances will result in crashes when the C++
/// functions are invoked.
/// Instead, tasks can encapsulate `clearAllCachedFiles()` to provide a function that deletes
/// any undeleted caches when the user wishes to.
final class LlmTaskRunner {
private typealias CLlmSession = UnsafeMutableRawPointer

private static let cacheSuffix = ".cache"
private static let globalCacheDirectory = FileManager.default.temporaryDirectory
.versionIndependentAppending(component: "mediapipe.genai.inference.cache")
private static let cacheDirectory = LlmTaskRunner.globalCacheDirectory
.versionIndependentAppending(component: "\(UUID().uuidString)")

private let cLlmSession: CLlmSession

private let modelCacheFile: URL

/// Creates a new instance of `LlmTaskRunner` with the given session config.
///
/// - Parameters:
/// - sessionConfig: C session config of type `LlmSessionConfig`.
public init(sessionConfig: LlmSessionConfig) {
/// No safe guards for session creation since the C APIs only throw fatal errors.
/// `LlmInferenceEngine_CreateSession()` will always return an llm session if the call
init(config: Config) throws {
guard FileManager.default.fileExists(atPath: config.modelPath),
let modelName = config.modelPath.components(separatedBy: "/").last
else {
throw GenAiInferenceError.modelNotFound
}

/// A `UUID`-based component is added to the cache path to prevent the app from crashing if a
/// model cache is already found in the temporary directory.
/// The cache will be deleted when the task runner is de-allocated. Deletion on de-allocation
/// is preferred over deleting all caches on initialization to prevent the model caches of
/// other task runners from being deleted prematurely during their lifetime.
///
/// Note: No safe guards for session creation since the C APIs only throw fatal errors.
/// `LlmInferenceEngine_CreateSession()` will always return an LLM session if the call
/// completes.
self.cLlmSession = withUnsafePointer(to: sessionConfig) { LlmInferenceEngine_CreateSession($0) }
cLlmSession = LlmTaskRunner.cacheDirectory.path.withCString { cCacheDir in
return config.modelPath.withCString { cModelPath in
let cSessionConfig = LlmSessionConfig(
model_path: cModelPath,
cache_dir: cCacheDir,
sequence_batch_size: Int(config.sequenceBatchSize),
num_decode_steps_per_sync: Int(config.numberOfDecodeStepsPerSync),
max_tokens: Int(config.maxTokens),
topk: Int(config.topk),
temperature: config.temperature,
random_seed: config.randomSeed)
return withUnsafePointer(to: cSessionConfig) { LlmInferenceEngine_CreateSession($0) }
}
}

modelCacheFile = LlmTaskRunner.cacheDirectory.versionIndependentAppending(
component: "\(modelName)\(LlmTaskRunner.cacheSuffix)")
}

/// Invokes the C inference engine with the given input text to generate an array of `String`
@@ -38,7 +81,7 @@ public final class LlmTaskRunner {
/// - Parameters:
/// - inputText: A `String` that is used to query the LLM.
/// - Throws: An error if the LLM's response is invalid.
public func predict(inputText: String) throws -> [String] {
func predict(inputText: String) throws -> [String] {
/// No safe guards for the call since the C++ APIs only throw fatal errors.
/// `LlmInferenceEngine_Session_PredictSync()` will always return a `LlmResponseContext` if the
/// call completes.
@@ -60,7 +103,7 @@ public final class LlmTaskRunner {
return responseStrings
}

public func predict(
func predict(
inputText: String, progress: @escaping (_ partialResult: [String]?, _ error: Error?) -> Void,
completion: @escaping (() -> Void)
) {
@@ -99,12 +142,104 @@ public final class LlmTaskRunner {
}
}

/// Clears all cached files created by `LlmInference` to prevent exponential growth of your app
/// size. Please ensure that this method is not called during the lifetime of any instances of
/// `LlmTaskRunner`.
class func clearAllCachedFiles() {
// Delete the cache directory.
do {
try FileManager.default.removeItem(at: LlmTaskRunner.globalCacheDirectory)
} catch {
// Errors thrown here are not relevant to the user; they are usually file-not-found errors.
}
}

deinit {
LlmInferenceEngine_Session_Delete(cLlmSession)

/// Deleting the model cache.
/// Performed on the current thread since only one file needs to be deleted.
///
/// Note: The implementation will have to be updated if the C++ core changes the cache suffix.
///
/// Note: `deinit` does not get invoked in the following circumstances:
/// 1. If a crash occurs before the task runner is de-allocated.
/// 2. If an instance of the task is created from `main()` and the app is terminated.
/// For example, if the task is an instance variable of the main `ViewController` which
/// doesn't get destroyed until the app quits.
/// Task interfaces that use the task runner should additionally provide a function that
/// encapsulates `LlmTaskRunner.clearAllCachedFiles()` to clean up any undeleted caches and
/// avoid exponential growth in app size. The OS clears these directories only if the device
/// runs out of storage space.
/// Tasks should not attempt to clear undeleted caches on initialization since users can
/// create multiple instances of the task and there is no way of knowing whether they are
/// still active. Deleting the caches of active task instances will result in crashes when the
/// C++ functions are invoked.
do {
try FileManager.default.removeItem(at: modelCacheFile)
} catch {
// Could not delete file. Common cause: file not found.
}
}
}
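To make the on-disk layout concrete, the following sketch (an assumption-level illustration, not part of the diff) mirrors the path construction above: caches live under a shared `mediapipe.genai.inference.cache` directory inside the temporary directory, scoped per task runner by a `UUID`, with one `<modelName>.cache` file per model.

```swift
import Foundation

// Mirrors the layout produced by the initializer above (sketch only):
//   <tmp>/mediapipe.genai.inference.cache/<UUID>/<modelName>.cache
func expectedCacheFileURL(forModelAt modelPath: String, instanceID: UUID) -> URL {
  let modelName = (modelPath as NSString).lastPathComponent
  return FileManager.default.temporaryDirectory
    .appendingPathComponent("mediapipe.genai.inference.cache")
    .appendingPathComponent(instanceID.uuidString)
    .appendingPathComponent("\(modelName).cache")
}

// Example with a placeholder model path:
// .../tmp/mediapipe.genai.inference.cache/<UUID>/gemma-2b-it.bin.cache
print(expectedCacheFileURL(forModelAt: "/path/to/gemma-2b-it.bin", instanceID: UUID()))
```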

extension LlmTaskRunner {
/// Configuration for setting up a `LlmTaskRunner`.
struct Config {
/// The absolute path to the model asset bundle stored locally on the device.
let modelPath: String

/// Sequence batch size for encoding. Used by GPU only.
let sequenceBatchSize: UInt

/// Number of decode steps per sync. Used by GPU only.
let numberOfDecodeStepsPerSync: UInt

/// The total length of the kv-cache. In other words, this is the total number of input + output
/// tokens the model needs to handle.
let maxTokens: UInt

/// The top K number of tokens to be sampled from for each decoding step. A value of 1 means
/// greedy decoding. Defaults to 40.
let topk: UInt

/// The randomness when decoding the next token. A value of 0.0f means greedy decoding. Defaults
/// to 0.8.
let temperature: Float

/// The random seed for sampling tokens.
let randomSeed: Int

/// Creates a new instance of `Config` with the provided values.
///
/// - Parameters:
/// - modelPath: The absolute path to a model asset bundle stored locally on the device.
/// - sequenceBatchSize: Sequence batch size for encoding. Used by GPU only. Number of
/// input tokens to process at a time for batch processing. Setting this value to 1 means both
/// the encoding and decoding share the same graph of sequence length of 1. Setting this value
/// to 0 means the batch size will be optimized programmatically.
/// - numberOfDecodeStepsPerSync: Number of decode steps per sync. Used by GPU only.
/// The default value is 3.
/// - maxTokens: Maximum number of tokens for input and output.
/// - topk: Top K number of tokens to be sampled from for each decoding step.
/// - temperature: Randomness when decoding the next token, 0.0f means greedy decoding.
/// - randomSeed: Random seed for sampling tokens.
init(
modelPath: String, sequenceBatchSize: UInt, numberOfDecodeStepsPerSync: UInt, maxTokens: UInt,
topk: UInt, temperature: Float, randomSeed: Int
) {
self.modelPath = modelPath
self.sequenceBatchSize = sequenceBatchSize
self.numberOfDecodeStepsPerSync = numberOfDecodeStepsPerSync
self.maxTokens = maxTokens
self.topk = topk
self.temperature = temperature
self.randomSeed = randomSeed
}
}
}
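A hypothetical sketch of wiring a `Config` into the task runner directly (only possible from inside the module now that `LlmTaskRunner` is no longer public); the model path and random seed are placeholders, and the numeric values mirror the defaults `LlmInference` passes through.

```swift
func runDirectly() throws -> [String] {
  let config = LlmTaskRunner.Config(
    modelPath: "/path/to/model.bin",  // placeholder
    sequenceBatchSize: 0,             // 0 = batch size optimized programmatically (GPU only)
    numberOfDecodeStepsPerSync: 3,    // default used by `LlmInference`
    maxTokens: 512,
    topk: 40,
    temperature: 0.8,
    randomSeed: 0)                    // placeholder seed

  // Throws `GenAiInferenceError.modelNotFound` if no file exists at `modelPath`.
  let runner = try LlmTaskRunner(config: config)
  return try runner.predict(inputText: "Hello")
}
```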

private extension LlmTaskRunner {
/// A wrapper class whose object will be used as the C++ callback context.
/// The progress and completion callbacks cannot be invoked without a context.
class CallbackInfo {
@@ -130,8 +265,8 @@ extension LlmTaskRunner {
}
}

extension LlmTaskRunner {
private class func responseStrings(from responseContext: LlmResponseContext) -> [String]? {
private extension LlmTaskRunner {
class func responseStrings(from responseContext: LlmResponseContext) -> [String]? {
guard let cResponseArray = responseContext.response_array else {
return nil
}
@@ -147,3 +282,13 @@ extension LlmTaskRunner {
return responseStrings
}
}

fileprivate extension URL {
func versionIndependentAppending(component: String) -> URL {
if #available(iOS 16, *) {
return self.appending(component: component)
} else {
return self.appendingPathComponent(component)
}
}
}
69 changes: 40 additions & 29 deletions mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift
@@ -13,19 +13,23 @@
// limitations under the License.

import Foundation
import MediaPipeTasksGenAIC

/// A MediaPipe task that performs inference using a given Large Language Model.
///
/// Note: Inherits from `NSObject` for Objective C interoperability.
@objc(MPPLLMInference) public final class LlmInference: NSObject {
private static let numberOfDecodeStepsPerSync = 3
private static let sequenceBatchSize = 0
@objc(MPPLlmInference) public final class LlmInference: NSObject {
private static let numberOfDecodeStepsPerSync: UInt = 3
private static let sequenceBatchSize: UInt = 0
private static let cacheCleanupQueueName = "com.google.mediapipe.genai.cacheCleanupQueue.\(UUID().uuidString)"
private static let responseGenerationInProgressQueueName =
"com.google.mediapipe.genai.isResponseGenerationInProgressQueue"
"com.google.mediapipe.genai.isResponseGenerationInProgressQueue.\(UUID().uuidString)"
/// Serial queue for cache cleanup.
private static let cacheCleanupQueue = DispatchQueue(
label: cacheCleanupQueueName)

private let llmTaskRunner: LlmTaskRunner

/// Concurrent queue to implement readers-writers lock on `responseGenerationInProgress`.
private let responseGenerationInProgressQueue = DispatchQueue(
label: LlmInference.responseGenerationInProgressQueueName,
attributes: .concurrent)
@@ -52,25 +56,17 @@
/// - Parameters:
/// - options: The options of type `LlmInference.Options` to use for configuring the
/// `LlmInference`.
@objc public init(options: Options) {
let modelPath = strdup(options.modelPath)
let cacheDirectory = strdup(FileManager.default.temporaryDirectory.path)

defer {
free(modelPath)
free(cacheDirectory)
}

let sessionConfig = LlmSessionConfig(
model_path: modelPath,
cache_dir: cacheDirectory,
sequence_batch_size: LlmInference.sequenceBatchSize,
num_decode_steps_per_sync: LlmInference.numberOfDecodeStepsPerSync,
max_tokens: options.maxTokens,
@objc public init(options: Options) throws {
let taskRunnerConfig = LlmTaskRunner.Config(
modelPath: options.modelPath,
sequenceBatchSize: LlmInference.sequenceBatchSize,
numberOfDecodeStepsPerSync: LlmInference.numberOfDecodeStepsPerSync,
maxTokens: options.maxTokens,
topk: options.topk,
temperature: options.temperature,
random_seed: options.randomSeed)
llmTaskRunner = LlmTaskRunner(sessionConfig: sessionConfig)
randomSeed: options.randomSeed)

llmTaskRunner = try LlmTaskRunner(config: taskRunnerConfig)

super.init()
}
@@ -80,9 +76,9 @@ import MediaPipeTasksGenAIC
///
/// - Parameters:
/// - modelPath: The absolute path to a model asset bundle stored locally on the device.
@objc public convenience init(modelPath: String) {
@objc public convenience init(modelPath: String) throws {
let options = Options(modelPath: modelPath)
self.init(options: options)
try self.init(options: options)
}
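A caller-side sketch (not in the diff) of the updated throwing initializers; `generateResponse(inputText:)` is assumed from the existing synchronous API of this class, and the model path is a placeholder.

```swift
import MediaPipeTasksGenAI  // assumed module name

func runInference() {
  do {
    let llmInference = try LlmInference(modelPath: "/path/to/model.bin")
    let response = try llmInference.generateResponse(inputText: "Write a haiku about the sea.")
    print(response)
  } catch GenAiInferenceError.modelNotFound {
    // New in this change: a missing model file surfaces as an error instead of a crash.
    print("No model file exists at the provided path.")
  } catch {
    print("Response generation failed: \(error)")
  }
}
```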

/// Generates a response based on the input text.
@@ -149,6 +145,20 @@ import MediaPipeTasksGenAIC
})
}

/// Clears all cached files created by `LlmInference` to prevent exponential growth of your app
/// size. Please ensure that this method is not called during the lifetime of any instances of
/// `LlmInference`. If the cache is deleted while an instance of `LlmInference` is in scope,
/// calling one of its methods will result in undefined behaviour and may lead to a crash.
public class func clearAllCachedFiles(completion: @escaping (() -> Void)) {
/// Asynchronously deleting the files to prevent blocking the current thread as there may be
/// multiple undeleted weight caches. Choosing a serial queue to let callers wait until the
/// previous call for deletion is completed.
cacheCleanupQueue.async {
LlmTaskRunner.clearAllCachedFiles()
completion()
}
}
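Usage sketch (not in the diff): clear leftover caches at a point where no `LlmInference` instance is alive, for example on the next app launch before any task is created.

```swift
LlmInference.clearAllCachedFiles {
  print("MediaPipe GenAI caches cleared.")
}
```

Because the cleanup runs on a serial queue, a second call made while a previous deletion is still in flight waits for it to finish.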

/// Throw error if response generation is in progress or update response generation state.
private func shouldContinueWithResponseGeneration() throws {
if responseGenerationInProgress {
@@ -158,7 +168,7 @@ import MediaPipeTasksGenAIC
responseGenerationInProgress = true
}

private static func humanReadableString(
private class func humanReadableString(
llmResponses: [String], stripLeadingWhitespaces: Bool = true
) -> String? {
guard let llmResponse = llmResponses.first else {
@@ -180,11 +190,11 @@ extension LlmInference {

/// The total length of the kv-cache. In other words, this is the total number of input + output
/// tokens the model needs to handle.
@objc public var maxTokens: Int = 512
@objc public var maxTokens: UInt = 512

/// The top K number of tokens to be sampled from for each decoding step. A value of 1 means
/// greedy decoding. Defaults to 40.
@objc public var topk: Int = 40
@objc public var topk: UInt = 40

/// The randomness when decoding the next token. A value of 0.0f means greedy decoding. Defaults
/// to 0.8.
@@ -203,17 +213,18 @@ extension LlmInference {
self.modelPath = modelPath
super.init()
}

}
}
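A short sketch (not in the diff) of configuring the updated `UInt`-typed options; the model path is a placeholder and the initializer must be called from a throwing context or wrapped in `do`/`catch`.

```swift
let options = LlmInference.Options(modelPath: "/path/to/model.bin")  // placeholder path
options.maxTokens = 1024
options.topk = 40
options.temperature = 0.7

let llmInference = try LlmInference(options: options)
```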

/// An extension to `String` to add some utility functions.
extension String {
fileprivate extension String {
private static let tokenSplitter = "▁"
/// Note this is NOT an underscore: ▁(U+2581)
private static let newLine = "<0x0A>"
private static let eod = "\\[eod\\]"

fileprivate func humanReadableString(stripLeadingWhitespaces: Bool = true) -> String? {
func humanReadableString(stripLeadingWhitespaces: Bool = true) -> String? {
var humanReadableString = self.replacingOccurrences(of: String.tokenSplitter, with: " ")
.replacingOccurrences(of: String.newLine, with: "\n")
humanReadableString =