Implement tensor parallelism #17

Closed
marib00 opened this issue Jun 12, 2024 · 4 comments

marib00 commented Jun 12, 2024

I thought implementing tensor parallelism would be an interesting exercise. There's a PyTorch tutorial for it, and even some code examples, but so far no joy.

I started simple, trying to shard the MLP like this:

# run using: torchrun --standalone --nproc-per-node=2 train_gpt2_tp.py

import os

import torch.nn as nn
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel

# torchrun sets WORLD_SIZE; build a 1-D mesh over all local GPUs
_world_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))

class Block(nn.Module):

    def __init__(self, config):

        ...
        
        # was: self.mlp = MLP(config)
        self.mlp = parallelize_module(
            module=MLP(config),
            device_mesh=device_mesh,
            parallelize_plan={
                "c_fc": ColwiseParallel(),    # shard the up-projection over columns
                "c_proj": RowwiseParallel(),  # shard the down-projection over rows, all-reduce the output
            },
        )

But PyTorch (nightly) gives me grief:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/Sync-shared/projects/repos/build-nanogpt/train_gpt2_tp.py", line 326, in <module>
[rank0]:     norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/utils/clip_grad.py", line 21, in _no_grad_wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/nn/utils/clip_grad.py", line 68, in clip_grad_norm_
[rank0]:     norms.extend(torch._foreach_norm(device_grads, norm_type))
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/_compile.py", line 31, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/distributed/_tensor/api.py", line 309, in __torch_dispatch__
[rank0]:     return DTensor._op_dispatcher.dispatch(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/distributed/_tensor/_dispatch.py", line 115, in dispatch
[rank0]:     op_info = self.unwrap_to_op_info(op_call, args, kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/distributed/_tensor/_dispatch.py", line 348, in unwrap_to_op_info
[rank0]:     args_schema.append(try_get_replicate_spec(arg, mesh))
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/torch/distributed/_tensor/_dispatch.py", line 329, in try_get_replicate_spec
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: aten._foreach_norm.Scalar: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
[rank1]: (identical traceback and RuntimeError on rank 1)

As a quick fix I tried converting what I thought were DTensors to local tensors:

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu    = nn.GELU(approximate='tanh')
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x.to_local() # change here!

but then I get even more grief 🤦‍♂️:

[rank0]:   File "/mnt/Sync-shared/projects/repos/build-nanogpt/train_gpt2_tp.py", line 58, in forward
[rank0]:     return x.to_local()
[rank0]:            ^^^^^^^^^^
[rank0]: AttributeError: 'AsyncCollectiveTensor' object has no attribute 'to_local'

Any ideas? 🙏


marib00 commented Jun 12, 2024

Turns out RowwiseParallel(use_local_output=True) is the default, so x should already be a torch.Tensor... 🤔
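
For anyone else hitting the AttributeError above: a hedged workaround (an assumption, not something confirmed in this thread) is to force the pending collective to finish before treating the output as a plain tensor. The AsyncCollectiveTensor named in the traceback lives in torch.distributed._functional_collectives and, as far as I can tell, exposes a wait() method:

# Hedged sketch: resolve a pending functional collective before using its
# result as a regular torch.Tensor.
import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor

def to_plain_tensor(x: torch.Tensor) -> torch.Tensor:
    if isinstance(x, AsyncCollectiveTensor):
        x = x.wait()  # block until the underlying all-reduce/all-gather completes
    return x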


marib00 commented Jun 12, 2024

Progress: disabling gradient clipping and the fused AdamW actually makes it work (even with torch.compile!) 🎉🎉🎉
Next step: tensor-parallel attention 👍
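
For reference, a minimal sketch of those two changes, plus an optional DTensor-aware gradient clip as an alternative to dropping clipping entirely. This is an assumption-laden sketch, not the actual train_gpt2_tp.py code: model and the learning rate are placeholders, and it assumes DTensor.full_tensor() for gathering sharded grads.

import torch
from torch.distributed._tensor import DTensor

# 1) use the non-fused AdamW, since the fused path trips over DTensor params here
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, fused=False)

# 2) either skip torch.nn.utils.clip_grad_norm_, or clip manually in a
#    DTensor-aware way so plain tensors and DTensors never mix in _foreach_norm:
def clip_grad_norm_mixed(params, max_norm):
    grads = [p.grad for p in params if p.grad is not None]
    norms = []
    for g in grads:
        if isinstance(g, DTensor):
            g = g.full_tensor()               # gather the sharded grad (a collective)
        norms.append(torch.linalg.vector_norm(g))
    total_norm = torch.linalg.vector_norm(torch.stack(norms))
    clip_coef = float(max_norm / (total_norm + 1e-6))
    if clip_coef < 1.0:
        for g in grads:
            g.mul_(clip_coef)                 # a Python-float scale works for Tensor and DTensor
    return total_norm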


marib00 commented Jun 14, 2024

Almost there, but I've discovered some weird behaviour in torch.distributed.tensor.parallel.SequenceParallel(). It should shard across the sequence dimension, i.e. [B, T, C] -> [B, T//_world_size, C], but it seems to be tiling instead, i.e. [B, T, C] -> [B, T*_world_size, C], which obviously doesn't sit well with the loss function: it now gets _world_size times too many logits. I didn't realise how much I didn't know about parallelism! Investigation continues... 🧐
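
A quick way to tell sharding from tiling while digging (a sketch, assuming the activation you probe is still a DTensor at that point): a DTensor reports the global shape via .shape and this rank's shard via .to_local().shape, so the two cases look different immediately.

from torch.distributed._tensor import DTensor

def describe(tag, x):
    # DTensor: .shape is the global shape, .to_local().shape is this rank's shard
    if isinstance(x, DTensor):
        print(f"{tag}: global={tuple(x.shape)} local={tuple(x.to_local().shape)} placements={x.placements}")
    else:
        print(f"{tag}: plain tensor {tuple(x.shape)}")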

On the bright side, I'm getting just shy of 150k tok/sec on a 2x 3090 setup, and I can now work with a model that wouldn't fit on a single 3090. Bad news for 8x A100 users - you have too many GPUs to shard the 12 attention heads across 🤣
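
For the attention side, here's a rough sketch of the Megatron-style plan (hedged: it assumes CausalSelfAttention has been refactored into separate q_proj/k_proj/v_proj linears, hypothetical names not in build-nanogpt, because column-sharding the fused c_attn would slice across the q/k/v boundaries; the forward also has to reshape with the per-rank head count):

# Rough sketch only, reusing config, device_mesh and _world_size from the
# earlier snippet; q_proj/k_proj/v_proj are hypothetical split projections.
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel

assert config.n_head % _world_size == 0, "heads must divide evenly across ranks"

self.attn = parallelize_module(
    module=CausalSelfAttention(config),
    device_mesh=device_mesh,
    parallelize_plan={
        "q_proj": ColwiseParallel(),   # each rank holds n_head // _world_size heads
        "k_proj": ColwiseParallel(),
        "v_proj": ColwiseParallel(),
        "c_proj": RowwiseParallel(),   # partial matmuls, then all-reduce on the output
    },
)
# inside the attention forward: n_head_local = config.n_head // _world_size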

#protip: I didn't know it was possible to use the VS Code debugger with distributed workloads, but it turns out you can, and conditional breakpoints, stepping into libraries, etc. all work like a dream! Here's my launch.json for it - hope people find it useful:

{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Distributed train_gpt2_tp.py",
            "type": "debugpy",
            "request": "launch",
            "purpose": ["debug-in-terminal"],
            "console": "integratedTerminal",
            "module":"torch.distributed.run",
            "args":["--standalone","--nproc_per_node=2","train_gpt2_tp.py"],
            "justMyCode": false,
        }
    ]
}


marib00 commented Jun 18, 2024

Ok, all done, except it's disappointingly slow! 😮 It's quite possible I've messed something up, so if anybody notices anything, please let me know.

The repo is available at https://github.com/marib00/build-nanogpt and the file of interest is train_gpt2_tp.py - I didn't touch any of the other files.

I have added some benchmarks to README.md.

marib00 closed this as completed Jun 18, 2024