Skip to content

Commit

Permalink
increase coverage of wml by extracting auth
Browse files Browse the repository at this point in the history
  • Loading branch information
bourdakos1 committed Apr 3, 2019
1 parent e36b241 commit 541b2e1
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 26 deletions.
7 changes: 0 additions & 7 deletions __tests__/api/wml.test.js

This file was deleted.

13 changes: 13 additions & 0 deletions __tests__/commands/train.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,19 @@ describe('train', () => {
return promise
})

it('trains with a zip', async () => {
const promise = train([
'__tests__/fake.zip',
'--config',
'__tests__/config.yaml'
])
await wait()
await wait()
io.send('no')
io.send(keys.enter)
return promise
})

it('watches progress', async () => {
const promise = train(['--config', '__tests__/config.yaml'])
// Need to wait twice for some reason...
Expand Down
Binary file added __tests__/fake.zip
Binary file not shown.
32 changes: 13 additions & 19 deletions src/api/wml.js
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@ class WML {
}

async authenticate() {
return api.authenticate(this._url, this._username, this._password)
if (!this._token) {
this._token = await api.authenticate(
this._url,
this._username,
this._password
)
}
}

async startTraining(trainingScript) {
Expand All @@ -47,30 +53,22 @@ class WML {
}

async createMonitorSocket(modelId) {
if (!this._token) {
this._token = await this.authenticate()
}
await this.authenticate()
return api.socket(this._url, this._token, modelId)
}

async getTrainingRun(modelId) {
if (!this._token) {
this._token = await this.authenticate()
}
await this.authenticate()
return api.getModel(this._url, this._token, modelId)
}

async listTrainingRuns() {
if (!this._token) {
this._token = await this.authenticate()
}
await this.authenticate()
return api.getModels(this._url, this._token)
}

async createTrainingDefinition() {
if (!this._token) {
this._token = await this.authenticate()
}
await this.authenticate()
// Deep copy.
const trainingDefinition = JSON.parse(
JSON.stringify(DEFAULT_TRAINING_DEFINITION)
Expand All @@ -84,9 +82,7 @@ class WML {
}

async addTrainingScript(trainingDefinition, trainingScript) {
if (!this._token) {
this._token = await this.authenticate()
}
await this.authenticate()

const trainingZip = (() => {
if (trainingScript) {
Expand All @@ -105,9 +101,7 @@ class WML {
}

async startTrainingRun(trainingDefinition) {
if (!this._token) {
this._token = await this.authenticate()
}
await this.authenticate()

const steps = safeGet(() => this._trainingParams.steps) || DEFAULT_STEPS
const gpu = safeGet(() => this._trainingParams.gpu) || DEFAULT_GPU
Expand Down

0 comments on commit 541b2e1

Please sign in to comment.