From b87e042560b4a4c5382815ad8679a9ab6a78dc5b Mon Sep 17 00:00:00 2001 From: Mike Grabowski Date: Tue, 30 Sep 2025 20:40:24 +0200 Subject: [PATCH 01/18] initial commit --- packages/mlc/android/.editorconfig | 14 -- packages/mlc/android/build.gradle | 95 +++------- packages/mlc/android/gradle.properties | 10 +- .../mlc/android/src/main/AndroidManifest.xml | 11 +- .../android/src/main/AndroidManifestNew.xml | 11 -- .../com/{ai => reactnativeai}/ChatState.kt | 28 ++- .../com/{ai => reactnativeai}/ModelState.kt | 13 +- .../NativeMLCEngineModule.kt} | 175 ++++++++++-------- .../NativeMLCEnginePackage.kt} | 33 ++-- packages/mlc/android/src/newarch/AiSpec.kt | 2 +- packages/mlc/android/src/oldarch/AiSpec.kt | 13 -- packages/mlc/package.json | 3 + 12 files changed, 185 insertions(+), 223 deletions(-) delete mode 100644 packages/mlc/android/.editorconfig delete mode 100644 packages/mlc/android/src/main/AndroidManifestNew.xml rename packages/mlc/android/src/main/java/com/{ai => reactnativeai}/ChatState.kt (75%) rename packages/mlc/android/src/main/java/com/{ai => reactnativeai}/ModelState.kt (91%) rename packages/mlc/android/src/main/java/com/{ai/AiModule.kt => reactnativeai/NativeMLCEngineModule.kt} (71%) rename packages/mlc/android/src/main/java/com/{ai/AiPackage.kt => reactnativeai/NativeMLCEnginePackage.kt} (51%) delete mode 100644 packages/mlc/android/src/oldarch/AiSpec.kt diff --git a/packages/mlc/android/.editorconfig b/packages/mlc/android/.editorconfig deleted file mode 100644 index f9340500..00000000 --- a/packages/mlc/android/.editorconfig +++ /dev/null @@ -1,14 +0,0 @@ -[*.{kt,kts}] -indent_style=space -indent_size=2 -continuation_indent_size=2 -insert_final_newline=true -max_line_length=160 -ktlint_code_style=android_studio -ktlint_standard=enabled -ktlint_experimental=enabled -ktlint_standard_filename=disabled # dont require PascalCase filenames -ktlint_standard_no-wildcard-imports=disabled # allow .* imports -ktlint_function_signature_body_expression_wrapping=multiline -ij_kotlin_allow_trailing_comma_on_call_site=false -ij_kotlin_allow_trailing_comma=false diff --git a/packages/mlc/android/build.gradle b/packages/mlc/android/build.gradle index 2d8fea4d..12bee6f5 100644 --- a/packages/mlc/android/build.gradle +++ b/packages/mlc/android/build.gradle @@ -1,5 +1,7 @@ +import com.android.Version + buildscript { - def kotlin_version = rootProject.ext.has("kotlinVersion") ? rootProject.ext.get("kotlinVersion") : project.properties["Ai_kotlinVersion"] + def kotlin_version = rootProject.ext.has("kotlinVersion") ? rootProject.ext.get("kotlinVersion") : project.properties["MLC_kotlinVersion"] repositories { google() @@ -12,45 +14,34 @@ buildscript { } } -def reactNativeArchitectures() { - def value = rootProject.getProperties().get("reactNativeArchitectures") - return value ? value.split(",") : ["armeabi-v7a", "x86", "x86_64", "arm64-v8a"] -} - -def isNewArchitectureEnabled() { - return rootProject.hasProperty("newArchEnabled") && rootProject.getProperty("newArchEnabled") == "true" -} - apply plugin: "com.android.library" apply plugin: "kotlin-android" - -if (isNewArchitectureEnabled()) { - apply plugin: "com.facebook.react" -} +apply plugin: "com.facebook.react" def getExtOrDefault(name) { - return rootProject.ext.has(name) ? rootProject.ext.get(name) : project.properties["Ai_" + name] + return rootProject.ext.has(name) ? rootProject.ext.get(name) : project.properties["MLC_" + name] } def getExtOrIntegerDefault(name) { - return rootProject.ext.has(name) ? 
rootProject.ext.get(name) : (project.properties["Ai_" + name]).toInteger() + return rootProject.ext.has(name) ? rootProject.ext.get(name) : (project.properties["MLC_" + name]).toInteger() } -def supportsNamespace() { - def parsed = com.android.Version.ANDROID_GRADLE_PLUGIN_VERSION.tokenize('.') +static def supportsNamespace() { + def parsed = Version.ANDROID_GRADLE_PLUGIN_VERSION.tokenize('.') def major = parsed[0].toInteger() def minor = parsed[1].toInteger() + + // Namespace support was added in 7.3.0 return (major == 7 && minor >= 3) || major >= 8 } - android { if (supportsNamespace()) { - namespace "com.ai" - + namespace "com.reactnativeai" + sourceSets { main { - manifest.srcFile "src/main/AndroidManifestNew.xml" + manifest.srcFile "src/main/AndroidManifest.xml" } } } @@ -60,7 +51,6 @@ android { defaultConfig { minSdkVersion getExtOrIntegerDefault("minSdkVersion") targetSdkVersion getExtOrIntegerDefault("targetSdkVersion") - buildConfigField "boolean", "IS_NEW_ARCHITECTURE_ENABLED", isNewArchitectureEnabled().toString() } buildFeatures { @@ -84,20 +74,13 @@ android { sourceSets { main { - if (isNewArchitectureEnabled()) { - java.srcDirs += [ - "src/newarch", - "${project.buildDir}/generated/source/codegen/java" - ] - } else { - java.srcDirs += ["src/oldarch"] - } - - // Include prebuilt native libraries - jniLibs.srcDirs += ["${projectDir}/../prebuilt/android/jniLibs"] - - // Include prebuilt model assets - assets.srcDirs += ["${projectDir}/../prebuilt/android/models"] + java.srcDirs += [ + "src", + "generated/java", + "generated/jni", + "../prebuilt/android/lib/mlc4j/src/main/java" + ] + jniLibs.srcDirs += ["../prebuilt/android/lib/mlc4j/output"] } } } @@ -112,34 +95,16 @@ def kotlin_version = getExtOrDefault("kotlinVersion") dependencies { implementation "com.facebook.react:react-native:+" implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version" + + // Dependencies implementation 'com.google.code.gson:gson:2.10.1' - - // Use prebuilt MLC runtime AAR (no more mlc4j project dependency!) 
- implementation files("${projectDir}/../prebuilt/android/mlc-runtime.aar") - - // UI dependencies - implementation 'androidx.core:core-ktx:1.10.1' - implementation 'androidx.lifecycle:lifecycle-runtime-ktx:2.6.1' - implementation 'com.github.jeziellago:compose-markdown:0.5.2' - implementation 'androidx.activity:activity-compose:1.7.1' - implementation platform('androidx.compose:compose-bom:2022.10.00') - implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.6.1' - implementation 'androidx.compose.material3:material3:1.1.0' - implementation 'androidx.compose.material:material-icons-extended' - implementation 'androidx.appcompat:appcompat:1.6.1' - implementation 'androidx.navigation:navigation-compose:2.5.3' - - // Testing - testImplementation 'junit:junit:4.13.2' - androidTestImplementation 'androidx.test.ext:junit:1.1.5' - androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1' - androidTestImplementation platform('androidx.compose:compose-bom:2022.10.00') + + // MLC4J dependency + implementation files('../prebuilt/android/lib/mlc4j/output/tvm4j_core.jar') } -if (isNewArchitectureEnabled()) { - react { - jsRootDir = file("../src/") - libraryName = "Ai" - codegenJavaPackageName = "com.ai" - } -} \ No newline at end of file +react { + jsRootDir = file("../src/") + libraryName = "MLC" + codegenJavaPackageName = "com.reactnativeai" +} diff --git a/packages/mlc/android/gradle.properties b/packages/mlc/android/gradle.properties index e045f675..cd6660de 100644 --- a/packages/mlc/android/gradle.properties +++ b/packages/mlc/android/gradle.properties @@ -1,5 +1,5 @@ -Ai_kotlinVersion=1.7.0 -Ai_minSdkVersion=21 -Ai_targetSdkVersion=31 -Ai_compileSdkVersion=31 -Ai_ndkversion=21.4.7075529 +MLC_kotlinVersion=1.7.0 +MLC_minSdkVersion=21 +MLC_targetSdkVersion=31 +MLC_compileSdkVersion=31 +MLC_ndkversion=21.4.7075529 diff --git a/packages/mlc/android/src/main/AndroidManifest.xml b/packages/mlc/android/src/main/AndroidManifest.xml index 928ca695..05facbb8 100644 --- a/packages/mlc/android/src/main/AndroidManifest.xml +++ b/packages/mlc/android/src/main/AndroidManifest.xml @@ -1,3 +1,12 @@ + package="com.reactnativeai"> + diff --git a/packages/mlc/android/src/main/AndroidManifestNew.xml b/packages/mlc/android/src/main/AndroidManifestNew.xml deleted file mode 100644 index 178a69d3..00000000 --- a/packages/mlc/android/src/main/AndroidManifestNew.xml +++ /dev/null @@ -1,11 +0,0 @@ - - - - - - - diff --git a/packages/mlc/android/src/main/java/com/ai/ChatState.kt b/packages/mlc/android/src/main/java/com/reactnativeai/ChatState.kt similarity index 75% rename from packages/mlc/android/src/main/java/com/ai/ChatState.kt rename to packages/mlc/android/src/main/java/com/reactnativeai/ChatState.kt index cd984fca..c47c4510 100644 --- a/packages/mlc/android/src/main/java/com/ai/ChatState.kt +++ b/packages/mlc/android/src/main/java/com/reactnativeai/ChatState.kt @@ -1,15 +1,15 @@ -package com.ai +package com.reactnativeai import ai.mlc.mlcllm.MLCEngine import ai.mlc.mlcllm.OpenAIProtocol import ai.mlc.mlcllm.OpenAIProtocol.ChatCompletionMessage -import java.io.File -import java.util.concurrent.Executors import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.channels.toList import kotlinx.coroutines.launch +import java.io.File +import java.util.concurrent.Executors class Chat(modelConfig: ModelConfig, modelDir: File) { private val engine = MLCEngine() @@ -21,21 +21,33 @@ class Chat(modelConfig: ModelConfig, modelDir: 
File) { engine.reload(modelDir.path, modelConfig.modelLib) } - fun generateResponse(messages: MutableList, callback: GenerateCallback) { + fun generateResponse( + messages: MutableList, + callback: GenerateCallback + ) { executorService.submit { viewModelScope.launch { val chatResponse = engine.chat.completions.create(messages = messages) - val response = chatResponse.toList().joinToString("") { it.choices.joinToString("") { it.delta.content?.text ?: "" } } + val response = chatResponse.toList().joinToString("") { + it.choices.joinToString("") { choice -> + choice.delta.content?.text ?: "" + } + } callback.onMessageReceived(response) } } } - fun streamResponse(messages: MutableList, callback: StreamCallback) { + fun streamResponse( + messages: MutableList, + callback: StreamCallback + ) { executorService.submit { viewModelScope.launch { - val chatResponse = engine.chat.completions.create(messages = messages, stream_options = OpenAIProtocol.StreamOptions(include_usage = true)) - + val chatResponse = engine.chat.completions.create( + messages = messages, + stream_options = OpenAIProtocol.StreamOptions(include_usage = true) + ) var finishReasonLength = false var streamingText = "" diff --git a/packages/mlc/android/src/main/java/com/ai/ModelState.kt b/packages/mlc/android/src/main/java/com/reactnativeai/ModelState.kt similarity index 91% rename from packages/mlc/android/src/main/java/com/ai/ModelState.kt rename to packages/mlc/android/src/main/java/com/reactnativeai/ModelState.kt index 413b229a..38496e33 100644 --- a/packages/mlc/android/src/main/java/com/ai/ModelState.kt +++ b/packages/mlc/android/src/main/java/com/reactnativeai/ModelState.kt @@ -1,9 +1,8 @@ -package com.ai +package com.reactnativeai -import androidx.compose.runtime.mutableIntStateOf -import com.ai.AiModule.Companion.MODEL_CONFIG_FILENAME -import com.ai.AiModule.Companion.MODEL_URL_SUFFIX -import com.ai.AiModule.Companion.PARAMS_CONFIG_FILENAME +import com.reactnativeai.NativeMLCEngineModule.Companion.MODEL_CONFIG_FILENAME +import com.reactnativeai.NativeMLCEngineModule.Companion.MODEL_URL_SUFFIX +import com.reactnativeai.NativeMLCEngineModule.Companion.PARAMS_CONFIG_FILENAME import com.google.gson.Gson import java.io.File import java.io.FileOutputStream @@ -17,7 +16,7 @@ import kotlinx.coroutines.withContext class ModelState(private val modelConfig: ModelConfig, private val modelDir: File) { private var paramsConfig = ParamsConfig(emptyList()) val progress = MutableStateFlow(0) - val total = mutableIntStateOf(1) + val total = MutableStateFlow(1) val id: UUID = UUID.randomUUID() private val remainingTasks = emptySet().toMutableSet() private val downloadingTasks = emptySet().toMutableSet() @@ -113,7 +112,7 @@ class ModelState(private val modelConfig: ModelConfig, private val modelDir: Fil private fun indexModel() { progress.value = 0 - total.intValue = modelConfig.tokenizerFiles.size + paramsConfig.paramsRecords.size + total.value = modelConfig.tokenizerFiles.size + paramsConfig.paramsRecords.size // Adding Tokenizer to download tasks for (tokenizerFilename in modelConfig.tokenizerFiles) { diff --git a/packages/mlc/android/src/main/java/com/ai/AiModule.kt b/packages/mlc/android/src/main/java/com/reactnativeai/NativeMLCEngineModule.kt similarity index 71% rename from packages/mlc/android/src/main/java/com/ai/AiModule.kt rename to packages/mlc/android/src/main/java/com/reactnativeai/NativeMLCEngineModule.kt index 64cd2aee..7449fb77 100644 --- a/packages/mlc/android/src/main/java/com/ai/AiModule.kt +++ 
b/packages/mlc/android/src/main/java/com/reactnativeai/NativeMLCEngineModule.kt @@ -1,4 +1,4 @@ -package com.ai +package com.reactnativeai import ai.mlc.mlcllm.OpenAIProtocol import ai.mlc.mlcllm.OpenAIProtocol.ChatCompletionMessage @@ -6,8 +6,6 @@ import android.os.Environment import android.util.Log import com.facebook.react.bridge.* import com.facebook.react.bridge.ReactContext.RCTDeviceEventEmitter -import com.facebook.react.module.annotations.ReactModule -import com.facebook.react.turbomodule.core.interfaces.TurboModule import com.google.gson.Gson import com.google.gson.annotations.SerializedName import java.io.File @@ -20,15 +18,11 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch import kotlinx.coroutines.withContext -@ReactModule(name = AiModule.NAME) -class AiModule(reactContext: ReactApplicationContext) : - ReactContextBaseJavaModule(reactContext), - TurboModule { - +class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEngineSpec(reactContext) { override fun getName(): String = NAME companion object { - const val NAME = "Ai" + const val NAME = "NativeMLCEngine" const val APP_CONFIG_FILENAME = "mlc-app-config.json" const val MODEL_CONFIG_FILENAME = "mlc-chat-config.json" @@ -40,6 +34,7 @@ class AiModule(reactContext: ReactApplicationContext) : emptyList().toMutableList(), emptyList().toMutableList() ) + private val gson = Gson() private lateinit var chat: Chat @@ -78,8 +73,7 @@ class AiModule(reactContext: ReactApplicationContext) : return modelConfig } - @ReactMethod - fun getModel(name: String, promise: Promise) { + override fun getModel(name: String, promise: Promise) { appConfig = getAppConfig() val modelConfig = appConfig.modelList.find { modelRecord -> modelRecord.modelId == name } @@ -89,17 +83,14 @@ class AiModule(reactContext: ReactApplicationContext) : return } - // Return a JSON object with details val modelConfigInstance = Arguments.createMap().apply { - putString("modelId", modelConfig.modelId) - putString("modelLib", modelConfig.modelLib) // Add more fields if needed + putString("model_id", modelConfig.modelId) } promise.resolve(modelConfigInstance) } - @ReactMethod - fun getModels(promise: Promise) { + override fun getModels(promise: Promise) { try { appConfig = getAppConfig() appConfig.modelLibs = emptyList().toMutableList() @@ -116,72 +107,91 @@ class AiModule(reactContext: ReactApplicationContext) : } } - @ReactMethod - fun doGenerate(instanceId: String, messages: ReadableArray, promise: Promise) { - val messageList = mutableListOf() - - for (i in 0 until messages.size()) { - val messageMap = messages.getMap(i) // Extract ReadableMap - - val role = if (messageMap.getString("role") == "user") OpenAIProtocol.ChatCompletionRole.user else OpenAIProtocol.ChatCompletionRole.assistant - val content = messageMap.getString("content") ?: "" - - messageList.add(ChatCompletionMessage(role, content)) - } - - CoroutineScope(Dispatchers.Main).launch { - try { - chat.generateResponse( - messageList, - callback = object : Chat.GenerateCallback { - override fun onMessageReceived(message: String) { - promise.resolve(message) - } - } - ) - } catch (e: Exception) { - Log.e("AI", "Error generating response", e) - } - } + override fun generateText( + messages: ReadableArray?, + options: ReadableMap?, + promise: Promise? 
+ ) { + TODO("Not yet implemented") } - @ReactMethod - fun doStream(instanceId: String, messages: ReadableArray, promise: Promise) { - val messageList = mutableListOf() - - for (i in 0 until messages.size()) { - val messageMap = messages.getMap(i) // Extract ReadableMap - - val role = if (messageMap.getString("role") == "user") OpenAIProtocol.ChatCompletionRole.user else OpenAIProtocol.ChatCompletionRole.assistant - val content = messageMap.getString("content") ?: "" - - messageList.add(ChatCompletionMessage(role, content)) - } - CoroutineScope(Dispatchers.Main).launch { - chat.streamResponse( - messageList, - callback = object : Chat.StreamCallback { - override fun onUpdate(message: String) { - val event: WritableMap = Arguments.createMap().apply { - putString("content", message) - } - sendEvent("onChatUpdate", event) - } + override fun streamText( + messages: ReadableArray?, + options: ReadableMap?, + promise: Promise? + ) { + TODO("Not yet implemented") + } - override fun onFinished(message: String) { - val event: WritableMap = Arguments.createMap().apply { - putString("content", message) - } - sendEvent("onChatComplete", event) - } - } - ) - } - promise.resolve(null) + override fun cancelStream(streamId: String?, promise: Promise?) { + TODO("Not yet implemented") } - @ReactMethod - fun downloadModel(instanceId: String, promise: Promise) { +// @ReactMethod +// fun doGenerate(instanceId: String, messages: ReadableArray, promise: Promise) { +// val messageList = mutableListOf() +// +// for (i in 0 until messages.size()) { +// val messageMap = messages.getMap(i) // Extract ReadableMap +// +// val role = if (messageMap.getString("role") == "user") OpenAIProtocol.ChatCompletionRole.user else OpenAIProtocol.ChatCompletionRole.assistant +// val content = messageMap.getString("content") ?: "" +// +// messageList.add(ChatCompletionMessage(role, content)) +// } +// +// CoroutineScope(Dispatchers.Main).launch { +// try { +// chat.generateResponse( +// messageList, +// callback = object : Chat.GenerateCallback { +// override fun onMessageReceived(message: String) { +// promise.resolve(message) +// } +// } +// ) +// } catch (e: Exception) { +// Log.e("AI", "Error generating response", e) +// } +// } +// } + +// @ReactMethod +// fun doStream(instanceId: String, messages: ReadableArray, promise: Promise) { +// val messageList = mutableListOf() +// +// for (i in 0 until messages.size()) { +// val messageMap = messages.getMap(i) // Extract ReadableMap +// +// val role = if (messageMap.getString("role") == "user") OpenAIProtocol.ChatCompletionRole.user else OpenAIProtocol.ChatCompletionRole.assistant +// val content = messageMap.getString("content") ?: "" +// +// messageList.add(ChatCompletionMessage(role, content)) +// } +// CoroutineScope(Dispatchers.Main).launch { +// chat.streamResponse( +// messageList, +// callback = object : Chat.StreamCallback { +// override fun onUpdate(message: String) { +// val event: WritableMap = Arguments.createMap().apply { +// putString("content", message) +// } +// sendEvent("onChatUpdate", event) +// } +// +// override fun onFinished(message: String) { +// val event: WritableMap = Arguments.createMap().apply { +// putString("content", message) +// } +// sendEvent("onChatComplete", event) +// } +// } +// ) +// } +// promise.resolve(null) +// } + + override fun downloadModel(instanceId: String, promise: Promise) { CoroutineScope(Dispatchers.IO).launch { try { val appConfig = getAppConfig() @@ -221,12 +231,15 @@ class AiModule(reactContext: ReactApplicationContext) : } } + 
override fun removeModel(modelId: String?, promise: Promise?) { + TODO("Not yet implemented") + } + private fun sendEvent(eventName: String, data: Any?) { reactApplicationContext.getJSModule(RCTDeviceEventEmitter::class.java)?.emit(eventName, data) } - @ReactMethod - fun prepareModel(instanceId: String, promise: Promise) { + override fun prepareModel(instanceId: String, promise: Promise) { CoroutineScope(Dispatchers.IO).launch { try { val appConfig = getAppConfig() @@ -253,6 +266,10 @@ class AiModule(reactContext: ReactApplicationContext) : } } + override fun unloadModel(promise: Promise?) { + TODO("Not yet implemented") + } + private suspend fun downloadModelConfig(modelRecord: ModelRecord) { withContext(Dispatchers.IO) { // Don't download if config is downloaded already diff --git a/packages/mlc/android/src/main/java/com/ai/AiPackage.kt b/packages/mlc/android/src/main/java/com/reactnativeai/NativeMLCEnginePackage.kt similarity index 51% rename from packages/mlc/android/src/main/java/com/ai/AiPackage.kt rename to packages/mlc/android/src/main/java/com/reactnativeai/NativeMLCEnginePackage.kt index e63df794..0d747589 100644 --- a/packages/mlc/android/src/main/java/com/ai/AiPackage.kt +++ b/packages/mlc/android/src/main/java/com/reactnativeai/NativeMLCEnginePackage.kt @@ -1,16 +1,17 @@ -package com.ai +package com.reactnativeai + +import com.facebook.react.BaseReactPackage -import com.facebook.react.TurboReactPackage import com.facebook.react.bridge.NativeModule import com.facebook.react.bridge.ReactApplicationContext import com.facebook.react.module.model.ReactModuleInfo import com.facebook.react.module.model.ReactModuleInfoProvider import java.util.HashMap -class AiPackage : TurboReactPackage() { +class NativeMLCEnginePackage : BaseReactPackage() { override fun getModule(name: String, reactContext: ReactApplicationContext): NativeModule? 
= - if (name == AiModule.NAME) { - AiModule(reactContext) + if (name == NativeMLCEngineModule.NAME) { + NativeMLCEngineModule(reactContext) } else { null } @@ -18,20 +19,14 @@ class AiPackage : TurboReactPackage() { override fun getReactModuleInfoProvider(): ReactModuleInfoProvider = ReactModuleInfoProvider { val moduleInfos: MutableMap = HashMap() - val isTurboModule: Boolean = BuildConfig.IS_NEW_ARCHITECTURE_ENABLED - moduleInfos[AiModule.NAME] = ReactModuleInfo( - AiModule.NAME, - AiModule.NAME, - // canOverrideExistingModule - false, - // needsEagerInit - false, - // hasConstants - true, - // isCxxModule - false, - // isTurboModule - isTurboModule + moduleInfos[NativeMLCEngineModule.NAME] = ReactModuleInfo( + NativeMLCEngineModule.NAME, + NativeMLCEngineModule.NAME, + canOverrideExistingModule = false, + needsEagerInit = false, + hasConstants = true, + isCxxModule = false, + isTurboModule = true ) moduleInfos diff --git a/packages/mlc/android/src/newarch/AiSpec.kt b/packages/mlc/android/src/newarch/AiSpec.kt index e85df490..d120835c 100644 --- a/packages/mlc/android/src/newarch/AiSpec.kt +++ b/packages/mlc/android/src/newarch/AiSpec.kt @@ -1,4 +1,4 @@ -package com.ai +package com.reactnativeai import com.facebook.react.bridge.ReactApplicationContext diff --git a/packages/mlc/android/src/oldarch/AiSpec.kt b/packages/mlc/android/src/oldarch/AiSpec.kt deleted file mode 100644 index bd5fe9cb..00000000 --- a/packages/mlc/android/src/oldarch/AiSpec.kt +++ /dev/null @@ -1,13 +0,0 @@ -package com.ai - -import com.facebook.react.bridge.Promise -import com.facebook.react.bridge.ReactApplicationContext -import com.facebook.react.bridge.ReactContextBaseJavaModule - -abstract class AiSpec internal constructor(context: ReactApplicationContext) : ReactContextBaseJavaModule(context) { - - abstract fun getModel(name: String, promise: Promise) - abstract fun getModels(promise: Promise) - abstract fun doGenerate(instanceId: String, text: String, promise: Promise) - abstract fun doStream(instanceId: String, text: String, promise: Promise) -} diff --git a/packages/mlc/package.json b/packages/mlc/package.json index 8009b319..41d990fc 100644 --- a/packages/mlc/package.json +++ b/packages/mlc/package.json @@ -85,6 +85,9 @@ "modulesProvider": { "NativeMLCEngine": "MLCEngine" } + }, + "android": { + "javaPackageName": "com.reactnativeai" } }, "dependencies": { From 234240f42b98f59eb731a2776dadab1503323c04 Mon Sep 17 00:00:00 2001 From: Mike Grabowski Date: Thu, 9 Oct 2025 19:43:40 -0700 Subject: [PATCH 02/18] chore: move packages, clean-up implementation, drop gson and more --- packages/mlc/android/build.gradle | 9 +- .../mlc/android/src/main/AndroidManifest.xml | 2 +- .../ai}/ChatState.kt | 4 +- .../ai}/ModelState.kt | 18 +-- .../ai}/NativeMLCEngineModule.kt | 118 ++++++------------ .../ai}/NativeMLCEnginePackage.kt | 2 +- packages/mlc/android/src/newarch/AiSpec.kt | 5 - packages/mlc/package.json | 2 +- 8 files changed, 58 insertions(+), 102 deletions(-) rename packages/mlc/android/src/main/java/com/{reactnativeai => callstack/ai}/ChatState.kt (96%) rename packages/mlc/android/src/main/java/com/{reactnativeai => callstack/ai}/ModelState.kt (86%) rename packages/mlc/android/src/main/java/com/{reactnativeai => callstack/ai}/NativeMLCEngineModule.kt (70%) rename packages/mlc/android/src/main/java/com/{reactnativeai => callstack/ai}/NativeMLCEnginePackage.kt (97%) delete mode 100644 packages/mlc/android/src/newarch/AiSpec.kt diff --git a/packages/mlc/android/build.gradle b/packages/mlc/android/build.gradle index 
12bee6f5..93cc7f95 100644 --- a/packages/mlc/android/build.gradle +++ b/packages/mlc/android/build.gradle @@ -11,11 +11,13 @@ buildscript { dependencies { classpath "com.android.tools.build:gradle:7.2.1" classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version" + classpath "org.jetbrains.kotlin:kotlin-serialization:$kotlin_version" } } apply plugin: "com.android.library" apply plugin: "kotlin-android" +apply plugin: "kotlinx-serialization" apply plugin: "com.facebook.react" def getExtOrDefault(name) { @@ -37,7 +39,7 @@ static def supportsNamespace() { android { if (supportsNamespace()) { - namespace "com.reactnativeai" + namespace "com.callstack.ai" sourceSets { main { @@ -80,6 +82,7 @@ android { "generated/jni", "../prebuilt/android/lib/mlc4j/src/main/java" ] + assets.srcDirs += ["../prebuilt/android/lib/mlc4j/src/main/assets"] jniLibs.srcDirs += ["../prebuilt/android/lib/mlc4j/output"] } } @@ -97,7 +100,7 @@ dependencies { implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version" // Dependencies - implementation 'com.google.code.gson:gson:2.10.1' + implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.6.3' // MLC4J dependency implementation files('../prebuilt/android/lib/mlc4j/output/tvm4j_core.jar') @@ -106,5 +109,5 @@ dependencies { react { jsRootDir = file("../src/") libraryName = "MLC" - codegenJavaPackageName = "com.reactnativeai" + codegenJavaPackageName = "com.callstack.ai" } diff --git a/packages/mlc/android/src/main/AndroidManifest.xml b/packages/mlc/android/src/main/AndroidManifest.xml index 05facbb8..ae210e41 100644 --- a/packages/mlc/android/src/main/AndroidManifest.xml +++ b/packages/mlc/android/src/main/AndroidManifest.xml @@ -1,5 +1,5 @@ + package="com.callstack.ai"> + diff --git a/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt b/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt index 0067ac20..c968c6ea 100644 --- a/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt +++ b/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt @@ -85,10 +85,9 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn val chatResponse = engine.chat.completions.create( messages = messageList, - temperature = options?.getDouble("temperature")?.toFloat(), - max_tokens = options?.getInt("maxTokens"), - top_p = options?.getDouble("topP")?.toFloat(), - frequency_penalty = options?.getDouble("topK")?.toFloat(), + temperature = options?.takeIf { it.hasKey("temperature") }?.getDouble("temperature")?.toFloat(), + max_tokens = options?.takeIf { it.hasKey("maxTokens") }?.getInt("maxTokens"), + top_p = options?.takeIf { it.hasKey("topP") }?.getDouble("topP")?.toFloat(), response_format = responseFormat ) From cfd7bf2e30e6e45172d1a0434df228ffcb990553 Mon Sep 17 00:00:00 2001 From: Mike Grabowski Date: Mon, 13 Oct 2025 00:07:22 -0700 Subject: [PATCH 08/18] update android --- apps/example-apple/src/App.ios.tsx | 162 ++++++++++++++++ apps/example-apple/src/App.tsx | 147 +------------- apps/example-apple/src/screens/MLCScreen.tsx | 2 +- .../com/callstack/ai/NativeMLCEngineModule.kt | 183 +++++++++++------- 4 files changed, 278 insertions(+), 216 deletions(-) create mode 100644 apps/example-apple/src/App.ios.tsx diff --git a/apps/example-apple/src/App.ios.tsx b/apps/example-apple/src/App.ios.tsx new file mode 100644 index 00000000..fa717648 --- /dev/null +++ b/apps/example-apple/src/App.ios.tsx @@ -0,0 +1,162 @@ +import './global.css' + +import { 
createNativeBottomTabNavigator } from '@bottom-tabs/react-navigation' +import { NavigationContainer } from '@react-navigation/native' +import { createNativeStackNavigator } from '@react-navigation/native-stack' +import { StatusBar } from 'expo-status-bar' +import React from 'react' +import { Image } from 'react-native' +import { KeyboardProvider } from 'react-native-keyboard-controller' +import { SafeAreaProvider } from 'react-native-safe-area-context' + +import LLMScreen from './screens/LLMScreen' +import MLCScreen from './screens/MLCScreen' +import PlaygroundScreen from './screens/PlaygroundScreen' +import SpeechScreen from './screens/SpeechScreen' +import TranscribeScreen from './screens/TranscribeScreen' + +const Tab = createNativeBottomTabNavigator() + +const RootStack = createNativeStackNavigator() +const LLMStack = createNativeStackNavigator() +const MLCStack = createNativeStackNavigator() +const PlaygroundStack = createNativeStackNavigator() +const TranscribeStack = createNativeStackNavigator() +const SpeechStack = createNativeStackNavigator() + +function LLMStackScreen() { + return ( + + ( + + ), + }} + /> + + ) +} + +function MLCStackScreen() { + return ( + + + + ) +} + +function PlaygroundStackScreen() { + return ( + + + + ) +} + +function TranscribeStackScreen() { + return ( + + + + ) +} + +function SpeechStackScreen() { + return ( + + + + ) +} + +function Tabs() { + return ( + + ({ sfSymbol: 'brain.head.profile' }), + }} + /> + ({ sfSymbol: 'cpu' }), + }} + /> + ({ sfSymbol: 'play.circle' }), + }} + /> + ({ sfSymbol: 'text.quote' }), + }} + /> + ({ sfSymbol: 'speaker.wave.3' }), + }} + /> + + ) +} + +export default function App() { + return ( + + + + + + + + + + + ) +} diff --git a/apps/example-apple/src/App.tsx b/apps/example-apple/src/App.tsx index fa717648..0023a93e 100644 --- a/apps/example-apple/src/App.tsx +++ b/apps/example-apple/src/App.tsx @@ -1,161 +1,16 @@ import './global.css' -import { createNativeBottomTabNavigator } from '@bottom-tabs/react-navigation' -import { NavigationContainer } from '@react-navigation/native' -import { createNativeStackNavigator } from '@react-navigation/native-stack' -import { StatusBar } from 'expo-status-bar' import React from 'react' -import { Image } from 'react-native' import { KeyboardProvider } from 'react-native-keyboard-controller' import { SafeAreaProvider } from 'react-native-safe-area-context' -import LLMScreen from './screens/LLMScreen' import MLCScreen from './screens/MLCScreen' -import PlaygroundScreen from './screens/PlaygroundScreen' -import SpeechScreen from './screens/SpeechScreen' -import TranscribeScreen from './screens/TranscribeScreen' - -const Tab = createNativeBottomTabNavigator() - -const RootStack = createNativeStackNavigator() -const LLMStack = createNativeStackNavigator() -const MLCStack = createNativeStackNavigator() -const PlaygroundStack = createNativeStackNavigator() -const TranscribeStack = createNativeStackNavigator() -const SpeechStack = createNativeStackNavigator() - -function LLMStackScreen() { - return ( - - ( - - ), - }} - /> - - ) -} - -function MLCStackScreen() { - return ( - - - - ) -} - -function PlaygroundStackScreen() { - return ( - - - - ) -} - -function TranscribeStackScreen() { - return ( - - - - ) -} - -function SpeechStackScreen() { - return ( - - - - ) -} - -function Tabs() { - return ( - - ({ sfSymbol: 'brain.head.profile' }), - }} - /> - ({ sfSymbol: 'cpu' }), - }} - /> - ({ sfSymbol: 'play.circle' }), - }} - /> - ({ sfSymbol: 'text.quote' }), - }} - /> - ({ sfSymbol: 'speaker.wave.3' 
}), - }} - /> - - ) -} export default function App() { return ( - - - - - - + ) diff --git a/apps/example-apple/src/screens/MLCScreen.tsx b/apps/example-apple/src/screens/MLCScreen.tsx index eb48abef..d0f27ec2 100644 --- a/apps/example-apple/src/screens/MLCScreen.tsx +++ b/apps/example-apple/src/screens/MLCScreen.tsx @@ -37,7 +37,7 @@ export default function MLCScreen() { // Step 2: Create and prepare model const model = mlc.languageModel(modelId) await model.download((event) => { - setStatusText(`Downloading model: ${event.status}`) + setStatusText(`Downloading model: ${event.percentage}`) }) setStatusText('Preparing model...') await model.prepare() diff --git a/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt b/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt index c968c6ea..63e0cf18 100644 --- a/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt +++ b/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt @@ -25,8 +25,8 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn private val json = Json { ignoreUnknownKeys = true } private val engine by lazy { MLCEngine() } - private val executorService = Executors.newSingleThreadExecutor() - private val engineScope = CoroutineScope(Dispatchers.Main + Job()) + private val executorService = Executors.newFixedThreadPool(1) + private val engineScope = CoroutineScope(Dispatchers.IO) private val appConfig by lazy { val jsonString = reactApplicationContext.applicationContext.assets.open(APP_CONFIG_FILENAME).bufferedReader().use { it.readText() } @@ -64,6 +64,71 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn messages: ReadableArray, options: ReadableMap?, promise: Promise + ) { + engineScope.launch { + try { + val messageList = mutableListOf() + + for (i in 0 until messages.size()) { + val messageMap = messages.getMap(i) + val role = if (messageMap?.getString("role") == "user") OpenAIProtocol.ChatCompletionRole.user else OpenAIProtocol.ChatCompletionRole.assistant + val content = messageMap?.getString("content") ?: "" + messageList.add(ChatCompletionMessage(role, content)) + } + + val responseFormat = options?.getMap("responseFormat")?.let { formatMap -> + val type = formatMap.getString("type") ?: "text" + val schema = formatMap.getString("schema") + OpenAIProtocol.ResponseFormat(type, schema) + } + + val chatResponse = engine.chat.completions.create( + messages = messageList, + temperature = options?.takeIf { it.hasKey("temperature") }?.getDouble("temperature")?.toFloat(), + max_tokens = options?.takeIf { it.hasKey("maxTokens") }?.getInt("maxTokens"), + top_p = options?.takeIf { it.hasKey("topP") }?.getDouble("topP")?.toFloat(), + response_format = responseFormat, + stream_options = OpenAIProtocol.StreamOptions( + true + ) + ) + + val responseList = chatResponse.toList() + val lastResponse = responseList.lastOrNull() + + val accumulatedContent = responseList.joinToString("") { response -> + response.choices.firstOrNull()?.delta?.content?.text ?: "" + } + + val finalRole = responseList.mapNotNull { it.choices.firstOrNull()?.delta?.role?.toString() }.lastOrNull() + val finalFinishReason = responseList.mapNotNull { it.choices.firstOrNull()?.finish_reason }.lastOrNull() + + val response = Arguments.createMap().apply { + putString("role", finalRole ?: "assistant") + putString("content", accumulatedContent) + putArray("tool_calls", Arguments.createArray()) + finalFinishReason?.let { 
putString("finish_reason", it) } + lastResponse?.usage?.let { usage -> + val usageArgs = Arguments.createMap().apply { + putInt("prompt_tokens", usage.prompt_tokens) + putInt("completion_tokens", usage.completion_tokens) + putInt("total_tokens", usage.total_tokens) + } + putMap("usage", usageArgs) + } + } + + promise.resolve(response) + } catch (e: Exception) { + promise.reject("GENERATION_ERROR", e.message) + } + } + } + + override fun streamText( + messages: ReadableArray, + options: ReadableMap?, + promise: Promise ) { executorService.submit { engineScope.launch { @@ -88,12 +153,55 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn temperature = options?.takeIf { it.hasKey("temperature") }?.getDouble("temperature")?.toFloat(), max_tokens = options?.takeIf { it.hasKey("maxTokens") }?.getInt("maxTokens"), top_p = options?.takeIf { it.hasKey("topP") }?.getDouble("topP")?.toFloat(), - response_format = responseFormat + response_format = responseFormat, + stream_options = OpenAIProtocol.StreamOptions( + true + ) ) - val response = chatResponse.toList().joinToString("") { - it.choices.joinToString("") { choice -> - choice.delta.content?.text ?: "" + var accumulatedContent = "" + var finalRole: String? = null + var finalFinishReason: String? = null + var usage: Map? = null + + for (streamResponse in chatResponse) { + // Check for usage (indicates completion) + streamResponse.usage?.let { + usage = mapOf( + "prompt_tokens" to it.prompt_tokens, + "completion_tokens" to it.completion_tokens, + "total_tokens" to it.total_tokens + ) + } + + streamResponse.choices.firstOrNull()?.let { choice -> + choice.delta.content?.let { content -> + accumulatedContent += content.text ?: "" + } + choice.finish_reason?.let { finishReason -> + finalFinishReason = finishReason + } + choice.delta.role?.let { role -> + finalRole = role.toString() + } + } + + if (usage != null) { + break + } + } + + val response = Arguments.createMap().apply { + putString("role", finalRole ?: "assistant") + putString("content", accumulatedContent) + finalFinishReason?.let { putString("finish_reason", it) } + usage?.let { usageMap -> + val usageArgs = Arguments.createMap().apply { + putInt("prompt_tokens", usageMap["prompt_tokens"] as Int) + putInt("completion_tokens", usageMap["completion_tokens"] as Int) + putInt("total_tokens", usageMap["total_tokens"] as Int) + } + putMap("usage", usageArgs) } } @@ -105,69 +213,6 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn } } - override fun streamText( - messages: ReadableArray, - options: ReadableMap?, - promise: Promise - ) { -// executorService.submit { -// engineScope.launch { -// try { -// val messageList = mutableListOf() -// -// for (i in 0 until messages.size()) { -// val messageMap = messages.getMap(i) -// val role = if (messageMap?.getString("role") == "user") OpenAIProtocol.ChatCompletionRole.user else OpenAIProtocol.ChatCompletionRole.assistant -// val content = messageMap?.getString("content") ?: "" -// messageList.add(ChatCompletionMessage(role, content)) -// } -// -// val chatResponse = engine.chat.completions.create( -// messages = messageList, -// stream_options = OpenAIProtocol.StreamOptions(include_usage = true) -// ) -// var finishReasonLength = false -// var streamingText = "" -// -// for (res in chatResponse) { -// for (choice in res.choices) { -// choice.delta.content?.let { content -> -// streamingText = content.asText() -// } -// choice.finish_reason?.let { finishReason -> -// if (finishReason == 
"length") { -// finishReasonLength = true -// } -// } -// } -// -// val event: WritableMap = Arguments.createMap().apply { -// putString("content", streamingText) -// } -// sendEvent("onChatUpdate", event) -// -// if (finishReasonLength) { -// streamingText = " [output truncated due to context length limit...]" -// val truncatedEvent: WritableMap = Arguments.createMap().apply { -// putString("content", streamingText) -// } -// sendEvent("onChatUpdate", truncatedEvent) -// } -// } -// -// val finalEvent: WritableMap = Arguments.createMap().apply { -// putString("content", streamingText) -// } -// sendEvent("onChatComplete", finalEvent) -// -// promise.resolve(null) -// } catch (e: Exception) { -// promise.reject("STREAMING_ERROR", "Error streaming text", e) -// } -// } -// } - } - override fun cancelStream(streamId: String, promise: Promise) { TODO("Not yet implemented") } From 6d2f0a2892091cf20ba8639a148691967069d9d3 Mon Sep 17 00:00:00 2001 From: Mike Grabowski Date: Mon, 13 Oct 2025 00:32:20 -0700 Subject: [PATCH 09/18] feat: ui tweak --- apps/example-apple/src/screens/MLCScreen.tsx | 83 ++++++++++++-------- 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/apps/example-apple/src/screens/MLCScreen.tsx b/apps/example-apple/src/screens/MLCScreen.tsx index d0f27ec2..285a6504 100644 --- a/apps/example-apple/src/screens/MLCScreen.tsx +++ b/apps/example-apple/src/screens/MLCScreen.tsx @@ -26,33 +26,42 @@ export default function MLCScreen() { const [structuredResponse, setStructuredResponse] = useState | null>(null) + const [model, setModel] = useState(null) + const [modelId, setModelId] = useState('') const setupModel = async () => { - // Step 1: Get available models - setStatusText('Getting available models...') - const models = await MLCEngine.getModels() - const modelId = models[0]!.model_id! - setStatusText(`Selected model: ${modelId}`) - - // Step 2: Create and prepare model - const model = mlc.languageModel(modelId) - await model.download((event) => { - setStatusText(`Downloading model: ${event.percentage}`) - }) - setStatusText('Preparing model...') - await model.prepare() - setStatusText('Model ready') - - return model + try { + setIsLoading(true) + setStatusText('Getting available models...') + const models = await MLCEngine.getModels() + const selectedModelId = models[0]!.model_id! 
+ setModelId(selectedModelId) + setStatusText(`Selected model: ${selectedModelId}`) + + const modelInstance = mlc.languageModel(selectedModelId) + await modelInstance.download((event) => { + setStatusText(`Downloading model: ${event.percentage}`) + }) + setStatusText('Preparing model...') + await modelInstance.prepare() + setModel(modelInstance) + setStatusText('Model ready') + } catch (error) { + setStatusText(`Setup error: ${error}`) + } finally { + setIsLoading(false) + } } const runGenerateText = async () => { try { + if (!model) { + setStatusText('Please setup model first') + return + } setIsLoading(true) setResponse('') - const model = await setupModel() - // Generate text using AI SDK setStatusText('Generating response...') @@ -75,8 +84,6 @@ export default function MLCScreen() { setIsLoading(true) setResponse('') - const model = await setupModel() - // Stream text using AI SDK setStatusText('Streaming response...') @@ -102,9 +109,6 @@ export default function MLCScreen() { setIsLoading(true) setStructuredResponse(null) - const model = await setupModel() - - // Generate structured output setStatusText('Generating structured response...') const result = await generateObject({ @@ -132,47 +136,62 @@ export default function MLCScreen() { + + Setup Model + + + {model ? `Ready: ${modelId}` : 'Get models → Download → Prepare'} + + + + Generate Text - Get models → Download → Prepare → Generate + Use prepared model to generate text Stream Text - Same as above, but streaming incrementally + Stream text incrementally using prepared model Generate Structured Output - AI identifies itself with name & description + Generate structured output using prepared model From fc7cf7c94ba721dd8ba55687422e8b8801738e75 Mon Sep 17 00:00:00 2001 From: Mike Grabowski Date: Mon, 13 Oct 2025 01:03:12 -0700 Subject: [PATCH 10/18] fix --- apps/example-apple/src/screens/MLCScreen.tsx | 14 ++-- .../com/callstack/ai/NativeMLCEngineModule.kt | 65 +++++++++---------- 2 files changed, 39 insertions(+), 40 deletions(-) diff --git a/apps/example-apple/src/screens/MLCScreen.tsx b/apps/example-apple/src/screens/MLCScreen.tsx index 285a6504..f9a91491 100644 --- a/apps/example-apple/src/screens/MLCScreen.tsx +++ b/apps/example-apple/src/screens/MLCScreen.tsx @@ -32,9 +32,12 @@ export default function MLCScreen() { const setupModel = async () => { try { setIsLoading(true) + setStatusText('Getting available models...') + const models = await MLCEngine.getModels() const selectedModelId = models[0]!.model_id! + setModelId(selectedModelId) setStatusText(`Selected model: ${selectedModelId}`) @@ -42,8 +45,11 @@ export default function MLCScreen() { await modelInstance.download((event) => { setStatusText(`Downloading model: ${event.percentage}`) }) + setStatusText('Preparing model...') + await modelInstance.prepare() + setModel(modelInstance) setStatusText('Model ready') } catch (error) { @@ -55,14 +61,9 @@ export default function MLCScreen() { const runGenerateText = async () => { try { - if (!model) { - setStatusText('Please setup model first') - return - } setIsLoading(true) setResponse('') - // Generate text using AI SDK setStatusText('Generating response...') const result = await generateText({ @@ -84,10 +85,9 @@ export default function MLCScreen() { setIsLoading(true) setResponse('') - // Stream text using AI SDK setStatusText('Streaming response...') - const result = await streamText({ + const result = streamText({ model, prompt: 'Hello! Who are you? 
Please introduce yourself.', }) diff --git a/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt b/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt index 63e0cf18..09b7be01 100644 --- a/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt +++ b/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt @@ -10,7 +10,6 @@ import java.io.File import java.util.UUID import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.Job import kotlinx.coroutines.launch import kotlinx.coroutines.channels.toList import java.util.concurrent.Executors @@ -130,6 +129,9 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn options: ReadableMap?, promise: Promise ) { + // TODO: Patch MLC to return requestId and expose abortFunc from Android SDK + val streamId = UUID.randomUUID().toString() + executorService.submit { engineScope.launch { try { @@ -159,53 +161,50 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn ) ) - var accumulatedContent = "" - var finalRole: String? = null var finalFinishReason: String? = null - var usage: Map? = null for (streamResponse in chatResponse) { - // Check for usage (indicates completion) - streamResponse.usage?.let { - usage = mapOf( - "prompt_tokens" to it.prompt_tokens, - "completion_tokens" to it.completion_tokens, - "total_tokens" to it.total_tokens - ) - } - + // Emit stream update streamResponse.choices.firstOrNull()?.let { choice -> choice.delta.content?.let { content -> - accumulatedContent += content.text ?: "" + val deltaArgs = Arguments.createMap().apply { + val delta = Arguments.createMap().apply { + putString("content", content.text ?: "") + putString("role", choice.delta.role.toString()) + } + putMap("delta", delta) + } + emitOnChatUpdate(deltaArgs) } choice.finish_reason?.let { finishReason -> finalFinishReason = finishReason } - choice.delta.role?.let { role -> - finalRole = role.toString() - } } - if (usage != null) { - break - } - } - - val response = Arguments.createMap().apply { - putString("role", finalRole ?: "assistant") - putString("content", accumulatedContent) - finalFinishReason?.let { putString("finish_reason", it) } - usage?.let { usageMap -> - val usageArgs = Arguments.createMap().apply { - putInt("prompt_tokens", usageMap["prompt_tokens"] as Int) - putInt("completion_tokens", usageMap["completion_tokens"] as Int) - putInt("total_tokens", usageMap["total_tokens"] as Int) + // Check for usage (indicates completion) + streamResponse.usage?.let { usage -> + val completeArgs = Arguments.createMap().apply { + val usageArgs = Arguments.createMap().apply { + putInt("prompt_tokens", usage.prompt_tokens) + putInt("completion_tokens", usage.completion_tokens) + putInt("total_tokens", usage.total_tokens) + val extraArgs = Arguments.createMap().apply { + usage.extra?.let { extra -> + extra.prefill_tokens_per_s?.let { putDouble("prefill_tokens_per_s", extra.prefill_tokens_per_s.toDouble()) } + extra.decode_tokens_per_s?.let { putDouble("decode_tokens_per_s", it.toDouble()) } + extra.num_prefill_tokens?.let { putInt("num_prefill_tokens", it) } + } + } + putMap("extra", extraArgs) + } + putMap("usage", usageArgs) + putString("finish_reason", finalFinishReason) } - putMap("usage", usageArgs) + emitOnChatComplete(completeArgs) } } - promise.resolve(response) + promise.resolve(streamId) } catch (e: Exception) { promise.reject("GENERATION_ERROR", e.message) } From 
8034dd63b77a7548947702b5e75752c6cd9c7518 Mon Sep 17 00:00:00 2001 From: Mike Grabowski Date: Mon, 13 Oct 2025 01:06:17 -0700 Subject: [PATCH 11/18] remove progress counter --- .../android/src/main/java/com/callstack/ai/ModelDownloader.kt | 2 -- 1 file changed, 2 deletions(-) diff --git a/packages/mlc/android/src/main/java/com/callstack/ai/ModelDownloader.kt b/packages/mlc/android/src/main/java/com/callstack/ai/ModelDownloader.kt index ac589d62..8b34eb89 100644 --- a/packages/mlc/android/src/main/java/com/callstack/ai/ModelDownloader.kt +++ b/packages/mlc/android/src/main/java/com/callstack/ai/ModelDownloader.kt @@ -44,11 +44,9 @@ class ModelDownloader( listOf( async { downloadSingleFile(MODEL_CONFIG_FILENAME) - onProgress(progressCounter.incrementAndGet(), -1) }, async { downloadSingleFile(PARAMS_CONFIG_FILENAME) - onProgress(progressCounter.incrementAndGet(), -1) } ).awaitAll() } From fd78a95bd3993708bbb1d16b2ffd22c54de93734 Mon Sep 17 00:00:00 2001 From: Mike Grabowski Date: Mon, 13 Oct 2025 02:51:33 -0700 Subject: [PATCH 12/18] chore: finish the work --- .../com/callstack/ai/NativeMLCEngineModule.kt | 70 +++++++++++++------ packages/mlc/mlc-package-config-android.json | 36 ++++------ 2 files changed, 63 insertions(+), 43 deletions(-) diff --git a/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt b/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt index 09b7be01..d00e314d 100644 --- a/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt +++ b/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt @@ -66,14 +66,7 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn ) { engineScope.launch { try { - val messageList = mutableListOf() - - for (i in 0 until messages.size()) { - val messageMap = messages.getMap(i) - val role = if (messageMap?.getString("role") == "user") OpenAIProtocol.ChatCompletionRole.user else OpenAIProtocol.ChatCompletionRole.assistant - val content = messageMap?.getString("content") ?: "" - messageList.add(ChatCompletionMessage(role, content)) - } + val messageList = parseMessages(messages) val responseFormat = options?.getMap("responseFormat")?.let { formatMap -> val type = formatMap.getString("type") ?: "text" @@ -112,6 +105,24 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn putInt("prompt_tokens", usage.prompt_tokens) putInt("completion_tokens", usage.completion_tokens) putInt("total_tokens", usage.total_tokens) + val extraArgs = Arguments.createMap().apply { + usage.extra?.let { extra -> + extra.prefill_tokens_per_s?.let { + putDouble( + "prefill_tokens_per_s", + extra.prefill_tokens_per_s.toDouble() + ) + } + extra.decode_tokens_per_s?.let { + putDouble( + "decode_tokens_per_s", + it.toDouble() + ) + } + extra.num_prefill_tokens?.let { putInt("num_prefill_tokens", it) } + } + } + putMap("extra", extraArgs) } putMap("usage", usageArgs) } @@ -135,14 +146,7 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn executorService.submit { engineScope.launch { try { - val messageList = mutableListOf() - - for (i in 0 until messages.size()) { - val messageMap = messages.getMap(i) - val role = if (messageMap?.getString("role") == "user") OpenAIProtocol.ChatCompletionRole.user else OpenAIProtocol.ChatCompletionRole.assistant - val content = messageMap?.getString("content") ?: "" - messageList.add(ChatCompletionMessage(role, content)) - } + val messageList = parseMessages(messages) val 
responseFormat = options?.getMap("responseFormat")?.let { formatMap -> val type = formatMap.getString("type") ?: "text" @@ -181,6 +185,8 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn } } + promise.resolve(streamId) + // Check for usage (indicates completion) streamResponse.usage?.let { usage -> val completeArgs = Arguments.createMap().apply { @@ -190,8 +196,18 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn putInt("total_tokens", usage.total_tokens) val extraArgs = Arguments.createMap().apply { usage.extra?.let { extra -> - extra.prefill_tokens_per_s?.let { putDouble("prefill_tokens_per_s", extra.prefill_tokens_per_s.toDouble()) } - extra.decode_tokens_per_s?.let { putDouble("decode_tokens_per_s", it.toDouble()) } + extra.prefill_tokens_per_s?.let { + putDouble( + "prefill_tokens_per_s", + extra.prefill_tokens_per_s.toDouble() + ) + } + extra.decode_tokens_per_s?.let { + putDouble( + "decode_tokens_per_s", + it.toDouble() + ) + } extra.num_prefill_tokens?.let { putInt("num_prefill_tokens", it) } } } @@ -203,8 +219,6 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn emitOnChatComplete(completeArgs) } } - - promise.resolve(streamId) } catch (e: Exception) { promise.reject("GENERATION_ERROR", e.message) } @@ -270,6 +284,22 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn promise.resolve(null) } + private fun parseMessages(messages: ReadableArray): List { + val messageList = mutableListOf() + + for (i in 0 until messages.size()) { + val messageMap = messages.getMap(i) + val role = if (messageMap?.getString("role") == "user") + OpenAIProtocol.ChatCompletionRole.user + else + OpenAIProtocol.ChatCompletionRole.assistant + val content = messageMap?.getString("content") ?: "" + messageList.add(ChatCompletionMessage(role, content)) + } + + return messageList + } + private fun getModelConfig(modelId: String): Pair? 
{ val modelRecord = appConfig.model_list.find { it.model_id == modelId } ?: return null val modelDir = File(reactApplicationContext.getExternalFilesDir(""), modelRecord.model_id) diff --git a/packages/mlc/mlc-package-config-android.json b/packages/mlc/mlc-package-config-android.json index 48d10083..adc9a6c4 100644 --- a/packages/mlc/mlc-package-config-android.json +++ b/packages/mlc/mlc-package-config-android.json @@ -2,44 +2,34 @@ "device": "android", "model_list": [ { - "model": "HF://mlc-ai/Llama-3.2-3B-Instruct-q4f16_1-MLC", - "model_id": "Llama-3.2-3B-Instruct", - "estimated_vram_bytes": 2000000000, + "model": "HF://mlc-ai/Llama-3.2-1B-Instruct-q4f16_1-MLC", + "model_id": "Llama-3.2-1B-Instruct-q4f16_1-MLC", + "estimated_vram_bytes": 1200000000, "bundle_weight": false, "overrides": { "context_window_size": 4096, - "prefill_chunk_size": 1024 - } - }, - { - "model": "HF://mlc-ai/Phi-3-mini-4k-instruct-q4f16_1-MLC", - "model_id": "Phi-3-mini-4k-instruct", - "estimated_vram_bytes": 2500000000, - "bundle_weight": false, - "overrides": { - "context_window_size": 4096, - "prefill_chunk_size": 1024 + "prefill_chunk_size": 512 } }, { - "model": "HF://mlc-ai/Mistral-7B-Instruct-v0.2-q4f16_1-MLC", - "model_id": "Mistral-7B-Instruct", - "estimated_vram_bytes": 4500000000, + "model": "HF://mlc-ai/Phi-3.5-mini-instruct-q4f16_1-MLC", + "model_id": "Phi-3.5-mini-instruct-q4f16_1-MLC", + "estimated_vram_bytes": 2300000000, "bundle_weight": false, "overrides": { "context_window_size": 4096, - "prefill_chunk_size": 1024 + "prefill_chunk_size": 256 } }, { - "model": "HF://mlc-ai/Qwen2.5-1.5B-Instruct-q4f16_1-MLC", - "model_id": "Qwen2.5-1.5B-Instruct", - "estimated_vram_bytes": 1000000000, + "model": "HF://mlc-ai/Qwen2.5-0.5B-Instruct-q4f16_1-MLC", + "model_id": "Qwen2.5-0.5B-Instruct-q4f16_1-MLC", + "estimated_vram_bytes": 600000000, "bundle_weight": false, "overrides": { - "context_window_size": 4096, + "context_window_size": 2048, "prefill_chunk_size": 1024 } } ] -} \ No newline at end of file +} From 5704d0bf3cb567f16db0acf3fa36de772cdfb9c9 Mon Sep 17 00:00:00 2001 From: Mike Grabowski Date: Mon, 13 Oct 2025 03:01:09 -0700 Subject: [PATCH 13/18] modify on android --- packages/mlc/ios/MLCEngine.mm | 90 ++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 43 deletions(-) diff --git a/packages/mlc/ios/MLCEngine.mm b/packages/mlc/ios/MLCEngine.mm index 9769f49c..80d71e4d 100644 --- a/packages/mlc/ios/MLCEngine.mm +++ b/packages/mlc/ios/MLCEngine.mm @@ -323,9 +323,9 @@ - (BOOL)downloadFile:(NSString*)modelUrl filename:(NSString*)filename toURL:(NSU return YES; } -// Download all model files with status updates +// Download all model files with percentage updates - (void)downloadModelFiles:(NSDictionary*)modelRecord - status:(void (^)(NSString* status))statusCallback + progress:(void (^)(double percentage))progressCallback error:(NSError**)error { NSString* modelId = modelRecord[@"model_id"]; NSString* modelUrl = modelRecord[@"model_url"]; @@ -357,17 +357,8 @@ - (void)downloadModelFiles:(NSDictionary*)modelRecord return; } - // Download and save model config if it doesn't exist - if (![[NSFileManager defaultManager] fileExistsAtPath:[modelConfigURL path]]) { - if (statusCallback) statusCallback(@"Downloading model configuration..."); - if (![self downloadFile:modelUrl filename:@"mlc-chat-config.json" toURL:modelConfigURL error:error]) { - return; - } - } - // Download and save ndarray-cache if it doesn't exist if (![[NSFileManager defaultManager] fileExistsAtPath:[ndarrayCacheURL path]]) { 
-    if (statusCallback) statusCallback(@"Downloading cache configuration...");
     if (![self downloadFile:modelUrl filename:@"ndarray-cache.json" toURL:ndarrayCacheURL error:error]) {
       return;
     }
@@ -388,31 +379,14 @@ - (void)downloadModelFiles:(NSDictionary*)modelRecord
     *error = ndarrayCacheJsonError;
     return;
   }
-
-  // Download parameter files from ndarray cache
-  NSArray* records = ndarrayCache[@"records"];
-  if ([records isKindOfClass:[NSArray class]] && records.count > 0) {
-    int currentFile = 0;
-    int totalFiles = (int)records.count;
-
-    for (NSDictionary* record in records) {
-      NSString* dataPath = record[@"dataPath"];
-      if (dataPath) {
-        NSURL* fileURL = [modelDirURL URLByAppendingPathComponent:dataPath];
-        if (![[NSFileManager defaultManager] fileExistsAtPath:[fileURL path]]) {
-          currentFile++;
-          NSString* fileName = [dataPath lastPathComponent];
-          if (statusCallback) {
-            statusCallback([NSString stringWithFormat:@"Downloading %@ (%d/%d)...", fileName, currentFile, totalFiles]);
-          }
-          if (![self downloadFile:modelUrl filename:dataPath toURL:fileURL error:error]) {
-            return;
-          }
-        }
-      }
+
+  // Download and save model config if it doesn't exist
+  if (![[NSFileManager defaultManager] fileExistsAtPath:[modelConfigURL path]]) {
+    if (![self downloadFile:modelUrl filename:@"mlc-chat-config.json" toURL:modelConfigURL error:error]) {
+      return;
     }
   }
-
+
   // Read and parse model config
   NSData* modelConfigData = [NSData dataWithContentsOfURL:modelConfigURL];
   if (!modelConfigData) {
@@ -429,21 +403,51 @@ - (void)downloadModelFiles:(NSDictionary*)modelRecord
     return;
   }
 
-  // Download tokenizer files
+  // Create unified list of files to download
+  NSMutableArray* filesToDownload = [NSMutableArray new];
+
+  // Add parameter files from ndarray cache
+  NSArray* records = ndarrayCache[@"records"];
+  if ([records isKindOfClass:[NSArray class]]) {
+    for (NSDictionary* record in records) {
+      NSString* dataPath = record[@"dataPath"];
+      if (dataPath) {
+        NSURL* fileURL = [modelDirURL URLByAppendingPathComponent:dataPath];
+        if (![[NSFileManager defaultManager] fileExistsAtPath:[fileURL path]]) {
+          [filesToDownload addObject:dataPath];
+        }
+      }
+    }
+  }
+
+  // Add tokenizer files
   NSArray* tokenizerFiles = modelConfig[@"tokenizer_files"];
-  if ([tokenizerFiles isKindOfClass:[NSArray class]] && tokenizerFiles.count > 0) {
-    if (statusCallback) statusCallback(@"Downloading tokenizer files...");
+  if ([tokenizerFiles isKindOfClass:[NSArray class]]) {
     for (NSString* filename in tokenizerFiles) {
       NSURL* fileURL = [modelDirURL URLByAppendingPathComponent:filename];
       if (![[NSFileManager defaultManager] fileExistsAtPath:[fileURL path]]) {
-        if (![self downloadFile:modelUrl filename:filename toURL:fileURL error:error]) {
-          return;
-        }
+        [filesToDownload addObject:filename];
       }
     }
   }
 
-  if (statusCallback) statusCallback(@"Download complete");
+  // Download all files with progress tracking
+  NSInteger totalFiles = filesToDownload.count;
+  for (NSInteger i = 0; i < totalFiles; i++) {
+    NSString* filename = filesToDownload[i];
+    NSURL* fileURL = [modelDirURL URLByAppendingPathComponent:filename];
+
+    // Calculate and emit progress
+    double percentage = totalFiles > 0 ? (double)(i + 1) / totalFiles * 100.0 : 100.0;
+    if (progressCallback) {
+      progressCallback(round(percentage));
+    }
+
+    // Download the file
+    if (![self downloadFile:modelUrl filename:filename toURL:fileURL error:error]) {
+      return;
+    }
+  }
 }
 
 - (void)downloadModel:(NSString*)modelId
@@ -460,8 +464,8 @@ - (void)downloadModel:(NSString*)modelId
   NSError* downloadError = nil;
 
   [self downloadModelFiles:modelRecord
-                    status:^(NSString* status) {
-                      [self emitOnDownloadProgress:@{@"status" : status}];
+                  progress:^(double percentage) {
+                    [self emitOnDownloadProgress:@{@"percentage" : @(percentage)}];
                   }
                      error:&downloadError];
 

From 9f9957eb6b62f1e73a0a121cc2ac41d74913a833 Mon Sep 17 00:00:00 2001
From: Mike Grabowski
Date: Mon, 13 Oct 2025 03:03:10 -0700
Subject: [PATCH 14/18] fix

---
 .../src/main/java/com/callstack/ai/NativeMLCEngineModule.kt | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt b/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt
index d00e314d..cb3f413f 100644
--- a/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt
+++ b/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt
@@ -167,6 +167,8 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn
 
     var finalFinishReason: String? = null
 
+    promise.resolve(streamId)
+
     for (streamResponse in chatResponse) {
       // Emit stream update
       streamResponse.choices.firstOrNull()?.let { choice ->
@@ -185,8 +187,6 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn
       }
     }
 
-    promise.resolve(streamId)
-
     // Check for usage (indicates completion)
     streamResponse.usage?.let { usage ->
       val completeArgs = Arguments.createMap().apply {

From 63fa673291ebead1d150367e1e4f370aa8d476f6 Mon Sep 17 00:00:00 2001
From: Mike Grabowski
Date: Mon, 13 Oct 2025 06:11:50 -0700
Subject: [PATCH 15/18] chore: address feedback

---
 .../java/com/callstack/ai/NativeMLCEngineModule.kt | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt b/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt
index cb3f413f..a54ab507 100644
--- a/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt
+++ b/packages/mlc/android/src/main/java/com/callstack/ai/NativeMLCEngineModule.kt
@@ -289,10 +289,15 @@ class NativeMLCEngineModule(reactContext: ReactApplicationContext) : NativeMLCEn
 
     for (i in 0 until messages.size()) {
       val messageMap = messages.getMap(i)
-      val role = if (messageMap?.getString("role") == "user")
-        OpenAIProtocol.ChatCompletionRole.user
-      else
-        OpenAIProtocol.ChatCompletionRole.assistant
+      val roleString = messageMap?.getString("role") ?: throw IllegalArgumentException("Message role is required")
+
+      val role = when (roleString) {
+        "user" -> OpenAIProtocol.ChatCompletionRole.user
+        "assistant" -> OpenAIProtocol.ChatCompletionRole.assistant
+        "system" -> OpenAIProtocol.ChatCompletionRole.system
+        else -> throw IllegalArgumentException("Unknown message role: $roleString")
+      }
+
       val content = messageMap?.getString("content") ?: ""
       messageList.add(ChatCompletionMessage(role, content))
     }

From bde2251db2bf15d90cd9b006df456956c0f103dd Mon Sep 17 00:00:00 2001
From: Mike Grabowski
Date: Mon, 13 Oct 2025 06:59:07 -0700
Subject: [PATCH 16/18] fix

---
 .../main/java/com/callstack/ai/ModelDownloader.kt | 4 ++--
packages/mlc/ios/MLCEngine.mm | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/packages/mlc/android/src/main/java/com/callstack/ai/ModelDownloader.kt b/packages/mlc/android/src/main/java/com/callstack/ai/ModelDownloader.kt index 8b34eb89..5d2d0fad 100644 --- a/packages/mlc/android/src/main/java/com/callstack/ai/ModelDownloader.kt +++ b/packages/mlc/android/src/main/java/com/callstack/ai/ModelDownloader.kt @@ -71,7 +71,7 @@ class ModelDownloader( chunk.map { filename -> async { downloadSingleFile(filename) - onProgress(progressCounter.incrementAndGet(), allFiles.size) + onProgress(progressCounter.incrementAndGet(), remainingFiles.size) } }.awaitAll() } @@ -79,7 +79,7 @@ class ModelDownloader( } // Final progress update - onProgress(progressCounter.get(), allFiles.size) + onProgress(progressCounter.get(), remainingFiles.size) } diff --git a/packages/mlc/ios/MLCEngine.mm b/packages/mlc/ios/MLCEngine.mm index 80d71e4d..5d3ed0f2 100644 --- a/packages/mlc/ios/MLCEngine.mm +++ b/packages/mlc/ios/MLCEngine.mm @@ -437,16 +437,16 @@ - (void)downloadModelFiles:(NSDictionary*)modelRecord NSString* filename = filesToDownload[i]; NSURL* fileURL = [modelDirURL URLByAppendingPathComponent:filename]; - // Calculate and emit progress + // Download the file first + if (![self downloadFile:modelUrl filename:filename toURL:fileURL error:error]) { + return; + } + + // Calculate and emit progress after successful download double percentage = totalFiles > 0 ? (double)(i + 1) / totalFiles * 100.0 : 100.0; if (progressCallback) { progressCallback(round(percentage)); } - - // Download the file - if (![self downloadFile:modelUrl filename:filename toURL:fileURL error:error]) { - return; - } } } From 2c666934f7a18d56e0c354c96f8dfe281430db7c Mon Sep 17 00:00:00 2001 From: Mike Grabowski Date: Mon, 13 Oct 2025 07:00:01 -0700 Subject: [PATCH 17/18] chore: update download progress --- website/src/docs/mlc/model-management.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/src/docs/mlc/model-management.md b/website/src/docs/mlc/model-management.md index a58d0467..08946b7b 100644 --- a/website/src/docs/mlc/model-management.md +++ b/website/src/docs/mlc/model-management.md @@ -56,7 +56,7 @@ You can track download progress: ```typescript await model.download((event) => { - console.log(`Download: ${event.status}`); + console.log(`Download: ${event.progress}`); }); ``` From 3450aaaebd12da9f0a2773740615cf4626107f5c Mon Sep 17 00:00:00 2001 From: Mike Grabowski Date: Mon, 13 Oct 2025 07:05:18 -0700 Subject: [PATCH 18/18] chore: update models configuration across both platforms --- packages/mlc/mlc-package-config-android.json | 6 ++--- packages/mlc/mlc-package-config-ios.json | 26 ++++++++++---------- website/src/docs/mlc/model-management.md | 20 +++++++-------- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/packages/mlc/mlc-package-config-android.json b/packages/mlc/mlc-package-config-android.json index adc9a6c4..6770c65b 100644 --- a/packages/mlc/mlc-package-config-android.json +++ b/packages/mlc/mlc-package-config-android.json @@ -3,7 +3,7 @@ "model_list": [ { "model": "HF://mlc-ai/Llama-3.2-1B-Instruct-q4f16_1-MLC", - "model_id": "Llama-3.2-1B-Instruct-q4f16_1-MLC", + "model_id": "Llama-3.2-1B-Instruct", "estimated_vram_bytes": 1200000000, "bundle_weight": false, "overrides": { @@ -13,7 +13,7 @@ }, { "model": "HF://mlc-ai/Phi-3.5-mini-instruct-q4f16_1-MLC", - "model_id": "Phi-3.5-mini-instruct-q4f16_1-MLC", + "model_id": "Phi-3.5-mini-instruct", 
"estimated_vram_bytes": 2300000000, "bundle_weight": false, "overrides": { @@ -23,7 +23,7 @@ }, { "model": "HF://mlc-ai/Qwen2.5-0.5B-Instruct-q4f16_1-MLC", - "model_id": "Qwen2.5-0.5B-Instruct-q4f16_1-MLC", + "model_id": "Qwen2.5-0.5B-Instruct", "estimated_vram_bytes": 600000000, "bundle_weight": false, "overrides": { diff --git a/packages/mlc/mlc-package-config-ios.json b/packages/mlc/mlc-package-config-ios.json index e02b7fc1..361da634 100644 --- a/packages/mlc/mlc-package-config-ios.json +++ b/packages/mlc/mlc-package-config-ios.json @@ -2,19 +2,19 @@ "device": "iphone", "model_list": [ { - "model": "HF://mlc-ai/Llama-3.2-3B-Instruct-q4f16_1-MLC", - "model_id": "Llama-3.2-3B-Instruct", - "estimated_vram_bytes": 2000000000, + "model": "HF://mlc-ai/Qwen2.5-0.5B-Instruct-q4f16_1-MLC", + "model_id": "Qwen2.5-0.5B-Instruct", + "estimated_vram_bytes": 600000000, "bundle_weight": false, "overrides": { - "context_window_size": 4096, + "context_window_size": 2048, "prefill_chunk_size": 1024 } }, { - "model": "HF://mlc-ai/Phi-3-mini-4k-instruct-q4f16_1-MLC", - "model_id": "Phi-3-mini-4k-instruct", - "estimated_vram_bytes": 2500000000, + "model": "HF://mlc-ai/Llama-3.2-1B-Instruct-q4f16_1-MLC", + "model_id": "Llama-3.2-1B-Instruct", + "estimated_vram_bytes": 1200000000, "bundle_weight": false, "overrides": { "context_window_size": 4096, @@ -22,9 +22,9 @@ } }, { - "model": "HF://mlc-ai/Mistral-7B-Instruct-v0.2-q4f16_1-MLC", - "model_id": "Mistral-7B-Instruct", - "estimated_vram_bytes": 4500000000, + "model": "HF://mlc-ai/Llama-3.2-3B-Instruct-q4f16_1-MLC", + "model_id": "Llama-3.2-3B-Instruct", + "estimated_vram_bytes": 2000000000, "bundle_weight": false, "overrides": { "context_window_size": 4096, @@ -32,9 +32,9 @@ } }, { - "model": "HF://mlc-ai/Qwen2.5-1.5B-Instruct-q4f16_1-MLC", - "model_id": "Qwen2.5-1.5B-Instruct", - "estimated_vram_bytes": 1000000000, + "model": "HF://mlc-ai/Phi-3.5-mini-instruct-q4f16_1-MLC", + "model_id": "Phi-3.5-mini-instruct", + "estimated_vram_bytes": 2300000000, "bundle_weight": false, "overrides": { "context_window_size": 4096, diff --git a/website/src/docs/mlc/model-management.md b/website/src/docs/mlc/model-management.md index 08946b7b..511b3050 100644 --- a/website/src/docs/mlc/model-management.md +++ b/website/src/docs/mlc/model-management.md @@ -6,14 +6,14 @@ This guide covers the complete lifecycle of MLC models - from discovery and down The package includes a prebuilt runtime optimized for the following models: -| Model ID | Size | -|----------|------| -| `Llama-3.2-3B-Instruct` | ~2GB | -| `Phi-3-mini-4k-instruct` | ~2.5GB | -| `Mistral-7B-Instruct` | ~4.5GB (requires 8GB+ RAM) | -| `Qwen2.5-1.5B-Instruct` | ~1GB | +| Model ID | Size | Best For | +|----------|------|----------| +| `Qwen2.5-0.5B-Instruct` | ~600MB | Fast responses, basic conversations | +| `Llama-3.2-1B-Instruct` | ~1.2GB | Balanced performance and quality | +| `Llama-3.2-3B-Instruct` | ~2GB | High quality responses, complex reasoning | +| `Phi-3.5-mini-instruct` | ~2.3GB | Code generation, technical tasks | -> **Note**: These are the only models supported for direct download. For other models, you'll need to build MLC from source (documentation coming soon). +> **Note**: These models use q4f16_1 quantization (4-bit weights, 16-bit activations) optimized for mobile devices. For other models, you'll need to build MLC from source (documentation coming soon). 
## Model Lifecycle @@ -27,7 +27,7 @@ import { MLCEngine } from '@react-native-ai/mlc'; const models = await MLCEngine.getModels(); console.log('Available models:', models); -// Output: [{ model_id: 'Llama-3.2-3B-Instruct' }, ...] +// Output: [{ model_id: 'Llama-3.2-1B-Instruct' }, ...] ``` ### Creating Model Instance @@ -37,7 +37,7 @@ Create a model instance using the `mlc.languageModel()` method: ```typescript import { mlc } from '@react-native-ai/mlc'; -const model = mlc.languageModel('Llama-3.2-3B-Instruct'); +const model = mlc.languageModel('Llama-3.2-1B-Instruct'); ``` ### Downloading Models @@ -56,7 +56,7 @@ You can track download progress: ```typescript await model.download((event) => { - console.log(`Download: ${event.progress}`); + console.log(`Download: ${event.percentage}%`); }); ```
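For reference, the documented pieces above compose as follows. This is a minimal sketch of the flow the series settles on (`getModels`, `languageModel`, and the percentage-based `download` callback); the `prepareModel` wrapper and the choice of model ID are illustrative, not part of the package API:

```typescript
import { MLCEngine, mlc } from '@react-native-ai/mlc';

// Illustrative wrapper (not a package API): prepare the smallest bundled model.
async function prepareModel(): Promise<void> {
  // List the models bundled with the prebuilt runtime
  const models = await MLCEngine.getModels();
  console.log('Available models:', models);

  // Uses the simplified model ids from the updated package configs
  const model = mlc.languageModel('Qwen2.5-0.5B-Instruct');

  // Download progress events now carry a rounded `percentage` (0-100)
  // instead of the old status strings
  await model.download((event) => {
    console.log(`Download: ${event.percentage}%`);
  });
}
```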