
Issue with GPU allocation (only occurs when running SAC not PPO) #1431

Closed
AlexS28 opened this issue Feb 19, 2024 · 2 comments
Labels: question (Request for help or information)

Comments

AlexS28 commented Feb 19, 2024

Hi, I am trying to run my environment using MJX, and everything works perfectly fine when I use PPO. With PPO I can easily go over 8000 environments with a batch size over 1000 and no memory issues. With SAC, however, it is a different story no matter what settings I try: even with a batch size of 32 and a single environment it fails. I cannot avoid running into the memory allocation error copy/pasted below. From trials in other simulators, my particular environment really only trains with SAC, not PPO, so I do need SAC to work. Any help would be greatly appreciated.

My computer has an Nvidia 4090 GPU.

Running this training:
train_fn = functools.partial(
    sac.train, num_timesteps=30_000_000, num_evals=1,
    episode_length=INITIAL_PARAMS.RL_PARAMS.MAX_EPISODE_TIMESTEPS,
    normalize_observations=True, action_repeat=1,
    num_envs=1, batch_size=32, seed=0)

Getting this error:
2024-02-19 15:14:56.882326: W external/xla/xla/service/hlo_rematerialization.cc:2941] Can't reduce memory use below -30.29GiB (-32529110990 bytes) by rematerialization; only reduced to 49.62GiB (53280000080 bytes), down from 49.62GiB (53280000080 bytes) originally
2024-02-19 15:15:06.899957: W external/tsl/tsl/framework/bfc_allocator.cc:487] Allocator (GPU_0_bfc) ran out of memory trying to allocate 49.62GiB (rounded to 53280000000)requested by op
2024-02-19 15:15:06.900601: W external/tsl/tsl/framework/bfc_allocator.cc:499] ____******
E0219 15:15:06.900728 6066 pjrt_stream_executor_client.cc:2766] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 53280000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 8B
constant allocation: 8B
maybe_live_out allocation: 49.62GiB
preallocated temp allocation: 160B
total allocation: 49.62GiB
total fragmentation: 168B (0.00%)
Peak buffers:
Buffer 1:
Size: 49.62GiB
Operator: op_name="pmap(init)/jit(main)/broadcast_in_dim[shape=(30000000, 444) broadcast_dimensions=()]" source_file="/home/sclaer/ASCENT_biped/venv/lib/python3.10/site-packages/brax/training/replay_buffers.py" source_line=112
XLA Label: fusion
Shape: f32[30000000,444]
==========================

Buffer 2:
	Size: 32B
	XLA Label: tuple
	Shape: (f32[30000000,444], s32[], s32[], u32[2])
	==========================

Buffer 3:
	Size: 32B
	XLA Label: tuple
	Shape: (f32[30000000,444], s32[], s32[], u32[2])
	==========================

Buffer 4:
	Size: 16B
	XLA Label: fusion
	Shape: (s32[], s32[])
	==========================

Buffer 5:
	Size: 8B
	Entry Parameter Subshape: u32[2]
	==========================

Buffer 6:
	Size: 8B
	XLA Label: fusion
	Shape: u32[2]
	==========================

Buffer 7:
	Size: 4B
	XLA Label: fusion
	Shape: s32[]
	==========================

Buffer 8:
	Size: 4B
	XLA Label: fusion
	Shape: s32[]
	==========================

Buffer 9:
	Size: 4B
	XLA Label: constant
	Shape: f32[]
	==========================

Buffer 10:
	Size: 4B
	XLA Label: constant
	Shape: s32[]
	==========================

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/sclaer/ASCENT_biped/sim/mujoco/rl_mjx.py", line 367, in
make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)
File "/home/sclaer/ASCENT_biped/venv/lib/python3.10/site-packages/brax/training/agents/sac/train.py", line 429, in train
buffer_state = jax.pmap(replay_buffer.init)(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 53280000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 8B
constant allocation: 8B
maybe_live_out allocation: 49.62GiB
preallocated temp allocation: 160B
total allocation: 49.62GiB
total fragmentation: 168B (0.00%)
Peak buffers:
Buffer 1:
Size: 49.62GiB
Operator: op_name="pmap(init)/jit(main)/broadcast_in_dim[shape=(30000000, 444) broadcast_dimensions=()]" source_file="/home/sclaer/ASCENT_biped/venv/lib/python3.10/site-packages/brax/training/replay_buffers.py" source_line=112
XLA Label: fusion
Shape: f32[30000000,444]
==========================

Buffer 2:
	Size: 32B
	XLA Label: tuple
	Shape: (f32[30000000,444], s32[], s32[], u32[2])
	==========================

Buffer 3:
	Size: 32B
	XLA Label: tuple
	Shape: (f32[30000000,444], s32[], s32[], u32[2])
	==========================

Buffer 4:
	Size: 16B
	XLA Label: fusion
	Shape: (s32[], s32[])
	==========================

Buffer 5:
	Size: 8B
	Entry Parameter Subshape: u32[2]
	==========================

Buffer 6:
	Size: 8B
	XLA Label: fusion
	Shape: u32[2]
	==========================

Buffer 7:
	Size: 4B
	XLA Label: fusion
	Shape: s32[]
	==========================

Buffer 8:
	Size: 4B
	XLA Label: fusion
	Shape: s32[]
	==========================

Buffer 9:
	Size: 4B
	XLA Label: constant
	Shape: f32[]
	==========================

Buffer 10:
	Size: 4B
	XLA Label: constant
	Shape: s32[]
	==========================
AlexS28 added the question (Request for help or information) label Feb 19, 2024
@kinalmehta

You are not setting the max_replay_size parameter, and when it is not specified, Brax uses num_timesteps as the max_replay_size, which is 30_000_000 in your case. So a replay buffer with 30 million entries is being initialized, which exhausts the VRAM. Try setting max_replay_size to a smaller value and things should work.
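
For context on the size in the log: the failing buffer has shape f32[30000000,444], i.e. 30,000,000 × 444 × 4 bytes = 53,280,000,000 bytes (≈49.62 GiB), which is exactly the allocation the error reports and far more than a 24 GB 4090 can hold. Below is a minimal sketch of the suggested change, reusing the call from the issue; the cap of 1_000_000 transitions is an illustrative assumption, not a recommended value (tune it to your VRAM; at 444 floats per row it needs roughly 1.8 GB):

train_fn = functools.partial(
    sac.train, num_timesteps=30_000_000, num_evals=1,
    episode_length=INITIAL_PARAMS.RL_PARAMS.MAX_EPISODE_TIMESTEPS,
    normalize_observations=True, action_repeat=1,
    num_envs=1, batch_size=32,
    max_replay_size=1_000_000,  # illustrative cap; if omitted, the buffer defaults to num_timesteps entries
    seed=0)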

AlexS28 commented Feb 23, 2024

Awesome, thanks! My RL training is working now with that change.

AlexS28 closed this as completed Feb 23, 2024