
No reproducible using tensorflow backend #2280

Closed
zhuoqiang opened this issue Apr 12, 2016 · 96 comments

Comments

@zhuoqiang

With the Theano backend (CPU, or GPU without cuDNN), I could train a reproducible model by setting:

import random
import numpy

fixed_seed_num = 1234
numpy.random.seed(fixed_seed_num)
random.seed(fixed_seed_num)  # not sure if needed or not

In pure TensorFlow without the Keras wrapper, training is also reproducible with:

import tensorflow
tensorflow.set_random_seed(fixed_seed_num)

I don't know why, but with Keras + the TensorFlow backend, none of the above gives a reproducible training run.

Environment:

  • Keras: v0.3.2
  • tensorflow: v0.7.1
  • Macbook OS X v10.11.4
  • tensorflow is using CPU

BTW: it would be great if Keras could expose a unified API for reproducible training, something like:

keras.set_random_seed(fixed_seed_num)
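For illustration only, a minimal sketch of what such a helper could do, assuming it just needs to seed Python's `random`, NumPy and TensorFlow when they are installed. `seed_everything` is a hypothetical name, not a Keras API; it calls `tf.random.set_seed` (TensorFlow 2.x) or `tf.set_random_seed` (TensorFlow 1.x), whichever exists. It does nothing about the GPU and threading sources of non-determinism discussed later in this thread.

```python
import random


def seed_everything(seed):
    """Hypothetical helper (not a Keras API): seed every RNG we can find."""
    random.seed(seed)  # Python's built-in RNG
    try:
        import numpy as np
        np.random.seed(seed)  # NumPy's global RNG
    except ImportError:
        pass
    try:
        import tensorflow as tf
    except ImportError:
        return  # TensorFlow not installed: nothing more to seed
    if hasattr(tf, "random") and hasattr(tf.random, "set_seed"):
        tf.random.set_seed(seed)  # TensorFlow 2.x
    else:
        tf.set_random_seed(seed)  # TensorFlow 1.x


seed_everything(1234)
first = [random.random() for _ in range(3)]
seed_everything(1234)
second = [random.random() for _ in range(3)]
print(first == second)  # True
```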
@fchollet
Member

BTW: it would be great if Keras could expose a unified API for reproducible training

Right. I will look into it. Or does anybody else want to take a look at it?

@bplank

bplank commented Jul 30, 2016

Is there any update on this?

I could train reproducible models on theano (setting the seed before the keras import #439), but not when using the tensorflow backend.

related: #850

@kudkudak
Contributor

I ran into this problem in Edward; here is the fix we went with after a rather long discussion: blei-lab/edward#184. Long story short: it is pretty hard to seed TensorFlow if you have a single shared session. I would be very interested to hear if there is a better solution :)

@fish128

fish128 commented Oct 17, 2016

Any update on this?

@bluelight773

Correct me if I'm wrong, but it looks like this issue is still open and there is currently no way in Keras with a TensorFlow backend to get reproducible results. Any update? Workaround?

@kudkudak
Contributor

kudkudak commented Nov 5, 2016

Well, there is this hack: blei-lab/edward#184. I can propose a PR with that to Keras if that makes sense, @fchollet?

The solution is to simply add a set_seed() function, but raise an error if someone calls it after a TF variable is created. You cannot reseed after a Variable has been created, because the previous seed was already used to create its initializer.
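A framework-free sketch of that proposal, assuming a hypothetical `SeedGuard` class that stands in for the backend: `set_seed()` works until the first variable is created and raises afterwards, since reseeding at that point could no longer affect the initial values already drawn. The class and method names are illustrative only, not Keras or TensorFlow APIs.

```python
import random


class SeedGuard:
    """Hypothetical stand-in for a backend that tracks variable creation."""

    def __init__(self):
        self._rng = random.Random()
        self._variables_created = False

    def set_seed(self, seed):
        # Reseeding after a variable exists would not change its initial
        # value, so refuse loudly instead of silently doing nothing useful.
        if self._variables_created:
            raise RuntimeError("set_seed() must be called before any variable is created")
        self._rng.seed(seed)

    def create_variable(self, size):
        # Initial values are drawn from the seeded RNG at creation time.
        self._variables_created = True
        return [self._rng.uniform(-0.05, 0.05) for _ in range(size)]


backend = SeedGuard()
backend.set_seed(42)
weights = backend.create_variable(4)
try:
    backend.set_seed(0)
except RuntimeError as err:
    print(err)  # set_seed() must be called before any variable is created
```

The design choice here is to fail early: an error at the wrong call site is easier to diagnose than a seed that is accepted but has no effect.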

@pibkac

pibkac commented Dec 7, 2016

Any news on this issue? @bluelight773 I think it's reproducible when running on the CPU, but that is not really an option most of the time.

@NianzuMa

NianzuMa commented Dec 20, 2016

@fchollet @zhuoqiang Could you confirm this?

Maybe there is a workaround: use Keras and TensorFlow together, following this post:
https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html

Use Keras's predefined layers to speed up building your model, but use TensorFlow for the input, output, and optimization.
Take a look at this code; it seems to reproduce the result.
I use a CentOS 7 server with a Tesla K40. It always prints 0.6268 as the result.

>>> keras.__version__
'1.1.1'
>>> tf.__version__
'0.12.0-rc1'

You should seed it by

import numpy as np
np.random.seed(42)
import tensorflow as tf
tf.set_random_seed(42)
"""
Different behaviors during training and testing

Some Keras layers (e.g. Dropout, BatchNormalization) behave differently at training time and testing time.
You can tell whether a layer uses the "learning phase" (train/test) by printing layer.uses_learning_phase,
a boolean: True if the layer has a different behavior in training mode and test mode, False otherwise.

If your model includes such layers, then you need to specify the value of the learning phase as part of feed_dict,
so that your model knows whether to apply dropout/etc or not.

To make use of the learning phase, simply pass the value "1" (training mode) or "0" (test mode) to feed_dict:
"""
import numpy as np
np.random.seed(42)
import tensorflow as tf
tf.set_random_seed(42)
sess = tf.Session()
from keras.layers import Dropout, Dense, LSTM
from keras import backend as K
K.set_session(sess)
from keras.objectives import categorical_crossentropy
from keras.metrics import categorical_accuracy as accuracy

# load data
from tensorflow.examples.tutorials.mnist import input_data
mnist_data = input_data.read_data_sets('MNIST_data', one_hot=True)

img = tf.placeholder(tf.float32, shape=(None, 784))
labels = tf.placeholder(tf.float32, shape=(None, 10))

x = Dense(128, activation='relu')(img)
x = Dropout(0.5)(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.5)(x)
preds = Dense(10, activation='softmax')(x)

loss = tf.reduce_mean(categorical_crossentropy(labels, preds))

# train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
train_step = tf.train.RMSPropOptimizer(learning_rate=0.001).minimize(loss)
# train_step = tf.train.AdagradOptimizer(learning_rate=0.001).minimize(loss)
# train_step = tf.train.AdadeltaOptimizer(learning_rate=0.001).minimize(loss)

with sess.as_default():
    sess.run(tf.global_variables_initializer())
    for i in range(100):
        batch = mnist_data.train.next_batch(50)
        train_step.run(feed_dict={img: batch[0],
                                  labels: batch[1],
                                  K.learning_phase(): 1})

acc_value = accuracy(labels, preds)
with sess.as_default():
    print(acc_value.eval(feed_dict={img: mnist_data.test.images,
                                    labels: mnist_data.test.labels,
                                    K.learning_phase(): 0}))

@nejumi

nejumi commented Jan 25, 2017

I heard that Keras is going to be merged into TensorFlow. Can I expect the reproducibility problem to be solved at the same time? If so, it will be a great improvement for Kaggle usage!

@brannondorsey

@nejumi, ditto. This lack of support makes it really hard to run experiments with Keras & TF. I appreciate the convos and solutions here but really hoping this gets fixed soon.

@diogoff
Contributor

diogoff commented Mar 7, 2017

In principle, this should do it:

import numpy as np
np.random.seed(...)
import tensorflow as tf
tf.set_random_seed(...)

However, there is still non-determinism in cuDNN.

With theano it is possible to ensure reproducibility of cuDNN by setting dnn.conv flags: #2479 (comment)

With tensorflow, how do we set those flags?

lewfish added a commit to azavea/raster-vision that referenced this issue Mar 10, 2017
According to keras-team/keras#2280, experiments aren't
reproducible when using Keras/Tensorflow at the moment. So, we might as well
remove seeding the random number generator, since it's not doing what we expect.
@pibkac

pibkac commented Mar 10, 2017

For some time, I had at least reproducible results when running the training on the CPU. However, even that seems not to work any more. Has anyone experienced the same?

@iaguas

iaguas commented Mar 30, 2017

I'm looking for a way of reproducing keras code, but I'm supposing that it's not possible. Am I right?

@iaguas

iaguas commented Mar 30, 2017

Thanks @diogoff, but my problem is that I use TensorFlow as the backend and I also use cuDNN. That's the case you are also looking for a solution to.

@diogoff
Contributor

diogoff commented Mar 30, 2017

I gave up on reproducibility because I found that when forcing deterministic behavior in cuDNN, training would be much slower (e.g. from 15 secs/epoch to 30 secs/epoch).

@pylang

pylang commented Apr 1, 2017

IMO this is a critical issue that merits a high priority. Running a complex model for several minutes is meaningless unless results can be reproduced.

Running the same cell multiple times has given results that differ by several orders of magnitude. I can confirm the latest suggestion does not work for Keras 2.0.2 / TensorFlow 1.0 backend / Anaconda 4.2 / Windows 7:

import numpy as np
np.random.seed(123)
import tensorflow as tf
tf.set_random_seed(123)

@diogoff
Contributor

diogoff commented Apr 1, 2017

@pylang are you using cuDNN?

@pylang

pylang commented Apr 1, 2017

@diogoff I have not taken extra steps to install cuDNN, so my assumption is no, though I don't know how to verify this for certain.

@diogoff
Contributor

diogoff commented Apr 1, 2017

Try:
$ ls -las /usr/local/cuda/include/*dnn*
and
$ ls -las /usr/local/cuda/lib64/*dnn*

If you see libcudnn.so installed, you have it and probably tensorflow is using it.

If I remember correctly, TensorFlow prints some warning/info messages on startup saying which libraries it has loaded. On my system, libcudnn.so was one of them.

@pylang

pylang commented Apr 1, 2017

I searched all files on my Windows machine and found none by that name, nor any system files with "cudnn" (only folders included in Anaconda's TensorFlow site package). I also don't see any warnings aside from the "TensorFlow backend" warning upon import. Seeing that I have not directly installed the driver, find no library files under this name, and see no unusual warnings at import, I conclude I do not have cuDNN installed.

@pylang

pylang commented Apr 1, 2017

On another note, I suspect the main issue with non-reproducible results in Keras may be related to how the weights are randomly initialized on each call.

I discovered (late last night) that the kernel_initializer argument has a number of options for setting up a distribution from which (I assume) the weights are drawn. I have not run substantial tests to reach a conclusion nor investigated these options further yet, but my initial tests suggest that selecting different initializers influences the reproducibility of results. For instance, the default initializer is called "glorot_uniform". I played with some other distributions and managed to get more reproducible results, although with much higher error.

Since there are many variables, perhaps we should post a simple example here, e.g. a single Dense layer doing 1-input linear regression. The results should be consistent for everyone who runs it, so we can then compare them across different machines and users.

@diogoff
Contributor

diogoff commented Apr 3, 2017

I picked up mnist_cnn.py from the examples and set up keras.json in this way:

{
    "image_data_format": "channels_first",
    "epsilon": 1e-07,
    "floatx": "float32",
    "backend": "tensorflow"
}

I ran python mnist_cnn.py a couple of times and the results did not seem to be reproducible.

Then I edited mnist_cnn.py and inserted the following code between from __future__ import print_function (line 8) and import keras (line 9):

import numpy as np
np.random.seed(123)
import tensorflow as tf
tf.set_random_seed(123)

The results now look sufficiently reproducible to me. The small differences I assume are due to the use of cuDNN.

I tried running without cuDNN:
$ TF_USE_CUDNN=0 python mnist_cnn.py
but it seems it's not possible:
UnimplementedError (see above for traceback): Conv2D for GPU is not currently supported without cudnn

@diogoff
Contributor

diogoff commented Apr 3, 2017

If I switch the backend to Theano:

{
    "image_data_format": "channels_first",
    "epsilon": 1e-07,
    "floatx": "float32",
    "backend": "theano"
}

and insert the following code between lines 8-9 in mnist_cnn.py:

import numpy as np
np.random.seed(123)

and then run:
$ THEANO_FLAGS="dnn.conv.algo_bwd_filter=deterministic,dnn.conv.algo_bwd_data=deterministic" python mnist_cnn.py
the results are fully reproducible.
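If editing the command line is inconvenient, the same flags can, as far as I know, be placed in the THEANO_FLAGS environment variable from inside the script, as long as this happens before Theano is first imported (importing Keras with the Theano backend imports Theano). A small sketch; `add_theano_flags` is a hypothetical helper that appends to any flags already present instead of overwriting them.

```python
import os


def add_theano_flags(*flags):
    """Hypothetical helper: append flags to THEANO_FLAGS without clobbering existing ones."""
    current = os.environ.get("THEANO_FLAGS", "")
    parts = [p for p in current.split(",") if p]  # THEANO_FLAGS is comma-separated
    for flag in flags:
        if flag not in parts:
            parts.append(flag)
    os.environ["THEANO_FLAGS"] = ",".join(parts)


# Must run before `import theano` / `import keras`: Theano reads its flags at import time.
add_theano_flags(
    "dnn.conv.algo_bwd_filter=deterministic",
    "dnn.conv.algo_bwd_data=deterministic",
)
```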

@pylang

pylang commented Apr 4, 2017

@diogoff for clarity, what do you consider fully reproducible? Do you know how close your loss results are between runs? I'd like to compare notes.

@diogoff
Contributor

diogoff commented Apr 4, 2017

With fully reproducible, I mean I always get exactly the same results in every run:

loss: 0.3336 - acc: 0.8981 - val_loss: 0.0788 - val_acc: 0.9759
loss: 0.1214 - acc: 0.9642 - val_loss: 0.0548 - val_acc: 0.9828
loss: 0.0893 - acc: 0.9733 - val_loss: 0.0443 - val_acc: 0.9847
loss: 0.0735 - acc: 0.9783 - val_loss: 0.0391 - val_acc: 0.9871
loss: 0.0666 - acc: 0.9804 - val_loss: 0.0363 - val_acc: 0.9872
loss: 0.0590 - acc: 0.9825 - val_loss: 0.0369 - val_acc: 0.9873
loss: 0.0542 - acc: 0.9836 - val_loss: 0.0338 - val_acc: 0.9889
loss: 0.0505 - acc: 0.9850 - val_loss: 0.0314 - val_acc: 0.9889
loss: 0.0467 - acc: 0.9861 - val_loss: 0.0299 - val_acc: 0.9896
loss: 0.0451 - acc: 0.9867 - val_loss: 0.0319 - val_acc: 0.9898
loss: 0.0421 - acc: 0.9874 - val_loss: 0.0297 - val_acc: 0.9894
loss: 0.0405 - acc: 0.9880 - val_loss: 0.0309 - val_acc: 0.9895
Test loss: 0.0309449151449    <-- exactly the same up to the last digit
Test accuracy: 0.9895

Keras 2.0.2, Theano 0.9.0 with libgpuarray, CUDA 8.0, cuDNN 5.1.10

@MyVanitar

I don't think you can get reproducible results using TensorFlow-GPU. I also have a GPU in the system, but TensorFlow uses the CPU because I installed the CPU version. You can create another Anaconda environment for this purpose and test the idea.

@Sucran

Sucran commented May 21, 2018

@VanitarNordic Yes, I agree with you. It is difficult to get reproducible results using TF-GPU. In my case, two runs of the same code gave slightly different loss values in each epoch; the curves are almost the same but not identical (maybe we can already call this reproducible). However, I still hope someone can figure out what causes this issue. I think it might be identical on the CPU, just like you said. Thank you for the advice.

dmitriydligach pushed a commit to dmitriydligach/Phenotype that referenced this issue Jun 28, 2018
@MartinThoma
Contributor

@abali96 Do you have any idea why setting PYTHONHASHSEED would be necessary?

@ageron
Contributor

ageron commented Aug 8, 2018

@MartinThoma, setting the PYTHONHASHSEED environment variable to 0 ensures that python's built-in hash() function outputs the same result across multiple runs of the program (without this, the hash() function is only stable within a single run of the program). This hash() function is used everywhere, for example when you create a set or a dict. Try running this:

>>> set("abcdefghijklmnopqrstuvwxyz")
{'p', 'f', 'g', 'i', 'n', 'o', 'k', 'c', 'h', 'b', 'v', 'a', 'd', 's', 'u', 'q', 'j', 'z', 'm', 'r', 'w', 'l', 't', 'x', 'y', 'e'}
>>> set("abcdefghijklmnopqrstuvwxyz")
{'p', 'f', 'g', 'i', 'n', 'o', 'k', 'c', 'h', 'b', 'v', 'a', 'd', 's', 'u', 'q', 'j', 'z', 'm', 'r', 'w', 'l', 't', 'x', 'y', 'e'}

If I stop the Python shell and I run the same commands, I get a different result:

>>> set("abcdefghijklmnopqrstuvwxyz")
{'c', 'y', 'q', 'g', 'a', 'u', 'd', 'k', 'w', 'j', 'm', 's', 'e', 'o', 'b', 'h', 'l', 'r', 't', 'x', 'z', 'n', 'p', 'v', 'f', 'i'}
>>> set("abcdefghijklmnopqrstuvwxyz")
{'c', 'y', 'q', 'g', 'a', 'u', 'd', 'k', 'w', 'j', 'm', 's', 'e', 'o', 'b', 'h', 'l', 'r', 't', 'x', 'z', 'n', 'p', 'v', 'f', 'i'}

However, if I start python like this:

PYTHONHASHSEED=0 python

Then I always get the same result, even across multiple restarts of the Python shell:

>>> set("abcdefghijklmnopqrstuvwxyz")
{'x', 'i', 'r', 'p', 'd', 'c', 'l', 'y', 'h', 'm', 'z', 'k', 'o', 'a', 'g', 'f', 'u', 'e', 'w', 'n', 'b', 'q', 'j', 't', 's', 'v'}

However, I noticed that setting this environment variable within the Python program did not have any effect. It only worked when it was set before the program started. Looking at Python's source code, it seems that this environment variable is read upon startup, so I think there's no use setting it after startup. I'll submit a fix to Keras's documentation.

Hope this helps,
Aurélien
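One way to check this from a script is to launch fresh interpreters with `subprocess` and pass PYTHONHASHSEED in each child's environment, so it is set before that child starts. A sketch using only the standard library; `set_order` is a hypothetical helper name.

```python
import os
import subprocess
import sys

# Prints the iteration order of a set of strings, which depends on str hashes.
SNIPPET = 'print("".join(set("abcdefghijklmnopqrstuvwxyz")))'


def set_order(hashseed=None):
    """Run SNIPPET in a new interpreter, optionally with a fixed PYTHONHASHSEED."""
    env = dict(os.environ)
    env.pop("PYTHONHASHSEED", None)
    if hashseed is not None:
        env["PYTHONHASHSEED"] = str(hashseed)  # read by the child at startup
    result = subprocess.run(
        [sys.executable, "-c", SNIPPET],
        env=env, capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()


# With a fixed seed, separate interpreters agree on the order.
print(set_order(0) == set_order(0))  # True
# Without it, the order usually differs from one interpreter to the next.
print(set_order(), set_order())
```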

@thanhnguyentang

I think this non-deterministic behaviour is mostly due not to Keras but to TensorFlow itself (e.g., https://stackoverflow.com/questions/45865665/getting-reproducible-results-using-tensorflow-gpu?newreg=4a6ec43834884576a175961e7f2188db).

I tried to run the pure TensorFlow code fully_connected_feed.py from the TensorFlow repo with the following settings (as recommended by other responses above):

import tensorflow as tf
import numpy as np 
import random 
import os 
os.environ['PYTHONHASHSEED'] = '0'
np.random.seed(2019)
random.seed(2019)
tf.set_random_seed(2019)

session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,
                              inter_op_parallelism_threads=1)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)

and set shuffle=False in line 78 of fully_connected_feed.py but could not obtain reproducibility.

I could not obtain reproducibility even when running the code on the CPU.
Note: For the Keras + Theano backend, I have also obtained a perfect reproducibility.

@AE51

AE51 commented Nov 28, 2018

I've also inserted the explicit kernel (and bias) initialization:
x = layers.Dense(64, activation='relu', kernel_initializer=keras.initializers.glorot_uniform(seed=123))(x),
and this has worked.
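Why this may help, as I understand it: an initializer with its own seed draws from a dedicated random stream, so the initial weights no longer depend on how much of the global RNG other code consumed first. A framework-free sketch of a Glorot-uniform style draw, assuming the usual limit sqrt(6 / (fan_in + fan_out)); `glorot_uniform_weights` is a hypothetical helper, not the Keras implementation.

```python
import math
import random


def glorot_uniform_weights(fan_in, fan_out, seed):
    """Draw a fan_in x fan_out matrix from U(-limit, limit) with a private RNG."""
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    rng = random.Random(seed)  # independent of the global `random` state
    return [[rng.uniform(-limit, limit) for _ in range(fan_out)] for _ in range(fan_in)]


random.random()  # consuming the global RNG has no effect on the draw below
w1 = glorot_uniform_weights(3, 2, seed=123)
random.seed(999)  # neither does reseeding it
w2 = glorot_uniform_weights(3, 2, seed=123)
print(w1 == w2)  # True
```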

@mrgkumar

mrgkumar commented Apr 3, 2019

I tried multiple versions of TensorFlow from 1.2.1 to 1.13.1, and all of them have issues on the CPU. However, when I set the backend to CNTK, I get perfectly matching results.
I used a Sequential model with only Dense layers and no dropout.

@erotavlas

I'm getting this issue. I'm using this example https://www.depends-on-the-definition.com/lstm-with-char-embeddings-for-ner/ with the ner dataset. https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus

Every time I run the training, it produces completely different F-score, precision, and recall values. The lowest F-score I had was 71% and the highest was 83%. I think that is a huge variation.

I tried the same on two machines, one with a GPU and one using only the CPU, and the outcome is the same: unreproducible results.

I was using keras 2.2.2 + tensorflow 1.10 (unable to use latest version due to unresolved bug in keras >= 2.2.3)

@anjanaw

anjanaw commented Apr 9, 2019

Getting same issue +1

@anjanaw

anjanaw commented Apr 15, 2019

any update on this please?

@alberduris

alberduris commented Apr 19, 2019

Still can't get reproducible results even after the usage of this method

import os
import random
import numpy as np
from tensorflow import set_random_seed

def _seed_everything(seed=2019):
    os.environ['PYTHONHASHSEED'] = str(seed)  # Os
    random.seed(seed)  # Python random
    np.random.seed(seed)  # Numpy random
    set_random_seed(seed)  # TF random

Please consider it a high priority, as it becomes really difficult to do research with TF + Keras.

@ageron
Contributor

ageron commented Apr 19, 2019

Hi @alberduris, please read my comment about PYTHONHASHSEED: you cannot set it within your program, you have to set it before starting Python (or Jupyter). Check out my video for more details.

@jsl303

jsl303 commented May 5, 2019

Putting the following code in the beginning, I can consistently reproduce the result 100% if I only use Dense layer.

import numpy as np
import random as rn
import tensorflow as tf
import os
os.environ['PYTHONHASHSEED'] = '0'
np.random.seed(1)
rn.seed(2)
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
from tensorflow.keras import backend as K
tf.set_random_seed(3)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)

However, I get different results if I insert this one line "model.add(Conv2D(32, 3, activation='relu'))" before "model.add(Flatten())".

Input> flatten > dense produces consistent result, but input > conv2d > flatten > dense produces different result every time I run the code.

I'd appreciate any guidance.

@ageron
Contributor

ageron commented May 5, 2019

@jsl303 , it's no use setting PYTHONHASHSEED within the Python program, you have to set it before starting Python (it's only read by Python upon startup). This is used by Python to compute the hash (e.g., for dictionaries or sets). If you don't use any code that relies on the order of the items in sets or dictionaries, it won't make a difference, but I wouldn't count on it.

To convince yourself, try starting Python multiple times and run this command: print("".join(set("abcdefghijklmnopqrstuvwxyz"))). For example:

$ python3
...
>>> print("".join(set("abcdefghijklmnopqrstuvwxyz")))
oeqsytnmfprwbvhldxijzcugak
>>> print("".join(set("abcdefghijklmnopqrstuvwxyz")))
oeqsytnmfprwbvhldxijzcugak
>>> exit()

$ python3
...
>>> print("".join(set("abcdefghijklmnopqrstuvwxyz")))
qjufgnolbdewycpitkzvarxsmh
>>> print("".join(set("abcdefghijklmnopqrstuvwxyz")))
qjufgnolbdewycpitkzvarxsmh
>>> exit()

As you can see, although the order is consistent within one Python execution, it is not consistent across multiple runs. If you try to set PYTHONHASHSEED within your Python code, it won't make any difference, give it a try! The correct approach is to set PYTHONHASHSEED before Python starts, for example on the command line:

$ PYTHONHASHSEED=0 python3
...
>>> print("".join(set("abcdefghijklmnopqrstuvwxyz")))
xirpdclyhmzkoagfuewnbqjtsv
>>> print("".join(set("abcdefghijklmnopqrstuvwxyz")))
xirpdclyhmzkoagfuewnbqjtsv
>>> exit()

$ PYTHONHASHSEED=0 python3
...
>>> print("".join(set("abcdefghijklmnopqrstuvwxyz")))
xirpdclyhmzkoagfuewnbqjtsv
>>> print("".join(set("abcdefghijklmnopqrstuvwxyz")))
xirpdclyhmzkoagfuewnbqjtsv
>>> exit()

Hope this helps.

@jsl303

jsl303 commented May 5, 2019 via email

@ageron
Contributor

ageron commented May 5, 2019

@jsl303 , glad I could help. But even if you don't use any sets or dictionaries in your code, if any library you call iterates over sets or dictionaries, the result will not be deterministic across runs. So I really recommend setting PYTHONHASHSEED=0 outside of Python. Moreover, if you are using a GPU, then it won't be deterministic because some GPU operations used by TensorFlow (through CuDNN and CUDA) are just not perfectly deterministic (such as tf.reduce_sum()). So unfortunately, if you really need things to be deterministic, you need to use the CPU. And you also need to use a single thread. Check out my video on this topic.

@kiranvarmas

kiranvarmas commented May 15, 2019

I am getting 100% reproducible results after following Ageron's video. Below are the settings I used:

OS: Windows
Libraries used: Keras (Tensorflow in the backend)
IDE: Anaconda

I first created a PYTHONHASHSEED environment variable in Windows environment variables and set it to 0.

I opened jupyter notebook through Anaconda prompt and added the below code at the start of the program.

import numpy as np
import tensorflow as tf
import random as rn

np.random.seed(10)
rn.seed(10)

config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
with tf.Session(config=config) as sess:
    pass

tf.set_random_seed(10)

I then used the seed when splitting the dataset and when initializing the network's weights:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30,random_state=random_seed)

model.add(Dense(45, input_dim=42, activation='tanh',kernel_initializer=keras.initializers.glorot_normal(seed=random_seed)))

Hope this helps.

The only remaining problem is that, to reproduce the results, I have to restart the kernel every time. I hope someone can suggest a way to avoid restarting the kernel.

Update:
I replaced
with tf.Session(config=config) as sess:
    pass
tf.set_random_seed(10)

with the below in the first cell and now I don't have to restart the kernel.

from keras import backend as K
tf.set_random_seed(10)
sess = tf.Session(graph=tf.get_default_graph(), config=config)
K.set_session(sess)

Just running this cell before the desired seed iteration helps.

-Kiran Varma

@samra-irshad

After searching for 2-3 days, a solution that I saw somewhere worked for me: I changed the optimizer from Adam to Adagrad, and now I get consistent results. I'm still trying to find the reason for this.

@pylang

pylang commented Jun 26, 2021

Has this been fixed? What's the recommended solution?

@ageron
Contributor

ageron commented Jun 28, 2021

@pylang . Here's a summary on how to get 100% reproducibility (at the cost of performance!):

  • Make sure you use the same version of Python and of all libraries at each run.
  • Obviously, make sure the data doesn't change between runs.
  • Set all the random seeds: np.random.seed(42), tf.random.set_seed(42), etc., and set the PYTHONHASHSEED environment variable to 0 before starting Python (or Jupyter, or whatever Python environment you're using).
  • Avoid using the GPU. CUDA does have deterministic versions of most operations, but TensorFlow uses the faster but non-deterministic versions for many operations (e.g., tf.reduce_sum() is non-deterministic, and even if your code doesn't explicitly use it, many other operations use it).
  • More generally, avoid parallelism. Because of floating point errors, the order of execution matters. For example, 1 + 1 + 1/3 is not perfectly equal to 1/3 + 1 + 1 (check it out in a Python shell). To make TensorFlow single-threaded, you can set the TF_NUM_INTEROP_THREADS and TF_NUM_INTRAOP_THREADS environment variables to 1 before starting Python (or at least before importing TensorFlow).
  • Some library functions are non-deterministic. For example, os.listdir() returns files in an arbitrary order, so you should sort the list before using it.
  • The same code may sometimes produce different results on different platforms (e.g., 32-bit vs 64-bit platforms, or Windows vs Linux). So don't expect 100% reproducibility on different platforms (unless you work very hard to work around every single difference).

If you follow these guidelines, you should get 100% reproducible results. But as you can see, it comes at a high cost (especially dropping the GPU!), so you may ask yourself: is it worth the effort? Perhaps instead of perfect reproducibility, you could run the code multiple times and ensure that it produces approximately the same output on average, and the variance is low.

Hope this helps.
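A few of the points in the list above can be illustrated with the standard library alone. This sketch sets the two TensorFlow thread-count environment variables (before TensorFlow would be imported), shows that floating-point addition order changes the result, and sorts `os.listdir()` output; `list_files_sorted` is a hypothetical helper name.

```python
import os

# Single-threaded TensorFlow: set these before TensorFlow is imported.
os.environ["TF_NUM_INTEROP_THREADS"] = "1"
os.environ["TF_NUM_INTRAOP_THREADS"] = "1"

# Floating-point addition is not associative, so evaluation order matters.
left_to_right = 1 + 1 + 1 / 3
reordered = 1 / 3 + 1 + 1
print(left_to_right, reordered)  # 2.3333333333333335 2.333333333333333
print(left_to_right == reordered)  # False


def list_files_sorted(directory):
    """os.listdir() returns entries in arbitrary order; sort for a stable one."""
    return sorted(os.listdir(directory))
```

In a data-loading loop, iterating over `list_files_sorted(path)` instead of `os.listdir(path)` keeps the file order, and thus the batch order, the same across runs.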

@minsung-k

minsung-k commented Feb 23, 2022

@abali96

I found your comment here:
(https://keras.io/getting_started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development)

The way I get reproducible results is to call tf.random.set_seed() inside the function that builds the model.

For example:

"""
def model():
tf.random.set_seed()
--
Build model
--
return **
"""

Hope others solve this problem as well!
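The reason this works, I believe, is that resetting the global seed at the start of the builder restarts the random sequence, so every call draws the same initial values; it may also explain why re-running a seeding cell avoided kernel restarts earlier in this thread. A framework-free sketch in which `random.seed` stands in for `tf.random.set_seed`; both builder functions are hypothetical and just return random "initial weights".

```python
import random


def build_model(seed=42, n_weights=4):
    """Hypothetical builder: reseed first, then draw the 'initial weights'."""
    random.seed(seed)  # stands in for tf.random.set_seed(seed)
    return [random.gauss(0.0, 0.05) for _ in range(n_weights)]


def build_model_without_reseed(n_weights=4):
    """Same draw, but it continues from wherever the global RNG currently is."""
    return [random.gauss(0.0, 0.05) for _ in range(n_weights)]


print(build_model() == build_model())  # True: each call restarts the sequence
random.seed(42)
print(build_model_without_reseed() == build_model_without_reseed())  # False
```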

@hwei-hw

hwei-hw commented Feb 23, 2022 via email

@ahallermed

For anyone who still runs into the issue described here, please have a look at these links:

  1. Tensorflow has an experimental option to run the model training on GPU in a deterministic way.
    https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism
    I tried it and it worked very well for me (using Keras LSTM, Dropout, and Dense layers).

  2. Here is a repository that formerly provided patches for TensorFlow determinism and now tracks the status of each piece of functionality: whether a deterministic version is available, or why there are still issues with it.
    https://github.com/NVIDIA/framework-determinism/blob/master/doc/tensorflow_status.md

  3. FYI: Theano backend can't be used anymore since keras 2.4.0
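For TensorFlow 2.x, a sketch that combines seeding with the option from the first link, as I understand them: `tf.keras.utils.set_random_seed()` (seeds Python, NumPy and TensorFlow in recent versions) and `tf.config.experimental.enable_op_determinism()`. `configure_determinism` is a hypothetical wrapper that returns False when TensorFlow is missing or too old to provide either function; the idea is to call it before building the model.

```python
def configure_determinism(seed=42):
    """Hypothetical wrapper: seed all RNGs and request deterministic TF ops.

    Returns True if both settings were applied, False otherwise.
    """
    try:
        import tensorflow as tf
    except ImportError:
        return False  # TensorFlow is not installed
    try:
        set_seed = tf.keras.utils.set_random_seed
        enable = tf.config.experimental.enable_op_determinism
    except (AttributeError, ImportError):
        return False  # this TensorFlow version lacks one of the two APIs
    set_seed(seed)  # seeds Python's random, NumPy and TensorFlow
    enable()  # ops without a deterministic kernel may now raise an error
    return True


print(configure_determinism(42))
```

Expect training to be slower with op determinism enabled, which matches the performance cost mentioned earlier in this thread.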
