MPS support #169

Open
louismullie opened this issue Jul 29, 2023 · 2 comments

Comments

@louismullie

Thanks for the amazing, clean library. I had a few issues running it on MPS, so I thought I'd share how I got it to work. It may not be suitable for a PR, given that some of it is hacky, but it could evolve into one if there is interest and additional input from the community.

Minimal diff to get it to run: master...louismullie:llama2.c:master

1. Disable AutoCast

RuntimeError: User specified an unsupported autocast device_type 'mps'

AutoCast not yet supported on MPS. See: pytorch/pytorch#88415
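A common workaround (a minimal sketch, assuming `device_type` has already been detected elsewhere in the script) is to substitute a no-op context on MPS:

```python
import torch
from contextlib import nullcontext

device_type = "mps"  # assumption: detected earlier in the script

# torch.autocast raises for device_type='mps', so use a no-op context there
ctx = (
    torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    if device_type == "cuda"
    else nullcontext()
)

with ctx:
    y = torch.ones(2, 2) @ torch.ones(2, 2)
print(y.sum().item())  # 8.0
```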

2. Disable the fused optimizer

fused=True requires all the params to be floating point Tensors of supported devices

fused=True is not yet supported on MPS. No GitHub issue found.
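One way to gate this (a sketch, not taken from the linked diff) is to request the fused kernel only when CUDA is available:

```python
import torch

model = torch.nn.Linear(4, 4)  # stand-in for the real model

# fused AdamW kernels are only implemented for CUDA tensors here,
# so fall back to the default (non-fused) implementation on MPS/CPU
use_fused = torch.cuda.is_available()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=use_fused)
```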

3. Enable CPU fallback for unsupported PyTorch ops

The operator 'aten::_weight_norm_interface' is not currently implemented for the MPS device

Set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 and expect a slowdown.
See: pytorch/pytorch#77754, pytorch/pytorch#77764
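This flag is read when torch initializes, so (a minimal sketch) it has to be set before the first `import torch`, either in the shell or at the very top of train.py:

```python
import os

# must be set before torch is first imported, or it has no effect
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (deliberately imported after setting the env var)

# unsupported ops such as aten::_weight_norm_interface now fall back to CPU
print(os.environ["PYTORCH_ENABLE_MPS_FALLBACK"])  # prints "1"
```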

4. Suppress Dynamo warnings

torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Unsupported device type: mps

Call torch._logging.set_logs(dynamo=logging.ERROR) to suppress them.

5. Detect MPS wherever CUDA can't be detected

Self-explanatory.
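A sketch of what that detection order looks like in practice (the variable names here are assumptions, not taken from the diff):

```python
import torch

# prefer CUDA, then MPS, then fall back to CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

x = torch.zeros(3, device=device)
print(x.device.type)
```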

6. Remove use of complex operations in rotary embedding

Complex types are unsupported on MPS.

See: pytorch/pytorch#95976

This required reimplementing the functions precompute_freqs_cis, reshape_for_broadcast, and apply_rotary_emb in train.py to split out the complex operations:

from typing import Tuple

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    # stack cosine and sine parts instead of returning a complex tensor
    return torch.stack([torch.cos(freqs), torch.sin(freqs)], -1)

def reshape_for_broadcast(freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cos.shape == freqs_sin.shape == (x.shape[1], x.shape[-1] // 2)
    # shape like (1, seqlen, 1, ..., head_dim // 2) so it broadcasts over batch/heads
    shape = [1] * ndim
    shape[1] = x.shape[1]
    shape[-1] = x.shape[-1] // 2
    return freqs_cos.view(*shape), freqs_sin.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs: torch.Tensor,  # (..., 2) stack of cos/sin from precompute_freqs_cis
) -> Tuple[torch.Tensor, torch.Tensor]:
    freqs_cos, freqs_sin = freqs[..., 0], freqs[..., 1]
    freqs_cos, freqs_sin = reshape_for_broadcast(freqs_cos, freqs_sin, xq)

    # treat the last dimension as interleaved (real, imaginary) pairs
    xq = xq.view(*xq.shape[:-1], -1, 2)
    xk = xk.view(*xk.shape[:-1], -1, 2)
    xq_r, xq_i = xq[..., 0], xq[..., 1]
    xk_r, xk_i = xk[..., 0], xk[..., 1]

    # rotate: (a + bi)(cos + i*sin) = (a*cos - b*sin) + i*(a*sin + b*cos)
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # re-interleave and flatten back to the original head dimension
    xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
    xk_out = torch.stack((xk_out_r, xk_out_i), dim=-1).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)

Essentially:

  • Separate the computation of the cosine and sine parts in precompute_freqs_cis().
  • Split the real and imaginary parts of the input tensors in apply_rotary_emb() before applying the transformations.

Unfortunately, this results in a ~10x slowdown in the apply_rotary_emb function, and I haven't had much luck optimizing it further. Maybe someone else has some ideas!
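As a sanity check (a standalone sketch, not part of the diff), the split cos/sin rotation can be verified against the complex formulation on CPU, where complex dtypes are supported:

```python
import torch

torch.manual_seed(0)
seqlen, half = 8, 4

# complex route (as in the original code): multiply by e^{i*theta}
freqs = torch.outer(torch.arange(seqlen).float(), torch.rand(half))
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # cos + i*sin

x = torch.randn(seqlen, half, 2)
x_complex = torch.view_as_complex(x)
ref = torch.view_as_real(x_complex * freqs_cis)

# split route (MPS-friendly): rotate real/imaginary parts explicitly
cos, sin = torch.cos(freqs), torch.sin(freqs)
x_r, x_i = x[..., 0], x[..., 1]
out = torch.stack((x_r * cos - x_i * sin, x_r * sin + x_i * cos), dim=-1)

print(torch.allclose(ref, out, atol=1e-6))  # True
```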

7. NB: must use PyTorch nightly

With the above modifications, the code runs as is for training and inference in this conda environment:

conda create -n torch-gpu python=3.8
conda activate torch-gpu
conda install pytorch torchvision torchaudio -c pytorch-nightly

@tyeestudio

Thanks for sharing. What is the quality of inference on MPS, is it the same as the original (karpathy's version)?

@PhilippeFerreiraDeSousa

PhilippeFerreiraDeSousa commented Sep 18, 2023

Hi, awesome research!
I have actually been training on an M2 (Pro) Mac mini using this codebase by just adding

from torch._dynamo import config
config.suppress_errors = True

I still get a lot of error logs printed at startup (2.5 screen heights of them), but then the training goes on.
I have been wondering whether this impacts performance much. It has to.
