diff --git a/src/dashboard/src/pages/Submission/DataJob.tsx b/src/dashboard/src/pages/Submission/DataJob.tsx index e6885a078..ec9ee4f13 100644 --- a/src/dashboard/src/pages/Submission/DataJob.tsx +++ b/src/dashboard/src/pages/Submission/DataJob.tsx @@ -18,7 +18,7 @@ import { DirectoryPathTextField } from './components/GPUCard' import ClustersContext from '../../contexts/Clusters' import UserContext from '../../contexts/User' import TeamContext from '../../contexts/Team' -import { Link } from 'react-router-dom' +import { Link, useLocation } from 'react-router-dom' import Slide from '@material-ui/core/Slide' import { green } from '@material-ui/core/colors' import useFetch from 'use-http' @@ -42,6 +42,7 @@ const Transition = React.forwardRef { + const location = useLocation() const styles = useStyles() const [azureDataStorage, setAzureDataStorage] = useState('') const [nfsDataStorage, setNFSDataStorage] = useState('') @@ -51,7 +52,13 @@ const DataJob: React.FC = (props: any) => { const { email } = React.useContext(UserContext) const { currentTeamId } = React.useContext(TeamContext) const { clusters } = React.useContext(ClustersContext) - const [selectedCluster, saveSelectedCluster] = React.useState(() => clusters[0].id) + const [selectedCluster, saveSelectedCluster] = React.useState(() => { + const clusterId = location.state.cluster + if (clusters.some(({ id }) => id === clusterId)) { + return clusterId + } + return clusters[0].id + }) const [workStorage, setWorkStorage] = useState('') const [dataStorage, setDataStorage] = useState('') diff --git a/src/dashboard/src/pages/Submission/Training.tsx b/src/dashboard/src/pages/Submission/Training.tsx index 1d32afe72..c45fb338e 100644 --- a/src/dashboard/src/pages/Submission/Training.tsx +++ b/src/dashboard/src/pages/Submission/Training.tsx @@ -27,7 +27,7 @@ import { } from '@material-ui/core' import Tooltip from '@material-ui/core/Tooltip' import { Info, Delete, Add } from '@material-ui/icons' -import { withRouter } from 'react-router-dom' +import { useHistory, useLocation } from 'react-router-dom' import IconButton from '@material-ui/core/IconButton' import useFetch from 'use-http' import { join } from 'path' @@ -56,9 +56,17 @@ const sanitizePath = (path: string) => { path = join('.', path) return path } -const Training: React.ComponentClass = withRouter(({ history }) => { +const Training: React.FunctionComponent = () => { + const history = useHistory() + const location = useLocation() const { clusters } = React.useContext(ClustersContext) - const [selectedCluster, saveSelectedCluster] = React.useState(() => clusters[0].id) + const [selectedCluster, saveSelectedCluster] = React.useState(() => { + const clusterId = location.state.cluster + if (clusters.some(({ id }) => id === clusterId)) { + return clusterId + } + return clusters[0].id + }) const { email } = React.useContext(UserContext) const { currentTeamId } = React.useContext(TeamContext) // const team = 'platform'; @@ -1330,6 +1338,6 @@ const Training: React.ComponentClass = withRouter(({ history }) => { /> ) -}) +} export default Training