Skip to content

Commit

Permalink
Demo improvements (#29)
Browse files Browse the repository at this point in the history
* improved demo rendering on mobile

* improved genai demo rendering on mobile

* tidying up genai chat demo

* removes unusued import
  • Loading branch information
craiglabenz committed May 12, 2024
1 parent 9b936fa commit cef228c
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 171 deletions.
2 changes: 1 addition & 1 deletion packages/mediapipe-task-genai/example/lib/logging.dart
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import 'package:logging/logging.dart';
final log = Logger('Genai');

void initLogging() {
Logger.root.level = Level.FINEST;
Logger.root.level = Level.FINER;
Logger.root.onRecord.listen((record) {
io.stdout.writeln('${record.level.name} [${record.loggerName}]'
'['
Expand Down
64 changes: 30 additions & 34 deletions packages/mediapipe-task-genai/example/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ class AppBlocObserver extends BlocObserver {

@override
void onChange(BlocBase<dynamic> bloc, Change<dynamic> change) {
_log.finer('onChange(${bloc.runtimeType}, $change)');
_log.finest('onChange(${bloc.runtimeType}, $change)');
super.onChange(bloc, change);
}

@override
void onEvent(Bloc<dynamic, dynamic> bloc, Object? event) {
_log.finer('onEvent($event)');
_log.finest('onEvent($event)');
super.onEvent(bloc, event);
}

@override
void onError(BlocBase<dynamic> bloc, Object error, StackTrace stackTrace) {
// print('onError(${bloc.runtimeType}, $error, $stackTrace)');
_log.shout('onError(${bloc.runtimeType}, $error, $stackTrace)');
super.onError(bloc, error, stackTrace);
}
}
Expand Down Expand Up @@ -61,6 +61,19 @@ class _MainAppState extends State<MainApp> {
final titles = <String>['Inference'];
int titleIndex = 0;

@override
void initState() {
controller.addListener(() {
final newIndex = controller.page?.toInt();
if (newIndex != null && newIndex != titleIndex) {
setState(() {
titleIndex = newIndex;
});
}
});
super.initState();
}

void switchToPage(int index) {
controller.animateToPage(
index,
Expand All @@ -74,45 +87,28 @@ class _MainAppState extends State<MainApp> {

@override
Widget build(BuildContext context) {
const activeTextStyle = TextStyle(
fontWeight: FontWeight.bold,
color: Colors.orange,
);
const inactiveTextStyle = TextStyle(
color: Colors.white,
);
return Scaffold(
body: PageView(
controller: controller,
children: const <Widget>[
LlmInferenceDemo(),
],
),
bottomNavigationBar: SizedBox(
height: 50 + MediaQuery.of(context).viewPadding.bottom / 2,
child: ColoredBox(
color: Colors.blueGrey,
child: Column(
children: [
Row(
mainAxisAlignment: MainAxisAlignment.spaceEvenly,
children: <Widget>[
TextButton(
onPressed: () => switchToPage(0),
child: Text(
'Inference',
style:
titleIndex == 0 ? activeTextStyle : inactiveTextStyle,
),
),
],
),
SizedBox(
height: MediaQuery.of(context).viewPadding.bottom / 2,
),
],
bottomNavigationBar: BottomNavigationBar(
currentIndex: titleIndex,
onTap: switchToPage,
items: const <BottomNavigationBarItem>[
BottomNavigationBarItem(
icon: Icon(Icons.chat_bubble),
activeIcon: Icon(Icons.chat_bubble, color: Colors.blue),
label: 'Inference',
),
),
BottomNavigationBarItem(
icon: Icon(Icons.cancel),
activeIcon: Icon(Icons.cancel, color: Colors.blue),
label: 'Coming soon',
),
],
),
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,7 @@ class ModelLocationProvider {

// Useful for quick development if there is any friction around passing
// environment variables
static const hardcodedLocations = <LlmModel, String>{
LlmModel.gemma4bCpu:
'https://storage.googleapis.com/random-storage-asdf/gemma/gemma-2b-it-cpu-int4.bin',
LlmModel.gemma4bGpu:
'https://storage.googleapis.com/random-storage-asdf/gemma/gemma-2b-it-gpu-int4.bin',
LlmModel.gemma8bCpu:
'https://storage.googleapis.com/random-storage-asdf/gemma/gemma-2b-it-cpu-int8.bin',
LlmModel.gemma8bGpu:
'https://storage.googleapis.com/random-storage-asdf/gemma/gemma-2b-it-gpu-int8.bin',
};
static const hardcodedLocations = <LlmModel, String>{};

static ModelPaths _getModelLocationsFromEnvironment() {
final locations = <LlmModel, String>{};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,14 @@ class LlmInferenceEngine extends BaseLlmInferenceEngine {
if (response is LlmResponseContext) {
publish(response.responseArray.join(''));
if (response.isDone) {
_log.finer('response.isDone - closing controller from $publish');
_endResponse();
}
} else if (response is String) {
_log.fine(response);
} else {
throw Exception(
'Unexpected sizeInTokens result of type ${response.runtimeType} : $response',
'Unexpected generateResponse result of type ${response.runtimeType} : $response',
);
}
}
Expand Down
24 changes: 24 additions & 0 deletions packages/mediapipe-task-text/example/lib/keyboard_hider.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright 2014 The Flutter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

import 'package:flutter/services.dart';
import 'package:flutter/widgets.dart';

class KeyboardHider extends StatelessWidget {
const KeyboardHider({required this.child, super.key});

final Widget? child;

@override
Widget build(BuildContext context) {
return GestureDetector(
onTap: () {
// Not sure why this one isn't working.
// FocusScope.of(context).unfocus();
SystemChannels.textInput.invokeMethod('TextInput.hide');
},
child: child,
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import 'dart:async';
import 'dart:typed_data';
import 'package:example/keyboard_hider.dart';
import 'package:flutter/material.dart';
import 'package:getwidget/getwidget.dart';
import 'package:mediapipe_text/mediapipe_text.dart';
Expand Down Expand Up @@ -68,34 +69,36 @@ class _LanguageDetectionDemoState extends State<LanguageDetectionDemo>
void _showDetectionResults(LanguageDetectorResult result) {
setState(
() {
results.last = Card(
key: Key('prediction-"$_isProcessing" ${results.length}'),
margin: const EdgeInsets.all(10),
child: Column(
children: [
Padding(
padding: const EdgeInsets.all(10),
child: Text(_isProcessing!),
),
Padding(
padding: const EdgeInsets.all(10.0),
child: Wrap(
children: <Widget>[
...result.predictions
.enumerate<Widget>(
(prediction, index) => _languagePrediction(
prediction,
predictionColors[index],
),
// Take first 4 because the model spits out dozens of
// astronomically low probability language predictions
max: predictionColors.length,
)
.toList(),
],
results.last = KeyboardHider(
child: Card(
key: Key('prediction-"$_isProcessing" ${results.length}'),
margin: const EdgeInsets.all(10),
child: Column(
children: [
Padding(
padding: const EdgeInsets.all(10),
child: Text(_isProcessing!),
),
Padding(
padding: const EdgeInsets.all(10.0),
child: Wrap(
children: <Widget>[
...result.predictions
.enumerate<Widget>(
(prediction, index) => _languagePrediction(
prediction,
predictionColors[index],
),
// Take first 4 because the model spits out dozens of
// astronomically low probability language predictions
max: predictionColors.length,
)
.toList(),
],
),
),
),
],
],
),
),
);
_isProcessing = null;
Expand Down
69 changes: 32 additions & 37 deletions packages/mediapipe-task-text/example/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ class TextTaskPagesState extends State<TextTaskPages> {
final titles = <String>['Classify', 'Embed', 'Detect Languages'];
int titleIndex = 0;

@override
void initState() {
controller.addListener(() {
final newIndex = controller.page?.toInt();
if (newIndex != null && newIndex != titleIndex) {
setState(() {
titleIndex = newIndex;
});
}
});
super.initState();
}

void switchToPage(int index) {
controller.animateToPage(
index,
Expand All @@ -48,13 +61,6 @@ class TextTaskPagesState extends State<TextTaskPages> {

@override
Widget build(BuildContext context) {
const activeTextStyle = TextStyle(
fontWeight: FontWeight.bold,
color: Colors.orange,
);
const inactiveTextStyle = TextStyle(
color: Colors.white,
);
return Scaffold(
appBar: AppBar(title: Text(titles[titleIndex])),
body: PageView(
Expand All @@ -65,37 +71,26 @@ class TextTaskPagesState extends State<TextTaskPages> {
LanguageDetectionDemo(),
],
),
bottomNavigationBar: SizedBox(
height: 50,
child: ColoredBox(
color: Colors.blueGrey,
child: Row(
mainAxisAlignment: MainAxisAlignment.spaceEvenly,
children: <Widget>[
TextButton(
onPressed: () => switchToPage(0),
child: Text(
'Classify',
style: titleIndex == 0 ? activeTextStyle : inactiveTextStyle,
),
),
TextButton(
onPressed: () => switchToPage(1),
child: Text(
'Embed',
style: titleIndex == 1 ? activeTextStyle : inactiveTextStyle,
),
),
TextButton(
onPressed: () => switchToPage(2),
child: Text(
'Detect Languages',
style: titleIndex == 2 ? activeTextStyle : inactiveTextStyle,
),
),
],
bottomNavigationBar: BottomNavigationBar(
currentIndex: titleIndex,
onTap: switchToPage,
items: const <BottomNavigationBarItem>[
BottomNavigationBarItem(
icon: Icon(Icons.search),
activeIcon: Icon(Icons.search, color: Colors.blue),
label: 'Classify',
),
BottomNavigationBarItem(
icon: Icon(Icons.arrow_downward),
activeIcon: Icon(Icons.arrow_downward, color: Colors.blue),
label: 'Embed',
),
),
BottomNavigationBarItem(
icon: Icon(Icons.flag),
activeIcon: Icon(Icons.flag, color: Colors.blue),
label: 'Detect Languages',
),
],
),
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import 'dart:async';
import 'dart:typed_data';
import 'package:example/keyboard_hider.dart';
import 'package:flutter/material.dart';
import 'package:getwidget/getwidget.dart';
import 'package:mediapipe_core/mediapipe_core.dart';
Expand Down Expand Up @@ -65,24 +66,26 @@ class _TextClassificationDemoState extends State<TextClassificationDemo>

setState(
() {
results.last = Card(
key: Key('Classification::"$_isProcessing" ${results.length}'),
margin: const EdgeInsets.all(10),
child: Column(
children: [
Padding(
padding: const EdgeInsets.all(10),
child: Text(_isProcessing!),
),
Padding(
padding: const EdgeInsets.all(10.0),
child: Wrap(
children: <Widget>[
...categoryWidgets,
],
results.last = KeyboardHider(
child: Card(
key: Key('Classification::"$_isProcessing" ${results.length}'),
margin: const EdgeInsets.all(10),
child: Column(
children: [
Padding(
padding: const EdgeInsets.all(10),
child: Text(_isProcessing!),
),
),
],
Padding(
padding: const EdgeInsets.all(10.0),
child: Wrap(
children: <Widget>[
...categoryWidgets,
],
),
),
],
),
),
);
_isProcessing = null;
Expand Down Expand Up @@ -135,7 +138,7 @@ class _TextClassificationDemoState extends State<TextClassificationDemo>
child: Column(
children: <Widget>[
TextField(controller: _controller),
...results,
...results.reversed,
],
),
),
Expand Down
Loading

0 comments on commit cef228c

Please sign in to comment.