Benchmarking EJML vs Numpy / Pytorch #178

Open

brr53 opened this issue Feb 1, 2023 · 5 comments

brr53 commented Feb 1, 2023

I've noticed EJML matrix multiplication is more than 20 times slower than PyTorch or NumPy. Below is some test code to reproduce the situation. Am I doing something wrong, and what can I do to achieve Python-level performance in Java? Thank you!

public static void main(String[] args) {
    var mat1 = fill(new double[15][32]);
    var mat2 = fill(new double[32][25600]);
    var mask2dMat = new SimpleMatrix(mat1);
    var proto2dMat = new SimpleMatrix(mat2);

    var ts = System.currentTimeMillis();
    mask2dMat.mult(proto2dMat); // multiply
    System.out.println(System.currentTimeMillis() - ts); // 20ms on my PC
}

private static double[][] fill(double[][] fMat) {
    for (double[] row : fMat) {
        for (int i = 0; i < row.length; i++) {
            row[i] = ThreadLocalRandom.current().nextFloat();
        }
    }
    return fMat;
}

VS

mat1 = torch.randn(15, 32)
mat2 = torch.randn(32, 25600)
timestamp = int(time.time() * 1000)
mat1 @ mat2 # multiply
print(int(time.time() * 1000) - timestamp) # 0ms on same PC

Thank you.

lessthanoptimal (Owner) commented Feb 1, 2023
Most likely what you are seeing is Just-In-Time (JIT) compilation vs ahead-of-time compilation. The Java Virtual Machine uses a JIT compiler, and typically when benchmarking you are interested in the steady-state performance. So to "warm up" the JVM you run a few iterations, then run it again. Java developers typically use JMH for micro-benchmarks since it automates all of this for you.

Also, you will get more accurate results in a micro-benchmark on any language/platform if you run it for a sufficient period of time. In this case, I would pick a number of iterations so that it takes a few seconds to run.

As for which one is faster in the steady state, I'm not sure here. NumPy is basically a wrapper around LAPACK and/or Eigen. On small matrices the performance is actually comparable. On large dense matrices LAPACK/Eigen is typically about 2.5x faster since their compilers can do fancier optimization. For sparse matrices the performance is comparable since the optimization advantage goes away.
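
(For reference, a minimal JMH sketch of how this multiply could be benchmarked with proper warm-up. This is not from the thread; the class name and warm-up/measurement parameters are illustrative choices.)

import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.ejml.simple.SimpleMatrix;
import org.openjdk.jmh.annotations.*;

@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(1)
@State(Scope.Thread)
public class MultBenchmark {
    SimpleMatrix mat1, mat2;

    @Setup
    public void setup() {
        mat1 = SimpleMatrix.random_DDRM(15, 32, 0, 1.0, new Random(42));
        mat2 = SimpleMatrix.random_DDRM(32, 25600, 0, 1.0, new Random(43));
    }

    @Benchmark
    public SimpleMatrix mult() {
        // Returning the result keeps JMH from dead-code eliminating the multiply.
        return mat1.mult(mat2);
    }
}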

brr53 (Author) commented Feb 1, 2023

Thank you for your reply. After running the computation in a loop, it appears EJML is still 5 times slower than PyTorch in my tests. Would you be able to confirm I am using it properly, please?

Test code:

static SimpleMatrix mat1 = new SimpleMatrix(fill(new double[15][32]));
static SimpleMatrix mat2 = new SimpleMatrix(fill(new double[32][25600]));

public static void multiplyAndDisplaySpeed() {
    var ts = System.currentTimeMillis();
    mat1.mult(mat2); // multiply
    System.out.println(System.currentTimeMillis() - ts);
}

private static double[][] fill(double[][] fMat) {
    for (double[] row : fMat) {
        for (int i = 0; i < row.length; i++) {
            row[i] = ThreadLocalRandom.current().nextFloat();
        }
    }
    return fMat;
}

public static void main(String[] args) {
    while (true)
        multiplyAndDisplaySpeed(); // this displays 5ms infinitely on my device. Meanwhile pytorch is 0ms
}

lessthanoptimal (Owner) commented Feb 1, 2023

Here is my test source code. I did a few permutations to see what was going on.

Java

import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.CommonOps_MT_DDRM;
import org.ejml.simple.SimpleMatrix;

public class Foobar {
    static Random rand = ThreadLocalRandom.current();
    static SimpleMatrix mat1 = SimpleMatrix.random_DDRM(15, 32, 0, 1.0, rand);
    static SimpleMatrix mat2 = SimpleMatrix.random_DDRM(32, 25600, 0, 1.0, rand);

    public static void multiplyAndDisplaySpeed( int count ) {
        var ts = System.nanoTime();

        for (int i = 0; i < count; i++) {
            mat1.mult(mat2); // multiply
        }
        double elapsedTimeMS = (System.nanoTime() - ts)*1e-6;
        System.out.println("A) average: " + (elapsedTimeMS/count) + " (ms) total: " + elapsedTimeMS);
    }

    public static void multiplyAndDisplaySpeed2( final int count ) {
        var output = new DMatrixRMaj(1, 1);
        var ts = System.nanoTime();
        for (int i = 0; i < count; i++) {
            CommonOps_DDRM.mult(mat1.getDDRM(), mat2.getDDRM(), output);
        }
        double elapsedTimeMS = (System.nanoTime() - ts)*1e-6;
        System.out.println("B) average: " + (elapsedTimeMS/count) + " (ms) total: " + elapsedTimeMS);
    }

    public static void multiplyAndDisplaySpeed3( final int count ) {
        var output = new DMatrixRMaj(1, 1);
        var ts = System.nanoTime();
        for (int i = 0; i < count; i++) {
            CommonOps_MT_DDRM.mult(mat1.getDDRM(), mat2.getDDRM(), output);
        }
        double elapsedTimeMS = (System.nanoTime() - ts)*1e-6;
        System.out.println("C) average: " + (elapsedTimeMS/count) + " (ms) total: " + elapsedTimeMS);
    }

    public static void main( String[] args ) {
        multiplyAndDisplaySpeed(100);
        multiplyAndDisplaySpeed(500);
        multiplyAndDisplaySpeed2(100);
        multiplyAndDisplaySpeed2(500);
        multiplyAndDisplaySpeed3(500);
        multiplyAndDisplaySpeed3(1000);
    }
}

Python:

import time
import numpy as np

mat1 = np.random.random((15, 32))
mat2 = np.random.random((32, 25600))

N = 2000
time0 = time.time()
for i in range(N):
  np.matmul(mat1, mat2)
time1 = time.time()

average = 1000.0*(time1-time0)/N
print("Average time", average)

lessthanoptimal (Owner) commented Feb 1, 2023

Running concurrent code makes a big difference here: about a 7x speed-up. There's a switch in SimpleMatrix which turns on threads that isn't being triggered. Could probably be improved...

Java:   0.557 (ms)
Python: 0.295 (ms)

These matrices are also large enough that the better SIMD optimization in the C/C++ code Python wraps is kicking in. I'm willing to bet that if you made the matrices even bigger the speed difference would increase. The new Vector API in Java should help close the gap.
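
(For context, a rough sketch of the kind of SIMD inner kernel the Vector API enables. This is not EJML code; it assumes a JDK with the incubating module enabled via --add-modules jdk.incubator.vector, and the class/method names are illustrative.)

import jdk.incubator.vector.DoubleVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

public class DotKernel {
    static final VectorSpecies<Double> SPECIES = DoubleVector.SPECIES_PREFERRED;

    // Dot product of two equal-length arrays: the hot loop inside a
    // row-major dense matrix multiply.
    static double dot(double[] a, double[] b) {
        DoubleVector acc = DoubleVector.zero(SPECIES);
        int i = 0;
        for (; i <= a.length - SPECIES.length(); i += SPECIES.length()) {
            DoubleVector va = DoubleVector.fromArray(SPECIES, a, i);
            DoubleVector vb = DoubleVector.fromArray(SPECIES, b, i);
            acc = va.fma(vb, acc); // fused multiply-add where supported
        }
        double sum = acc.reduceLanes(VectorOperators.ADD);
        for (; i < a.length; i++) sum += a[i] * b[i]; // scalar tail
        return sum;
    }
}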

lessthanoptimal changed the title from "EJML over 20 times slower than Numpy / Pytorch" to "Benchmarking EJML vs Numpy / Pytorch" on Feb 10, 2023
lessthanoptimal (Owner) commented Feb 10, 2023

Changed the title to get people to actually read this thread. A new stable version has been released, and here are its benchmark results using SimpleMatrix. The logic for switching to concurrency has been improved:

SimpleMatrix:   1.000 (ms)
CommonOps_MT:   0.614 (ms)

The overhead of SimpleMatrix does slow it down a bit, but most people probably won't care. NumPy is still faster.
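
(As a usage note, my illustration rather than something from the thread: the SimpleMatrix overhead can be skipped by operating on DMatrixRMaj directly and reusing a preallocated output matrix, which also avoids a fresh allocation on every call.)

import java.util.Random;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_MT_DDRM;
import org.ejml.dense.row.RandomMatrices_DDRM;

public class DirectMult {
    public static void main(String[] args) {
        var rand = new Random(42);
        DMatrixRMaj a = RandomMatrices_DDRM.rectangle(15, 32, rand);
        DMatrixRMaj b = RandomMatrices_DDRM.rectangle(32, 25600, rand);
        DMatrixRMaj out = new DMatrixRMaj(15, 25600); // reused across calls

        // Concurrent multiply writes into the preallocated output.
        CommonOps_MT_DDRM.mult(a, b, out);
    }
}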
