Skip to content

Commit 4055ef9

Browse files
jajoetargos
authored andcommitted
feat: implement optimized algorithm for 2x2 and 3x3 multiplication
1 parent fdc1c07 commit 4055ef9

File tree

3 files changed

+186
-64
lines changed

3 files changed

+186
-64
lines changed

benchmark/mmul.js

Lines changed: 72 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -9,39 +9,6 @@ var suite = new Benchmark.Suite;
99

1010
var Matrix = require('../src/index');
1111

12-
function strassen_2x2(a,b){
13-
var a11 = a.get(0,0);
14-
var b11 = b.get(0,0);
15-
var a12 = a.get(0,1);
16-
var b12 = b.get(0,1);
17-
var a21 = a.get(1,0);
18-
var b21 = b.get(1,0);
19-
var a22 = a.get(1,1);
20-
var b22 = b.get(1,1);
21-
22-
// Compute intermediate values.
23-
var m1 = (a11+a22)*(b11+b22);
24-
var m2 = (a21+a22)*b11;
25-
var m3 = a11*(b12-b22);
26-
var m4 = a22*(b21-b11);
27-
var m5 = (a11+a12)*b22;
28-
var m6 = (a21-a11)*(b11+b12);
29-
var m7 = (a12-a22)*(b21+b22);
30-
31-
// Combine intermediate values into the output.
32-
var c11 =m1+m4-m5+m7;
33-
var c12 = m3+m5;
34-
var c21 = m2+m4;
35-
var c22 = m1-m2+m3+m6;
36-
37-
var c = new Matrix(2,2);
38-
c.set(0,0,c11);
39-
c.set(0,1,c12);
40-
c.set(1,0,c21);
41-
c.set(1,1,c22);
42-
return c;
43-
}
44-
4512
// bad, very bad...
4613
function strassen_nxn(a,b){
4714
if(a.rows == 2){
@@ -90,38 +57,80 @@ function strassen_nxn(a,b){
9057

9158
var m = Matrix.randInt(x, y);
9259
var m2 = Matrix.randInt(y, x);
60+
var a0 = m.clone();
61+
var a1 = m.clone();
9362

94-
/*console.log("test avec strassen n by n")
95-
console.time("r0");
96-
var r0 = m.mmul_strassen_2(m, m2);
97-
console.timeEnd("r0")*/
98-
console.log("test avec une implementation standard")
99-
console.time("r1");
100-
var r1 = m.mmul(m2);
101-
console.timeEnd("r1")
102-
console.log("test avec une implementation de Strassen basee sur du Dynamic Padding")
103-
console.time("r2")
104-
var r2 = m.mmul_strassen(m, m2);
105-
console.timeEnd("r2")
106-
if(x == 2 && y == 2){
63+
/*if(x == 2 && y == 2){
10764
console.log("Test avec Strassen 2*2")
10865
console.time("r3")
109-
var r3 =strassen_2x2(m, m2);
66+
var r3 =strassen_2x2(a0, m2);
11067
console.timeEnd("r3")
11168
}
112-
113-
/*suite
114-
.add('mmul1', function() {
115-
m.mmul(m2);
116-
})
117-
.add('mmul2', function() {
118-
m.mmul_strassen(m, m2);
119-
})
120-
.on('cycle', function(event) {
121-
console.log(String(event.target));
122-
})
123-
.on('complete', function() {
124-
console.log('Fastest is ' + this.filter('fastest').map('name'));
125-
})
126-
.run();
127-
*/
69+
if(x == 3 && y == 3){
70+
console.log("Test avec Strassen 3*3")
71+
console.time("r3")
72+
var r3 =strassen_3x3(a1, m2);
73+
console.timeEnd("r3")
74+
}*/
75+
if(x == 2 && y == 2){
76+
suite
77+
.add('mmul1', function() {
78+
m.mmul(m2);
79+
})
80+
.add('mmul_strassen', function() {
81+
m.mmul_strassen(m, m2);
82+
})
83+
.add('strassen 2x2', function() {
84+
m.strassen_2x2(m2); // a0 is a copy of m
85+
})
86+
.on('cycle', function(event) {
87+
console.log(String(event.target));
88+
})
89+
.on('complete', function() {
90+
console.log('Fastest is ' + this.filter('fastest').map('name'));
91+
})
92+
.run();
93+
}
94+
else if(x == 3 && y == 3){
95+
suite
96+
.add('mmul1', function() {
97+
m.mmul(m2);
98+
})
99+
.add('mmul_strassen', function() {
100+
m.mmul_strassen(m, m2);
101+
})
102+
.add('strassen 3x3', function() {
103+
m.strassen_3x3(m2); // a0 is a copy of m
104+
})
105+
.on('cycle', function(event) {
106+
console.log(String(event.target));
107+
})
108+
.on('complete', function() {
109+
console.log('Fastest is ' + this.filter('fastest').map('name'));
110+
})
111+
.run();
112+
}
113+
else if(Math.max(x,y) < 200){
114+
suite
115+
.add('mmul1', function() {
116+
m.mmul(m2);
117+
})
118+
.add('mmul_strassen', function() {
119+
m.mmul_strassen(m, m2);
120+
})
121+
.on('cycle', function(event) {
122+
console.log(String(event.target));
123+
})
124+
.on('complete', function() {
125+
console.log('Fastest is ' + this.filter('fastest').map('name'));
126+
})
127+
.run();
128+
}
129+
else{
130+
console.time("mmul");
131+
var r1 = m.mmul(m2);
132+
console.timeEnd("mmul")
133+
console.time("mmul strassen dynamic padding")
134+
var r2 = m.mmul_strassen(m, m2);
135+
console.timeEnd("mmul strassen dynamic padding")
136+
}

src/abstractMatrix.js

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,109 @@ function abstractMatrix(superCtor) {
977977
return result;
978978
}
979979

980+
strassen_2x2(other){
981+
var result = new this.constructor[Symbol.species](2, 2);
982+
const a11 = this.get(0,0);
983+
const b11 = other.get(0,0);
984+
const a12 = this.get(0,1);
985+
const b12 = other.get(0,1);
986+
const a21 = this.get(1,0);
987+
const b21 = other.get(1,0);
988+
const a22 = this.get(1,1);
989+
const b22 = other.get(1,1);
990+
991+
// Compute intermediate values.
992+
const m1 = (a11+a22)*(b11+b22);
993+
const m2 = (a21+a22)*b11;
994+
const m3 = a11*(b12-b22);
995+
const m4 = a22*(b21-b11);
996+
const m5 = (a11+a12)*b22;
997+
const m6 = (a21-a11)*(b11+b12);
998+
const m7 = (a12-a22)*(b21+b22);
999+
1000+
// Combine intermediate values into the output.
1001+
const c00 =m1+m4-m5+m7;
1002+
const c01 = m3+m5;
1003+
const c10 = m2+m4;
1004+
const c11 = m1-m2+m3+m6;
1005+
1006+
result.set(0,0,c00);
1007+
result.set(0,1,c01);
1008+
result.set(1,0,c10);
1009+
result.set(1,1,c11);
1010+
return result;
1011+
}
1012+
1013+
strassen_3x3(other){
1014+
var result = new this.constructor[Symbol.species](3, 3);
1015+
1016+
const a00 = this.get(0,0);
1017+
const a01 = this.get(0,1);
1018+
const a02 = this.get(0,2);
1019+
const a10 = this.get(1,0);
1020+
const a11 = this.get(1,1);
1021+
const a12 = this.get(1,2);
1022+
const a20 = this.get(2,0);
1023+
const a21 = this.get(2,1);
1024+
const a22 = this.get(2,2);
1025+
1026+
const b00 = other.get(0,0);
1027+
const b01 = other.get(0,1);
1028+
const b02 = other.get(0,2);
1029+
const b10 = other.get(1,0);
1030+
const b11 = other.get(1,1);
1031+
const b12 = other.get(1,2);
1032+
const b20 = other.get(2,0);
1033+
const b21 = other.get(2,1);
1034+
const b22 = other.get(2,2);
1035+
1036+
const m1 = (a00+a01+a02-a10-a11-a21-a22)*b11;
1037+
const m2 = (a00-a10)*(-b01+b11);
1038+
const m3 = a11*(-b00+b01+b10-b11-b12-b20+b22);
1039+
const m4 = (-a00+a10+a11)*(b00-b01+b11);
1040+
const m5 = (a10+a11)*(-b00+b01);
1041+
const m6 = a00*b00;
1042+
const m7 = (-a00+a20+a21)*(b00-b02+b12);
1043+
const m8 = (-a00+a20)*(b02-b12);
1044+
const m9 = (a20+a21)*(-b00+b02);
1045+
const m10 = (a00+a01+a02-a11-a12-a20-a21)*b12;
1046+
const m11 = a21*(-b00+b02+b10-b11-b12-b20+b21);
1047+
const m12 = (-a02+a21+a22)*(b11+b20-b21);
1048+
const m13 = (a02-a22)*(b11-b21);
1049+
const m14 = a02*b20;
1050+
const m15 = (a21+a22)*(-b20+b21);
1051+
const m16 = (-a02+a11+a12)*(b12+b20-b22);
1052+
const m17 = (a02-a12)*(b12-b22);
1053+
const m18 = (a11+a12)*(-b20+b22);
1054+
const m19= a01*b10;
1055+
const m20 = a12*b21;
1056+
const m21 = a10*b02;
1057+
const m22 = a20*b01;
1058+
const m23 = a22*b22;
1059+
1060+
const c00 = m6+m14+m19;
1061+
const c01 = m1+m4+m5+m6+m12+m14+m15;
1062+
const c02 = m6+m7+m9+m10+m14+m16+m18;
1063+
const c10 = m2+m3+m4+m6+m14+m16+m17;
1064+
const c11 = m2+m4+m5+m6+m20;
1065+
const c12 = m14+m16+m17+m18+m21;
1066+
const c20 = m6+m7+m8+m11+m12+m13+m14;
1067+
const c21 = m12+m13+m14+m15+m22;
1068+
const c22 = m6+m7+m8+m9+m23;
1069+
1070+
result.set(0,0,c00);
1071+
result.set(0,1,c01);
1072+
result.set(0,2,c02);
1073+
result.set(1,0,c10);
1074+
result.set(1,1,c11);
1075+
result.set(1,2,c12);
1076+
result.set(2,0,c20);
1077+
result.set(2,1,c21);
1078+
result.set(2,2,c22);
1079+
return result;
1080+
}
1081+
1082+
9801083
/**
9811084
* Returns the matrix product between x and y. More efficient than mmul(other) only when we multiply squared matrix and when the size of the matrix is > 1000.
9821085
* @param {Matrix} x
@@ -1636,4 +1739,4 @@ function abstractMatrix(superCtor) {
16361739
}
16371740

16381741
return Matrix;
1639-
}
1742+
}

test/matrix/utility.js

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,14 @@ describe('utility methods', function () {
142142
var matrix2 = new Matrix([[2,1],[1,1]]);
143143
matrix.mmul_strassen(matrix2).to2DArray().should.eql([[8,6], [15,8]]);
144144
});
145+
146+
it('mmul 2x2 and 3x3', function (){
147+
var matrix = new Matrix([[2,4],[7,1]]);
148+
var matrix2 = new Matrix([[2,1],[1,1]]);
149+
matrix.strassen_2x2(matrix2).to2DArray().should.eql([[8,6], [15,8]]);
150+
151+
matrix = new Matrix([[2,4,1],[7,1,2],[5,1,3]]);
152+
matrix2 = new Matrix([[2,1,3],[7,1,1],[6,2,7]]);
153+
matrix.strassen_3x3(matrix2).to2DArray().should.eql([[38,8,17],[33,12,36],[35,12,37]]);
154+
});
145155
});

0 commit comments

Comments
 (0)