Skip to content

Commit fdc1c07

Browse files
jajoetargos
authored andcommitted
feat: add fast multiplication algorithm (strassen)
* add an implementation of matrix product in the benchmark (strassen's algorithm) * add the function mmul_strassen to the class abstractMatrix. * Modification of the benchmark of mmul : use integer instead of float between 0 and 1. * Add a test of mmul_strassen
1 parent 3de8a15 commit fdc1c07

File tree

3 files changed

+238
-26
lines changed

3 files changed

+238
-26
lines changed

benchmark/mmul.js

Lines changed: 104 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,126 @@
22

33
var x = parseInt(process.argv[2]) || 5;
44
var y = parseInt(process.argv[3]) || x;
5-
console.log(`mmul operations benchmark for ${x}x${y} matrix`);
5+
// console.log(`mmul operations benchmark for ${x}x${y} matrix`);
66

77
var Benchmark = require('benchmark');
88
var suite = new Benchmark.Suite;
99

1010
var Matrix = require('../src/index');
1111

12-
Matrix.prototype.mmul2 = function (other) {
13-
other = Matrix.checkMatrix(other);
14-
if (this.columns !== other.rows)
15-
console.warn('Number of columns of left matrix are not equal to number of rows of right matrix.');
16-
17-
var m = this.rows;
18-
var n = this.columns;
19-
var p = other.columns;
20-
21-
var result = Matrix.zeros(m, p);
22-
for (var i = 0; i < m; i++) {
23-
for (var k = 0; k < n; k++) {
24-
for (var j = 0; j < p; j++) {
25-
result[i][j] += this[i][k] * other[k][j];
26-
}
27-
}
12+
function strassen_2x2(a,b){
13+
var a11 = a.get(0,0);
14+
var b11 = b.get(0,0);
15+
var a12 = a.get(0,1);
16+
var b12 = b.get(0,1);
17+
var a21 = a.get(1,0);
18+
var b21 = b.get(1,0);
19+
var a22 = a.get(1,1);
20+
var b22 = b.get(1,1);
21+
22+
// Compute intermediate values.
23+
var m1 = (a11+a22)*(b11+b22);
24+
var m2 = (a21+a22)*b11;
25+
var m3 = a11*(b12-b22);
26+
var m4 = a22*(b21-b11);
27+
var m5 = (a11+a12)*b22;
28+
var m6 = (a21-a11)*(b11+b12);
29+
var m7 = (a12-a22)*(b21+b22);
30+
31+
// Combine intermediate values into the output.
32+
var c11 =m1+m4-m5+m7;
33+
var c12 = m3+m5;
34+
var c21 = m2+m4;
35+
var c22 = m1-m2+m3+m6;
36+
37+
var c = new Matrix(2,2);
38+
c.set(0,0,c11);
39+
c.set(0,1,c12);
40+
c.set(1,0,c21);
41+
c.set(1,1,c22);
42+
return c;
43+
}
44+
45+
// bad, very bad...
46+
function strassen_nxn(a,b){
47+
if(a.rows == 2){
48+
return strassen_2x2(a, b);
2849
}
29-
return result;
30-
};
50+
else{
51+
var size = a.rows;
52+
var size1 = size - 1;
53+
var demi_size0 = parseInt(size/2);
54+
var demi_size1 = parseInt(demi_size0 - 1);
55+
// a et b must be the same size and rows = columns
56+
57+
var a11 = a.subMatrix(0, demi_size1, 0, demi_size1);
58+
var b11 = b.subMatrix(0, demi_size1, 0, demi_size1);
59+
var a12 = a.subMatrix(0, demi_size1, demi_size0, size1);
60+
var b12 = b.subMatrix(0, demi_size1, demi_size0, size1);
61+
var a21 = a.subMatrix(demi_size0, size1, 0, demi_size1);
62+
var b21 = b.subMatrix(demi_size0, size1, 0, demi_size1);
63+
var a22 = a.subMatrix(demi_size0, size1, demi_size0, size1);
64+
var b22 = b.subMatrix(demi_size0, size1, demi_size0, size1);
65+
66+
// Compute intermediate values.
67+
var m1 = strassen_nxn(Matrix.add(a11,a22),Matrix.add(b11,b22));
68+
var m2 = strassen_nxn(Matrix.add(a21,a22),b11);
69+
var m3 = strassen_nxn(a11,Matrix.sub(b12,b22));
70+
var m4 = strassen_nxn(a22,Matrix.sub(b21,b11));
71+
var m5 = strassen_nxn(Matrix.add(a11,a12),b22);
72+
var m6 = strassen_nxn(Matrix.sub(a21,a11),Matrix.add(b11,b12));
73+
var m7 = strassen_nxn(Matrix.sub(a12,a22),Matrix.add(b21,b22));
74+
75+
// Combine intermediate values into the output.
76+
var c11 = Matrix.add(m1,m4).sub(m5).add(m7);
77+
var c12 = Matrix.add(m3,m5);
78+
var c21 = Matrix.add(m2,m4);
79+
var c22 = Matrix.sub(m1,m2).add(m3).add(m6);
80+
81+
var c = new Matrix(size,size);
82+
c.setSubMatrix(c11,0,0);
83+
c.setSubMatrix(c12,0,demi_size0);
84+
c.setSubMatrix(c21,demi_size0,0);
85+
c.setSubMatrix(c22,demi_size0,demi_size0);
86+
return c;
87+
}
88+
}
89+
90+
91+
var m = Matrix.randInt(x, y);
92+
var m2 = Matrix.randInt(y, x);
3193

32-
var m = Matrix.rand(x, y);
33-
var m2 = Matrix.rand(y, x);
94+
/*console.log("test avec strassen n by n")
95+
console.time("r0");
96+
var r0 = m.mmul_strassen_2(m, m2);
97+
console.timeEnd("r0")*/
98+
console.log("test avec une implementation standard")
99+
console.time("r1");
100+
var r1 = m.mmul(m2);
101+
console.timeEnd("r1")
102+
console.log("test avec une implementation de Strassen basee sur du Dynamic Padding")
103+
console.time("r2")
104+
var r2 = m.mmul_strassen(m, m2);
105+
console.timeEnd("r2")
106+
if(x == 2 && y == 2){
107+
console.log("Test avec Strassen 2*2")
108+
console.time("r3")
109+
var r3 =strassen_2x2(m, m2);
110+
console.timeEnd("r3")
111+
}
34112

35-
suite
113+
/*suite
36114
.add('mmul1', function() {
37115
m.mmul(m2);
38116
})
39117
.add('mmul2', function() {
40-
m.mmul2(m2);
118+
m.mmul_strassen(m, m2);
41119
})
42120
.on('cycle', function(event) {
43-
console.log(String(event.target));
121+
console.log(String(event.target));
44122
})
45123
.on('complete', function() {
46-
console.log('Fastest is ' + this.filter('fastest').map('name'));
124+
console.log('Fastest is ' + this.filter('fastest').map('name'));
47125
})
48126
.run();
127+
*/

src/abstractMatrix.js

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,25 @@ function abstractMatrix(superCtor) {
122122
return matrix;
123123
}
124124

125+
/**
126+
* Creates a matrix with the given dimensions. Values will be randomly set.
127+
* @param {number} rows - Number of rows
128+
* @param {number} columns - Number of columns
129+
* @param {function} [rng] - Random number generator (default: Math.random)
130+
* @returns {Matrix} The new matrix
131+
*/
132+
static randInt(rows, columns, rng) {
133+
if (rng === undefined) rng = Math.random;
134+
var matrix = this.empty(rows, columns);
135+
for (var i = 0; i < rows; i++) {
136+
for (var j = 0; j < columns; j++) {
137+
var value = parseInt(rng()*1000);
138+
matrix.set(i, j, value);
139+
}
140+
}
141+
return matrix;
142+
}
143+
125144
/**
126145
* Creates an identity matrix with the given dimension. Values of the diagonal will be 1 and others will be 0.
127146
* @param {number} rows - Number of rows
@@ -958,6 +977,114 @@ function abstractMatrix(superCtor) {
958977
return result;
959978
}
960979

980+
/**
981+
* Returns the matrix product between x and y. More efficient than mmul(other) only when we multiply squared matrix and when the size of the matrix is > 1000.
982+
* @param {Matrix} x
983+
* @param {Matrix} y
984+
* @returns {Matrix}
985+
*/
986+
mmul_strassen(y){
987+
var x = this.clone();
988+
var r1 = x.rows;
989+
var c1 = x.columns;
990+
var r2 = y.rows;
991+
var c2 = y.columns;
992+
if(c1 != r2){
993+
console.log(`Multiplying ${r1} x ${c1} and ${r2} x ${c2} matrix: dimensions do not match.`)
994+
}
995+
996+
// Put a matrix into the top left of a matrix of zeros.
997+
// `rows` and `cols` are the dimensions of the output matrix.
998+
function embed(mat, rows, cols){
999+
var r = mat.rows;
1000+
var c = mat.columns;
1001+
if((r==rows) && (c==cols)){
1002+
return mat;
1003+
}
1004+
else{
1005+
var resultat = Matrix.zeros(rows, cols);
1006+
resultat = resultat.setSubMatrix(mat, 0, 0);
1007+
return resultat;
1008+
}
1009+
}
1010+
1011+
1012+
// Make sure both matrices are the same size.
1013+
// This is exclusively for simplicity:
1014+
// this algorithm can be implemented with matrices of different sizes.
1015+
1016+
var r = Math.max(r1, r2);
1017+
var c = Math.max(c1, c2);
1018+
var x = embed(x, r, c);
1019+
var y = embed(y, r, c);
1020+
1021+
// Our recursive multiplication function.
1022+
function block_mult(a, b, rows, cols){
1023+
// For small matrices, resort to naive multiplication.
1024+
if (rows <= 512 || cols <= 512){
1025+
return a.mmul(b); // a is equivalent to this
1026+
}
1027+
1028+
// Apply dynamic padding.
1029+
if ((rows % 2 == 1) && (cols % 2 == 1)) {
1030+
a = embed(a, rows + 1, cols + 1);
1031+
b = embed(b, rows + 1, cols + 1);
1032+
}
1033+
else if (rows % 2 == 1){
1034+
a = embed(a, rows + 1, cols);
1035+
b = embed(b, rows + 1, cols);
1036+
}
1037+
else if (cols % 2 == 1){
1038+
a = embed(a, rows, cols + 1);
1039+
b = embed(b, rows, cols + 1);
1040+
}
1041+
1042+
var half_rows = parseInt(a.rows / 2);
1043+
var half_cols = parseInt(a.columns / 2);
1044+
// Subdivide input matrices.
1045+
var a11 = a.subMatrix(0, half_rows -1, 0, half_cols - 1);
1046+
var b11 = b.subMatrix(0, half_rows -1, 0, half_cols - 1);
1047+
1048+
var a12 = a.subMatrix(0, half_rows -1, half_cols, a.columns - 1);
1049+
var b12 = b.subMatrix(0, half_rows -1, half_cols, b.columns - 1);
1050+
1051+
var a21 = a.subMatrix(half_rows, a.rows - 1, 0, half_cols - 1);
1052+
var b21 = b.subMatrix(half_rows, b.rows - 1, 0, half_cols - 1);
1053+
1054+
var a22 = a.subMatrix(half_rows, a.rows - 1, half_cols, a.columns - 1);
1055+
var b22 = b.subMatrix(half_rows, b.rows - 1, half_cols, b.columns - 1);
1056+
1057+
// Compute intermediate values.
1058+
var m1 = block_mult(Matrix.add(a11,a22), Matrix.add(b11,b22), half_rows, half_cols);
1059+
var m2 = block_mult(Matrix.add(a21,a22), b11, half_rows, half_cols);
1060+
var m3 = block_mult(a11, Matrix.sub(b12, b22), half_rows, half_cols);
1061+
var m4 = block_mult(a22, Matrix.sub(b21,b11), half_rows, half_cols);
1062+
var m5 = block_mult(Matrix.add(a11,a12), b22, half_rows, half_cols);
1063+
var m6 = block_mult(Matrix.sub(a21, a11), Matrix.add(b11, b12), half_rows, half_cols);
1064+
var m7 = block_mult(Matrix.sub(a12,a22), Matrix.add(b21,b22), half_rows, half_cols);
1065+
1066+
// Combine intermediate values into the output.
1067+
var c11 = Matrix.add(m1, m4);
1068+
c11.sub(m5);
1069+
c11.add(m7);
1070+
var c12 = Matrix.add(m3,m5);
1071+
var c21 = Matrix.add(m2,m4);
1072+
var c22 = Matrix.sub(m1,m2);
1073+
c22.add(m3);
1074+
c22.add(m6);
1075+
1076+
//Crop output to the desired size (undo dynamic padding).
1077+
var resultat = Matrix.zeros(2*c11.rows, 2*c11.columns);
1078+
resultat = resultat.setSubMatrix(c11, 0, 0);
1079+
resultat = resultat.setSubMatrix(c12, c11.rows, 0)
1080+
resultat = resultat.setSubMatrix(c21, 0, c11.columns);
1081+
resultat = resultat.setSubMatrix(c22, c11.rows, c11.columns);
1082+
return resultat.subMatrix(0, rows - 1, 0, cols - 1);
1083+
}
1084+
var resultat_final = block_mult(x, y, r, c);
1085+
return resultat_final;
1086+
};
1087+
9611088
/**
9621089
* Returns a row-by-row scaled matrix
9631090
* @param {Number} [min=0] - Minimum scaled value
@@ -1198,7 +1325,7 @@ function abstractMatrix(superCtor) {
11981325
}
11991326

12001327
/*
1201-
Matrix views
1328+
Matrix views
12021329
*/
12031330

12041331
/**

test/matrix/utility.js

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,10 @@ describe('utility methods', function () {
136136
matrix.repeat(2, 2).to2DArray().should.eql([[1, 2, 1, 2], [3, 4, 3, 4], [1, 2, 1, 2], [3, 4, 3, 4]]);
137137
matrix.repeat(1, 2).to2DArray().should.eql([[1, 2, 1, 2], [3, 4, 3, 4]]);
138138
});
139+
140+
it('mmul strassen', function (){
141+
var matrix = new Matrix([[2,4],[7,1]]);
142+
var matrix2 = new Matrix([[2,1],[1,1]]);
143+
matrix.mmul_strassen(matrix2).to2DArray().should.eql([[8,6], [15,8]]);
144+
});
139145
});

0 commit comments

Comments
 (0)