Skip to content

Commit dbe7c99

Browse files
committed
refactor: make sum by row or column return an array
BREAKING CHANGE: `matrix.sum('row')` and `matrix.sum('column')` now return an array instead of a Matrix.
1 parent 63b95d1 commit dbe7c99

File tree

4 files changed

+52
-41
lines changed

4 files changed

+52
-41
lines changed

matrix.d.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,17 @@ declare module 'ml-matrix' {
149149
minColumnIndex(column: number): number[];
150150
diag(): number[];
151151
diagonal(): number[];
152-
sum(by?: 'row' | 'column'): Matrix | number;
152+
153+
/**
154+
* Returns the sum of all elements of the matrix.
155+
*/
156+
sum(): number;
157+
/**
158+
* Returns the sum by the dimension given.
159+
* @param by - sum by 'row' or 'column'.
160+
*/
161+
sum(by: 'row' | 'column'): number[];
162+
153163
mean(): number;
154164
prod(): number;
155165
norm(type: 'frobenius' | 'max'): number;

src/abstractMatrix.js

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@ import {
88
checkColumnIndex,
99
checkColumnVector,
1010
checkRange,
11-
checkIndices,
12-
sumByRow,
13-
sumByColumn,
14-
sumAll
11+
checkIndices
1512
} from './util';
13+
import { sumByRow, sumByColumn, sumAll } from './stat';
1614
import MatrixTransposeView from './views/transpose';
1715
import MatrixRowView from './views/row';
1816
import MatrixSubView from './views/sub';
@@ -949,20 +947,16 @@ export default function AbstractMatrix(superCtor) {
949947
return diag;
950948
}
951949

952-
/**
953-
* Returns the sum by the argument given, if no argument given,
954-
* it returns the sum of all elements of the matrix.
955-
* @param {string} by - sum by 'row' or 'column'.
956-
* @return {Matrix|number}
957-
*/
958950
sum(by) {
959951
switch (by) {
960952
case 'row':
961953
return sumByRow(this);
962954
case 'column':
963955
return sumByColumn(this);
964-
default:
956+
case undefined:
965957
return sumAll(this);
958+
default:
959+
throw new Error(`invalid option: ${by}`);
966960
}
967961
}
968962

src/stat.js

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import { newArray } from './util';
2+
3+
export function sumByRow(matrix) {
4+
var sum = newArray(matrix.rows);
5+
for (var i = 0; i < matrix.rows; ++i) {
6+
for (var j = 0; j < matrix.columns; ++j) {
7+
sum[i] += matrix.get(i, j);
8+
}
9+
}
10+
return sum;
11+
}
12+
13+
export function sumByColumn(matrix) {
14+
var sum = newArray(matrix.columns);
15+
for (var i = 0; i < matrix.rows; ++i) {
16+
for (var j = 0; j < matrix.columns; ++j) {
17+
sum[j] += matrix.get(i, j);
18+
}
19+
}
20+
return sum;
21+
}
22+
23+
export function sumAll(matrix) {
24+
var v = 0;
25+
for (var i = 0; i < matrix.rows; i++) {
26+
for (var j = 0; j < matrix.columns; j++) {
27+
v += matrix.get(i, j);
28+
}
29+
}
30+
return v;
31+
}

src/util.js

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import Matrix from './matrix';
2-
31
/**
42
* @private
53
* Check that a row index is not out of bounds
@@ -140,34 +138,12 @@ export function getRange(from, to) {
140138
return arr;
141139
}
142140

143-
export function sumByRow(matrix) {
144-
var sum = Matrix.zeros(matrix.rows, 1);
145-
for (var i = 0; i < matrix.rows; ++i) {
146-
for (var j = 0; j < matrix.columns; ++j) {
147-
sum.set(i, 0, sum.get(i, 0) + matrix.get(i, j));
148-
}
149-
}
150-
return sum;
151-
}
152-
153-
export function sumByColumn(matrix) {
154-
var sum = Matrix.zeros(1, matrix.columns);
155-
for (var i = 0; i < matrix.rows; ++i) {
156-
for (var j = 0; j < matrix.columns; ++j) {
157-
sum.set(0, j, sum.get(0, j) + matrix.get(i, j));
158-
}
159-
}
160-
return sum;
161-
}
162-
163-
export function sumAll(matrix) {
164-
var v = 0;
165-
for (var i = 0; i < matrix.rows; i++) {
166-
for (var j = 0; j < matrix.columns; j++) {
167-
v += matrix.get(i, j);
168-
}
141+
export function newArray(length) {
142+
var array = [];
143+
for (var i = 0; i < length; i++) {
144+
array.push(i);
169145
}
170-
return v;
146+
return array;
171147
}
172148

173149
function checkNumber(name, value) {

0 commit comments

Comments
 (0)