diff --git a/src/lib/cluster/k_means.ts b/src/lib/cluster/k_means.ts index 2867aa86..b9658e1f 100644 --- a/src/lib/cluster/k_means.ts +++ b/src/lib/cluster/k_means.ts @@ -138,6 +138,7 @@ export class KMeans { * @returns {number[]} */ public predict(X: Type2DMatrix): number[] { + validateMatrix2D(X); return _.map(X, data => { return this.getClosestCentroids(data, this.centroids, this.distance); }); diff --git a/test/clusters/k_means.test.ts b/test/clusters/k_means.test.ts index d735fbae..594bc696 100644 --- a/test/clusters/k_means.test.ts +++ b/test/clusters/k_means.test.ts @@ -68,4 +68,16 @@ describe('clusters:k_means', () => { const pred2 = kmean.predict(predVector2); expect(_.isEqual(expectedResult, pred2)).toBe(true); }); + + it('should not fit none 2D matrix', () => { + const kmean = new KMeans({ k: 2 }); + expect(() => kmean.fit([1, 2])).toThrow( + 'The matrix is not 2D shaped: [1,2] of [2]' + ); + expect(() => kmean.fit(null)).toThrow( + 'values passed to tensor(values) must be an array of numbers or booleans, or a TypedArray' + ); + // TODO: implement datatype check to the validation method + // expect(() => kmean.fit([["aa", "bb"]])).toThrow('The matrix is not 2D shaped: [1,2] of [2]'); + }); });