Loading real data on subset of hosts #187

Merged: 1 commit, Apr 2, 2024


khatwanimohit
Collaborator

No description provided.

@khatwanimohit khatwanimohit force-pushed the mohit/hosts_real_data branch 2 times, most recently from 92cd603 to 3841984 Compare September 29, 2023 23:06
@rwitten rwitten removed their assignment Sep 30, 2023
Collaborator

@rwitten rwitten left a comment

Just unassigning myself. We will talk live to discuss landing this CR.

@khatwanimohit khatwanimohit force-pushed the mohit/hosts_real_data branch 3 times, most recently from a19360b to d41cbb6 Compare January 31, 2024 19:59
@khatwanimohit khatwanimohit force-pushed the mohit/hosts_real_data branch 2 times, most recently from 7db70f5 to 9d6ae76 Compare February 27, 2024 20:40
@khatwanimohit khatwanimohit force-pushed the mohit/hosts_real_data branch 6 times, most recently from 60f3a1d to b4578c1 Compare March 12, 2024 23:43
@@ -317,12 +317,16 @@ def get_individual_scales(scale):
def calculate_global_batch_sizes(raw_keys):
  """ Calculates target global batch size from target devices and per_device_batch"""
  per_device_batch_size = raw_keys['per_device_batch_size']
  expansion_factor_real_data = raw_keys['expansion_factor_real_data']
  num_devices = get_num_target_devices(raw_keys)
  if per_device_batch_size < 1.0:
    # For per_device_batch_size<1, we load the data as if per_device_batch_size=1
    global_batch_size_to_load = num_devices
Collaborator

But couldn't fewer than that number of hosts load the data?

Collaborator

Like I think we should still ramp the data?
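For readers following the hunk above, here is a minimal sketch of the load-size versus train-size distinction the comment in the code describes. Only the `per_device_batch_size < 1.0` branch is visible in the diff. The `else` branch, the train-on size, and the function name `sketch_global_batch_sizes` are assumptions for illustration, and the sketch deliberately leaves out `expansion_factor_real_data` because the hunk is cut off before it is used.

```python
from typing import Tuple


def sketch_global_batch_sizes(per_device_batch_size: float, num_devices: int) -> Tuple[int, int]:
  """Returns (global_batch_size_to_load, global_batch_size_to_train_on).

  Hypothetical helper, not the function from the PR.
  """
  if per_device_batch_size < 1.0:
    # Visible in the hunk: load data as if per_device_batch_size were 1.
    global_batch_size_to_load = num_devices
  else:
    # Assumption: the hunk ends before this branch is shown.
    global_batch_size_to_load = int(num_devices * per_device_batch_size)
  # Assumption: the batch that is trained on keeps the fractional value.
  global_batch_size_to_train_on = int(num_devices * per_device_batch_size)
  return global_batch_size_to_load, global_batch_size_to_train_on


if __name__ == "__main__":
  # 8 devices at per_device_batch_size=0.5: load 8 rows, train on 4.
  print(sketch_global_batch_sizes(0.5, 8))
```

Under these assumptions the loaded batch is larger than the trained-on batch whenever `per_device_batch_size` is below 1, which is the situation the review question is about.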

MaxText/input_pipeline/input_pipeline_interface.py: two outdated review threads (resolved)
@rwitten rwitten removed their assignment Mar 18, 2024
@khatwanimohit khatwanimohit force-pushed the mohit/hosts_real_data branch 4 times, most recently from c108249 to f92ec46 Compare March 21, 2024 23:01
@rwitten rwitten requested a review from aireenmei March 22, 2024 18:06
Collaborator

@rwitten rwitten left a comment

This has my approval but please discuss with Aireen and Roshani before merging.

@aireenmei @RoshaniN -- this CR lets MaxText make balancing decisions about the number of hosts that read from GCS. We find in practice that this is a useful lever: a thundering herd of VMs can overwhelm GCS even though the aggregate data volume is small.
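To make the idea of a host subset concrete, here is a hedged sketch of one way a host could decide whether it reads from GCS. The stride rule and the name `host_loads_real_data` are hypothetical and are not taken from this PR. Treating `expansion_factor_real_data == -1` as "every host loads" is also an assumption, inferred from the later request for results with a value other than -1.

```python
def host_loads_real_data(host_index: int, expansion_factor_real_data: int) -> bool:
  """Hypothetical rule: every expansion_factor-th host reads real data."""
  if expansion_factor_real_data == -1:
    # Assumption: -1 disables the feature, so all hosts load real data.
    return True
  if expansion_factor_real_data < 1:
    raise ValueError("expansion_factor_real_data must be -1 or a positive integer")
  return host_index % expansion_factor_real_data == 0


if __name__ == "__main__":
  loading_hosts = [h for h in range(64) if host_loads_real_data(h, 4)]
  print(len(loading_hosts), loading_hosts[:4])
```

A stride rule spreads the reading hosts evenly across the host index range. Other selection rules would serve the same purpose of reducing the number of concurrent GCS readers.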

@aireenmei
Collaborator

Thanks for adding me to the thread. I think I missed some context, so I'm not sure I understand the whole idea. I see that only a subset of hosts loads data. Are they going to pass the data to the hosts that are not loading real data? Why are the rest of the hosts returning synthetic data?

@RoshaniN
Collaborator

I don't see any issues, since standalone_dataloader would use the same input pipeline as train. I would like to understand the recommendations for expansion_factor_real_data; I will do that offline.

@RoshaniN RoshaniN self-requested a review March 25, 2024 20:26
@rwitten rwitten removed their assignment Mar 26, 2024
@aireenmei
Collaborator

Could you share some convergence results when expansion_factor_real_data != -1 ?

@khatwanimohit
Collaborator Author

Could you share some convergence results when expansion_factor_real_data != -1 ?

Convergence test with expansion_factor_real_data=4 (i.e. 16 of the 64 hosts load the real data):
https://cloudlogging.app.goo.gl/DPDSXu2tSM3ga8hG8
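The arithmetic behind "16 hosts out of 64" can be sketched as below. The helper names are hypothetical. Two assumptions are not confirmed by this thread: that -1 means every host loads, and that the global batch stays the same size, so each loading host supplies a proportionally larger share.

```python
def num_hosts_loading_real_data(num_hosts: int, expansion_factor_real_data: int) -> int:
  """Hypothetical helper: how many hosts read real data."""
  if expansion_factor_real_data == -1:
    # Assumption: -1 means the feature is off and all hosts load.
    return num_hosts
  if expansion_factor_real_data < 1 or num_hosts % expansion_factor_real_data != 0:
    raise ValueError("expansion factor must be a positive divisor of num_hosts")
  return num_hosts // expansion_factor_real_data


def rows_per_loading_host(global_batch_size: int, num_hosts: int,
                          expansion_factor_real_data: int) -> int:
  """Assumes the global batch is split evenly across the loading hosts."""
  loading = num_hosts_loading_real_data(num_hosts, expansion_factor_real_data)
  return global_batch_size // loading


if __name__ == "__main__":
  print(num_hosts_loading_real_data(64, 4))   # hosts that read real data
  print(rows_per_loading_host(256, 64, 4))    # rows each of them supplies
  print(rows_per_loading_host(256, 64, -1))   # rows per host with the feature off
```

Under these assumptions, a loading host supplies `expansion_factor_real_data` times as many rows as it would if every host loaded.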

@aireenmei
Collaborator

Thanks! Could you also run a convergence test with grain? bash end_to_end/test_convergence_1b_params.sh DATASET_TYPE="c4-array_record" ...

@khatwanimohit
Collaborator Author

khatwanimohit commented Apr 2, 2024

Thanks! Could you also run a convergence test with grain? bash end_to_end/test_convergence_1b_params.sh DATASET_TYPE="c4-array_record" ...

Convergence run with c4-array_record: https://cloudlogging.app.goo.gl/jiFMzAx8SDRw4nM46

@aireenmei I will also add Airflow tests for both convergence tests.

Collaborator

@RoshaniN RoshaniN left a comment

Thanks Mohit!

MaxText/configs/base.yml: outdated review thread (resolved)
MaxText/input_pipeline/input_pipeline_interface.py: outdated review thread (resolved)
Collaborator

@aireenmei aireenmei left a comment

Thanks Mohit!

@copybara-service copybara-service bot merged commit 5cb6052 into main Apr 2, 2024
8 checks passed
@copybara-service copybara-service bot deleted the mohit/hosts_real_data branch April 2, 2024 21:21
@A9isha A9isha mentioned this pull request Apr 11, 2024