From 2ddcad9053097c75bcdd97d234c1f8b605a88241 Mon Sep 17 00:00:00 2001 From: Yaw Joseph Etse Date: Wed, 18 May 2022 22:12:38 -0400 Subject: [PATCH 1/2] feat: custom modelfitargs for linear models --- .gitignore | 3 ++- src/linear_model/LinearRegression.ts | 8 ++++++-- src/linear_model/LogisticRegression.ts | 8 ++++++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 4d06003..b57f5e7 100644 --- a/.gitignore +++ b/.gitignore @@ -107,4 +107,5 @@ dist # IDE Files .vscode/ -.idea/ \ No newline at end of file +.idea/ +.dccache \ No newline at end of file diff --git a/src/linear_model/LinearRegression.ts b/src/linear_model/LinearRegression.ts index 1913ed2..995f577 100644 --- a/src/linear_model/LinearRegression.ts +++ b/src/linear_model/LinearRegression.ts @@ -15,6 +15,7 @@ import { SGDRegressor } from './SgdRegressor' import { getBackend } from '../tf-singleton' +import { ModelFitArgs } from '../types' /** * LinearRegression implementation using gradient descent @@ -39,6 +40,8 @@ export interface LinearRegressionParams { * **default = true** */ fitIntercept?: boolean + modelFitOptions?: Partial + } /* @@ -66,7 +69,7 @@ Next steps: * ``` */ export class LinearRegression extends SGDRegressor { - constructor({ fitIntercept = true }: LinearRegressionParams = {}) { + constructor({ fitIntercept = true, modelFitOptions }: LinearRegressionParams = {}) { let tf = getBackend() super({ modelCompileArgs: { @@ -80,7 +83,8 @@ export class LinearRegression extends SGDRegressor { verbose: 0, callbacks: [ tf.callbacks.earlyStopping({ monitor: 'mse', patience: 30 }) - ] + ], + ...modelFitOptions }, denseLayerArgs: { units: 1, diff --git a/src/linear_model/LogisticRegression.ts b/src/linear_model/LogisticRegression.ts index 159cd36..b235bb3 100644 --- a/src/linear_model/LogisticRegression.ts +++ b/src/linear_model/LogisticRegression.ts @@ -15,6 +15,7 @@ import { SGDClassifier } from './SgdClassifier' import { getBackend } from '../tf-singleton' +import { ModelFitArgs } from '../types' // First pass at a LogisticRegression implementation using gradient descent // Trying to mimic the API of scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html @@ -35,6 +36,7 @@ export interface LogisticRegressionParams { C?: number /** Whether or not the intercept should be estimator not. **default = true** */ fitIntercept?: boolean + modelFitOptions?: Partial } /** Builds a linear classification model with associated penalty and regularization @@ -63,7 +65,8 @@ export class LogisticRegression extends SGDClassifier { constructor({ penalty = 'l2', C = 1, - fitIntercept = true + fitIntercept = true, + modelFitOptions }: LogisticRegressionParams = {}) { // Assume Binary classification // If we call fit, and it isn't binary then update args @@ -80,7 +83,8 @@ export class LogisticRegression extends SGDClassifier { verbose: 0, callbacks: [ tf.callbacks.earlyStopping({ monitor: 'loss', patience: 50 }) - ] + ], + ...modelFitOptions }, denseLayerArgs: { units: 1, From 7fa5c4259902d7dca0a925002cbfaf1937dc2b1b Mon Sep 17 00:00:00 2001 From: Dan Crescimanno Date: Wed, 18 May 2022 21:41:05 -0700 Subject: [PATCH 2/2] feat: added test case for custom callbacks. works great and somehow serializes. --- src/linear_model/LinearRegression.test.ts | 32 +++++++++++++++++++++++ src/linear_model/LinearRegression.ts | 10 ++++--- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/linear_model/LinearRegression.test.ts b/src/linear_model/LinearRegression.test.ts index 2e54a97..6681df6 100644 --- a/src/linear_model/LinearRegression.test.ts +++ b/src/linear_model/LinearRegression.test.ts @@ -17,6 +17,38 @@ describe('LinearRegression', function () { expect(roughlyEqual(lr.intercept as number, 0)).toBe(true) }, 30000) + it('Works on arrays (small example) with custom callbacks', async function () { + let trainingHasStarted = false + const onTrainBegin = async (logs: any) => { + trainingHasStarted = true + console.log('training begins') + } + const lr = new LinearRegression({ + modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] } + }) + await lr.fit([[1], [2]], [2, 4]) + expect(tensorEqual(lr.coef, tf.tensor1d([2]), 0.1)).toBe(true) + expect(roughlyEqual(lr.intercept as number, 0)).toBe(true) + expect(trainingHasStarted).toBe(true) + }, 30000) + + it('Works on arrays (small example) with custom callbacks', async function () { + let trainingHasStarted = false + const onTrainBegin = async (logs: any) => { + trainingHasStarted = true + console.log('training begins') + } + const lr = new LinearRegression({ + modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] } + }) + await lr.fit([[1], [2]], [2, 4]) + + const serialized = await lr.toJSON() + const newModel = await fromJSON(serialized) + expect(tensorEqual(newModel.coef, tf.tensor1d([2]), 0.1)).toBe(true) + expect(roughlyEqual(newModel.intercept as number, 0)).toBe(true) + }, 30000) + it('Works on small multi-output example (small example)', async function () { const lr = new LinearRegression() await lr.fit( diff --git a/src/linear_model/LinearRegression.ts b/src/linear_model/LinearRegression.ts index 995f577..c09a620 100644 --- a/src/linear_model/LinearRegression.ts +++ b/src/linear_model/LinearRegression.ts @@ -41,7 +41,6 @@ export interface LinearRegressionParams { */ fitIntercept?: boolean modelFitOptions?: Partial - } /* @@ -53,7 +52,7 @@ Next steps: /** Linear Least Squares * @example * ```js - * import {LinearRegression} from 'scikitjs' + * import { LinearRegression } from 'scikitjs' * * let X = [ * [1, 2], @@ -63,13 +62,16 @@ Next steps: * [10, 20] * ] * let y = [3, 5, 8, 8, 30] - * const lr = new LinearRegression({fitIntercept: false}) + * const lr = new LinearRegression({ fitIntercept: false }) await lr.fit(X, y) lr.coef.print() // probably around [1, 1] * ``` */ export class LinearRegression extends SGDRegressor { - constructor({ fitIntercept = true, modelFitOptions }: LinearRegressionParams = {}) { + constructor({ + fitIntercept = true, + modelFitOptions + }: LinearRegressionParams = {}) { let tf = getBackend() super({ modelCompileArgs: {