### Clear Vars

In [8]:
def clear_vars(keep=()):
    """Delete user-defined names from this module's global namespace.

    Always preserved: names starting with an underscore, the interactive
    shell entries ``In``, ``Out``, ``get_ipython``, ``exit`` and ``quit``,
    and ``clear_vars`` itself. Keeping the function alive matters: without
    it the first call would remove the helper and any later ``clear_vars()``
    cell would fail with ``NameError``.

    Args:
        keep: Optional name (str) or iterable of names that should also
            survive the cleanup. Defaults to keeping nothing extra.
    """
    if isinstance(keep, str):
        # A bare string would otherwise be treated as an iterable of characters.
        keep = (keep,)
    protected = {"In", "Out", "get_ipython", "exit", "quit", "clear_vars"}
    protected.update(keep)

    namespace = globals()
    # Iterate over a snapshot of the keys; the dict is mutated inside the loop.
    for _n in list(namespace):
        if not _n.startswith("_") and _n not in protected:
            del namespace[_n]

### Device Copy Overhead

In [None]:
# Drop user-defined globals left over from earlier cells before this benchmark runs.
clear_vars()

In [3]:
import time, torch, torch.nn.functional as F
# Benchmark configuration.
V=50000                  # number of entries in the score vector
dtype=torch.bfloat16
iters = 2048             # timed repetitions per measurement

# --- 1) argmax on MPS, then pull the single resulting scalar to the host ---
x=torch.randn(V, device="mps", dtype=dtype)
_ = int(torch.argmax(x).item())   # untimed warm-up so first-call setup cost is excluded
torch.mps.synchronize()           # MPS work is queued asynchronously: drain it before timing
s=time.perf_counter()
for _ in range(iters): _ = int(torch.argmax(x).item())
torch.mps.synchronize(); print("1-scalar copy:", time.perf_counter()-s)

# --- 2) sampling entirely on the CPU ---
p=F.softmax(x, -1).to("cpu")
_ = int(torch.multinomial(p,1))   # untimed warm-up
s=time.perf_counter()
for _ in range(iters): _ = int(torch.multinomial(p,1))  # simulating CPU sampler
print("CPU multinomial -> CPU scalar:", time.perf_counter()-s)

# --- 3) sampling on MPS, copying each drawn index back to the host ---
p=F.softmax(x, -1).to("mps")
_ = int(torch.multinomial(p,1))   # untimed warm-up
torch.mps.synchronize()           # make sure softmax and warm-up finished before the clock starts
s=time.perf_counter()
for _ in range(iters): _ = int(torch.multinomial(p,1))
print("MPS multinomial -> CPU scalar:", time.perf_counter()-s)


1-scalar copy: 0.6497018329973798
CPU multinomial -> CPU scalar: 1.870685500005493
MPS multinomial -> CPU scalar: 2.4869550000003073


### Rotary DTypes

In [None]:
# Drop user-defined globals left over from earlier cells before the rotary experiment.
clear_vars()

In [4]:
import torch, torch.nn as nn

def _apply_rope(x, cos, sin):
    """Rotate the two halves of ``x``'s last axis by the given cos/sin tables.

    ``x`` is split into a first and second half along its last dimension;
    ``cos`` and ``sin`` must broadcast against each half. The rotated halves
    are concatenated back along the last dimension, so the result has the
    same last-axis length as ``x``.
    """
    lo, hi = torch.chunk(x, 2, dim=-1)
    rotated_lo = lo * cos + hi * sin
    rotated_hi = hi * cos - lo * sin
    return torch.cat([rotated_lo, rotated_hi], dim=-1)

class RotaryPositionEncoding(nn.Module):
    """Precomputed cos/sin tables for rotary position encoding.

    The constructor builds ``dim // 2`` angular frequencies. When
    ``dim // 4`` is non-zero only the first ``dim // 4`` of them are active
    (powers of ``1 / 1024`` spaced by ``linspace(0, 1)``); the remaining
    entries are zero, i.e. those channels are left unrotated. The tables
    ``cos`` and ``sin`` have shape ``(max_seq_len, dim // 2)`` and are stored
    as non-persistent buffers, so they follow ``.to()`` but are not saved in
    the state dict.

    Args:
        dim: Size of the axis being rotated (last axis of the input given
            to :meth:`apply`).
        max_seq_len: Number of positions to precompute.
        dtype: dtype used for the frequency/angle computation and the tables.
    """

    def __init__(self, dim, max_seq_len, dtype):
        super().__init__()
        half = dim // 2
        keep = dim // 4
        base = 1024
        # NOTE(review): frequencies, positions and angles are all computed
        # directly in `dtype`. In bfloat16 integers above 256 are not exactly
        # representable, so long tables lose positional precision. Presumably
        # intentional for this dtype comparison -- confirm before reuse.
        if keep == 0:
            angular = (1 / base) ** torch.linspace(0, 1, steps=half, dtype=dtype)
        else:
            active = (1 / base) ** torch.linspace(0, 1, steps=keep, dtype=dtype)
            angular = torch.cat([active, active.new_zeros(half - keep)])
        # register_buffer is the long-standing equivalent of assigning an
        # nn.Buffer and also works on PyTorch releases that predate nn.Buffer.
        self.register_buffer("inv_freq", angular, persistent=False)
        t = torch.arange(max_seq_len, dtype=dtype)
        theta = torch.einsum("i,j->ij", t, self.inv_freq)
        self.register_buffer("cos", theta.cos().to(dtype), persistent=False)
        self.register_buffer("sin", theta.sin().to(dtype), persistent=False)
        self._max_seq_len = int(max_seq_len)

    def forward(self, x_BTHD, L=None):
        """Rotate ``x_BTHD`` using the first ``L`` rows of the tables.

        Args:
            x_BTHD: Input whose axis 1 is the sequence axis and whose last
                axis has length ``dim``.
            L: Number of positions to use; defaults to ``x_BTHD.shape[1]``.

        Returns:
            Tensor with the same shape as ``x_BTHD``.
        """
        if L is None:
            L = x_BTHD.shape[1]
        # Reshape (L, half) -> (1, L, 1, half) so the tables broadcast over
        # the leading and third axes of the input.
        cos = self.cos[None, :L, None, :]
        sin = self.sin[None, :L, None, :]
        return _apply_rope(x_BTHD, cos, sin)

    def apply(self, x_BTHD, L=None):
        """Apply the rotary encoding; kept for existing ``rope.apply(x, L)`` callers.

        This name shadows ``nn.Module.apply(fn)``, which parent modules call
        on every child (e.g. ``model.apply(init_fn)``). To keep that working,
        a call with a single non-tensor callable and no ``L`` is forwarded to
        ``nn.Module.apply``; every other call performs the rotation.
        """
        if L is None and callable(x_BTHD) and not isinstance(x_BTHD, torch.Tensor):
            return super().apply(x_BTHD)
        return self.forward(x_BTHD, L)



In [None]:
rotary_dtypes = [torch.float16, torch.bfloat16]
B = 1; T = 1024; H = 8; D = 128; iters = 2048

# The tables must be sized by the per-head dimension D: the rotation splits the
# last axis of x (length D) in half and multiplies each half by a (T, dim // 2)
# table, so dim = H * D would not broadcast against x.
# The fp32 reference does not depend on the dtype under test, so build it once.
# TODO: rope_ref is not yet compared against `rope` in this cell.
rope_ref = RotaryPositionEncoding(D, T, dtype=torch.float32)

for dtype in rotary_dtypes:
    rope = RotaryPositionEncoding(D, T, dtype=dtype)
    x = torch.randn(B, T, H, D, dtype=dtype)
    for i in range(iters):
        rope.apply(x, T)