Loading real data on subset of hosts #187
92cd603 to 3841984
Just unassigning myself. We will talk live to discuss landing this CR.
3841984 to f064388
a19360b to d41cbb6
7db70f5 to 9d6ae76
60f3a1d to b4578c1
MaxText/pyconfig.py
Outdated
@@ -317,12 +317,16 @@ def get_individual_scales(scale):
def calculate_global_batch_sizes(raw_keys):
  """ Calculates target global batch size from target devices and per_device_batch"""
  per_device_batch_size = raw_keys['per_device_batch_size']
  expansion_factor_real_data = raw_keys['expansion_factor_real_data']
  num_devices = get_num_target_devices(raw_keys)
  if per_device_batch_size < 1.0:
    # For per_device_batch_size<1, we load the data as if per_device_batch_size=1
    global_batch_size_to_load = num_devices
But couldn't fewer than that number of hosts load the data?
Like I think we should still ramp the data?
c108249 to f92ec46
This has my approval but please discuss with Aireen and Roshani before merging.
@aireenmei @RoshaniN -- this CR lets MaxText balance how many hosts read from GCS. In practice we find this is a useful lever: a thundering herd of VMs can crush GCS even though the aggregate data volume isn't large.
Thanks for adding me to the thread. I think I missed some context, so I'm not sure I understand the whole idea. I see that only a subset of hosts loads data; are they going to pass the data to the hosts that are not loading real data? Why are the rest of the hosts returning synthetic data?
I don't see any issues, as standalone_dataloader would be using the same input pipeline as train. I would like to understand the recommendations for expansion_factor_real_data; will do that offline.
f92ec46 to 24d9512
Could you share some convergence results when expansion_factor_real_data != -1?
Convergence test with expansion_factor_real_data=4 (i.e. 16 hosts out of 64 hosts will load the real data)
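For readers following the thread, the host arithmetic in the comment above (64 hosts, expansion_factor_real_data=4, 16 hosts loading) can be sketched as below. This is only an illustration, not the code in this CR: the helper names are hypothetical, treating -1 as "every host loads" is inferred from the question about `!= -1`, and choosing the lowest-indexed hosts as the loaders is an assumption.

```python
def num_hosts_loading_real_data(num_hosts: int, expansion_factor_real_data: int) -> int:
  """Hypothetical helper: how many hosts read real data from storage."""
  if expansion_factor_real_data == -1:
    # Assumed default: the feature is off and every host loads real data.
    return num_hosts
  if expansion_factor_real_data <= 0 or num_hosts % expansion_factor_real_data != 0:
    raise ValueError("expansion_factor_real_data must be -1 or a positive divisor of num_hosts")
  return num_hosts // expansion_factor_real_data


def host_loads_real_data(host_index: int, num_hosts: int, expansion_factor_real_data: int) -> bool:
  """Hypothetical helper: assumes the lowest-indexed hosts are the loaders."""
  return host_index < num_hosts_loading_real_data(num_hosts, expansion_factor_real_data)


# The configuration quoted in the convergence test above.
loaders = [h for h in range(64) if host_loads_real_data(h, 64, 4)]
print(len(loaders))
```

In a real multi-host run the host index and host count would presumably come from the runtime rather than being passed in by hand.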
Thanks! Could you also run a convergence test with grain?
Convergence run with c4-array_record: https://cloudlogging.app.goo.gl/jiFMzAx8SDRw4nM46 @aireenmei I will also add airflow tests for both convergence tests.
Thanks Mohit!
24d9512 to 0681c96
Thanks Mohit!
0681c96 to 4272b6a
4272b6a to 2a0972b