Skip to content

Commit

Permalink
feat: added test case for custom callbacks. works great and somehow s…
Browse files Browse the repository at this point in the history
…erializes.
  • Loading branch information
dcrescim committed May 19, 2022
1 parent 2ddcad9 commit 7fa5c42
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
32 changes: 32 additions & 0 deletions src/linear_model/LinearRegression.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,38 @@ describe('LinearRegression', function () {
expect(roughlyEqual(lr.intercept as number, 0)).toBe(true)
}, 30000)

it('Works on arrays (small example) with custom callbacks', async function () {
let trainingHasStarted = false
const onTrainBegin = async (logs: any) => {
trainingHasStarted = true
console.log('training begins')
}
const lr = new LinearRegression({
modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] }
})
await lr.fit([[1], [2]], [2, 4])
expect(tensorEqual(lr.coef, tf.tensor1d([2]), 0.1)).toBe(true)
expect(roughlyEqual(lr.intercept as number, 0)).toBe(true)
expect(trainingHasStarted).toBe(true)
}, 30000)

it('Works on arrays (small example) with custom callbacks', async function () {
let trainingHasStarted = false
const onTrainBegin = async (logs: any) => {
trainingHasStarted = true
console.log('training begins')
}
const lr = new LinearRegression({
modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] }
})
await lr.fit([[1], [2]], [2, 4])

const serialized = await lr.toJSON()
const newModel = await fromJSON(serialized)
expect(tensorEqual(newModel.coef, tf.tensor1d([2]), 0.1)).toBe(true)
expect(roughlyEqual(newModel.intercept as number, 0)).toBe(true)
}, 30000)

it('Works on small multi-output example (small example)', async function () {
const lr = new LinearRegression()
await lr.fit(
Expand Down
10 changes: 6 additions & 4 deletions src/linear_model/LinearRegression.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ export interface LinearRegressionParams {
*/
fitIntercept?: boolean
modelFitOptions?: Partial<ModelFitArgs>

}

/*
Expand All @@ -53,7 +52,7 @@ Next steps:
/** Linear Least Squares
* @example
* ```js
* import {LinearRegression} from 'scikitjs'
* import { LinearRegression } from 'scikitjs'
*
* let X = [
* [1, 2],
Expand All @@ -63,13 +62,16 @@ Next steps:
* [10, 20]
* ]
* let y = [3, 5, 8, 8, 30]
* const lr = new LinearRegression({fitIntercept: false})
* const lr = new LinearRegression({ fitIntercept: false })
await lr.fit(X, y)
lr.coef.print() // probably around [1, 1]
* ```
*/
export class LinearRegression extends SGDRegressor {
constructor({ fitIntercept = true, modelFitOptions }: LinearRegressionParams = {}) {
constructor({
fitIntercept = true,
modelFitOptions
}: LinearRegressionParams = {}) {
let tf = getBackend()
super({
modelCompileArgs: {
Expand Down

0 comments on commit 7fa5c42

Please sign in to comment.