Skip to content

Commit

Permalink
refactor: simplifying typecheck and validation for predict
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonShin committed Dec 9, 2018
1 parent 1410990 commit 9520761
Showing 1 changed file with 6 additions and 27 deletions.
33 changes: 6 additions & 27 deletions src/lib/naive_bayes/multinomial.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import * as tf from '@tensorflow/tfjs';
import { countBy, zip } from 'lodash';
import { reshape } from '../ops';
import { reshape, validateMatrix2D } from '../ops';
import { IMlModel, Type1DMatrix, Type2DMatrix } from '../types';
import math from '../utils/MathExtra';

Expand Down Expand Up @@ -82,28 +82,9 @@ export class MultinomialNB<T extends number | string = number>
* @param {Type2DMatrix<number>} X - values to predict in Matrix format
* @returns T
*/
public predict(X: Type2DMatrix<number>): Type1DMatrix<T> {
try {
return X.map((x): T => this.singlePredict(x));
} catch (e) {
if (!isMatrix(X)) {
throw new Error('X must be a matrix');
} else {
throw e;
}
}
}

/**
* @param {IterableIterator<IterableIterator<number>>} X
* @returns IterableIterator
*/
public *predictIterator(
X: IterableIterator<IterableIterator<number>>
): IterableIterator<T> {
for (const x of X) {
yield this.singlePredict([...x]);
}
public predict(X: Type2DMatrix<number>): T[] {
validateMatrix2D(X);
return X.map(x => this.singlePredict(x));
}

/**
Expand Down Expand Up @@ -146,9 +127,7 @@ export class MultinomialNB<T extends number | string = number>
* @param {ReadonlyArray<number>} predictRow
* @returns T
*/
private singlePredict(
predictRow: Type1DMatrix<number>
): Type1DMatrix<T>[any] {
private singlePredict(predictRow: Type1DMatrix<number>): T {
const matrixX: tf.Tensor<tf.Rank> = tf.tensor1d(
predictRow as number[],
'float32'
Expand Down Expand Up @@ -181,7 +160,7 @@ export class MultinomialNB<T extends number | string = number>
const selectionIndex = allProbabilities.argMax().dataSync()[0];
allProbabilities.dispose();

return this.classCategories[selectionIndex];
return this.classCategories[selectionIndex] as T;
}

/**
Expand Down

0 comments on commit 9520761

Please sign in to comment.