Model compiling twice when using jax==0.2.10 or later #212

Open
wil-j-wil opened this issue Mar 29, 2021 · 6 comments

@wil-j-wil

Hi,

I recently updated JAX and noticed that my runtime increased. I have managed to isolate the issue: my objax model is compiling itself twice, i.e. on the second training iteration the model seems to recompile for some reason. This only happens for JAX versions 0.2.10 or later.

Any idea what the cause of this may be?

I hope this toy example is clear enough. I am using objax==1.3.1 and jaxlib==0.1.60.

import objax
import jax.numpy as np
from jax import vmap
import time


class GaussianLikelihood(objax.Module):
    """
    The Gaussian likelihood
    """
    def __init__(self,
                 variance=0.1):
        """
        :param variance: The observation noise variance
        """
        self.variance = objax.TrainVar(np.array(variance))

    def expected_log_lik(self, y, post_mean, post_cov):
        """Expected log-likelihood of y under a Gaussian posterior N(post_mean, post_cov)."""
        exp_log_lik = (
            -0.5 * np.log(2 * np.pi)
            - 0.5 * np.log(self.variance.value)
            - 0.5 * ((y - post_mean) ** 2 + post_cov) / self.variance.value
        )
        return exp_log_lik


class GP(objax.Module):
    """
    A GP model
    """
    def __init__(self,
                 likelihood,
                 X,
                 Y):
        self.X = np.array(X)
        self.Y = np.array(Y)
        self.likelihood = likelihood
        self.posterior_mean = objax.StateVar(np.zeros([X.shape[0], 1, 1]))
        self.posterior_variance = objax.StateVar(np.ones([X.shape[0], 1, 1]))

    def energy(self):
        """Sum of the expected log-likelihood over all data points (the training objective)."""
        mean_f, cov_f = self.posterior_mean.value, self.posterior_variance.value

        E = vmap(self.likelihood.expected_log_lik)(
            self.Y,
            mean_f,
            cov_f
        )

        return np.sum(E)


# generate some data
N = 1000000
x = np.linspace(-10, 100, num=N)
y = np.sin(x)

# set up the model
lik = GaussianLikelihood(variance=1.0)
gp_model = GP(likelihood=lik, X=x, Y=y)

energy = objax.GradValues(gp_model.energy, gp_model.vars())

lr_adam = 0.1
iters = 10
opt = objax.optimizer.Adam(gp_model.vars())


def train_op():
    dE, E = energy()  # compute energy and its gradients w.r.t. hypers
    return dE, E


train_op = objax.Jit(train_op, gp_model.vars())

t0 = time.time()
for i in range(1, iters + 1):
    t2 = time.time()
    grad, loss = train_op()
    opt(lr_adam, grad)
    t3 = time.time()
    # print('iter %2d, energy: %1.4f' % (i, loss[0]))
    print('iter time: %2.2f secs' % (t3-t2))
t1 = time.time()
print('optimisation time: %2.2f secs' % (t1-t0))

Running this script with jax==0.2.9 gives

iter time: 0.12 secs
iter time: 0.01 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.01 secs
iter time: 0.01 secs
optimisation time: 0.28 secs

Running the script with jax==0.2.10 gives

iter time: 0.14 secs
iter time: 0.08 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
optimisation time: 0.38 secs

As you can see, there is a significant difference in the 2nd iteration, as if the model is re-compiling itself.
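
One way to double-check that it really is recompilation rather than some other slowdown (a minimal sketch, reusing the names from the script above; it assumes the installed jax version exposes the jax_log_compiles option):

import jax

# Ask JAX to log every XLA compilation. With this enabled, a second
# "Compiling ..." message appearing on the 2nd iteration would confirm a
# genuine recompile rather than some other source of slowdown.
jax.config.update('jax_log_compiles', True)

for i in range(1, iters + 1):
    grad, loss = train_op()
    opt(lr_adam, grad)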

@AlexeyKurakin
Member

Hi,

I tried to run this code in Colab (with jax.__version__ == '0.2.11' and jaxlib.__version__ == '0.1.64') and I cannot reproduce this behavior. I actually observe that the second iteration is the fastest. Specifically, when I run the code on CPU in Colab, the output is the following:

iter time: 0.20 secs
iter time: 0.01 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
iter time: 0.02 secs
optimisation time: 0.35 secs

If I run the code on GPU (you need to select GPU in the Runtime -> Change runtime type menu), I see the following output:

iter time: 0.73 secs
iter time: 0.01 secs
iter time: 0.07 secs
iter time: 0.06 secs
iter time: 0.07 secs
iter time: 0.07 secs
iter time: 0.07 secs
iter time: 0.07 secs
iter time: 0.07 secs
iter time: 0.07 secs
optimisation time: 1.27 secs

So I recommend upgrading to the latest jax and jaxlib versions and seeing whether performance improves.

If you want to experiment with the latest version of JAX in Colab, you can do it here: https://colab.sandbox.google.com/
To install objax into Colab, add the line %pip --quiet install objax; see the example in this Objax Colab tutorial.
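
For reference, a minimal Colab cell for this (the version printout is my addition and assumes each package exposes __version__):

# Install objax in the Colab runtime, then confirm which versions are in use.
%pip --quiet install objax

import jax, jaxlib, objax
print(jax.__version__, jaxlib.__version__, objax.__version__)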

@AlexeyKurakin
Member

A few other general suggestions about the code you posted:

1. For better performance, the opt call should be inside train_op; otherwise opt won't be compiled and will be executed line by line instead. Here is an example:
@objax.Function.with_vars(gp_model.vars() + opt.vars())
def train_op():
  dE, E = energy()
  opt(lr_adam, dE)
  return E

train_op = objax.Jit(train_op)
2. If you deal with Objax variables (TrainVar, StateVar, etc.), it's better to avoid using jax.vmap directly, because it may treat Objax variables as constants. Instead, use objax.Vectorize:
class GaussianLikelihood(objax.Module):
    """
    The Gaussian likelihood
    """
    def __init__(self,
                 variance=0.1):
        """
        :param variance: The observation noise variance
        """
        self.variance = objax.TrainVar(np.array(variance))

    def __call__(self, y, post_mean, post_cov):
        """Expected log-likelihood of y under a Gaussian posterior N(post_mean, post_cov)."""
        exp_log_lik = (
            -0.5 * np.log(2 * np.pi)
            - 0.5 * np.log(self.variance.value)
            - 0.5 * ((y - post_mean) ** 2 + post_cov) / self.variance.value
        )
        return exp_log_lik

class GP(objax.Module):
    def __init__(self):
        self.likelihood = objax.Vectorize(GaussianLikelihood())

    def energy(self):
        return np.sum(self.likelihood( ... ))
3. Generally speaking, vmap or objax.Vectorize should not be called on each training step. Instead, call Vectorize inside __init__. The output of vmap / Vectorize is another callable, which you can then call on each training step. Internally, vmap / Vectorize perform compilation, so calling them on every training step may lead to unnecessary compilation on each step. See the example above for how this should look, and the sketch below for the pattern in isolation.
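
For concreteness, here is the contrast from point 3 as a short sketch (the class and attribute names are only illustrative, not taken from the code above):

import objax
import jax.numpy as np


class GoodModel(objax.Module):
    def __init__(self, likelihood):
        # Vectorize once, at construction time; the resulting callable is
        # reused unchanged on every training step.
        self.vec_likelihood = objax.Vectorize(likelihood, batch_axis=(0, 0, 0))

    def energy(self, y, mean_f, cov_f):
        return np.sum(self.vec_likelihood(y, mean_f, cov_f))


class BadModel(objax.Module):
    def __init__(self, likelihood):
        self.likelihood = likelihood

    def energy(self, y, mean_f, cov_f):
        # Anti-pattern: building a fresh Vectorize object on every call adds
        # avoidable tracing work inside the training loop.
        return np.sum(objax.Vectorize(self.likelihood, batch_axis=(0, 0, 0))(y, mean_f, cov_f))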

I'm closing this issue for now, feel free to reopen it if you have more questions.

@wil-j-wil
Author

Hi @AlexeyKurakin ,
Thanks a lot for the code tips, and apologies for the non-reproducible bug report.

Here is a Colab notebook which shows the double-compile issue: https://colab.research.google.com/drive/13yKlZ1-fI_pIG3gt_J5WFuiEco1PkcYw?usp=sharing

For the latest versions of jax and jaxlib, this gives

iter time: 0.07 secs
iter time: 0.06 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
optimisation time: 0.14 secs

whereas running the line %pip --quiet install --upgrade jax==0.2.9 jaxlib==0.1.60 at the start instead results in

iter time: 0.06 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
optimisation time: 0.07 secs

Can you spot something I'm doing wrong, or do you think this is a bug?

@wil-j-wil
Author

I should also mention that I suspect there is some double-compiling going on because, when I run much larger models with a similar setup, I have observed the "Very slow compile? If you want to file a bug, ..." warning twice in a single script (once on each of the first 2 iterations).

@wil-j-wil
Author

@AlexeyKurakin I also tried out your suggestion of using Vectorize rather than vmap, and now it seems like the model compiles on the first three iterations. I have checked this in Google Colab. Could we reopen the issue (I can't reopen it myself)?

import objax
import jax.numpy as np
import time


class GaussianLikelihood(objax.Module):
    """
    The Gaussian likelihood
    """
    def __init__(self,
                 variance=0.1):
        """
        :param variance: The observation noise variance
        """
        self.variance = objax.TrainVar(np.array(variance))

    def __call__(self, y, post_mean, post_cov):
        """Expected log-likelihood of y under a Gaussian posterior N(post_mean, post_cov)."""
        exp_log_lik = (
            -0.5 * np.log(2 * np.pi)
            - 0.5 * np.log(self.variance.value)
            - 0.5 * ((y - post_mean) ** 2 + post_cov) / self.variance.value
        )
        return exp_log_lik


class GP(objax.Module):
    """
    A GP model
    """
    def __init__(self,
                 likelihood,
                 X,
                 Y):
        self.X = np.array(X)
        self.Y = np.array(Y)
        self.likelihood = objax.Vectorize(likelihood, batch_axis=(0, 0, 0))
        self.posterior_mean = objax.StateVar(np.zeros([X.shape[0], 1, 1]))
        self.posterior_variance = objax.StateVar(np.ones([X.shape[0], 1, 1]))

    def energy(self):
        """Sum of the expected log-likelihood over all data points (the training objective)."""
        mean_f, cov_f = self.posterior_mean.value, self.posterior_variance.value

        E = self.likelihood(
            self.Y,
            mean_f,
            cov_f
        )

        return np.sum(E)


# generate some data
N = 1000000
x = np.linspace(-10, 100, num=N)
y = np.sin(x)

# set up the model
lik = GaussianLikelihood(variance=1.0)
gp_model = GP(likelihood=lik, X=x, Y=y)

energy = objax.GradValues(gp_model.energy, gp_model.vars())

lr_adam = 0.1
iters = 10
opt = objax.optimizer.Adam(gp_model.vars())


@objax.Function.with_vars(gp_model.vars() + opt.vars())
def train_op():
    dE, E = energy()  # compute energy and its gradients w.r.t. hypers
    opt(lr_adam, dE)
    return E


train_op = objax.Jit(train_op)

t0 = time.time()
for i in range(1, iters + 1):
    t2 = time.time()
    loss = train_op()
    t3 = time.time()
    # print('iter %2d, energy: %1.4f' % (i, loss[0]))
    print('iter time: %2.2f secs' % (t3-t2))
t1 = time.time()
print('optimisation time: %2.2f secs' % (t1-t0))

gives

iter time: 0.19 secs
iter time: 0.18 secs
iter time: 0.18 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
iter time: 0.00 secs
optimisation time: 0.56 secs
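
Aside on the measurement itself: the per-iteration numbers come from time.time() around an asynchronous JAX dispatch. A variant that blocks on the outputs before stopping the clock, assuming train_op returns JAX arrays, would be:

import time
import jax

t0 = time.time()
for i in range(1, iters + 1):
    t2 = time.time()
    loss = train_op()
    # JAX dispatches work asynchronously; blocking on the outputs before
    # stopping the timer ensures the measurement includes actual execution.
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), loss)
    t3 = time.time()
    print('iter time: %2.2f secs' % (t3 - t2))
t1 = time.time()
print('optimisation time: %2.2f secs' % (t1 - t0))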

@AlexeyKurakin reopened this Mar 31, 2021
@AlexeyKurakin
Member

OK, let me debug it and then get back to you.
