Merged
40 changes: 33 additions & 7 deletions mobile/examples/phi-3/android/README.md
@@ -1,13 +1,38 @@
# Local Chatbot on Android with Phi-3, ONNX Runtime Mobile and ONNX Runtime Generate() API
# Local Chatbot on Android with ONNX Runtime Mobile and ONNX Runtime Generate() API

## Overview

This is a basic [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) Android example application with [ONNX Runtime mobile](https://onnxruntime.ai/docs/tutorials/mobile/) and [ONNX Runtime Generate() API](https://github.com/microsoft/onnxruntime-genai) with support for efficiently running generative AI models. This app demonstrates the usage of phi-3 model in a simple question answering chatbot mode.
This is a flexible Android chatbot application built with [ONNX Runtime mobile](https://onnxruntime.ai/docs/tutorials/mobile/) and the [ONNX Runtime Generate() API](https://github.com/microsoft/onnxruntime-genai), which support efficiently running generative AI models. While it uses [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) by default, **it can work with any ONNX Runtime GenAI-compatible model** simply by updating the model configuration in the code.

### Model
The model used here is [ONNX Phi-3 model on HuggingFace](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4) with INT4 quantization and optimizations for mobile usage.

You can also optimize your fine-tuned PyTorch Phi-3 model for mobile usage following this example [Phi3 optimization with Olive](https://github.com/microsoft/Olive/tree/main/examples/phi3).
By default, this app uses the [ONNX Phi-3 model on HuggingFace](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4) with INT4 quantization and optimizations for mobile usage.

### Using Different Models
**The app is designed to work with any ONNX Runtime GenAI-compatible model.** To use a different model:

1. Open `MainActivity.java` in Android Studio
2. Locate the model configuration section at the top of the class (marked with comments)
3. Update the `MODEL_BASE_URL` to point to your model's download location
4. Update the `MODEL_FILES` list to include all required files for your model

Example for a different model:
```java
// Base URL for downloading model files (ensure it ends with '/')
private static final String MODEL_BASE_URL = "https://your-model-host.com/path/to/model/";

// List of required model files to download
private static final List<String> MODEL_FILES = Arrays.asList(
"config.json",
"genai_config.json",
"your-model.onnx",
"your-model.onnx.data",
"tokenizer.json",
"tokenizer_config.json"
// Add other required files...
);
```

**Note:** The model files will be downloaded to `/data/data/ai.onnxruntime.genai.demo/files` on the Android device.

### Requirements
- Android Studio Giraffe | 2022.3.1 or later (installed on Mac/Windows/Linux)
@@ -30,7 +55,7 @@ The current set up supports downloading Phi-3-mini model directly from Huggingface
You can also follow this link to download **Phi-3-mini**: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4
and manually copy it to the Android device file directory following the instructions below:

#### Steps for manual copying models to android device directory:
#### Steps for manual copying model files to android device directory:
From Android Studio:
- create (if necessary) and run your emulator/device
- make sure it has at least 8GB of internal storage
@@ -40,7 +65,8 @@ From Android Studio:
- Open Device Explorer in Android Studio
- Navigate to `/data/data/ai.onnxruntime.genai.demo/files`
- adjust as needed if the value returned by getFilesDir() differs for your emulator or device
- copy the whole [phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4) model folder to the `files` directory
- copy all the required model files (as specified in `MODEL_FILES` in MainActivity.java) directly to the `files` directory (an `adb` sketch follows below)
- For the default Phi-3 model, copy the files from [here](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4)
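
If you prefer the command line to Device Explorer, something along these lines should work — a sketch, assuming an emulator or rooted device, since `/data/data` is not directly writable over `adb` otherwise; `./phi3-model` is a placeholder for wherever you downloaded the files:

```bash
# Placeholder local folder holding the downloaded model files.
# /data/data is only writable over adb on emulators/rooted devices, so elevate first.
adb root
adb push ./phi3-model/. /data/data/ai.onnxruntime.genai.demo/files/
# Verify the files landed where MainActivity expects them.
adb shell ls /data/data/ai.onnxruntime.genai.demo/files
```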

### Step 3: Connect Android Device and Run the app
Connect your Android Device to your computer or select the Android Emulator in Android Studio Device manager.
2 changes: 1 addition & 1 deletion mobile/examples/phi-3/android/app/build.gradle.kts
@@ -52,6 +52,6 @@ dependencies {

// ONNX Runtime with GenAI
implementation("com.microsoft.onnxruntime:onnxruntime-android:latest.release")
implementation(files("libs/onnxruntime-genai-android-0.4.0-dev.aar"))
implementation(files("libs/onnxruntime-genai-android-0.8.1.aar"))
**Contributor:** Is there a published onnxruntime-genai-android package that we can use?

**Contributor (author):** I downloaded the aar file from https://github.com/microsoft/onnxruntime-genai/releases. GenAI is not publishing Android packages on Maven.

**Contributor:** Wonder if there's a way to have gradle download from GitHub? It'd be nice to not have to keep a copy of the .aar in the repo. We don't have to figure that out in this PR, though.


}
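
On the reviewer's question above: one possible (untested) way to avoid checking the .aar into the repo is to fetch it from the GitHub release at build time, for example with the third-party `de.undercouch.download` Gradle plugin. The release URL and asset name below are assumptions based on the version referenced in this PR:

```kotlin
// build.gradle.kts — hypothetical sketch, not part of this PR.
// Assumes the onnxruntime-genai GitHub release publishes an asset with this exact name.
plugins {
    id("de.undercouch.download") version "5.6.0"
}

tasks.register<de.undercouch.gradle.tasks.download.Download>("downloadGenAiAar") {
    src("https://github.com/microsoft/onnxruntime-genai/releases/download/v0.8.1/onnxruntime-genai-android-0.8.1.aar")
    dest(file("libs/onnxruntime-genai-android-0.8.1.aar"))
    overwrite(false) // keep the cached copy once downloaded
}

// Ensure the aar exists before the Android build consumes it.
tasks.named("preBuild") {
    dependsOn("downloadGenAiAar")
}
```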
MainActivity.java
@@ -24,21 +24,33 @@
import java.util.concurrent.Executors;
import java.util.function.Consumer;

import ai.onnxruntime.genai.SimpleGenAI;
import ai.onnxruntime.genai.GenAIException;
import ai.onnxruntime.genai.Generator;
import ai.onnxruntime.genai.GeneratorParams;
import ai.onnxruntime.genai.Sequences;
import ai.onnxruntime.genai.TokenizerStream;
import ai.onnxruntime.genai.demo.databinding.ActivityMainBinding;
import ai.onnxruntime.genai.Model;
import ai.onnxruntime.genai.Tokenizer;

public class MainActivity extends AppCompatActivity implements Consumer<String> {

private ActivityMainBinding binding;
// ===== MODEL CONFIGURATION - MODIFY THESE FOR DIFFERENT MODELS =====
// Base URL for downloading model files (ensure it ends with '/')
private static final String MODEL_BASE_URL = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/";

// List of required model files to download
private static final List<String> MODEL_FILES = Arrays.asList(
"added_tokens.json",
"config.json",
"configuration_phi3.py",
"genai_config.json",
"phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx",
"phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data",
"special_tokens_map.json",
"tokenizer.json",
"tokenizer.model",
"tokenizer_config.json"
);
// ===== END MODEL CONFIGURATION =====

private EditText userMsgEdt;
private Model model;
private Tokenizer tokenizer;
private SimpleGenAI genAI;
private ImageButton sendMsgIB;
private TextView generatedTV;
private TextView promptTV;
@@ -56,9 +68,7 @@ private static boolean fileExists(Context context, String fileName) {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);

binding = ActivityMainBinding.inflate(getLayoutInflater());
setContentView(binding.getRoot());
setContentView(R.layout.activity_main);

sendMsgIB = findViewById(R.id.idIBSend);
userMsgEdt = findViewById(R.id.idEdtMessage);
@@ -90,8 +100,6 @@ public void onSettingsApplied(int maxLength, float lengthPenalty) {
});


Consumer<String> tokenListener = this;

//enable scrolling and resizing of text boxes
generatedTV.setMovementMethod(new ScrollingMovementMethod());
getWindow().setSoftInputMode(WindowManager.LayoutParams.SOFT_INPUT_ADJUST_RESIZE);
@@ -100,7 +108,7 @@ public void onSettingsApplied(int maxLength, float lengthPenalty) {
sendMsgIB.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
if (tokenizer == null) {
if (genAI == null) {
// if user tries to submit prompt while model is still downloading, display a toast message.
Toast.makeText(MainActivity.this, "Model not loaded yet, please wait...", Toast.LENGTH_SHORT).show();
return;
@@ -131,77 +139,58 @@ public void onClick(View v) {
new Thread(new Runnable() {
@Override
public void run() {
TokenizerStream stream = null;
GeneratorParams generatorParams = null;
Generator generator = null;
Sequences encodedPrompt = null;
try {
stream = tokenizer.createStream();

generatorParams = model.createGeneratorParams();
//examples for optional parameters to format AI response
// Create generator parameters
GeneratorParams generatorParams = genAI.createGeneratorParams();

// Set optional parameters to format AI response
// https://onnxruntime.ai/docs/genai/reference/config.html
generatorParams.setSearchOption("length_penalty", lengthPenalty);
generatorParams.setSearchOption("max_length", maxLength);

encodedPrompt = tokenizer.encode(promptQuestion_formatted);
generatorParams.setInput(encodedPrompt);

generator = new Generator(model, generatorParams);

// try to measure average time taken to generate each token.
generatorParams.setSearchOption("length_penalty", (double)lengthPenalty);
generatorParams.setSearchOption("max_length", (double)maxLength);
long startTime = System.currentTimeMillis();
long firstTokenTime = startTime;
long currentTime = startTime;
int numTokens = 0;
while (!generator.isDone()) {
generator.computeLogits();
generator.generateNextToken();

int token = generator.getLastTokenInSequence(0);

if (numTokens == 0) { //first token
firstTokenTime = System.currentTimeMillis();
final long[] firstTokenTime = {startTime};
final long[] numTokens = {0};

// Token listener for streaming tokens
Consumer<String> tokenListener = token -> {
if (numTokens[0] == 0) {
firstTokenTime[0] = System.currentTimeMillis();
}

tokenListener.accept(stream.decode(token));

// Update UI with new token
MainActivity.this.accept(token);

Log.i(TAG, "Generated token: " + token);
numTokens[0] += 1;
};


Log.i(TAG, "Generated token: " + token + ": " + stream.decode(token));
Log.i(TAG, "Time taken to generate token: " + (System.currentTimeMillis() - currentTime)/ 1000.0 + " seconds");
currentTime = System.currentTimeMillis();
numTokens++;
}
long totalTime = System.currentTimeMillis() - firstTokenTime;

float promptProcessingTime = (firstTokenTime - startTime)/ 1000.0f;
float tokensPerSecond = (1000 * (numTokens -1)) / totalTime;
String fullResponse = genAI.generate(generatorParams, promptQuestion_formatted, tokenListener);

long totalTime = System.currentTimeMillis() - firstTokenTime[0];
float promptProcessingTime = (firstTokenTime[0] - startTime) / 1000.0f;
float tokensPerSecond = numTokens[0] > 1 ? (1000.0f * (numTokens[0] - 1)) / totalTime : 0;

runOnUiThread(() -> {
sendMsgIB.setEnabled(true);
sendMsgIB.setAlpha(1.0f);

// Display the token generation rate in a dialog popup
showTokenPopup(promptProcessingTime, tokensPerSecond);
});

Log.i(TAG, "Full response: " + fullResponse);
Log.i(TAG, "Prompt processing time (first token): " + promptProcessingTime + " seconds");
Log.i(TAG, "Tokens generated per second (excluding prompt processing): " + tokensPerSecond);
}
catch (GenAIException e) {
Log.e(TAG, "Exception occurred during model query: " + e.getMessage());
runOnUiThread(() -> {
Toast.makeText(MainActivity.this, "Error generating response: " + e.getMessage(), Toast.LENGTH_SHORT).show();
});
}
finally {
if (generator != null) generator.close();
if (encodedPrompt != null) encodedPrompt.close();
if (stream != null) stream.close();
if (generatorParams != null) generatorParams.close();
runOnUiThread(() -> {
sendMsgIB.setEnabled(true);
sendMsgIB.setAlpha(1.0f);
});
}

runOnUiThread(() -> {
sendMsgIB.setEnabled(true);
sendMsgIB.setAlpha(1.0f);
});
}
}).start();
}
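
Taken together, the hunk above replaces the manual Generator/TokenizerStream loop with the higher-level SimpleGenAI API. Below is a condensed, self-contained sketch of the new flow — not part of the PR; the API calls mirror the ones used in this diff, while the prompt template and option values are illustrative:

```java
import java.util.function.Consumer;

import ai.onnxruntime.genai.GenAIException;
import ai.onnxruntime.genai.GeneratorParams;
import ai.onnxruntime.genai.SimpleGenAI;

public final class SimpleGenAIDemo {
    public static void main(String[] args) {
        try {
            // Directory containing genai_config.json and the .onnx model files.
            SimpleGenAI genAI = new SimpleGenAI("/path/to/model/dir");

            GeneratorParams params = genAI.createGeneratorParams();
            // Same search options MainActivity sets; values here are illustrative.
            params.setSearchOption("max_length", 200.0);
            params.setSearchOption("length_penalty", 1.0);

            // Called once per decoded token, enabling streaming UI updates.
            Consumer<String> tokenListener = token -> System.out.print(token);

            // Blocks until generation finishes and returns the full response.
            String response = genAI.generate(params,
                    "<|user|>Hello!<|end|><|assistant|>", tokenListener);
            System.out.println("\n--\n" + response);
        } catch (GenAIException e) {
            e.printStackTrace();
        }
    }
}
```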
@@ -210,42 +199,28 @@ public void run() {

@Override
protected void onDestroy() {
tokenizer.close();
tokenizer = null;
model.close();
model = null;
if (genAI != null) {
genAI.close();
genAI = null;
}
super.onDestroy();
}

private void downloadModels(Context context) throws GenAIException {

final String baseUrl = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/";
List<String> files = Arrays.asList(
"added_tokens.json",
"config.json",
"configuration_phi3.py",
"genai_config.json",
"phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx",
"phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data",
"special_tokens_map.json",
"tokenizer.json",
"tokenizer.model",
"tokenizer_config.json");

List<Pair<String, String>> urlFilePairs = new ArrayList<>();
for (String file : files) {
for (String file : MODEL_FILES) {
if (!fileExists(context, file)) {
urlFilePairs.add(new Pair<>(
baseUrl + file,
MODEL_BASE_URL + file,
file));
}
}
if (urlFilePairs.isEmpty()) {
// Display a message using Toast
Toast.makeText(this, "All files already exist. Skipping download.", Toast.LENGTH_SHORT).show();
Log.d(TAG, "All files already exist. Skipping download.");
model = new Model(getFilesDir().getPath());
tokenizer = model.createTokenizer();
genAI = new SimpleGenAI(getFilesDir().getPath());
return;
}

Expand Down Expand Up @@ -276,15 +251,18 @@ public void onDownloadComplete() {

// Last download completed, create SimpleGenAI
try {
model = new Model(getFilesDir().getPath());
tokenizer = model.createTokenizer();
genAI = new SimpleGenAI(getFilesDir().getPath());
runOnUiThread(() -> {
Toast.makeText(context, "All downloads completed", Toast.LENGTH_SHORT).show();
progressText.setVisibility(View.INVISIBLE);
});
} catch (GenAIException e) {
e.printStackTrace();
throw new RuntimeException(e);
Log.e(TAG, "Failed to initialize SimpleGenAI: " + e.getMessage());
runOnUiThread(() -> {
Toast.makeText(context, "Failed to load model: " + e.getMessage(), Toast.LENGTH_LONG).show();
progressText.setText("Failed to load model");
});
}

}
Binary file modified mobile/examples/phi-3/android/gradle/wrapper/gradle-wrapper.jar
mobile/examples/phi-3/android/gradle/wrapper/gradle-wrapper.properties
@@ -1,6 +1,7 @@
#Mon Mar 25 10:44:29 AEST 2024
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.0-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-8.9-bin.zip
networkTimeout=10000
validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
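
For reference, one way to produce a wrapper bump like the one above is Gradle's built-in wrapper task, which rewrites `gradle-wrapper.properties` and the wrapper jar:

```bash
# Regenerates the wrapper files at the requested version.
./gradlew wrapper --gradle-version 8.9
```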