Skip to content

Commit

Permalink
Add model loading screen
Browse files Browse the repository at this point in the history
  • Loading branch information
mattsacks committed Jun 9, 2020
1 parent 6096c40 commit bdbda6b
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/App/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,21 @@ import { View } from "react-native";
import styles from "./styles";
import Camera from "../Camera";
import Detections from "../Detections";
import LoadGraph from "../LoadGraph";
/* import Readout from "../Readout"; */
/* import Controls from "../Controls"; */

export default function App() {
const [cameraPermission, setCameraPermission] = React.useState(false);
const [graph, setGraph] = React.useState<tensorflow.GraphModel>();
const [images, setImages] = React.useState<
IterableIterator<tensorflow.Tensor3D>
>();

if (!graph) {
return <LoadGraph setGraph={setGraph} />;
}

return (
<View style={styles.app}>
<View style={styles.cameraContainer}>
Expand All @@ -22,7 +28,7 @@ export default function App() {
setCameraPermission={setCameraPermission}
setImages={setImages}
/>
<Detections images={images} />
{/* <Detections images={images} graph={graph} /> */}
</View>
{/* <Controls /> */}
</View>
Expand Down
50 changes: 50 additions & 0 deletions src/LoadGraph/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import * as React from "react";
import { Text, View } from "react-native";
import * as tensorflow from "@tensorflow/tfjs";
import styles from "./styles";

// TODO: Make this an env variable
const ORIGIN =
"https://tfhub.dev/tensorflow/tfjs-model/ssd_mobilenet_v2/1/default/1";

interface Props {
setGraph: (graph: tensorflow.GraphModel) => void;
}

export default function LoadGraph(props: Props): JSX.Element {
const { setGraph } = props;
const [progress, setProgress] = React.useState(0);

React.useEffect(() => {
const loadGraph = async () => {
await tensorflow.ready();

const options = {
fromTFHub: true,
onProgress: (loading: number) => {
setProgress(Math.round(loading * 100));
}
};

const graph = await tensorflow.loadGraphModel(ORIGIN, options);

setGraph(graph);
};

loadGraph().catch((error) => {
console.info(`Error!: ${error}`);
});
}, []);

return (
<View style={styles.loading}>
<Text style={styles.loadingHeader}>Loading model…</Text>
<View style={styles.progressContainer}>
<View style={styles.progressBar}>
<View style={{ ...styles.progress, width: `${progress}%` }} />
</View>
<Text>{progress}%</Text>
</View>
</View>
);
}
36 changes: 36 additions & 0 deletions src/LoadGraph/styles.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { StyleSheet } from "react-native";

export default StyleSheet.create({
loading: {
alignItems: "center",
display: "flex",
flexDirection: "column",
height: "100%",
justifyContent: "center",
paddingHorizontal: 20
},
loadingHeader: {
fontSize: 20,
marginBottom: 4,
textAlign: "center"
},
progressContainer: {
display: "flex",
justifyContent: "space-between",
overflow: "hidden",
width: "100%"
},
progressBar: {
backgroundColor: "white",
borderColor: "black",
borderRadius: 4,
borderWidth: 2,
// flex: 1,
height: 12,
marginBottom: 4
},
progress: {
backgroundColor: "black",
height: "100%"
}
});

0 comments on commit bdbda6b

Please sign in to comment.