out of memory error #37

Closed
lucky2046 opened this issue Feb 21, 2024 · 3 comments

lucky2046 commented Feb 21, 2024

bash scripts/run_vision_chat.sh
I removed the --mesh_dim param.
The model is LWM-Chat-32K-Jax.
I get an out-of-memory error. How can I solve it?

My card is an NVIDIA 2080 Super with 8 GB.

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1708500656.672727   10871 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
I0221 15:30:57.202437 140383335174272 xla_bridge.py:513] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0221 15:30:57.202921 140383335174272 xla_bridge.py:513] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-02-21 15:36:18.340692: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.00GiB (rounded to 2147483648)requested by op 
2024-02-21 15:36:18.340908: W external/tsl/tsl/framework/bfc_allocator.cc:497] *________**********************************************************************_____________________
2024-02-21 15:36:18.340944: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2644] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2147483648 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.00GiB
              constant allocation:         0B
        maybe_live_out allocation:    2.00GiB
     preallocated temp allocation:         0B
                 total allocation:    3.00GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
        Buffer 1:
                Size: 2.00GiB
                Operator: op_name="pjit(to_dtype)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]" source_file="/mnt/data/test/LWM/lwm/vision_chat.py" source_line=199
                XLA Label: fusion
                Shape: f32[32,4096,4096]
                ==========================

        Buffer 2:
                Size: 1.00GiB
                Entry Parameter Subshape: bf16[32,4096,4096]
                ==========================


jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/data/test/LWM/lwm/vision_chat.py", line 254, in <module>
    run(main)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/mnt/data/test/LWM/lwm/vision_chat.py", line 249, in main
    sampler = Sampler()
  File "/mnt/data/test/LWM/lwm/vision_chat.py", line 51, in __init__
    self._load_model()
  File "/mnt/data/test/LWM/lwm/vision_chat.py", line 199, in _load_model
    self.params = tree_apply(shard_fns, self.params)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/jax_utils.py", line 148, in tree_apply
    return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/jax_utils.py", line 148, in <lambda>
    return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/distributed.py", line 95, in shard_fn
    return jax_shard_function(tensor).block_until_ready()
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2147483648 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.00GiB
              constant allocation:         0B
        maybe_live_out allocation:    2.00GiB
     preallocated temp allocation:         0B
                 total allocation:    3.00GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
        Buffer 1:
                Size: 2.00GiB
                Operator: op_name="pjit(to_dtype)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]" source_file="/mnt/data/test/LWM/lwm/vision_chat.py" source_line=199
                XLA Label: fusion
                Shape: f32[32,4096,4096]
                ==========================

        Buffer 2:
                Size: 1.00GiB
                Entry Parameter Subshape: bf16[32,4096,4096]
                ==========================


I0000 00:00:1708500978.900009   10871 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.
(lwm) test@test-3:/mnt/data/test/LWM$ nvidia-smi
Wed Feb 21 15:47:00 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2080 S...    Off| 00000000:01:00.0 Off |                  N/A |
|  0%   40C    P0               23W / 250W|      0MiB /  8192MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
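
As a sanity check on the two peak buffers XLA reports above, their sizes follow directly from the shapes in the log (a minimal sketch of the arithmetic in plain Python):

# Peak buffer sizes from the shapes reported by XLA above.
elements = 32 * 4096 * 4096               # elements in a [32, 4096, 4096] tensor
f32_bytes = elements * 4                  # Buffer 1: f32, 4 bytes per element
bf16_bytes = elements * 2                 # Buffer 2: bf16, 2 bytes per element
print(f32_bytes, f32_bytes / 2**30)       # 2147483648 -> 2.0 GiB
print(bf16_bytes, bf16_bytes / 2**30)     # 1073741824 -> 1.0 GiB
# The cast to float32 keeps the bf16 input and the f32 output live at the
# same time, so this single stacked tensor alone needs ~3 GiB on top of
# whatever params have already been sharded onto the 8 GiB card.
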
@jackyin68

Can you share your modified requirements.txt?

@lucky2046
Author

Can you share your modified requirements.txt?

I did not modify requirements.txt. I modified run_vision_chat.sh; here it is for your reference:

#!/bin/bash

export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd $PROJECT_DIR
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"

# MODEL_NAME='LWM-Chat-1M-Jax'
# MODEL_NAME='LWM-Chat-128K-Jax'
MODEL_NAME='LWM-Chat-32K-Jax'

export llama_tokenizer_path="/mnt/data/test/LWM/models/${MODEL_NAME}/tokenizer.model"
export vqgan_checkpoint="/mnt/data/test/LWM/models/${MODEL_NAME}/vqgan"
export lwm_checkpoint="/mnt/data/test/LWM/models/${MODEL_NAME}/params"
export input_file="/mnt/data/test/2020-07-30_pose_test_006.mp4"

python3 -u -m lwm.vision_chat \
    --prompt="What is the video about?" \
    --input_file="$input_file" \
    --vqgan_checkpoint="$vqgan_checkpoint" \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --max_n_frames=8 \
    --update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,remat_attention='',scan_mlp=False,scan_mlp_chunk_size=2048,remat_mlp='',remat_block='',scan_layers=True)" \
    --load_checkpoint="params::$lwm_checkpoint" \
    --tokenizer.vocab_file="$llama_tokenizer_path" \
2>&1 | tee ~/output.log
read

@wilson1yan
Contributor

I don't think your GPU has enough memory; by itself, a 7B model in fp32 is about 28 GB.
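
A quick worked check of that estimate (plain Python; treating "7B" as 7×10^9 parameters is an approximation):

# Approximate weight memory for a ~7B-parameter model, by dtype.
params = 7e9                     # ~7 billion parameters (approximation)
print(params * 4 / 1e9)          # fp32: ~28 GB of weights alone
print(params * 2 / 1e9)          # bf16/fp16: ~14 GB of weights alone
# Even in bf16 the weights exceed the 8 GB of VRAM on a 2080 Super,
# before activations or the VQGAN checkpoint are counted.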
