### Prelude

In [1]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"

def prelude():
    def _clear_vars():
        for _n in list(globals()):
            if _n != "prelude" and _n.startswith("_") and _n not in ("In","Out","get_ipython","exit","quit"):
                del globals()[_n]

    _clear_vars()

    G = globals()

    import contextlib, torch

    @contextlib.contextmanager
    def dev_sync(device: str):
        is_cuda = torch.cuda.is_available() and str(device).startswith("cuda")
        is_mps = hasattr(torch, "mps") and torch.mps.is_available() and str(device).startswith("mps")
        if is_cuda:
            torch.cuda.synchronize()
        elif is_mps:
            torch.mps.synchronize()
        try:
            yield
        finally:
            if is_cuda:
                torch.cuda.synchronize()
            elif is_mps:
                torch.mps.synchronize()
    G["dev_sync"] = dev_sync

# pretend export to help static analysis
def dev_sync(device: str):
    pass

prelude()

### Device Copy Overhead

In [2]:
prelude()

In [24]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"
os.environ["PYTORCH_MPS_FAST_MATH"] = "1"

In [33]:
import time, torch, torch.nn.functional as F, pandas as pd
V, iters = 50000, 2048
dtypes = [torch.bfloat16, torch.float16, torch.float32]
device = "mps"
target_devices = ["cpu", "mps"]

print(f"tensor shape: ({V},), iters: {iters}, dtypes: {dtypes}, target_devices: {target_devices}")
rows = []
for device in target_devices:
    for dtype in dtypes:
        X = torch.randn(V, device=device, dtype=dtype)
        p = F.softmax(X, -1)

        # same-device .to(...)
        for _ in range(10): 
            _ = X.to(device)
        with dev_sync(device):
            s = time.perf_counter()
            for _ in range(iters): 
                _ = X.to(device)
        t_same = time.perf_counter() - s
        t_same = 1e3 * t_same / iters

        # cross-device .to(...)
        target_device = "cpu" if device == "mps" else "mps"
        for _ in range(10): 
            _ = X.to(target_device)
        with dev_sync(device):
            s = time.perf_counter()
            for _ in range(iters): 
                _ = X.to(target_device)
        t_cross = time.perf_counter() - s
        t_cross = 1e3 * t_cross / iters

        # argmax -> CPU scalar
        for _ in range(10):
            _ = int(torch.argmax(X))
        with dev_sync(device):
            s = time.perf_counter()
            for _ in range(iters):
                _ = torch.argmax(X)
        t_argmax = time.perf_counter() - s
        t_argmax = 1e3 * t_argmax / iters

        # argmax -> CPU scalar
        for _ in range(10): 
            _ = int(torch.argmax(X).item())
        with dev_sync(device):
            s = time.perf_counter()
            for _ in range(iters): 
                _ = torch.argmax(X).item()
        t_argmax_item = time.perf_counter() - s
        t_argmax_item = 1e3 * t_argmax_item / iters

        # multinomial (tensor result)
        for _ in range(10): 
            _ = torch.multinomial(p, 1).item()
        with dev_sync(device):
            s = time.perf_counter()
            for _ in range(iters): 
                _ = torch.multinomial(p, 1)
        t_mult = time.perf_counter() - s
        t_mult = 1e3 * t_mult / iters

        # multinomial -> CPU scalar
        with dev_sync(device):
            s = time.perf_counter()
            for _ in range(iters): 
                _ = torch.multinomial(p, 1).item()
        t_mult_item = time.perf_counter() - s
        t_mult_item = 1e3 * t_mult_item / iters

        # tensor -> CPU multinomial -> CPU scalar
        for _ in range(10): 
            _ = torch.multinomial(p.to("cpu"), 1).item()
        with dev_sync(device):
            s = time.perf_counter()
            for _ in range(iters): 
                _ = torch.multinomial(p.to("cpu"), 1).item()  # simulating CPU sampler
        t_cpu_mult_item = time.perf_counter() - s
        t_cpu_mult_item = 1e3 * t_cpu_mult_item / iters

        rows.append({
            "device": device,
            "dtype": str(dtype).replace("torch.", ""),
            "to_same_dev": t_same,
            "to_different_dev": t_cross,
            "argmax": t_argmax,
            "argmax_item": t_argmax_item,
            "multinomial": t_mult,
            "multinomial_item": t_mult_item,
            "cpu_multinomial_item": t_cpu_mult_item,
        })

df = pd.DataFrame(rows)
df = df[[
    "device",
    "dtype",
    "to_same_dev",
    "to_different_dev",
    "argmax",
    "argmax_item",
    "multinomial",
    "multinomial_item",
    "cpu_multinomial_item",
]].set_index("device").sort_index()

print(df.to_string(float_format=lambda x: f"{x:.4f}"))

tensor shape: (50000,), iters: 2048, dtypes: [torch.bfloat16, torch.float16, torch.float32], target_devices: ['cpu', 'mps']
           dtype  to_same_dev  to_different_dev  argmax  argmax_item  multinomial  multinomial_item  cpu_multinomial_item
device                                                                                                                   
cpu     bfloat16       0.0003            0.1275  0.0360       0.0361       0.5669            0.5573                0.5546
cpu      float16       0.0002            0.1165  0.0233       0.0233       0.5250            0.5275                0.5281
cpu      float32       0.0002            0.1150  0.0458       0.0458       0.5990            0.5649                0.5562
mps     bfloat16       0.0002            0.1104  0.0217       0.1474       0.5712            0.7345                0.7183
mps      float16       0.0002            0.1154  0.0161       0.1508       0.6070            0.7264                0.6933
mps      float32      

#### Pseudo-torch.multinomial

In [48]:
prelude()

import torch, torch.nn.functional as F
import pandas as pd

def _gumbel_multinomial(X, unused=1):
    u = torch.rand(X.shape, device=X.device, dtype=X.dtype)
    g = -torch.log(-torch.log(u.clamp_min_(1e-6)))
    return torch.argmax(X + g, dim=-1)

def _pseudo_multinomial(p, unused=1):
    q = torch.empty_like(p)
    # validity checks (as in the torch kernel)
    _ = ((p.max() < float('inf')) & (p.min() >= 0)).item()
    _ = (p.sum() == 0).item() if p.dim()==1 else ((p.sum(1) == 0).sum().item())
    q.exponential_(1.0)
    r = p / q
    return torch.argmax(r)

devices = ['mps', 'cpu']
dtypes = [torch.float32, torch.float16, torch.bfloat16]
iters = 2048
V = 50000
funcs = [_gumbel_multinomial, _pseudo_multinomial, torch.multinomial]
perf_rows = []
numerics_rows = []
with torch.inference_mode():
    for device in devices:
        for dtype in dtypes:
            X = torch.randn(V, device=device, dtype=dtype)
            P = F.softmax(X, -1)
            for func in funcs:
                arg = P if func == torch.multinomial else X
                with dev_sync(device):
                    t0 = time.perf_counter()
                    for i in range(iters):
                        func(arg,1)
                t1 = time.perf_counter()
                rate = 1e3*(t1-t0)/iters
                perf_rows.append({
                    "device": device,
                    "func": func.__name__,
                    "dtype": str(dtype).replace("torch.", ""),
                    "avg (ms)": rate,
                })

df = pd.DataFrame(perf_rows)
df = df[[
    "device",
    "func",
    "dtype",
    "avg (ms)",
]].set_index(["func","device"]).sort_index(level=["func", "device"])

print(df.to_string(float_format=lambda x: f"{x:.3f}"))


                               dtype  avg (ms)
func                device                    
_gumbel_multinomial cpu      float32     0.350
                    cpu      float16     0.319
                    cpu     bfloat16     0.359
                    mps      float32     0.056
                    mps      float16     0.051
                    mps     bfloat16     0.057
_pseudo_multinomial cpu      float32     0.563
                    cpu      float16     0.527
                    cpu     bfloat16     0.561
                    mps      float32     0.623
                    mps      float16     0.619
                    mps     bfloat16     0.580
multinomial         cpu      float32     0.575
                    cpu      float16     0.533
                    cpu     bfloat16     0.556
                    mps      float32     0.610
                    mps      float16     0.606
                    mps     bfloat16     0.577


In [93]:
def _topk_filter(x, k):
    if k is None or k <= 0 or k >= x.size(-1):
        return x
    v, i = torch.topk(x, k)
    y = torch.full_like(x, float("-inf"))
    y.scatter_(dim=-1, index=i, src=v)
    return y

def _topk_filter_simple(x, k):
    """
    bug: if there are ties then returned tensor will have >k entries
    """
    v, _ = torch.topk(x, k)
    y = torch.where(x >= v[-1], x, torch.tensor(float("-inf"), device=x.device))
    return y

# dtypes = [torch.float32, torch.float16, torch.bfloat16]
dtypes = [ torch.float16,torch.float16, torch.float16,]
k = 100
V = 50000
seed = 1337
torch.manual_seed(seed)
device = "mps"
for dtype in dtypes:
    X = torch.randn(V, dtype=dtype, device=device)
    x_tf = _topk_filter(X, k)
    x_fs = _topk_filter_simple(X, k)

    rtol, atol = 1e-3, 1e-3
    a = x_tf.to(torch.float32)
    b = x_fs.to(torch.float32)

    fin = torch.isfinite(a) & torch.isfinite(b)
    if fin.any():
        diff = (a[fin] - b[fin])
        print("max_abs_err(finite):", diff.abs().max().item())
        print("mean_abs_err(finite):", diff.abs().mean().item())
        den = a[fin].abs().max().clamp_min(1e-12)
        print("rel_err(finite):", (diff.abs().max() / den).item())
        print("l2_rel_err(finite):", diff.pow(2).sum().sqrt().item() / a[fin].pow(2).sum().sqrt().clamp_min(1e-12).item())
        print("cosine_sim(finite):", torch.nn.functional.cosine_similarity(a[fin].flatten(), b[fin].flatten(), dim=0).item())
    else:
        print("No overlapping finite values; values likely -inf everywhere except kept indices.")

    # Robust isclose on finite entries only
    if fin.any():
        frac_mismatch = (~torch.isclose(a[fin], b[fin], rtol=rtol, atol=atol)).float().mean().item()
    else:
        frac_mismatch = float('nan')
    print("fraction_mismatch(isclose, finite):", frac_mismatch)
    # python
    kidx_tf = torch.topk(x_tf, k).indices
    kidx_fs = torch.topk(x_fs, k).indices
    print("overlap:", torch.isin(kidx_tf, kidx_fs).float().mean().item(),
          "kept_counts:", torch.isfinite(x_tf).sum().item(), torch.isfinite(x_fs).sum().item())
    assert torch.allclose(x_tf, x_fs, rtol=1e-3, atol=1e-3), "will fail when _topk_filter_simple encounters rows with ties"

max_abs_err(finite): 0.0
mean_abs_err(finite): 0.0
rel_err(finite): 0.0
l2_rel_err(finite): 0.0
cosine_sim(finite): 0.9999998807907104
fraction_mismatch(isclose, finite): 0.0
overlap: 1.0 kept_counts: 100 100
max_abs_err(finite): 0.0
mean_abs_err(finite): 0.0
rel_err(finite): 0.0
l2_rel_err(finite): 0.0
cosine_sim(finite): 1.0
fraction_mismatch(isclose, finite): 0.0
overlap: 1.0 kept_counts: 100 100
max_abs_err(finite): 0.0
mean_abs_err(finite): 0.0
rel_err(finite): 0.0
l2_rel_err(finite): 0.0
cosine_sim(finite): 1.0
fraction_mismatch(isclose, finite): 0.0
overlap: 1.0 kept_counts: 100 101


AssertionError: 

In [72]:
prelude()

import torch, torch.nn.functional as F

def _topp_filter(x, p):
    if p is None or p <= 0.0 or p >= 1.0:
        return x
    probs = F.softmax(x, dim=-1)
    s, i = torch.sort(probs, dim=-1, descending=True)
    c = torch.cumsum(s, dim=-1)
    ms = c <= p
    ms[..., 0] = True
    m = torch.zeros_like(ms, dtype=torch.bool).scatter(-1, i, ms)
    return x.masked_fill(~m, float("-inf"))

def _topp_filter_simple(x, p):
    if p is None or p <= 0.0 or p >= 1.0:
        return x
    probs = F.softmax(x, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cdf = torch.cumsum(sorted_probs, dim=-1)
    mask = cdf <= p
    mask[..., 0] = True
    keep = sorted_idx[mask]
    new_x = torch.full_like(x, float("-inf"))
    new_x[keep] = x[keep]
    return new_x

dtypes = [torch.float32, torch.float16, torch.bfloat16]
p = 0.9
V = 50000
seed = 1337
torch.manual_seed(seed)
device = "mps"
for dtype in dtypes:
    X = torch.randn(V, dtype=dtype, device=device)
    x_tf = _topp_filter(X, p)
    x_fs = _topp_filter_simple(X, p)
    print(f"max val x_tf: {torch.argmax(x_tf, dim=-1)}")
    print(f"max val x_fs: {torch.argmax(x_fs, dim=-1)}")
    assert torch.equal(x_tf, x_fs)

max val x_tf: 5822
max val x_fs: 5822
max val x_tf: 6912
max val x_fs: 6912
max val x_tf: 43770
max val x_fs: 43770


### Rotary DTypes

In [4]:
prelude()

In [13]:
import torch, torch.nn as nn

def _apply_rope(x, cos, sin):
    x1, x2 = x.chunk(2, dim=-1)
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat((y1, y2), dim=-1)

class RotaryPositionEncoding(nn.Module):
    def __init__(self, head_dim, max_seq_len, dtype):
        super().__init__()
        half = head_dim // 2
        keep = head_dim // 4
        base = 1024
        if keep == 0:
            angular = (1 / base) ** torch.linspace(0, 1, steps=half, dtype=dtype)
        else:
            active = (1 / base) ** torch.linspace(0, 1, steps=keep, dtype=dtype)
            angular = torch.cat([active, active.new_zeros(half - keep)])
        self.inv_freq = nn.Buffer(angular, persistent=False)
        t = torch.arange(max_seq_len, dtype=dtype)
        theta = torch.einsum("i,j->ij", t, self.inv_freq)
        self.cos = nn.Buffer(theta.cos().to(dtype), persistent=False)
        self.sin = nn.Buffer(theta.sin().to(dtype), persistent=False)
        self._max_seq_len = int(max_seq_len)

    def forward(self, x_BTHD):
        L = x_BTHD.size(-3)
        cos = self.cos[None, :L, None, :]
        sin = self.sin[None, :L, None, :]
        return _apply_rope(x_BTHD, cos, sin)



In [14]:
from time import perf_counter as p
import torch.nn.functional as F

device = "mps"
rotary_dtypes = [torch.float32, torch.float16, torch.bfloat16]
ref_dtype = torch.float32
B, T, H, D, iters, max_seq_len = 1, 1024, 8, 128, 2048, 65536
ref_rotary = RotaryPositionEncoding(D, max_seq_len, dtype=ref_dtype).to(device)

for dtype in rotary_dtypes:
    rotary = RotaryPositionEncoding(D, max_seq_len, dtype=dtype).to(device)
    X = torch.randn(B, T, H, D, dtype=dtype, device=device)
    with dev_sync(device):
        t0 = p()
        for i in range(iters):
            rotary(X)
    dur = p() - t0
    print(f"Rotary [{dtype}] | Avg dur/iter: {dur/iters:.6f}s")

    S = torch.zeros(iters, dtype=ref_dtype, device=device)
    for i in range(iters):
        Y = torch.randn(B, T, H, D, dtype=ref_dtype, device=device)
        approx = rotary(Y.to(dtype)).to(ref_dtype)
        ref_res = ref_rotary(Y)
        dist = F.cosine_similarity(ref_res, approx, dim=-1).mean()
        S[i] = dist
    print(f"Rotary [{dtype}] | Mean cosine similarity: {S.mean().item():.6f}")


Rotary [torch.float32] | Avg dur/iter: 0.000150s
Rotary [torch.float32] | Mean cosine similarity: 1.000000
Rotary [torch.float16] | Avg dur/iter: 0.000136s
Rotary [torch.float16] | Mean cosine similarity: 0.999386
Rotary [torch.bfloat16] | Avg dur/iter: 0.000137s
Rotary [torch.bfloat16] | Mean cosine similarity: 0.959024
