-
Notifications
You must be signed in to change notification settings - Fork 0
/
gemm.cpp
executable file
·71 lines (58 loc) · 2.02 KB
/
gemm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <stdio.h> /* I/O lib ISOC */
#include <stdlib.h> /* Standard Lib ISOC */
#include <limits.h> /* Standard Lib ISOC */
#include <math.h> /* fabs ISOC */
#include <cblas.h> /* Basic Linear Algebra I/O */
#include <chrono>
/*
 * Absolute tolerance used when comparing the BLAS result with the naive one.
 * It must be > 0: the two implementations add the same products in different
 * orders, so their results differ by rounding error and an exact comparison
 * (eps == 0) reports a failure even when both are correct. 1e-6 is far above
 * double rounding noise for the inputs main() generates, yet far below the
 * contribution of a single wrong or missing term.
 */
#define eps 1e-6
/*
 * Naive triple-loop product C = A * B of two square, row-major matrices.
 *
 * N : matrix dimension; nothing is written when N <= 0.
 * A : left operand, N*N doubles, element (i,k) stored at A[i*N + k].
 * B : right operand, N*N doubles, element (k,j) stored at B[k*N + j].
 * C : output, N*N doubles; every element is overwritten (prior contents are
 *     ignored). C is written while A and B are still being read, so it should
 *     not overlap either input.
 */
void matrixMultiply(int N, double * A, double * B, double * C){
    if (N <= 0)
        return;

    /* Flat indices are computed in size_t: in int arithmetic i * N + k
       overflows as soon as N * N exceeds INT_MAX (N >= 46341). */
    const size_t n = static_cast<size_t>(N);

    for (size_t i = 0; i < n; i++) {
        const double *a_row = A + i * n;
        double *c_row = C + i * n;
        for (size_t j = 0; j < n; j++) {
            double sum = 0;
            for (size_t k = 0; k < n; k++)
                sum += a_row[k] * B[k * n + j];
            c_row[j] = sum;
        }
    }
}
/*
 * Benchmark driver: multiplies two random N x N row-major matrices once with
 * cblas_dgemm and once with matrixMultiply, prints the wall-clock time of
 * each, then compares the two results element by element and prints
 * "Passed" or "Failed" (plus the first 10 mismatching elements).
 *
 * argc/argv are unused. Returns EXIT_FAILURE when the matrices cannot be
 * allocated and 0 otherwise; the exit status is still 0 when the comparison
 * fails, as it was before.
 */
int main(int argc, char **argv) {
    (void)argc;
    (void)argv;

    const int N = 8192;
    /* Element count and byte size are computed in size_t so the products
       cannot overflow int. */
    const size_t count = static_cast<size_t>(N) * static_cast<size_t>(N);
    const size_t bytes = count * sizeof(double);

    double *A = static_cast<double *>(malloc(bytes));
    double *B = static_cast<double *>(malloc(bytes));
    double *C = static_cast<double *>(malloc(bytes));
    double *C_ref = static_cast<double *>(malloc(bytes));
    /* Four large buffers: bail out cleanly instead of dereferencing NULL. */
    if (A == NULL || B == NULL || C == NULL || C_ref == NULL) {
        fprintf(stderr, "Failed to allocate four %d x %d matrices\n", N, N);
        free(A);
        free(B);
        free(C);
        free(C_ref);
        return EXIT_FAILURE;
    }

    for (size_t i = 0; i < count; i++) {
        /* rand() returns values in [0, RAND_MAX]; dividing by RAND_MAX (not
           INT_MAX) maps them to [0, 1] on every platform. */
        A[i] = double(rand()) / RAND_MAX;
        B[i] = double(rand()) / RAND_MAX;
        C[i] = 0;
        C_ref[i] = 0;
    }

    /* Elapsed time in seconds, truncated to whole milliseconds. */
    const auto seconds_between = [](std::chrono::steady_clock::time_point from,
                                    std::chrono::steady_clock::time_point to) {
        return std::chrono::duration_cast<std::chrono::milliseconds>(to - from).count() / 1000.0;
    };

    auto start_time = std::chrono::steady_clock::now();
    /* Arguments: layout, transA, transB, M (rows of A/C), N (cols of B/C),
       K (cols of A / rows of B), alpha, A, lda, B, ldb, beta, C, ldc. */
    cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, N, N, N, 1.0, A, N, B, N, 0.0, C, N);
    auto end_time = std::chrono::steady_clock::now();
    printf("blas takes %f Seconds\n", seconds_between(start_time, end_time));

    start_time = std::chrono::steady_clock::now();
    matrixMultiply(N, A, B, C_ref);
    end_time = std::chrono::steady_clock::now();
    printf("naive takes %f Seconds\n", seconds_between(start_time, end_time));

    int mismatch_found = 0;
    int reported = 0;
    for (size_t i = 0; i < count; i++) {
        /* fabs keeps the difference in floating point (an integer abs would
           truncate it), and the negated <= also flags NaN, which a plain >
           comparison would let through. */
        const double diff = fabs(C[i] - C_ref[i]);
        if (!(diff <= eps)) {
            mismatch_found = 1;
            /* Only the first 10 mismatches are printed. */
            if (reported < 10) {
                printf("%zu\t%f\t%f\n", i, C[i], C_ref[i]);
                reported++;
            }
        }
    }

    if (mismatch_found != 0)
        printf("Failed\n");
    else
        printf("Passed\n");

    free(A);
    free(B);
    free(C);
    free(C_ref);
    return 0;
}