### Prelude

In [1]:
def prelude():
    """Reset the notebook namespace and (re)install the shared helpers.

    Deletes every module-level name that does not start with an underscore,
    except ``prelude`` itself and the interactive-shell entries ``In``,
    ``Out``, ``get_ipython``, ``exit`` and ``quit``. Afterwards ``dev_sync``
    is published as a global so later cells can use it.
    """
    keep = {"prelude", "In", "Out", "get_ipython", "exit", "quit"}

    def _clear_vars():
        # Underscore-prefixed names (``__name__``, ``__builtins__``, private
        # shell bookkeeping) are deliberately left alone; only user-level
        # names are wiped.
        for _n in list(globals()):
            if not _n.startswith("_") and _n not in keep:
                del globals()[_n]

    _clear_vars()

    G = globals()

    import contextlib

    def _pick_sync(device):
        """Return the synchronize callable for ``device``, or None if none applies."""
        name = str(device)
        # torch is imported lazily: the reset above may have removed a global
        # ``torch``, and CPU devices never need it at all.
        if name.startswith("cuda"):
            import torch
            return torch.cuda.synchronize if torch.cuda.is_available() else None
        if name.startswith("mps"):
            import torch
            available = hasattr(torch, "mps") and torch.backends.mps.is_available()
            return torch.mps.synchronize if available else None
        return None

    @contextlib.contextmanager
    def dev_sync(device: str):
        """Synchronize ``device`` on entry and on exit of the ``with`` body.

        For device strings starting with ``"cuda"`` or ``"mps"`` (when that
        backend is available) the matching ``synchronize()`` is called before
        the body runs and again afterwards, even if the body raises. For any
        other device this is a no-op.
        """
        sync = _pick_sync(device)
        if sync is not None:
            sync()
        try:
            yield
        finally:
            if sync is not None:
                sync()
    G["dev_sync"] = dev_sync

# Module-level placeholder so static analysis can resolve `dev_sync`;
# prelude() stores the working implementation in globals under the same name.
def dev_sync(device: str):
    """Do nothing; stand-in for the `dev_sync` that prelude() installs."""
    return None

### Device Copy Overhead

In [2]:
prelude()

In [3]:
import time, torch, torch.nn.functional as F
# Benchmark setup: a vocabulary-sized vector on MPS, sampled `iters` times.
V = 50000
dtype = torch.bfloat16
iters = 2048

x = torch.randn(V, device="mps", dtype=dtype)
# Drain the pending randn before starting the clock so only the loop is timed.
torch.mps.synchronize()
s = time.perf_counter()
for _ in range(iters):
    # argmax runs on the device; .item() brings one scalar back per iteration.
    _ = int(torch.argmax(x).item())
torch.mps.synchronize()
print("1-scalar copy:", time.perf_counter() - s)

p = F.softmax(x, -1).to("cpu")
s = time.perf_counter()
for _ in range(iters):
    _ = int(torch.multinomial(p, 1))  # simulating CPU sampler
print("CPU multinomial -> CPU scalar:", time.perf_counter() - s)

# The MPS measurement is taken twice (same label both times), matching the
# two identical passes this cell has always run.
for _pass in range(2):
    p = F.softmax(x, -1).to("mps")
    # Make sure the softmax has finished before the timer starts.
    torch.mps.synchronize()
    s = time.perf_counter()
    for _ in range(iters):
        _ = int(torch.multinomial(p, 1))
    print("MPS multinomial -> CPU scalar:", time.perf_counter() - s)


1-scalar copy: 0.5925095830170903
CPU multinomial -> CPU scalar: 1.8556773330201395
MPS multinomial -> CPU scalar: 2.4886934169917367
MPS multinomial -> CPU scalar: 2.4192957500054035


### Rotary DTypes

In [4]:
prelude()

In [5]:
import torch, torch.nn as nn

def _apply_rope(x, cos, sin):
    x1, x2 = x.chunk(2, dim=-1)
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat((y1, y2), dim=-1)

class RotaryPositionEncoding(nn.Module):
    def __init__(self, head_dim, max_seq_len, dtype):
        super().__init__()
        half = head_dim // 2
        keep = head_dim // 4
        base = 1024
        if keep == 0:
            angular = (1 / base) ** torch.linspace(0, 1, steps=half, dtype=dtype)
        else:
            active = (1 / base) ** torch.linspace(0, 1, steps=keep, dtype=dtype)
            angular = torch.cat([active, active.new_zeros(half - keep)])
        self.inv_freq = nn.Buffer(angular, persistent=False)
        t = torch.arange(max_seq_len, dtype=dtype)
        theta = torch.einsum("i,j->ij", t, self.inv_freq)
        self.cos = nn.Buffer(theta.cos().to(dtype), persistent=False)
        self.sin = nn.Buffer(theta.sin().to(dtype), persistent=False)
        self._max_seq_len = int(max_seq_len)

    def forward(self, x_BTHD):
        L = x_BTHD.size(-3)
        cos = self.cos[:L][None, :L, None, :]
        sin = self.sin[:L][None, :L, None, :]
        return _apply_rope(x_BTHD, cos, sin)



In [16]:
from time import perf_counter as p
import torch.nn.functional as F

# Benchmark configuration: target device, dtypes under test and the reference.
device = "mps"
ref_dtype = torch.float32
rotary_dtypes = [torch.float32, torch.float16, torch.bfloat16]
B, T, H, D = 1, 1024, 8, 128
iters, max_seq_len = 2048, 65536
ref_rotary = RotaryPositionEncoding(D, max_seq_len, dtype=ref_dtype).to(device)

for dtype in rotary_dtypes:
    rotary = RotaryPositionEncoding(D, max_seq_len, dtype=dtype).to(device)

    # Speed: repeated forward passes on one fixed input.
    X = torch.randn(B, T, H, D, dtype=dtype, device=device)
    with dev_sync(device):
        t0 = p()
        for _ in range(iters):
            rotary(X)
    # The clock is read after leaving dev_sync, so its exit step is included.
    dur = p() - t0
    print(f"Rotary [{dtype}] | Avg dur/iter: {dur/iters:.6f}s")

    # Accuracy: cosine similarity against the reference-dtype module on fresh
    # random inputs, averaged over the last dimension and then over all draws.
    per_iter = []
    for _ in range(iters):
        Y = torch.randn(B, T, H, D, dtype=ref_dtype, device=device)
        ref_res = ref_rotary(Y)
        approx = rotary(Y.to(dtype)).to(ref_dtype)
        per_iter.append(F.cosine_similarity(ref_res, approx, dim=-1).mean())
    S = torch.stack(per_iter)
    print(f"Rotary [{dtype}] | Mean cosine similarity: {S.mean().item():.6f}")


Rotary [torch.float32] | Avg dur/iter: 0.000122s
Rotary [torch.float32] | Mean cosine similarity: 1.000000
Rotary [torch.float16] | Avg dur/iter: 0.000104s
Rotary [torch.float16] | Mean cosine similarity: 0.999386
Rotary [torch.bfloat16] | Avg dur/iter: 0.000103s
Rotary [torch.bfloat16] | Mean cosine similarity: 0.959034
