diff --git a/.gitmodules b/.gitmodules index e69de29bb2d1d..b3c259fa9ee9b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "include/cpu_features"] + path = include/cpu_features + url = https://github.com/google/cpu_features diff --git a/CODEOWNERS b/CODEOWNERS index 53d2e1e7ed49e..8b91864a0bbff 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -30,7 +30,7 @@ /examples/export-docs/ @ggerganov /examples/gen-docs/ @ggerganov /examples/gguf/ @ggerganov -/examples/llama.android/ @ggerganov +/examples/llama.android/ @ggerganov @hanyin-arm @naco-siren /examples/llama.swiftui/ @ggerganov /examples/llama.vim @ggerganov /examples/lookahead/ @ggerganov diff --git a/README.md b/README.md index 6d30a8bdab058..fb2a57be13de0 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama) - Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi) - Go (no CGo needed): [hybridgroup/yzma](https://github.com/hybridgroup/yzma) +- Android: [llama.android](/examples/llama.android) diff --git a/docs/android.md b/docs/android.md index d2a835653fe5d..28b966ffc7c96 100644 --- a/docs/android.md +++ b/docs/android.md @@ -1,6 +1,26 @@ # Android +## Build with Android Studio + +Import the `examples/llama.android` directory into Android Studio, then perform a Gradle sync and build the project. +![Project imported into Android Studio](./android/imported-into-android-studio.png) + +This Android binding supports hardware acceleration up to `SME2` for **Arm** and `AMX` for **x86-64** CPUs on Android and ChromeOS devices. +It automatically detects the host's hardware to load compatible kernels. As a result, it runs seamlessly on both the latest premium devices and older devices that may lack modern CPU features or have limited RAM, without requiring any manual configuration. + +A minimal Android app frontend is included to showcase the binding’s core functionalities: +1. **Parse GGUF metadata** via `GgufMetadataReader` from either a `ContentResolver` provided `Uri` or a local `File`. +2. **Obtain a `TierDetection` or `InferenceEngine`** instance through the high-level facade APIs. +3. **Send a raw user prompt** for automatic template formatting, prefill, and decoding. Then collect the generated tokens in a Kotlin `Flow`. + +For a production-ready experience that leverages advanced features such as system prompts and benchmarks, check out [Arm AI Chat](https://play.google.com/store/apps/details?id=com.arm.aichat) on Google Play. +This project is made possible through a collaborative effort by Arm's **CT-ML**, **CE-ML** and **STE** groups: + +| ![Home screen](./android/arm-ai-chat-home-screen.png) | ![System prompt](./android/system-prompt-setup.png) | !["Haiku"](./android/chat-with-system-prompt-haiku.png) | +|:------------------------------------------------------:|:----------------------------------------------------:|:--------------------------------------------------------:| +| Home screen | System prompt | "Haiku" | + ## Build on Android using Termux [Termux](https://termux.dev/en/) is an Android terminal emulator and Linux environment app (no root required). As of writing, Termux is available experimentally in the Google Play Store; otherwise, it may be obtained directly from the project repo or on F-Droid. diff --git a/examples/llama.android/app/build.gradle.kts b/examples/llama.android/app/build.gradle.kts index 8d1b37195efd4..3524fe39c4528 100644 --- a/examples/llama.android/app/build.gradle.kts +++ b/examples/llama.android/app/build.gradle.kts @@ -1,16 +1,18 @@ plugins { - id("com.android.application") - id("org.jetbrains.kotlin.android") + alias(libs.plugins.android.application) + alias(libs.plugins.jetbrains.kotlin.android) } android { namespace = "com.example.llama" - compileSdk = 34 + compileSdk = 36 defaultConfig { - applicationId = "com.example.llama" + applicationId = "com.example.llama.aichat" + minSdk = 33 - targetSdk = 34 + targetSdk = 36 + versionCode = 1 versionName = "1.0" @@ -21,8 +23,17 @@ android { } buildTypes { + debug { + isMinifyEnabled = true + isShrinkResources = true + proguardFiles( + getDefaultProguardFile("proguard-android.txt"), + "proguard-rules.pro" + ) + } release { - isMinifyEnabled = false + isMinifyEnabled = true + isShrinkResources = true proguardFiles( getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro" @@ -36,30 +47,15 @@ android { kotlinOptions { jvmTarget = "1.8" } - buildFeatures { - compose = true - } - composeOptions { - kotlinCompilerExtensionVersion = "1.5.1" - } } dependencies { + implementation(libs.bundles.androidx) + implementation(libs.material) + + implementation(project(":lib")) - implementation("androidx.core:core-ktx:1.12.0") - implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.2") - implementation("androidx.activity:activity-compose:1.8.2") - implementation(platform("androidx.compose:compose-bom:2023.08.00")) - implementation("androidx.compose.ui:ui") - implementation("androidx.compose.ui:ui-graphics") - implementation("androidx.compose.ui:ui-tooling-preview") - implementation("androidx.compose.material3:material3") - implementation(project(":llama")) - testImplementation("junit:junit:4.13.2") - androidTestImplementation("androidx.test.ext:junit:1.1.5") - androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") - androidTestImplementation(platform("androidx.compose:compose-bom:2023.08.00")) - androidTestImplementation("androidx.compose.ui:ui-test-junit4") - debugImplementation("androidx.compose.ui:ui-tooling") - debugImplementation("androidx.compose.ui:ui-test-manifest") + testImplementation(libs.junit) + androidTestImplementation(libs.androidx.junit) + androidTestImplementation(libs.androidx.espresso.core) } diff --git a/examples/llama.android/app/proguard-rules.pro b/examples/llama.android/app/proguard-rules.pro index f1b424510da51..358020d2d2479 100644 --- a/examples/llama.android/app/proguard-rules.pro +++ b/examples/llama.android/app/proguard-rules.pro @@ -19,3 +19,11 @@ # If you keep the line number information, uncomment this to # hide the original source file name. #-renamesourcefileattribute SourceFile + +-keep class com.arm.aichat.* { *; } +-keep class com.arm.aichat.gguf.* { *; } + +-assumenosideeffects class android.util.Log { + public static int v(...); + public static int d(...); +} diff --git a/examples/llama.android/app/src/main/AndroidManifest.xml b/examples/llama.android/app/src/main/AndroidManifest.xml index 41a358a299154..8f7c606b41ecb 100644 --- a/examples/llama.android/app/src/main/AndroidManifest.xml +++ b/examples/llama.android/app/src/main/AndroidManifest.xml @@ -1,24 +1,21 @@ - - - + + android:exported="true"> diff --git a/examples/llama.android/app/src/main/java/com/example/llama/Downloadable.kt b/examples/llama.android/app/src/main/java/com/example/llama/Downloadable.kt deleted file mode 100644 index 78c231ae55d8c..0000000000000 --- a/examples/llama.android/app/src/main/java/com/example/llama/Downloadable.kt +++ /dev/null @@ -1,119 +0,0 @@ -package com.example.llama - -import android.app.DownloadManager -import android.net.Uri -import android.util.Log -import androidx.compose.material3.Button -import androidx.compose.material3.Text -import androidx.compose.runtime.Composable -import androidx.compose.runtime.getValue -import androidx.compose.runtime.mutableDoubleStateOf -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.remember -import androidx.compose.runtime.rememberCoroutineScope -import androidx.compose.runtime.setValue -import androidx.core.database.getLongOrNull -import androidx.core.net.toUri -import kotlinx.coroutines.delay -import kotlinx.coroutines.launch -import java.io.File - -data class Downloadable(val name: String, val source: Uri, val destination: File) { - companion object { - @JvmStatic - private val tag: String? = this::class.qualifiedName - - sealed interface State - data object Ready: State - data class Downloading(val id: Long): State - data class Downloaded(val downloadable: Downloadable): State - data class Error(val message: String): State - - @JvmStatic - @Composable - fun Button(viewModel: MainViewModel, dm: DownloadManager, item: Downloadable) { - var status: State by remember { - mutableStateOf( - if (item.destination.exists()) Downloaded(item) - else Ready - ) - } - var progress by remember { mutableDoubleStateOf(0.0) } - - val coroutineScope = rememberCoroutineScope() - - suspend fun waitForDownload(result: Downloading, item: Downloadable): State { - while (true) { - val cursor = dm.query(DownloadManager.Query().setFilterById(result.id)) - - if (cursor == null) { - Log.e(tag, "dm.query() returned null") - return Error("dm.query() returned null") - } - - if (!cursor.moveToFirst() || cursor.count < 1) { - cursor.close() - Log.i(tag, "cursor.moveToFirst() returned false or cursor.count < 1, download canceled?") - return Ready - } - - val pix = cursor.getColumnIndex(DownloadManager.COLUMN_BYTES_DOWNLOADED_SO_FAR) - val tix = cursor.getColumnIndex(DownloadManager.COLUMN_TOTAL_SIZE_BYTES) - val sofar = cursor.getLongOrNull(pix) ?: 0 - val total = cursor.getLongOrNull(tix) ?: 1 - cursor.close() - - if (sofar == total) { - return Downloaded(item) - } - - progress = (sofar * 1.0) / total - - delay(1000L) - } - } - - fun onClick() { - when (val s = status) { - is Downloaded -> { - viewModel.load(item.destination.path) - } - - is Downloading -> { - coroutineScope.launch { - status = waitForDownload(s, item) - } - } - - else -> { - item.destination.delete() - - val request = DownloadManager.Request(item.source).apply { - setTitle("Downloading model") - setDescription("Downloading model: ${item.name}") - setAllowedNetworkTypes(DownloadManager.Request.NETWORK_WIFI) - setDestinationUri(item.destination.toUri()) - } - - viewModel.log("Saving ${item.name} to ${item.destination.path}") - Log.i(tag, "Saving ${item.name} to ${item.destination.path}") - - val id = dm.enqueue(request) - status = Downloading(id) - onClick() - } - } - } - - Button(onClick = { onClick() }, enabled = status !is Downloading) { - when (status) { - is Downloading -> Text(text = "Downloading ${(progress * 100).toInt()}%") - is Downloaded -> Text("Load ${item.name}") - is Ready -> Text("Download ${item.name}") - is Error -> Text("Download ${item.name}") - } - } - } - - } -} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt index 9da04f7d3c32e..4923e8e764d5f 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt @@ -1,154 +1,266 @@ package com.example.llama -import android.app.ActivityManager -import android.app.DownloadManager -import android.content.ClipData -import android.content.ClipboardManager import android.net.Uri import android.os.Bundle -import android.os.StrictMode -import android.os.StrictMode.VmPolicy -import android.text.format.Formatter -import androidx.activity.ComponentActivity -import androidx.activity.compose.setContent -import androidx.activity.viewModels -import androidx.compose.foundation.layout.Box -import androidx.compose.foundation.layout.Column -import androidx.compose.foundation.layout.Row -import androidx.compose.foundation.layout.fillMaxSize -import androidx.compose.foundation.layout.padding -import androidx.compose.foundation.lazy.LazyColumn -import androidx.compose.foundation.lazy.items -import androidx.compose.foundation.lazy.rememberLazyListState -import androidx.compose.material3.Button -import androidx.compose.material3.LocalContentColor -import androidx.compose.material3.MaterialTheme -import androidx.compose.material3.OutlinedTextField -import androidx.compose.material3.Surface -import androidx.compose.material3.Text -import androidx.compose.runtime.Composable -import androidx.compose.ui.Modifier -import androidx.compose.ui.unit.dp -import androidx.core.content.getSystemService -import com.example.llama.ui.theme.LlamaAndroidTheme +import android.util.Log +import android.widget.EditText +import android.widget.TextView +import android.widget.Toast +import androidx.activity.enableEdgeToEdge +import androidx.activity.result.contract.ActivityResultContracts +import androidx.appcompat.app.AppCompatActivity +import androidx.lifecycle.lifecycleScope +import androidx.recyclerview.widget.LinearLayoutManager +import androidx.recyclerview.widget.RecyclerView +import com.arm.aichat.AiChat +import com.arm.aichat.InferenceEngine +import com.arm.aichat.TierDetection +import com.arm.aichat.gguf.GgufMetadata +import com.arm.aichat.gguf.GgufMetadataReader +import com.google.android.material.floatingactionbutton.FloatingActionButton +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.onCompletion +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext import java.io.File +import java.io.FileOutputStream +import java.io.InputStream +import java.util.UUID -class MainActivity( - activityManager: ActivityManager? = null, - downloadManager: DownloadManager? = null, - clipboardManager: ClipboardManager? = null, -): ComponentActivity() { - private val tag: String? = this::class.simpleName +class MainActivity : AppCompatActivity() { - private val activityManager by lazy { activityManager ?: getSystemService()!! } - private val downloadManager by lazy { downloadManager ?: getSystemService()!! } - private val clipboardManager by lazy { clipboardManager ?: getSystemService()!! } + // Android views + private lateinit var tierTv: TextView + private lateinit var pickerBtn: FloatingActionButton + private lateinit var ggufTv: TextView + private lateinit var messagesRv: RecyclerView + private lateinit var userInputEt: EditText + private lateinit var userSendBtn: FloatingActionButton - private val viewModel: MainViewModel by viewModels() + // Arm AI Chat engine and utils + private lateinit var detection: TierDetection + private lateinit var engine: InferenceEngine - // Get a MemoryInfo object for the device's current memory status. - private fun availableMemory(): ActivityManager.MemoryInfo { - return ActivityManager.MemoryInfo().also { memoryInfo -> - activityManager.getMemoryInfo(memoryInfo) - } - } + // Conversation states + private val messages = mutableListOf() + private val lastAssistantMsg = StringBuilder() + private val messageAdapter = MessageAdapter(messages) override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) + enableEdgeToEdge() + setContentView(R.layout.activity_main) + + // Find views + tierTv = findViewById(R.id.tier) + pickerBtn = findViewById(R.id.pick_model) + ggufTv = findViewById(R.id.gguf) + messagesRv = findViewById(R.id.messages) + messagesRv.layoutManager = LinearLayoutManager(this) + messagesRv.adapter = messageAdapter + userInputEt = findViewById(R.id.user_input) + userSendBtn = findViewById(R.id.user_send) + + // Arm AI Chat initialization + lifecycleScope.launch(Dispatchers.Default) { + // Obtain the device's CPU feature tier + detection = AiChat.getTierDetection(applicationContext) + withContext(Dispatchers.Main) { + tierTv.text = detection.getDetectedTier()?.description ?: "N/A" + } + + // Obtain the inference engine + engine = AiChat.getInferenceEngine(applicationContext) + } + + // Upon file picker button tapped, prompt user to select a GGUF metadata on the device + pickerBtn.setOnClickListener { + getContent.launch(arrayOf("*/*")) + } + + // Upon user send button tapped, validate input and send to engine + userSendBtn.setOnClickListener { + handleUserInput() + } + } + + private val getContent = registerForActivityResult( + ActivityResultContracts.OpenDocument() + ) { uri -> + Log.i(TAG, "Selected file uri:\n $uri") + uri?.let { handleSelectedModel(it) } + } - StrictMode.setVmPolicy( - VmPolicy.Builder(StrictMode.getVmPolicy()) - .detectLeakedClosableObjects() - .build() - ) - - val free = Formatter.formatFileSize(this, availableMemory().availMem) - val total = Formatter.formatFileSize(this, availableMemory().totalMem) - - viewModel.log("Current memory: $free / $total") - viewModel.log("Downloads directory: ${getExternalFilesDir(null)}") - - val extFilesDir = getExternalFilesDir(null) - - val models = listOf( - Downloadable( - "Phi-2 7B (Q4_0, 1.6 GiB)", - Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true"), - File(extFilesDir, "phi-2-q4_0.gguf"), - ), - Downloadable( - "TinyLlama 1.1B (f16, 2.2 GiB)", - Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true"), - File(extFilesDir, "tinyllama-1.1-f16.gguf"), - ), - Downloadable( - "Phi 2 DPO (Q3_K_M, 1.48 GiB)", - Uri.parse("https://huggingface.co/TheBloke/phi-2-dpo-GGUF/resolve/main/phi-2-dpo.Q3_K_M.gguf?download=true"), - File(extFilesDir, "phi-2-dpo.Q3_K_M.gguf") - ), - ) - - setContent { - LlamaAndroidTheme { - // A surface container using the 'background' color from the theme - Surface( - modifier = Modifier.fillMaxSize(), - color = MaterialTheme.colorScheme.background - ) { - MainCompose( - viewModel, - clipboardManager, - downloadManager, - models, - ) + /** + * Handles the file Uri from [getContent] result + */ + private fun handleSelectedModel(uri: Uri) { + // Update UI states + pickerBtn.isEnabled = false + userInputEt.hint = "Parsing GGUF..." + ggufTv.text = "Parsing metadata from selected file \n$uri" + + lifecycleScope.launch(Dispatchers.IO) { + // Parse GGUF metadata + Log.i(TAG, "Parsing GGUF metadata...") + contentResolver.openInputStream(uri)?.use { + GgufMetadataReader.create().readStructuredMetadata(it) + }?.let { metadata -> + // Update UI to show GGUF metadata to user + Log.i(TAG, "GGUF parsed: \n$metadata") + withContext(Dispatchers.Main) { + ggufTv.text = metadata.toString() } + // Ensure the model file is available + val modelName = metadata.filename() + FILE_EXTENSION_GGUF + contentResolver.openInputStream(uri)?.use { input -> + ensureModelFile(modelName, input) + }?.let { modelFile -> + loadModel(modelName, modelFile) + + withContext(Dispatchers.Main) { + userInputEt.hint = "Type and send a message!" + userInputEt.isEnabled = true + userSendBtn.isEnabled = true + } + } } } } -} -@Composable -fun MainCompose( - viewModel: MainViewModel, - clipboard: ClipboardManager, - dm: DownloadManager, - models: List -) { - Column { - val scrollState = rememberLazyListState() - - Box(modifier = Modifier.weight(1f)) { - LazyColumn(state = scrollState) { - items(viewModel.messages) { - Text( - it, - style = MaterialTheme.typography.bodyLarge.copy(color = LocalContentColor.current), - modifier = Modifier.padding(16.dp) - ) + /** + * Prepare the model file within app's private storage + */ + private suspend fun ensureModelFile(modelName: String, input: InputStream) = + withContext(Dispatchers.IO) { + File(ensureModelsDirectory(), modelName).also { file -> + // Copy the file into local storage if not yet done + if (!file.exists()) { + Log.i(TAG, "Start copying file to $modelName") + withContext(Dispatchers.Main) { + userInputEt.hint = "Copying file..." + } + + FileOutputStream(file).use { input.copyTo(it) } + Log.i(TAG, "Finished copying file to $modelName") + } else { + Log.i(TAG, "File already exists $modelName") } } } - OutlinedTextField( - value = viewModel.message, - onValueChange = { viewModel.updateMessage(it) }, - label = { Text("Message") }, - ) - Row { - Button({ viewModel.send() }) { Text("Send") } - Button({ viewModel.bench(8, 4, 1) }) { Text("Bench") } - Button({ viewModel.clear() }) { Text("Clear") } - Button({ - viewModel.messages.joinToString("\n").let { - clipboard.setPrimaryClip(ClipData.newPlainText("", it)) + + /** + * Load the model file from the app private storage + */ + private suspend fun loadModel(modelName: String, modelFile: File) = + withContext(Dispatchers.IO) { + Log.i(TAG, "Loading model $modelName") + withContext(Dispatchers.Main) { + userInputEt.hint = "Loading model..." + } + engine.loadModel(modelFile.path) + } + + /** + * Validate and send the user message into [InferenceEngine] + */ + private fun handleUserInput() { + userInputEt.text.toString().also { userSsg -> + if (userSsg.isEmpty()) { + Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show() + } else { + userInputEt.text = null + userSendBtn.isEnabled = false + + // Update message states + messages.add(Message(UUID.randomUUID().toString(), userSsg, true)) + lastAssistantMsg.clear() + messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false)) + + lifecycleScope.launch(Dispatchers.Default) { + engine.sendUserPrompt(userSsg) + .onCompletion { + withContext(Dispatchers.Main) { + userSendBtn.isEnabled = true + } + }.collect { token -> + val messageCount = messages.size + check(messageCount > 0 && !messages[messageCount - 1].isUser) + + messages.removeAt(messageCount - 1).copy( + content = lastAssistantMsg.append(token).toString() + ).let { messages.add(it) } + + withContext(Dispatchers.Main) { + messageAdapter.notifyItemChanged(messages.size - 1) + } + } } - }) { Text("Copy") } + } } + } - Column { - for (model in models) { - Downloadable.Button(viewModel, dm, model) + /** + * Run a benchmark with the model file + */ + private suspend fun runBenchmark(modelName: String, modelFile: File) = + withContext(Dispatchers.Default) { + Log.i(TAG, "Starts benchmarking $modelName") + withContext(Dispatchers.Main) { + userInputEt.hint = "Running benchmark..." + } + engine.bench( + pp=BENCH_PROMPT_PROCESSING_TOKENS, + tg=BENCH_TOKEN_GENERATION_TOKENS, + pl=BENCH_SEQUENCE, + nr=BENCH_REPETITION + ).let { result -> + messages.add(Message(UUID.randomUUID().toString(), result, false)) + withContext(Dispatchers.Main) { + messageAdapter.notifyItemChanged(messages.size - 1) + } } } + + /** + * Create the `models` directory if not exist. + */ + private fun ensureModelsDirectory() = + File(filesDir, DIRECTORY_MODELS).also { + if (it.exists() && !it.isDirectory) { it.delete() } + if (!it.exists()) { it.mkdir() } + } + + companion object { + private val TAG = MainActivity::class.java.simpleName + + private const val DIRECTORY_MODELS = "models" + private const val FILE_EXTENSION_GGUF = ".gguf" + + private const val BENCH_PROMPT_PROCESSING_TOKENS = 512 + private const val BENCH_TOKEN_GENERATION_TOKENS = 128 + private const val BENCH_SEQUENCE = 1 + private const val BENCH_REPETITION = 3 + } +} + +fun GgufMetadata.filename() = when { + basic.name != null -> { + basic.name?.let { name -> + basic.sizeLabel?.let { size -> + "$name-$size" + } ?: name + } + } + architecture?.architecture != null -> { + architecture?.architecture?.let { arch -> + basic.uuid?.let { uuid -> + "$arch-$uuid" + } ?: "$arch-${System.currentTimeMillis()}" + } + } + else -> { + "model-${System.currentTimeMillis().toHexString()}" } } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt deleted file mode 100644 index 45ac29938f441..0000000000000 --- a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt +++ /dev/null @@ -1,105 +0,0 @@ -package com.example.llama - -import android.llama.cpp.LLamaAndroid -import android.util.Log -import androidx.compose.runtime.getValue -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.setValue -import androidx.lifecycle.ViewModel -import androidx.lifecycle.viewModelScope -import kotlinx.coroutines.flow.catch -import kotlinx.coroutines.launch - -class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() { - companion object { - @JvmStatic - private val NanosPerSecond = 1_000_000_000.0 - } - - private val tag: String? = this::class.simpleName - - var messages by mutableStateOf(listOf("Initializing...")) - private set - - var message by mutableStateOf("") - private set - - override fun onCleared() { - super.onCleared() - - viewModelScope.launch { - try { - llamaAndroid.unload() - } catch (exc: IllegalStateException) { - messages += exc.message!! - } - } - } - - fun send() { - val text = message - message = "" - - // Add to messages console. - messages += text - messages += "" - - viewModelScope.launch { - llamaAndroid.send(text) - .catch { - Log.e(tag, "send() failed", it) - messages += it.message!! - } - .collect { messages = messages.dropLast(1) + (messages.last() + it) } - } - } - - fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) { - viewModelScope.launch { - try { - val start = System.nanoTime() - val warmupResult = llamaAndroid.bench(pp, tg, pl, nr) - val end = System.nanoTime() - - messages += warmupResult - - val warmup = (end - start).toDouble() / NanosPerSecond - messages += "Warm up time: $warmup seconds, please wait..." - - if (warmup > 5.0) { - messages += "Warm up took too long, aborting benchmark" - return@launch - } - - messages += llamaAndroid.bench(512, 128, 1, 3) - } catch (exc: IllegalStateException) { - Log.e(tag, "bench() failed", exc) - messages += exc.message!! - } - } - } - - fun load(pathToModel: String) { - viewModelScope.launch { - try { - llamaAndroid.load(pathToModel) - messages += "Loaded $pathToModel" - } catch (exc: IllegalStateException) { - Log.e(tag, "load() failed", exc) - messages += exc.message!! - } - } - } - - fun updateMessage(newMessage: String) { - message = newMessage - } - - fun clear() { - messages = listOf() - } - - fun log(message: String) { - messages += message - } -} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt b/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt new file mode 100644 index 0000000000000..0439f964415fb --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt @@ -0,0 +1,51 @@ +package com.example.llama + +import android.view.LayoutInflater +import android.view.View +import android.view.ViewGroup +import android.widget.TextView +import androidx.recyclerview.widget.RecyclerView + +data class Message( + val id: String, + val content: String, + val isUser: Boolean +) + +class MessageAdapter( + private val messages: List +) : RecyclerView.Adapter() { + + companion object { + private const val VIEW_TYPE_USER = 1 + private const val VIEW_TYPE_ASSISTANT = 2 + } + + override fun getItemViewType(position: Int): Int { + return if (messages[position].isUser) VIEW_TYPE_USER else VIEW_TYPE_ASSISTANT + } + + override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): RecyclerView.ViewHolder { + val layoutInflater = LayoutInflater.from(parent.context) + return if (viewType == VIEW_TYPE_USER) { + val view = layoutInflater.inflate(R.layout.item_message_user, parent, false) + UserMessageViewHolder(view) + } else { + val view = layoutInflater.inflate(R.layout.item_message_assistant, parent, false) + AssistantMessageViewHolder(view) + } + } + + override fun onBindViewHolder(holder: RecyclerView.ViewHolder, position: Int) { + val message = messages[position] + if (holder is UserMessageViewHolder || holder is AssistantMessageViewHolder) { + val textView = holder.itemView.findViewById(R.id.msg_content) + textView.text = message.content + } + } + + override fun getItemCount(): Int = messages.size + + class UserMessageViewHolder(view: View) : RecyclerView.ViewHolder(view) + class AssistantMessageViewHolder(view: View) : RecyclerView.ViewHolder(view) +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Color.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Color.kt deleted file mode 100644 index 40c30e8d97077..0000000000000 --- a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Color.kt +++ /dev/null @@ -1,11 +0,0 @@ -package com.example.llama.ui.theme - -import androidx.compose.ui.graphics.Color - -val Purple80 = Color(0xFFD0BCFF) -val PurpleGrey80 = Color(0xFFCCC2DC) -val Pink80 = Color(0xFFEFB8C8) - -val Purple40 = Color(0xFF6650a4) -val PurpleGrey40 = Color(0xFF625b71) -val Pink40 = Color(0xFF7D5260) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Theme.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Theme.kt deleted file mode 100644 index e742220a8d719..0000000000000 --- a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Theme.kt +++ /dev/null @@ -1,70 +0,0 @@ -package com.example.llama.ui.theme - -import android.app.Activity -import android.os.Build -import androidx.compose.foundation.isSystemInDarkTheme -import androidx.compose.material3.MaterialTheme -import androidx.compose.material3.darkColorScheme -import androidx.compose.material3.dynamicDarkColorScheme -import androidx.compose.material3.dynamicLightColorScheme -import androidx.compose.material3.lightColorScheme -import androidx.compose.runtime.Composable -import androidx.compose.runtime.SideEffect -import androidx.compose.ui.graphics.toArgb -import androidx.compose.ui.platform.LocalContext -import androidx.compose.ui.platform.LocalView -import androidx.core.view.WindowCompat - -private val DarkColorScheme = darkColorScheme( - primary = Purple80, - secondary = PurpleGrey80, - tertiary = Pink80 -) - -private val LightColorScheme = lightColorScheme( - primary = Purple40, - secondary = PurpleGrey40, - tertiary = Pink40 - - /* Other default colors to override - background = Color(0xFFFFFBFE), - surface = Color(0xFFFFFBFE), - onPrimary = Color.White, - onSecondary = Color.White, - onTertiary = Color.White, - onBackground = Color(0xFF1C1B1F), - onSurface = Color(0xFF1C1B1F), - */ -) - -@Composable -fun LlamaAndroidTheme( - darkTheme: Boolean = isSystemInDarkTheme(), - // Dynamic color is available on Android 12+ - dynamicColor: Boolean = true, - content: @Composable () -> Unit -) { - val colorScheme = when { - dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> { - val context = LocalContext.current - if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context) - } - - darkTheme -> DarkColorScheme - else -> LightColorScheme - } - val view = LocalView.current - if (!view.isInEditMode) { - SideEffect { - val window = (view.context as Activity).window - window.statusBarColor = colorScheme.primary.toArgb() - WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme - } - } - - MaterialTheme( - colorScheme = colorScheme, - typography = Typography, - content = content - ) -} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Type.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Type.kt deleted file mode 100644 index 0b87946ca3ab1..0000000000000 --- a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Type.kt +++ /dev/null @@ -1,34 +0,0 @@ -package com.example.llama.ui.theme - -import androidx.compose.material3.Typography -import androidx.compose.ui.text.TextStyle -import androidx.compose.ui.text.font.FontFamily -import androidx.compose.ui.text.font.FontWeight -import androidx.compose.ui.unit.sp - -// Set of Material typography styles to start with -val Typography = Typography( - bodyLarge = TextStyle( - fontFamily = FontFamily.Default, - fontWeight = FontWeight.Normal, - fontSize = 16.sp, - lineHeight = 24.sp, - letterSpacing = 0.5.sp - ) - /* Other default text styles to override - titleLarge = TextStyle( - fontFamily = FontFamily.Default, - fontWeight = FontWeight.Normal, - fontSize = 22.sp, - lineHeight = 28.sp, - letterSpacing = 0.sp - ), - labelSmall = TextStyle( - fontFamily = FontFamily.Default, - fontWeight = FontWeight.Medium, - fontSize = 11.sp, - lineHeight = 16.sp, - letterSpacing = 0.5.sp - ) - */ -) diff --git a/examples/llama.android/app/src/main/res/drawable/bg_assistant_message.xml b/examples/llama.android/app/src/main/res/drawable/bg_assistant_message.xml new file mode 100644 index 0000000000000..f90c3db458301 --- /dev/null +++ b/examples/llama.android/app/src/main/res/drawable/bg_assistant_message.xml @@ -0,0 +1,4 @@ + + + + diff --git a/examples/llama.android/app/src/main/res/drawable/bg_user_message.xml b/examples/llama.android/app/src/main/res/drawable/bg_user_message.xml new file mode 100644 index 0000000000000..3ca7daefec78a --- /dev/null +++ b/examples/llama.android/app/src/main/res/drawable/bg_user_message.xml @@ -0,0 +1,4 @@ + + + + diff --git a/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml b/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml index 7706ab9e6d407..2b068d11462a4 100644 --- a/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml +++ b/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml @@ -27,4 +27,4 @@ android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z" android:strokeWidth="1" android:strokeColor="#00000000" /> - + \ No newline at end of file diff --git a/examples/llama.android/app/src/main/res/drawable/outline_folder_open_24.xml b/examples/llama.android/app/src/main/res/drawable/outline_folder_open_24.xml new file mode 100644 index 0000000000000..f58b501e3bc31 --- /dev/null +++ b/examples/llama.android/app/src/main/res/drawable/outline_folder_open_24.xml @@ -0,0 +1,10 @@ + + + diff --git a/examples/llama.android/app/src/main/res/drawable/outline_send_24.xml b/examples/llama.android/app/src/main/res/drawable/outline_send_24.xml new file mode 100644 index 0000000000000..712adc00c4e3d --- /dev/null +++ b/examples/llama.android/app/src/main/res/drawable/outline_send_24.xml @@ -0,0 +1,11 @@ + + + diff --git a/examples/llama.android/app/src/main/res/layout/activity_main.xml b/examples/llama.android/app/src/main/res/layout/activity_main.xml new file mode 100644 index 0000000000000..90eda033e730f --- /dev/null +++ b/examples/llama.android/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,96 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/llama.android/app/src/main/res/layout/item_message_assistant.xml b/examples/llama.android/app/src/main/res/layout/item_message_assistant.xml new file mode 100644 index 0000000000000..b7fb50039398d --- /dev/null +++ b/examples/llama.android/app/src/main/res/layout/item_message_assistant.xml @@ -0,0 +1,15 @@ + + + + + diff --git a/examples/llama.android/app/src/main/res/layout/item_message_user.xml b/examples/llama.android/app/src/main/res/layout/item_message_user.xml new file mode 100644 index 0000000000000..fe871f12fa7ae --- /dev/null +++ b/examples/llama.android/app/src/main/res/layout/item_message_user.xml @@ -0,0 +1,15 @@ + + + + + diff --git a/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml b/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml index b3e26b4c60c27..6f3b755bf50c6 100644 --- a/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml +++ b/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml @@ -3,4 +3,4 @@ - + \ No newline at end of file diff --git a/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml b/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml index b3e26b4c60c27..6f3b755bf50c6 100644 --- a/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml +++ b/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml @@ -3,4 +3,4 @@ - + \ No newline at end of file diff --git a/examples/llama.android/app/src/main/res/values/strings.xml b/examples/llama.android/app/src/main/res/values/strings.xml index 7a9d314e2969b..36059fc799888 100644 --- a/examples/llama.android/app/src/main/res/values/strings.xml +++ b/examples/llama.android/app/src/main/res/values/strings.xml @@ -1,3 +1,3 @@ - LlamaAndroid + AI Chat basic sample diff --git a/examples/llama.android/app/src/main/res/values/themes.xml b/examples/llama.android/app/src/main/res/values/themes.xml index 8a24fda56602c..2e4fdad72e012 100644 --- a/examples/llama.android/app/src/main/res/values/themes.xml +++ b/examples/llama.android/app/src/main/res/values/themes.xml @@ -1,5 +1,10 @@ - + +