Very Large Memory Consumption for Even A Small Dataset #50

Closed
createmomo opened this issue Jan 22, 2022 · 5 comments

@createmomo

createmomo commented Jan 22, 2022

Dataset: fashion_mnist
Dataset Size: 36.42MB (https://www.tensorflow.org/datasets/catalog/fashion_mnist)

Reproduce the Issue:

from learned_optimization.tasks import fixed_mlp
task = fixed_mlp.FashionMnistRelu32_8()

or

from learned_optimization.tasks.datasets import base

batch_size=128
image_size=(8, 8)
splits = ("train[0:80%]", "train[80%:90%]", "train[90%:]", "test")
stack_channels = 1

dataset = base.preload_tfds_image_classification_datasets(
      "fashion_mnist",
      splits,
      batch_size=batch_size,
      image_size=image_size,
      stack_channels=stack_channels)

Issue Description:
As you can see, the original FashionMnist dataset is very small. However, when I run the above code, the memory usage becomes extremely high, such as 10 GB+.

In my case, the issue occurs when the program reaches this line, which is in the function preload_tfds_image_classification_datasets:

  return Datasets(
      *[make_python_iter(split) for split in splits],
      extra_info={"num_classes": num_classes})

Here is the code of make_python_iter:

  def make_python_iter(split: str) -> Iterator[Batch]:
    # load the entire dataset into memory
    dataset = tfds.load(datasetname, split=split, batch_size=-1)
    data = tfds.as_numpy(_image_map_fn(cfg, dataset))

    # use a python iterator as this is faster than TFDS.
    def generator_fn():

      def iter_fn():
        batches = data["image"].shape[0] // batch_size
        idx = onp.arange(data["image"].shape[0])
        while True:
          # every epoch shuffle indices
          onp.random.shuffle(idx)
          for bi in range(0, batches):
            idxs = idx[bi * batch_size:(bi + 1) * batch_size]

            def index_into(idxs, x):
              return x[idxs]

            yield jax.tree_map(functools.partial(index_into, idxs), data)

      return prefetch_iterator.PrefetchIterator(iter_fn(), prefetch_batches)

    return ThreadSafeIterator(LazyIterator(generator_fn))

Could you please suggest a way to reduce the huge memory usage? Do you have any idea why it requires so much memory, and has anyone else run into this issue?

Thank you very much; I am looking forward to your comments.

@lukemetz
Contributor

Hmm interesting. I would have suspected some memory overhead, but not nearly this much! Thank you for the carefully written issue. Sadly though, I am also not able to reproduce this on my machine. Could you please describe your hardware / setup?

Just to confirm this is CPU / host memory you are talking about correct? Not gpu?

FYI: There is some buffering going on with the "prefetch_batches: int = 300", but with 8x8 images this would mean 8*8*4 bytes (per image) * 128 (batch size) * 300 per split, or about 10mb.... So it is not this.
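
For reference, a quick back-of-the-envelope check of that estimate (a rough sketch; it only counts image bytes and assumes float32 storage with the default prefetch_batches=300):

# Rough size of the prefetch buffer for 8x8 float32 images.
bytes_per_image = 8 * 8 * 4        # 8x8 resized image, 4 bytes per float32 pixel
batch_size = 128
prefetch_batches = 300
num_splits = 4

per_split = bytes_per_image * batch_size * prefetch_batches  # ~9.8 MB
total = per_split * num_splits                               # ~39 MB for all 4 splits
print(per_split / 1e6, total / 1e6)

Either way, this is nowhere near 10 GB.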

@createmomo
Author

createmomo commented Jan 24, 2022

Hi Luke, thank you very much for your quick response. I hope the following details can be helpful.

I am using a GPU:

  • Driver Version: 470.82.01
  • CUDA Version: 11.4
  • NVIDIA GeForce RTX 3060

Python

  • 3.8
  • tensorflow: 2.7.0
  • jax: 0.2.27 (installed with GPU support via: pip install --upgrade "jax[cuda114]" jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html)

More Details:
I first noticed the large GPU memory usage when running the pes.py example:
python pes.py --train_log_dir somedir


I understand there are some prefetched batches, but as you calculated, the prefetched data should be very small.


Again, thank you for your reply. I am still investigating this issue and will let you know once I find something.

@lukemetz
Contributor

Thanks for the info and being such an early tester!

Just to confirm that is NOT gpu memory, but an explosion in host (CPU) memory?

If you are observing GPU memory, see: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html for flags on how to turn that off.

Do you still see memory increases if you turn off the GPU? e.g. something like:

CUDA_VISIBLE_DEVICES= python pes.py --train_log_dir ....
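
One quick way to confirm that such a run actually fell back to CPU (a small sanity-check sketch, not part of pes.py) is to check which backend JAX picks up in the same environment:

import jax

# With CUDA_VISIBLE_DEVICES set to an empty string, JAX should report the
# CPU backend and list only CPU devices.
print(jax.default_backend())  # expect "cpu"
print(jax.devices())          # expect only CPU devices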

@createmomo
Author

Hello, thank you so much! Your comment is really helpful, especially the jax GPU memory allocation link.

Just to confirm that is NOT gpu memory, but an explosion in host (CPU) memory?

In my case, it was GPU explosion.

If you are observing GPU memory, see: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html for flags on how to turn that off.

In my case, I finally managed to reduce the GPU memory usage (from 10 GB+ to ~700 MB when running pes.py) based on the suggestions in the above link.

What I did was:

  • Disable TensorFlow's use of the GPU, since TensorFlow is only used to load the fashion_mnist dataset. I added this line in base.py:
tf.config.experimental.set_visible_devices([], "GPU")
  • Disable JAX's GPU memory preallocation by setting this environment variable:
export XLA_PYTHON_CLIENT_PREALLOCATE=false

The link https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html describes several options for avoiding GPU OOM errors, so there could be other solutions as well.
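
Put together, the two changes look roughly like this (a minimal sketch; the preallocation variable just needs to be set before JAX initializes the GPU backend, whether in the shell or in Python):

import os
# Equivalent to `export XLA_PYTHON_CLIENT_PREALLOCATE=false`; must happen
# before JAX allocates its GPU memory pool.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import tensorflow as tf
# Hide all GPUs from TensorFlow so the tfds data loading stays on the host.
tf.config.experimental.set_visible_devices([], "GPU")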

@lukemetz
Contributor

Ahh, tf also tries to grab the GPU. That is annoying. I should fix that on my end. Going to make an issue. Thanks for posting your solution here!
