Hi, I am trying to run my environment using MJX, and everything works fine when I use PPO: I can scale past 8,000 environments with a batch size over 1,000 without any memory issues. With SAC, however, it is a different story: no matter what settings I try, it fails with the memory allocation error pasted below, even with a batch size of 32 and a single environment. From trials with other simulators, my particular environment only really trains with SAC, not PPO, so I do need SAC to work. Any help would be greatly appreciated.
Getting this error:
2024-02-19 15:14:56.882326: W external/xla/xla/service/hlo_rematerialization.cc:2941] Can't reduce memory use below -30.29GiB (-32529110990 bytes) by rematerialization; only reduced to 49.62GiB (53280000080 bytes), down from 49.62GiB (53280000080 bytes) originally
2024-02-19 15:15:06.899957: W external/tsl/tsl/framework/bfc_allocator.cc:487] Allocator (GPU_0_bfc) ran out of memory trying to allocate 49.62GiB (rounded to 53280000000)requested by op
2024-02-19 15:15:06.900601: W external/tsl/tsl/framework/bfc_allocator.cc:499] ____******
E0219 15:15:06.900728 6066 pjrt_stream_executor_client.cc:2766] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 53280000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 8B
constant allocation: 8B
maybe_live_out allocation: 49.62GiB
preallocated temp allocation: 160B
total allocation: 49.62GiB
total fragmentation: 168B (0.00%)
Peak buffers:
Buffer 1:
Size: 49.62GiB
Operator: op_name="pmap(init)/jit(main)/broadcast_in_dim[shape=(30000000, 444) broadcast_dimensions=()]" source_file="/home/sclaer/ASCENT_biped/venv/lib/python3.10/site-packages/brax/training/replay_buffers.py" source_line=112
XLA Label: fusion
Shape: f32[30000000,444]
==========================
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/sclaer/ASCENT_biped/sim/mujoco/rl_mjx.py", line 367, in <module>
make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)
File "/home/sclaer/ASCENT_biped/venv/lib/python3.10/site-packages/brax/training/agents/sac/train.py", line 429, in train
buffer_state = jax.pmap(replay_buffer.init)(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 53280000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 8B
constant allocation: 8B
maybe_live_out allocation: 49.62GiB
preallocated temp allocation: 160B
total allocation: 49.62GiB
total fragmentation: 168B (0.00%)
Peak buffers:
Buffer 1:
Size: 49.62GiB
Operator: op_name="pmap(init)/jit(main)/broadcast_in_dim[shape=(30000000, 444) broadcast_dimensions=()]" source_file="/home/sclaer/ASCENT_biped/venv/lib/python3.10/site-packages/brax/training/replay_buffers.py" source_line=112
XLA Label: fusion
Shape: f32[30000000,444]
==========================
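As a quick sanity check on the log above: the reported peak buffer, f32[30000000, 444], accounts for the entire failed allocation by itself (30 million rows of 444 float32 values each).

```python
# The peak buffer reported by XLA is f32[30000000, 444]:
# 30M replay-buffer rows, 444 float32 values per stored transition.
rows = 30_000_000
floats_per_row = 444
bytes_per_float32 = 4

total_bytes = rows * floats_per_row * bytes_per_float32
assert total_bytes == 53_280_000_000           # the exact byte count in the log
print(f"{total_bytes / 2**30:.2f} GiB")        # 49.62 GiB, as reported
```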
My computer has an Nvidia 4090 GPU.

Running this training:

train_fn = functools.partial(
    sac.train, num_timesteps=30_000_000, num_evals=1,
    episode_length=INITIAL_PARAMS.RL_PARAMS.MAX_EPISODE_TIMESTEPS,
    normalize_observations=True, action_repeat=1,
    num_envs=1, batch_size=32, seed=0)

You are not setting the max_replay_size parameter. When it is not specified, Brax uses num_timesteps as the max_replay_size, which is 30_000_000 in your case, so a replay buffer of 30 million transitions is initialized and exhausts the VRAM. Try setting max_replay_size to a smaller value and things should work.
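A minimal sketch of the suggested fix, reusing the training call from the question. The value `max_replay_size=1_000_000` is purely illustrative, not a tuned recommendation:

```python
import functools
from brax.training.agents.sac import train as sac

train_fn = functools.partial(
    sac.train,
    num_timesteps=30_000_000,
    num_evals=1,
    episode_length=INITIAL_PARAMS.RL_PARAMS.MAX_EPISODE_TIMESTEPS,
    normalize_observations=True,
    action_repeat=1,
    num_envs=1,
    batch_size=32,
    # Cap the replay buffer explicitly; otherwise Brax falls back to
    # num_timesteps (30M rows here), which is what triggers the OOM.
    max_replay_size=1_000_000,  # illustrative; ~1.65 GiB at 444 f32 per row
    seed=0,
)
```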