diff --git a/src/demo/colors.js b/src/demo/colors.js index 82d076cd..9c7ebd0b 100644 --- a/src/demo/colors.js +++ b/src/demo/colors.js @@ -1,7 +1,7 @@ const colors = { white: 'white', black: 'black', - grey: 'grey', + grey: '#4d575f', darkGrey: '#2f353e', lightGrey: '#e7e7e7', green: '#56c568', diff --git a/src/demo/models/pond.js b/src/demo/models/pond.js index f67ee09e..f2a2f0f3 100644 --- a/src/demo/models/pond.js +++ b/src/demo/models/pond.js @@ -6,14 +6,21 @@ import constants, {ClassType} from '../constants'; export const init = async () => { const state = getState(); let fishWithConfidence = await predictAllFish(state); - setState({totalPondFish: fishWithConfidence.length}); fishWithConfidence = _.sortBy(fishWithConfidence, ['confidence']); - const pondFishWithConfidence = fishWithConfidence.splice( + const fishByClassType = _.groupBy( + fishWithConfidence, + fish => fish.getResult().predictedClassId + ); + + let pondFish = fishByClassType[ClassType.Like]; + setState({totalPondFish: pondFish.length}); + pondFish = pondFish.splice(0, constants.maxPondFish); + const recallFish = fishByClassType[ClassType.Dislike].splice( 0, constants.maxPondFish ); - arrangeFish(pondFishWithConfidence); - setState({pondFish: pondFishWithConfidence}); + arrangeFish(pondFish); + setState({pondFish, recallFish}); }; const predictAllFish = state => { @@ -21,10 +28,8 @@ const predictAllFish = state => { let fishWithConfidence = []; state.fishData.map((fish, index) => { state.trainer.predict(fish).then(res => { - if (res.predictedClassId === ClassType.Like) { - fish.setResult(res); - fishWithConfidence.push(fish); - } + fish.setResult(res); + fishWithConfidence.push(fish); if (index === state.fishData.length - 1) { resolve(fishWithConfidence); @@ -34,7 +39,7 @@ const predictAllFish = state => { }); }; -const arrangeFish = fishes => { +export const arrangeFish = fishes => { let fishPositions = formatArrangement(); fishes.forEach(fish => { diff --git a/src/demo/renderer.js b/src/demo/renderer.js index df2c17ab..10684e0a 100644 --- a/src/demo/renderer.js +++ b/src/demo/renderer.js @@ -495,12 +495,12 @@ const drawPredictBot = state => { // Draw the fish for pond mode. const drawPondFishImages = () => { - const canvas = getState().canvas; - const ctx = canvas.getContext('2d'); - + const state = getState(); + const ctx = state.canvas.getContext('2d'); + const fishes = state.showRecallFish ? state.recallFish : state.pondFish; const fishBounds = []; - getState().pondFish.forEach(fish => { + fishes.forEach(fish => { const pondClickedFish = getState().pondClickedFish; const pondClickedFishUs = pondClickedFish && fish.id === pondClickedFish.id; diff --git a/src/demo/state.js b/src/demo/state.js index ef241f85..664c9d08 100644 --- a/src/demo/state.js +++ b/src/demo/state.js @@ -5,6 +5,8 @@ const initialState = { currentMode: null, fishData: [], pondFish: [], + recallFish: [], + showRecallFish: false, totalPondFish: null, backgroundCanvas: null, canvas: null, diff --git a/src/demo/ui.jsx b/src/demo/ui.jsx index afd23cae..521da947 100644 --- a/src/demo/ui.jsx +++ b/src/demo/ui.jsx @@ -7,6 +7,7 @@ import constants, {AppMode, Modes} from './constants'; import {toMode} from './toMode'; import {$time, currentRunTime, finishMovement, resetTraining} from './helpers'; import {onClassifyFish} from './models/train'; +import {arrangeFish} from './models/pond'; import colors from './colors'; import aiBotClosed from '../../public/images/ai-bot/ai-bot-closed.png'; import counterIcon from '../../public/images/data.png'; @@ -22,7 +23,9 @@ import { faPause, faBackward, faForward, - faEraser + faEraser, + faCheck, + faBan } from '@fortawesome/free-solid-svg-icons'; const styles = { @@ -42,6 +45,7 @@ const styles = { button: { cursor: 'pointer', backgroundColor: colors.white, + color: colors.grey, borderRadius: 8, minWidth: 160, outline: 'none', @@ -295,6 +299,30 @@ const styles = { transform: 'translateX(-45%)', pointerEvents: 'none' }, + recallContainer: { + position: 'absolute', + top: '4%', + right: '2.25%', + color: colors.white, + display: 'flex', + alignItems: 'center', + justifyContent: 'space-between' + }, + recallIcon: { + width: 30, + height: 30, + border: `5px solid ${colors.white}`, + borderRadius: 50, + padding: 6, + marginLeft: 8, + backgroundColor: colors.lightGrey + }, + bgRed: { + backgroundColor: colors.red + }, + bgGreen: { + backgroundColor: colors.green + }, pill: { display: 'flex', alignItems: 'center' @@ -855,12 +883,25 @@ let Predict = class Predict extends React.Component { }; Predict = Radium(Predict); -class Pond extends React.Component { +let Pond = class Pond extends React.Component { constructor(props) { super(props); } - onPondClick(e) { + toggleRecall = () => { + const state = getState(); + const showRecallFish = !state.showRecallFish; + const fish = showRecallFish ? state.recallFish : state.pondFish; + + // Don't call arrangeFish if fish have already been arranged. + if (fish.length > 0 && !fish[0].getXY()) { + arrangeFish(fish); + } + + setState({showRecallFish}); + }; + + onPondClick = e => { // Don't allow pond clicks if a Guide is currently showing. if (getCurrentGuide()) { return; @@ -920,13 +961,32 @@ class Pond extends React.Component { playSound('no'); } } - } + }; render() { const state = getState(); return ( -
this.onPondClick(e)}> + +