
Flashlight MLP is much slower than Jax+Jit python implementation #898

Closed
xueeinstein opened this issue Jun 19, 2022 · 1 comment
Labels: question (Further information is requested)

Comments

@xueeinstein

Question

Since both Flashlight and JAX are designed for high performance and easy implementation of low-level ops, a direct comparison seems worthwhile. I ran the compiled MLP example (flashlight/fl/examples/Perceptron.cpp) and reproduced the same MLP in Python with JAX + Flax + jit. The results show that Flashlight is much slower than JAX.

Additional Context

All tests are conducted on a CPU-only machine.

Flashlight: [screenshot]

Jax: [screenshot]

Htop monitor: [screenshot]
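
For a like-for-like comparison it may also help to pin JAX to the CPU backend explicitly, so the numbers aren't affected by any accelerator that happens to be visible. A minimal sketch (the config key below is standard JAX, but treat the exact call as an assumption about the installed version):

import jax

# Force the CPU backend; equivalently, export JAX_PLATFORM_NAME=cpu before running.
jax.config.update('jax_platform_name', 'cpu')

print(jax.devices())  # should list only CPU devices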

Contents of `test_jax_mlp.py`:
import numpy as np

import jax
import jax.numpy as jnp

from flax import linen as nn
from flax.training import train_state

import optax


class MLP(nn.Module):
    # Two-layer MLP: 100 ReLU hidden units, scalar output
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(100)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x


def mean_squared_error(*, logits, labels):
    df = logits - labels
    return (df * df).mean()


def get_dataset(num_samples, num_feat):
    # X: (num_feat, num_samples) random features; Y: (num_samples,) noisy cubic target
    X = np.random.random((num_feat, num_samples))
    Y = np.sum(np.power(X, 3), axis=0).transpose()
    Y += np.sin(2 * np.pi * np.random.random(num_samples))
    return X, Y


def create_train_state(rng, learning_rate, momentum, num_feat):
    mlp = MLP()
    params = mlp.init(rng, jnp.ones([1, num_feat]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=mlp.apply, params=params, tx=tx)


@jax.jit  # compile the whole forward/backward/update step into one XLA computation
def train_step(state, batch):
    def loss_fn(params):
        logits = MLP().apply({'params': params}, batch['X'])
        loss = mean_squared_error(logits=logits, labels=batch['Y'])
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


if __name__ == '__main__':
    num_samples = 10000
    num_feat = 10
    num_epoch = 100
    learning_rate = 0.0001
    momentum = 0.9

    X, Y = get_dataset(num_samples, num_feat)

    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, learning_rate, momentum, num_feat)
    del init_rng

    for e in range(1, num_epoch + 1):
        loss_lst = []
        for i in range(num_samples):  # batch size 1: one jit-compiled step per sample
            state, loss = train_step(state, {
                'X': np.expand_dims(X[:, i], 0),
                'Y': np.expand_dims(Y[i], 0),
            })
            loss_lst.append(loss)

        avg_loss = sum(loss_lst) / len(loss_lst)
        print(f'Epoch: {e} Mean Squared Error: {avg_loss}')
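
One caveat about timing the script above: JAX dispatches work asynchronously, so wall-clock measurements that don't synchronize can be misleading. A minimal sketch of how an epoch could be timed, reusing train_step, X, Y, and state from the script (the time import and the block_until_ready() call are additions for illustration, not part of the original benchmark):

import time

start = time.perf_counter()
for i in range(num_samples):
    state, loss = train_step(state, {
        'X': np.expand_dims(X[:, i], 0),
        'Y': np.expand_dims(Y[i], 0),
    })
loss.block_until_ready()  # drain the async dispatch queue before stopping the clock
print(f'Epoch wall time: {time.perf_counter() - start:.3f}s')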
xueeinstein added the question label on Jun 19, 2022
@jacobkahn
Member

This isn't surprising. The ArrayFire CPU backend is single-threaded and not optimized for performance: it does very little JIT compilation of computations and exists primarily as a debugging tool. The ArrayFire OpenCL backend is better (it uses a JIT and some async/parallel computation), but its CPU performance may still be underwhelming, and it isn't properly supported in Flashlight at the moment.

We're currently developing a CPU backend based on Intel's oneDNN library, which should significantly improve performance (and will also enable proper interoperability with the ArrayFire OpenCL backend). Stay tuned for updates -- commits for that backend will land directly on main.

cc @StrongerXi, who's adding this.
