Skip to content

Commit 6b57aae

Browse files
committed
feat: add mean by dimension and product methods
1 parent 220f2df commit 6b57aae

File tree

6 files changed

+123
-11
lines changed

6 files changed

+123
-11
lines changed

matrix.d.ts

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ declare module 'ml-matrix' {
22
type MaybeMatrix = Matrix | number[][];
33
type Rng = () => number;
44
type ScalarOrMatrix = number | Matrix;
5+
type MatrixDimension = 'row' | 'column';
56

67
class BaseView extends Matrix {}
78
class MatrixColumnView extends BaseView {
@@ -154,13 +155,35 @@ declare module 'ml-matrix' {
154155
* Returns the sum of all elements of the matrix.
155156
*/
156157
sum(): number;
158+
157159
/**
158-
* Returns the sum by the dimension given.
160+
* Returns the sum by the given dimension.
159161
* @param by - sum by 'row' or 'column'.
160162
*/
161-
sum(by: 'row' | 'column'): number[];
163+
sum(by: MatrixDimension): number[];
164+
165+
/**
166+
* Returns the product of all elements of the matrix.
167+
*/
168+
product(): number;
169+
170+
/**
171+
* Returns the product by the given dimension.
172+
* @param by - product by 'row' or 'column'.
173+
*/
174+
product(by: MatrixDimension): number[];
162175

176+
/**
177+
* Returns the mean of all elements of the matrix.
178+
*/
163179
mean(): number;
180+
181+
/**
182+
* Returns the mean by the given dimension.
183+
* @param by - mean by 'row' or 'column'.
184+
*/
185+
mean(by: MatrixDimension): number[];
186+
164187
prod(): number;
165188
norm(type: 'frobenius' | 'max'): number;
166189
cumulativeSum(): Matrix;

src/__tests__/matrix/mean.js

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import { Matrix } from '../..';
2+
3+
describe('mean by row and columns', () => {
4+
const matrix = new Matrix([[1, 2, 3], [4, 5, 6]]);
5+
it('mean by row', () => {
6+
expect(matrix.mean('row')).toStrictEqual([2, 5]);
7+
});
8+
9+
it('mean by column', () => {
10+
expect(matrix.mean('column')).toStrictEqual([2.5, 3.5, 4.5]);
11+
});
12+
13+
it('mean all', () => {
14+
expect(matrix.mean()).toBe(3.5);
15+
});
16+
});

src/__tests__/matrix/product.js

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import { Matrix } from '../..';
2+
3+
describe('product by row and columns', () => {
4+
const matrix = new Matrix([[1, 2, 3], [4, 5, 6]]);
5+
it('product by row', () => {
6+
expect(matrix.product('row')).toStrictEqual([6, 120]);
7+
});
8+
9+
it('product by column', () => {
10+
expect(matrix.product('column')).toStrictEqual([4, 10, 18]);
11+
});
12+
13+
it('product all', () => {
14+
expect(matrix.product()).toBe(720);
15+
});
16+
});

src/abstractMatrix.js

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import {
1010
checkRange,
1111
checkIndices
1212
} from './util';
13-
import { sumByRow, sumByColumn, sumAll } from './stat';
13+
import { sumByRow, sumByColumn, sumAll, productByRow, productByColumn, productAll } from './stat';
1414
import MatrixTransposeView from './views/transpose';
1515
import MatrixRowView from './views/row';
1616
import MatrixSubView from './views/sub';
@@ -960,12 +960,39 @@ export default function AbstractMatrix(superCtor) {
960960
}
961961
}
962962

963-
/**
964-
* Returns the mean of all elements of the matrix
965-
* @return {number}
966-
*/
967-
mean() {
968-
return this.sum() / this.size;
963+
product(by) {
964+
switch (by) {
965+
case 'row':
966+
return productByRow(this);
967+
case 'column':
968+
return productByColumn(this);
969+
case undefined:
970+
return productAll(this);
971+
default:
972+
throw new Error(`invalid option: ${by}`);
973+
}
974+
}
975+
976+
mean(by) {
977+
const sum = this.sum(by);
978+
switch (by) {
979+
case 'row': {
980+
for (let i = 0; i < this.rows; i++) {
981+
sum[i] /= this.columns;
982+
}
983+
return sum;
984+
}
985+
case 'column': {
986+
for (let i = 0; i < this.columns; i++) {
987+
sum[i] /= this.rows;
988+
}
989+
return sum;
990+
}
991+
case undefined:
992+
return sum / this.size;
993+
default:
994+
throw new Error(`invalid option: ${by}`);
995+
}
969996
}
970997

971998
/**

src/stat.js

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,33 @@ export function sumAll(matrix) {
2929
}
3030
return v;
3131
}
32+
33+
export function productByRow(matrix) {
34+
var sum = newArray(matrix.rows, 1);
35+
for (var i = 0; i < matrix.rows; ++i) {
36+
for (var j = 0; j < matrix.columns; ++j) {
37+
sum[i] *= matrix.get(i, j);
38+
}
39+
}
40+
return sum;
41+
}
42+
43+
export function productByColumn(matrix) {
44+
var sum = newArray(matrix.columns, 1);
45+
for (var i = 0; i < matrix.rows; ++i) {
46+
for (var j = 0; j < matrix.columns; ++j) {
47+
sum[j] *= matrix.get(i, j);
48+
}
49+
}
50+
return sum;
51+
}
52+
53+
export function productAll(matrix) {
54+
var v = 1;
55+
for (var i = 0; i < matrix.rows; i++) {
56+
for (var j = 0; j < matrix.columns; j++) {
57+
v *= matrix.get(i, j);
58+
}
59+
}
60+
return v;
61+
}

src/util.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,10 @@ export function getRange(from, to) {
138138
return arr;
139139
}
140140

141-
export function newArray(length) {
141+
export function newArray(length, value = 0) {
142142
var array = [];
143143
for (var i = 0; i < length; i++) {
144-
array.push(0);
144+
array.push(value);
145145
}
146146
return array;
147147
}

0 commit comments

Comments
 (0)