Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoTP] Make AutoTP work when num_heads not divisible by number of workers #4011

Merged
merged 55 commits into from Oct 25, 2023

Conversation

delock
Copy link
Contributor

@delock delock commented Jul 21, 2023

Currently AutoTP will assert when num_heads are not divisible by number of workers. However in some situation this might be what user intend to do. i.e. having three compute device for model with 16 heads and want to put all three compute device to work. In this situation, each worker will process 5 or 6 heads, which is still better than run the workload on two compute device and each worker process 8 heads and leave the third device idle.

This PR distribute attention heads to each worker as even as possible, and shard hidden_size according to this distribution, this allows AutoTP run OOB even when number heads are not divisible by number of devices installed on the system.

Copy link
Contributor

@mrwyattii mrwyattii left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a unit test to verify for an odd number of devices? Perhaps extend this test class:

class TestAutoTensorParallelism(DistributedTest):

to add something like:

@pytest.mark.world_size(3)
def test_odd_world_size(
        self,
        model_w_task,
        query,
        inf_kwargs,
        assert_fn,
        dtype,
    ):
        invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
        if invalid_test_msg:
            pytest.skip(invalid_test_msg)

        model, task = model_w_task
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "2"))

        pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt")
        bs_output = pipe(query, **inf_kwargs)

        pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
        # Switch device to GPU so that input tensors are not on CPU
        pipe.device = torch.device(get_accelerator().device_name(local_rank))
        ds_output = pipe(query, **inf_kwargs)

        print(local_rank, "baseline", bs_output)
        print(local_rank, "deepspeed", ds_output)
        assert assert_fn(bs_output, ds_output)

@delock delock requested a review from tjruwase as a code owner July 25, 2023 03:43
@delock
Copy link
Contributor Author

delock commented Jul 25, 2023

@mrwyattii Test added. There is a result mismatch assertion in the test and I can also reproduce this assertion with CPU+BF16. Will need sometime to debug this issue.

Can we add a unit test to verify for an odd number of devices? Perhaps extend this test class:

class TestAutoTensorParallelism(DistributedTest):

to add something like:

@pytest.mark.world_size(3)
def test_odd_world_size(
        self,
        model_w_task,
        query,
        inf_kwargs,
        assert_fn,
        dtype,
    ):
        invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
        if invalid_test_msg:
            pytest.skip(invalid_test_msg)

        model, task = model_w_task
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "2"))

        pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt")
        bs_output = pipe(query, **inf_kwargs)

        pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
        # Switch device to GPU so that input tensors are not on CPU
        pipe.device = torch.device(get_accelerator().device_name(local_rank))
        ds_output = pipe(query, **inf_kwargs)

        print(local_rank, "baseline", bs_output)
        print(local_rank, "deepspeed", ds_output)
        assert assert_fn(bs_output, ds_output)

@molly-smith molly-smith self-requested a review August 2, 2023 18:04
@delock
Copy link
Contributor Author

delock commented Aug 17, 2023

@mrwyattii @molly-smith I have identified the issue for result mismatch and fixed. Can you help restart workflow? Thanks!

@mrwyattii Test added. There is a result mismatch assertion in the test and I can also reproduce this assertion with CPU+BF16. Will need sometime to debug this issue.

Can we add a unit test to verify for an odd number of devices? Perhaps extend this test class:

class TestAutoTensorParallelism(DistributedTest):

to add something like:

@pytest.mark.world_size(3)
def test_odd_world_size(
        self,
        model_w_task,
        query,
        inf_kwargs,
        assert_fn,
        dtype,
    ):
        invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
        if invalid_test_msg:
            pytest.skip(invalid_test_msg)

        model, task = model_w_task
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "2"))

        pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt")
        bs_output = pipe(query, **inf_kwargs)

        pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
        # Switch device to GPU so that input tensors are not on CPU
        pipe.device = torch.device(get_accelerator().device_name(local_rank))
        ds_output = pipe(query, **inf_kwargs)

        print(local_rank, "baseline", bs_output)
        print(local_rank, "deepspeed", ds_output)
        assert assert_fn(bs_output, ds_output)

@delock
Copy link
Contributor Author

delock commented Sep 20, 2023

Hi @mrwyattii @molly-smith The test failure is fixed. Can you help restart CI workflow? Thanks!

@mrwyattii
Copy link
Contributor

@delock approved the PR, but there is a merge conflict. Can you resolve that? The PR will auto-merge after!

@tjruwase
Copy link
Contributor

tjruwase commented Oct 2, 2023

@delock, can you please help with the merge conflict?

auto-merge was automatically disabled October 7, 2023 08:57

Head branch was pushed to by a user without write access

@delock
Copy link
Contributor Author

delock commented Oct 7, 2023

@mrwyattii @tjruwase the conflict is resolved, thanks!

@delock
Copy link
Contributor Author

delock commented Oct 10, 2023

Conflict with lm_head parallelism resolved, and add uneven sharding support for lm_head parallel.

@delock
Copy link
Contributor Author

delock commented Oct 12, 2023

Hi @mrwyattii @tjruwase , the recent merge conflict had been resolved, and we also support uneven sharding of lm_head parallel. Can you take a quick look whether it can be put into merge queue? Thanks!

@tjruwase tjruwase added this pull request to the merge queue Oct 25, 2023
Merged via the queue into microsoft:master with commit f15cccf Oct 25, 2023
15 checks passed
baodii pushed a commit to baodii/DeepSpeed that referenced this pull request Nov 7, 2023
…orkers (microsoft#4011)

* allow number of heads not divisible by number of ranks

* get num_heads from model config, more robust

* simplify logic where num_head itself is sharded

* name tweaks

* make code more robust where num_attention_heads may not be defined in model_config

* support num_key_value_heads < num_attention_heads which is used by llama2

* add test for 5 ranks

* change odd rank # to 3 to avoid test skip

* add get_shard_size function

* modify sharding mechanism according to latest auto TP

* fix accuracy issue

* fix format

* skip tests with fusedqkv

* remove skip of fusedqkv tests

* skip test fusedqkv with odd number of ranks

* support model with n_heads in model_config

* fix TestInjectionPolicy::test[fp32-t5]

* fix uneven_heads on some fusedqkv types (microsoft#12)

* odd support fusedqkv

* fix format and clear text

* better fix when activation size cannot be divided by number of heads

* move tp_shard.py under module_inject

* Add get_num_kv_heads in tp_shard.py

* Refine according to comments

* remove old comment

* fix bug in getting num_kv_heads

* support uneven sharding of lm_head tensor parallel

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com>
Co-authored-by: mzl <mingzhi.liu@intel.com>
Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024
…orkers (microsoft#4011)

* allow number of heads not divisible by number of ranks

* get num_heads from model config, more robust

* simplify logic where num_head itself is sharded

* name tweaks

* make code more robust where num_attention_heads may not be defined in model_config

* support num_key_value_heads < num_attention_heads which is used by llama2

* add test for 5 ranks

* change odd rank # to 3 to avoid test skip

* add get_shard_size function

* modify sharding mechanism according to latest auto TP

* fix accuracy issue

* fix format

* skip tests with fusedqkv

* remove skip of fusedqkv tests

* skip test fusedqkv with odd number of ranks

* support model with n_heads in model_config

* fix TestInjectionPolicy::test[fp32-t5]

* fix uneven_heads on some fusedqkv types (microsoft#12)

* odd support fusedqkv

* fix format and clear text

* better fix when activation size cannot be divided by number of heads

* move tp_shard.py under module_inject

* Add get_num_kv_heads in tp_shard.py

* Refine according to comments

* remove old comment

* fix bug in getting num_kv_heads

* support uneven sharding of lm_head tensor parallel

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com>
Co-authored-by: mzl <mingzhi.liu@intel.com>
Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants