# **Broadcasting Simulator**

In [None]:
import torch
from typing import List, Tuple

def broadcast_shape(shape_a: Tuple[int, ...],
                    shape_b: Tuple[int, ...]) -> Tuple[int, ...]:
    """
    Compute the broadcasted shape of two tensors (NumPy/PyTorch rules).

    The shapes are aligned on their trailing dimension; the shorter one is
    left-padded with 1s. For each aligned pair the sizes must be equal, or
    one of them must be 1 (the size-1 side stretches to the other size).

    Args:
        shape_a: shape of the first tensor.
        shape_b: shape of the second tensor.

    Returns:
        The broadcasted shape as a tuple.

    Raises:
        ValueError: if some aligned pair of sizes differs and neither is 1.
    """
    # Align from the right by prepending 1s on the shorter shape
    a = list(shape_a)
    b = list(shape_b)

    max_len = max(len(a), len(b))
    a = [1] * (max_len - len(a)) + a
    b = [1] * (max_len - len(b)) + b

    out: List[int] = []
    for dim_a, dim_b in zip(a, b):
        if dim_a == dim_b or dim_b == 1:
            out.append(dim_a)
        elif dim_a == 1:
            # Take the *other* size rather than max(): a size-1 dim paired
            # with a size-0 dim broadcasts to 0, not 1.
            out.append(dim_b)
        else:
            raise ValueError(f"Shapes {shape_a} and {shape_b} are not "
                             f"broadcastable: conflict {dim_a} vs {dim_b}")
    return tuple(out)


def explain_broadcast(shape_a: Tuple[int, ...],
                      shape_b: Tuple[int, ...]) -> None:
    """
    Walk through broadcasting of two shapes, printing every step.

    Prints the raw shapes, the shapes after left-padding with 1s, and a
    per-dimension verdict. On the first conflicting dimension it prints a
    "NOT broadcastable" result and returns; nothing is raised. Otherwise the
    final broadcasted shape is printed.
    """
    divider = "-" * 40
    rank = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s so both have the same rank.
    padded_a = (1,) * (rank - len(shape_a)) + tuple(shape_a)
    padded_b = (1,) * (rank - len(shape_b)) + tuple(shape_b)

    print(f"A shape: {shape_a}")
    print(f"B shape: {shape_b}")
    print(divider)
    print("Aligned (with leading 1s):")
    print(f"A aligned: {padded_a}")
    print(f"B aligned: {padded_b}")
    print(divider)
    print("Compare dimensions (left → right):")

    result = []
    for axis in range(rank):
        size_a, size_b = padded_a[axis], padded_b[axis]
        # axis - rank is the negative index of this dim (-1 is the last dim)
        prefix = (f"  dim {axis} (from left, pos {axis - rank:+} from right): "
                  f"{size_a} vs {size_b} -> ")

        if size_a != size_b and 1 not in (size_a, size_b):
            print(prefix + "❌ conflict (not broadcastable)")
            print()
            print(f"Result: shapes {shape_a} and {shape_b} are NOT broadcastable.")
            return

        if size_a == size_b:
            kept, why = size_a, "same, keep"
        elif size_a == 1:
            kept, why = size_b, "A expands to match B"
        else:
            kept, why = size_a, "B expands to match A"
        result.append(kept)
        print(prefix + f"✔ {kept} ({why})")

    print(divider)
    print(f"Broadcasted shape: {tuple(result)}")
    print()


# --------- quick tests / examples ---------
if __name__ == "__main__":
    # 1) Classic outer-product style pair: (B,T,1) × (B,1,D) -> (B,T,D)
    B, T, D = 2, 4, 8
    explain_broadcast((B, T, 1), (B, 1, D))

    # 2) (B,T,D) + (D,): the 1-D shape is left-padded to (1,1,D)
    explain_broadcast((B, T, D), (D,))

    # 3) (B,T) × (D,): invalid unless T == D or one of them is 1.
    #    explain_broadcast reports a conflict by printing and returning, it
    #    does not raise, so this except is not expected to fire for a shape
    #    conflict.
    try:
        explain_broadcast((B, T), (D,))
    except Exception as e:
        print("Caught error:", e)

    # 4) Quick programmatic use: get the result shape without the narration
    print("broadcast_shape((3, 1), (1, 4)) ->",
          broadcast_shape((3, 1), (1, 4)))


A shape: (2, 4, 1)
B shape: (2, 1, 8)
----------------------------------------
Aligned (with leading 1s):
A aligned: (2, 4, 1)
B aligned: (2, 1, 8)
----------------------------------------
Compare dimensions (left → right):
  dim 0 (from left, pos -3 from right): 2 vs 2 -> ✔ 2 (same, keep)
  dim 1 (from left, pos -2 from right): 4 vs 1 -> ✔ 4 (B expands to match A)
  dim 2 (from left, pos -1 from right): 1 vs 8 -> ✔ 8 (A expands to match B)
----------------------------------------
Broadcasted shape: (2, 4, 8)

A shape: (2, 4, 8)
B shape: (8,)
----------------------------------------
Aligned (with leading 1s):
A aligned: (2, 4, 8)
B aligned: (1, 1, 8)
----------------------------------------
Compare dimensions (left → right):
  dim 0 (from left, pos -3 from right): 2 vs 1 -> ✔ 2 (B expands to match A)
  dim 1 (from left, pos -2 from right): 4 vs 1 -> ✔ 4 (B expands to match A)
  dim 2 (from left, pos -1 from right): 8 vs 8 -> ✔ 8 (same, keep)
----------------------------------------
Bro

In [None]:

# Re-run the (B,T,1) × (B,1,D) example on its own; relies on
# explain_broadcast having been defined by an earlier cell.
B, T, D = 2, 4, 8
explain_broadcast((B, T, 1), (B, 1, D))

A shape: (2, 4, 1)
B shape: (2, 1, 8)
----------------------------------------
Aligned (with leading 1s):
A aligned: (2, 4, 1)
B aligned: (2, 1, 8)
----------------------------------------
Compare dimensions (left → right):
  dim 0 (from left, pos -3 from right): 2 vs 2 -> ✔ 2 (same, keep)
  dim 1 (from left, pos -2 from right): 4 vs 1 -> ✔ 4 (B expands to match A)
  dim 2 (from left, pos -1 from right): 1 vs 8 -> ✔ 8 (A expands to match B)
----------------------------------------
Broadcasted shape: (2, 4, 8)



In [None]:
import torch
from typing import List, Tuple, Sequence

Shape = Tuple[int, ...]


def broadcast_shape_many(shapes: Sequence[Shape]) -> Shape:
    """
    Return the shape that N tensors broadcast to (NumPy/PyTorch rules).

    shapes: sequence of shapes, e.g. [(B,T,1), (B,1,D), (1,T,1)].
    An empty sequence yields (). Raises ValueError when, in some aligned
    dimension, the shapes contain more than one distinct size other than 1.
    """
    if not shapes:
        return ()

    # Left-pad every shape with 1s up to the largest rank.
    rank = max(len(shape) for shape in shapes)
    padded = [[1] * (rank - len(shape)) + list(shape) for shape in shapes]

    result: List[int] = []
    # zip(*padded) walks the aligned shapes one dimension (column) at a time.
    for axis, sizes in enumerate(zip(*padded)):
        column = list(sizes)
        distinct = sorted(set(column) - {1})
        if len(distinct) > 1:
            raise ValueError(
                f"Incompatible shapes at dim {axis}: {column} "
                f"(non-1 dims {distinct})"
            )
        # No non-1 size -> everything is 1; otherwise the single agreed size.
        result.append(distinct[0] if distinct else 1)

    return tuple(result)


def explain_broadcast_many(shapes: Sequence[Shape]) -> None:
    """
    Print a narrated, dimension-by-dimension account of N-way broadcasting.

    Shows the input shapes, the shapes after left-padding with 1s, and for
    every dimension which size wins and which tensors get stretched. Stops
    with a conflict message (no exception) at the first dimension holding
    more than one distinct non-1 size.
    """
    rule = "-" * 60

    print("Input shapes:")
    for idx, shape in enumerate(shapes):
        print(f"  Tensor {idx}: {shape}")
    print(rule)

    if not shapes:
        print("No shapes given → result shape is ().")
        return

    rank = max(len(shape) for shape in shapes)
    padded = [[1] * (rank - len(shape)) + list(shape) for shape in shapes]

    print("Aligned with leading 1s (so all have same rank):")
    for idx, row in enumerate(padded):
        print(f"  Tensor {idx} aligned: {tuple(row)}")
    print(rule)

    result: List[int] = []
    print("Per-dimension analysis (left → right):")
    for axis in range(rank):
        column = [row[axis] for row in padded]
        distinct = sorted(set(column) - {1})

        print(f"\nDim {axis} (from left):")
        for idx, size in enumerate(column):
            print(f"  - Tensor {idx}: {size}")

        if len(distinct) > 1:
            print(
                f"  -> ❌ Conflict: multiple distinct non-1 dims {distinct}. "
                f"Shapes are NOT broadcastable."
            )
            return

        if not distinct:
            print("  -> All dims are 1 → result dim = 1")
            result.append(1)
            continue

        agreed = distinct[0]
        print(f"  -> Non-1 dims agree on {agreed} → result dim = {agreed}")
        # Every tensor whose size here is 1 is the one being stretched.
        stretched = [idx for idx, size in enumerate(column) if size == 1]
        if stretched:
            print(f"     Tensors {stretched} broadcast (their dim 1 expands)")
        result.append(agreed)

    print("\n" + rule)
    print(f"Broadcasted shape: {tuple(result)}")
    print()


# --------------------------
# Example usage / quick tests
# --------------------------
if __name__ == "__main__":
    B, T, D = 2, 4, 8

    # 1) Classic attention-style shapes: (B,T,1) * (B,1,D) * (1,T,D)
    shapes1 = [(B, T, 1), (B, 1, D), (1, T, D)]
    explain_broadcast_many(shapes1)
    print("broadcast_shape_many:", broadcast_shape_many(shapes1))
    print()

    # 2) (B,T,D) + (D,) + (1,1,D)
    # (D,) is the same as (1,1,D) once it is left-padded with 1s to rank 3
    shapes2 = [(B, T, D), (D,), (1, 1, D)]
    explain_broadcast_many(shapes2)
    print("broadcast_shape_many:", broadcast_shape_many(shapes2))
    print()

    # 3) Incompatible case: (B,T) and (D,2)
    #    explain_broadcast_many only prints the conflict; the ValueError
    #    caught below comes from broadcast_shape_many.
    shapes3 = [(B, T), (D, 2)]
    try:
        explain_broadcast_many(shapes3)
        print("broadcast_shape_many:", broadcast_shape_many(shapes3))
    except ValueError as e:
        print("Caught error:", e)

Input shapes:
  Tensor 0: (2, 4, 1)
  Tensor 1: (2, 1, 8)
  Tensor 2: (1, 4, 8)
------------------------------------------------------------
Aligned with leading 1s (so all have same rank):
  Tensor 0 aligned: (2, 4, 1)
  Tensor 1 aligned: (2, 1, 8)
  Tensor 2 aligned: (1, 4, 8)
------------------------------------------------------------
Per-dimension analysis (left → right):

Dim 0 (from left):
  - Tensor 0: 2
  - Tensor 1: 2
  - Tensor 2: 1
  -> Non-1 dims agree on 2 → result dim = 2
     Tensors [2] broadcast (their dim 1 expands)

Dim 1 (from left):
  - Tensor 0: 4
  - Tensor 1: 1
  - Tensor 2: 4
  -> Non-1 dims agree on 4 → result dim = 4
     Tensors [1] broadcast (their dim 1 expands)

Dim 2 (from left):
  - Tensor 0: 1
  - Tensor 1: 8
  - Tensor 2: 8
  -> Non-1 dims agree on 8 → result dim = 8
     Tensors [0] broadcast (their dim 1 expands)

------------------------------------------------------------
Broadcasted shape: (2, 4, 8)

broadcast_shape_many: (2, 4, 8)

Input shapes

# Computing mean


https://www.youtube.com/watch?v=kCc8FmEb1nY
Karpathy Trick behind attention

44:00


<ul>
<li>Summary: A mean reduces the dims of a shape like [10,10] to [10] because one of the dims becomes the mean. Use keepdims=True to get [1,10] instead of [10]. This helps to make clear it is ready for broadcasting</li>
<li>create data pattern for debugging. Show the batch, time, Column dims</li>
<li>We are calculating the mean per vertical column. 4,8,2 has 2 columns so the mean is [mean col0, mean col1]</li>
<li>Broadcasting requires a 1 as one of the dimensions or pytorch uses the shift trick to add a 1 to the arg with the smaller number of dims.  </li>
<li>PyTorch first aligns the shapes on their last (trailing) dim, then prepends ones to the shorter shape until the number of dims matches the other arg. [4,8,2] and [2] become [4,8,2] and [1,1,2].</li>
<li>We use the phrase left aligned but pytorch calls it right aligned</li>
<li>[4,8,2]</li>
<li>[6]</li>
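<li>unsqueeze(i) inserts a size-1 dim at index i of the shape tuple. unsqueeze(0) adds a 1 to the front of the shape tuple, (6,)->(1,6), and unsqueeze(-1) adds a 1 to the end of the shape tuple, (1,6)->(1,6,1). An unsqueeze(2) inserts at index 2, (1,6,1)->(1,6,1,1). Note that unsqueeze(1) on (6,) gives (6,1), not (1,6).</li>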
<li>the 6 is in the same column as the 4. I call this left aligned but pytorch calls it right aligned.</li>
<li>Some operations like mean, sum eliminate one of the dimensions, to keep this dim=1, use keepdims=True</li>
</ul>

In [None]:
import torch
# NOTE: torch.tensor([6]) builds a 1-element tensor holding the value 6, so
# its shape is (1,), not (6,). Use e.g. torch.zeros(6) for a shape-(6,) tensor.
x = torch.tensor([6])
print(x.shape)
# unsqueeze(1) inserts a size-1 dim at index 1: (1,) -> (1, 1)
x=x.unsqueeze(1)
print(x.shape)
# unsqueeze(-1) appends a size-1 dim at the end: (1, 1) -> (1, 1, 1)
x = x.unsqueeze(-1)
print(x.shape)

torch.Size([1])
torch.Size([1, 1])
torch.Size([1, 1, 1])


In [None]:
import torch
x = torch.tensor([1,2,3,4,5,6], dtype=torch.float16)
last_5=x[0:5] # NOTE: despite the name this is the FIRST five elements; the last five would be x[-5:]. Unused below.
print('-------------')
print('reverse indexing')
# x[-k:] takes the last k elements
print(x[-1:])
print(x[-2:])
print(x[-3:])
print('-------------')
print('foward indexing')
# x[:k+1] takes elements 0..k inclusive (the stop index is exclusive)
print(x[:(0+1)])
print(x[:(1+1):])
print(x[:(2+1):])


#print(last_5)
#torch.mean(last_5)
print('------------')
print("easier to do forward indexing since we tokenize from l->r")
print("the below is incorrect because of the first empty array")
# x[:0] is empty, so the first mean printed is nan
for idx in range(len(x)):
  print("start from index:",idx," previous tokens:", x[:idx], "mean:",torch.mean(x[:idx]))

print("\n")
print("\n")
print("since we index from 0, we need to start from 1")

# x[:idx+1] includes element idx itself, so the slice is never empty
for idx in range(len(x)):
  print("start from index+1:",idx+1," previous tokens:", x[:idx+1], "mean:",torch.mean(x[:idx+1]))

# switch to the x(B,T,C) tensor convention
B,T,C = 4,8,2


B, T, C = 4, 8, 2
x = torch.zeros((B, T, C))

# Fill with a recognizable debugging pattern
for b in range(B):
    for t in range(T):
        x[b, t, 0] = b     # channel 0 shows batch index
        x[b, t, 1] = t     # channel 1 shows timestep index

print('----------')
print(x)
print('----------')

for b in range(B):
  for t in range(T):
    xprev = x[b,:t+1] # (t+1, C): rows 0..t of batch b
    xbow = torch.mean(xprev,0) # reduce over dim 0 of xprev (the time rows) -> (C,)
    print(f'b:{b},t:{t},xprev:{xprev}, xbow:{xbow}')

# C is never indexed: we step through (b, t) pairs [0,0],[0,1],[0,2],[0,3],...
# xprev holds all rows of batch b up to and including row t, not just the previous row.

-------------
reverse indexing
tensor([6.], dtype=torch.float16)
tensor([5., 6.], dtype=torch.float16)
tensor([4., 5., 6.], dtype=torch.float16)
-------------
foward indexing
tensor([1.], dtype=torch.float16)
tensor([1., 2.], dtype=torch.float16)
tensor([1., 2., 3.], dtype=torch.float16)
------------
easier to do forward indexing since we tokenize from l->r
the below is incorrect because of the first empty array
start from index: 0  previous tokens: tensor([], dtype=torch.float16) mean: tensor(nan, dtype=torch.float16)
start from index: 1  previous tokens: tensor([1.], dtype=torch.float16) mean: tensor(1., dtype=torch.float16)
start from index: 2  previous tokens: tensor([1., 2.], dtype=torch.float16) mean: tensor(1.5000, dtype=torch.float16)
start from index: 3  previous tokens: tensor([1., 2., 3.], dtype=torch.float16) mean: tensor(2., dtype=torch.float16)
start from index: 4  previous tokens: tensor([1., 2., 3., 4.], dtype=torch.float16) mean: tensor(2.5000, dtype=torch.float16)
sta

Attention definitions

<ul>
<li>
Given an embedding of size 768, how does this become (B,T,C)? C is the column (embedding) dim, so C = 768. T indexes the rows, one 768-vector per row. B is a collection of such (T,C) blocks of 768-vectors. We have to create B and T.
</li>
<li>
T = sequence length. and B is number of sequences processed in GPU memory at one time.
</li>
<li>

attention splits C to num_heads.
</li>
<li>
num_heads is a predefined constant:
num_heads = H.
C/H is the size of each head, not the number of attention heads; the number of attention heads is num_heads = H.
size per attention head = C/H
</li>
<li>
head_dim = C/H, 768/16
</li>
<li>
K_cache: (B, T, H, head_dim)
V_cache: (B, T, H, head_dim)
</li>

</ul>

Matrix Multiplies

<ul>
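<li>Matrix multiply by an all-ones matrix (not the identity matrix, which would return the data unchanged) produces sums and, once normalized, averages. The rows [2,7], [6,4], [6,5] produce the column sums [14,16] when multiplied by a 3x3 all-ones matrix</li>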
</ul>

In [None]:
# Karpathy: replicating running means in (B,T,C) format with a triangular matrix multiply
import torch
torch.manual_seed(42)
a = torch.ones(3,3)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('----')
print('b=')
print(b)
print('----')
print('c=')
print(c)
print('---end multiply w identiy matrix---')

# NOTE: `a` above is an all-ones matrix, not an identity matrix (the label
# printed above is a misnomer). Multiplying by it puts the column sums of b
# in every row of c: each entry is the dot product of a row [1,1,1] with a
# column of b, e.g. 2+6+6=14 and 7+4+5=16.

# Second step: take the lower-triangular part of the ones matrix instead.
# The zeros above the diagonal mean row t only sums rows 0..t of b.
print('---replace with lower triangular and make rows sum to 1---')
print('  ')
a = torch.tril(torch.ones(3,3))
# Normalize by the row sums of the triangular matrix itself (1, 2, 3) so each
# row sums to 1 and row t of c is the mean of rows 0..t of b. Dividing by
# the row sums of the previous all-ones `a` would scale every row by 1/3.
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('----')
print('b=')
print(b)
print('----')
print('c=')
print(c)
print('---end multiply w lower triagular rows normalzied to sum 1---')


a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
----
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
----
c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])
---end multiply w identiy matrix---
---replace with lower triangular and make rows sum to 1---
  
a=
tensor([[0.3333, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.0000],
        [0.3333, 0.3333, 0.3333]])
----
b=
tensor([[0., 4.],
        [0., 3.],
        [8., 4.]])
----
c=
tensor([[0.0000, 1.3333],
        [0.0000, 2.3333],
        [2.6667, 3.6667]])
---end multiply w lower triagular rows normalzied to sum 1---


In [None]:
# Compare torch.sum along each dim with and without keepdim on a 2x3 matrix.
x=torch.tensor([[1,2,3],[4,5,6]])
print(x, x.shape)
#y = x/torch.sum(x,dim=1)
# dim=0 sums down the rows, giving one value per column:
# keepdim=True keeps a size-1 dim -> (1, 3); keepdim=False drops it -> (3,)
print(torch.sum(x, dim=0, keepdim=True))
print(torch.sum(x, dim=0, keepdim=False))
# dim=1 sums across the columns, giving one value per row:
# keepdim=True -> (2, 1); keepdim=False -> (2,)
print(torch.sum(x, dim=1, keepdim=True))
print(torch.sum(x, dim=1, keepdim=False))
#print(s)
# no dim: sum of every element, a 0-dim tensor
print(torch.sum(x))

tensor([[1, 2, 3],
        [4, 5, 6]]) torch.Size([2, 3])
tensor([[5, 7, 9]])
tensor([5, 7, 9])
tensor([[ 6],
        [15]])
tensor([ 6, 15])
tensor([[3],
        [7]])
tensor(21)


# **Matrix and Vector multiply for averages.**

There are different versions of averages which are used for normalization. Karpathy develops the t-1 or autoregressive average, where the sequence a,b,c,d has 4 averages: at t=0, avg=a; at t=1, avg=(a+b)/2; at t=2, avg=
(a+b+c)/3, etc...


<ul>
<li>$v = \frac{1}{N}[1,1,1,...]$ where num ones = N</li>
<li>A cumulative average is the conventional average $\frac{1}{N}\sum_0^{N-1}x_i$. It is time invariant. Shifting the sequence produces the same average. $v=[1,1,1...len(x)]$ and the data is $x$. cumulative avg = $v^T@x$</li>
<li>A column average $v^T@X$</li>
<li>How to derive row and column avg. X=(B,D). Create a 2x3 test matrix
[[1,2,3],[4,5,6]]. Make sure the test matrix is not square or symmetric to reduce confusion. We have two options for the vector of ones v (it is a ones vector, not an identity matrix): 1x2 for v@X or 3x1 for X@v. v@X is (1x2)@(2x3) = 1x3: one value per column, so v@X is the column sum. X@v is (2x3)@(3x1) = 2x1: one value per row, so X@v is the row sum. Then multiply by 1/N to get the avg, where N = number of elements summed (2 for the column avg, 3 for the row avg).
</li>
<li>A row average $X@v$</li>
<li>Weighted avg for softmax. W=(T), V=(T,D). $avg=W^T@V$ Output = (D,), or (1,D) if the reduced dim is kept (keepdims=True), because avg collapses and removes the reduced dimension by default</li>
<li>How to derive Weighted Softmax Avg</li>
<li>Sequence avg: </li>
<li>How to derive sequence avg. </li>
</ul>

In [None]:
import torch

x = torch.tensor([1,2,3,4,5,6]).float()

# v = (1/N) * ones(N): the dot product v@x is the plain average of x
v = (1/6)*(torch.ones(6))
print("avg of sum of all elements v@x:",v@x)
# Row t of the lower-triangular ones matrix selects x[0..t], so tri@x is the
# vector of prefix sums; dividing by [1..6] turns them into running averages.
tri = torch.tril(torch.ones(6,6))
avg = (tri@x)/torch.arange(1,7)

print("sums:",tri@x)
print("torch arange:",torch.arange(1,7))
print("rolling avg:",avg) # 1/1, (1+2)/2, (1+2+3)/3,...

avg of sum of all elements v@x: tensor(3.5000)
sums: tensor([ 1.,  3.,  6., 10., 15., 21.])
torch arange: tensor([1, 2, 3, 4, 5, 6])
avg: tensor([1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000])


In [None]:
import torch

x = torch.tensor([1,2,3,4,5])
print(f'len(x):{len(x)}')
# Vector of ones, one entry per element of x
v = torch.ones(len(x))
# Lower-triangular matrix of ones (ones on and below the diagonal). A bare
# torch.ones(n, n) is all ones and would not match the name lower_tri.
lower_tri = torch.tril(torch.ones(len(x), len(x)))





len(x):5


In [None]:
import torch
import numpy as np
from typing import Optional, Union

def safe_torch_mean(x: torch.Tensor) -> torch.Tensor:
    """Return the float mean over all elements of x; an empty x gives 0.

    The element count is floored at 1 so an empty tensor divides 0 by 1
    instead of 0 by 0.
    """
    total = x.float().sum()
    count = x.numel()
    divisor = count if count >= 1 else 1
    return total / divisor

def safe_mean(
    x: torch.Tensor,
    dim: Optional[int] = None,
    keepdim: bool = False,
    default: Union[float, int] = 0.0,
) -> torch.Tensor:
    """
    Mean with two safety nets:
      - non-floating inputs are converted to float32 first
      - an empty reduction yields `default` instead of NaN

    With dim=None the mean is over all elements and the result is a scalar
    tensor. With an integer dim the mean is taken along that dim; when that
    dim has size 0 the result has the usual reduced shape (respecting
    `keepdim`) and is filled with `default`.
    """
    values = x if x.is_floating_point() else x.to(torch.float32)

    if dim is None:
        if values.numel():
            return values.mean()
        # nothing to average: scalar default
        return values.new_tensor(float(default))

    if values.size(dim):
        return values.mean(dim=dim, keepdim=keepdim)

    # Reduced dim is empty: assemble the result shape by hand.
    reduced_shape = list(values.shape)
    if keepdim:
        reduced_shape[dim] = 1
    else:
        reduced_shape.pop(dim)
    return values.new_full(reduced_shape, float(default))

x = torch.tensor([], dtype=torch.float32)
print(safe_mean(x))  # tensor(0.)

x = torch.randint(0, 10, (3,), dtype=torch.int8)
print(safe_mean(x))  # float32 mean, no error


def masked_mean(
    x: torch.Tensor,
    mask: torch.Tensor,
    dim: int,
    keepdim: bool = False,
    default: Union[float, int] = 0.0,
) -> torch.Tensor:
    """
    Mean over elements where mask == 1/True along `dim`.

    x: (..., D, ...); non-floating inputs are converted to float32.
    mask: same shape as x or broadcastable against x.
    dim: dimension to average over (indexes the broadcast shape).
    keepdim: keep the reduced dimension with size 1.
    default: value returned wherever no element is selected along `dim`.

    Returns a tensor with `dim` reduced.
    """
    if not x.is_floating_point():
        x = x.to(torch.float32)

    # Broadcast x and the mask to one common shape *before* reducing.
    # Without this, a mask such as (B, 1) or (T,) is summed at its own
    # smaller shape, so the denominator does not count the elements that
    # were actually summed in the numerator (or `dim` does not exist on it).
    x, m = torch.broadcast_tensors(x, mask.to(x.dtype))

    # Select instead of relying on multiply-by-zero alone: NaN/inf * 0 is
    # NaN, which would leak masked-out garbage into the sum.
    masked_x = torch.where(m != 0, x * m, torch.zeros_like(x))

    # sum over dim
    num = masked_x.sum(dim=dim, keepdim=keepdim)
    den = m.sum(dim=dim, keepdim=keepdim)

    # safe division: where den > 0, num / den; else default
    default_tensor = num.new_full(num.shape, float(default))
    mean = torch.where(den > 0, num / torch.clamp(den, min=1e-12), default_tensor)
    return mean

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])
mask = torch.tensor([[1, 0, 1],
                     [0, 0, 0]])  # second row all masked out

print(masked_mean(x, mask, dim=1))
# tensor([2., 0.])  (last row default=0)

print(masked_mean(x, mask, dim=1, default=-1.0))
# tensor([2., -1.])


def segment_mean(
    values: torch.Tensor,
    segment_ids: torch.Tensor,
    num_segments: Optional[int] = None,
    default: Union[float, int] = 0.0,
) -> torch.Tensor:
    """
    Compute mean over segments:
        segment_means[k] = mean(values[segment_ids == k])

    values: (N, D) or (N,)
    segment_ids: (N,) ints in [0, num_segments-1]
    """
    if not values.is_floating_point():
        values = values.to(torch.float32)

    if values.dim() == 1:
        values = values.unsqueeze(-1)  # make it (N, 1)

    N, D = values.shape
    segment_ids = segment_ids.to(torch.long)

    if num_segments is None:
        num_segments = int(segment_ids.max().item()) + 1 if N > 0 else 0

    device = values.device
    dtype = values.dtype

    # sums for each segment
    sums = torch.zeros(num_segments, D, device=device, dtype=dtype)
    counts = torch.zeros(num_segments, 1, device=device, dtype=dtype)

    # index_add along segment dimension
    sums.index_add_(0, segment_ids, values)
    counts.index_add_(0, segment_ids, torch.ones_like(values[:, :1]))

    default_tensor = sums.new_full(sums.shape, float(default))
    means = torch.where(
        counts > 0,
        sums / torch.clamp(counts, min=1e-12),
        default_tensor,
    )

    # squeeze if original was 1D
    if values.shape[1] == 1:
        means = means.squeeze(-1)

    return means

vals = torch.tensor([[1., 2.],
                     [3., 4.],
                     [10., 20.]], dtype=torch.float32)
seg = torch.tensor([0, 0, 2])   # segment 1 is empty

print(segment_mean(vals, seg, num_segments=3, default=0.0))
# tensor([[2., 3.],      # mean of rows 0 and 1
#         [0., 0.],      # empty segment -> default
#         [10., 20.]])   # row 2

def batch_safe_mean(
    x: torch.Tensor,
    mask: torch.Tensor,
    default: Union[float, int] = 0.0,
) -> torch.Tensor:
    """
    Masked mean over the time dimension (dim 1), one result per batch row.

    x: (B, T, D) or (B, T)
    mask: (B, T) with 1/True = valid entries

    Delegates to masked_mean; rows with no valid entries get `default`.
    A (B, T) input is returned as (B,), otherwise the result is (B, D).
    """
    was_2d = x.dim() == 2
    if was_2d:
        # give (B, T) a trailing feature dim so both cases share one path
        x = x.unsqueeze(-1)

    # (B, T) -> (B, T, 1) so the mask broadcasts across the feature dim
    pooled = masked_mean(
        x,
        mask.unsqueeze(-1),
        dim=1,           # average over time
        keepdim=False,
        default=default,
    )

    return pooled.squeeze(-1) if was_2d else pooled

# Demo: two sequences of length T=5 with D=3 features; the mask marks which
# time steps are real (1) and which are padding (0).
B, T, D = 2, 5, 3
x = torch.randn(B, T, D)
mask = torch.tensor([
    [1, 1, 1, 0, 0],   # first sequence length 3
    [0, 0, 0, 0, 0],   # second is fully padded
])

m = batch_safe_mean(x, mask, default=0.0)
print(m.shape)   # (2, 3)
# row 0: mean over first 3 time steps
# row 1: [0., 0., 0.] from default

import numpy as np

def np_safe_mean(x: np.ndarray, axis=None, keepdims=False):
    """
    Mean that returns 0 instead of NaN (and without warnings) for empty input.

    x: array-like input.
    axis: axis or axes to reduce, as for np.mean; None reduces everything.
    keepdims: keep reduced axes with size 1 (ignored for empty input with
        axis=None, which returns the plain float 0.0).

    Returns np.mean(x, ...) for non-empty input; for empty input a float
    zero array with the shape the reduction would have had.
    """
    if np.size(x) == 0:
        # Empty -> return 0 with requested shape
        if axis is None:
            return 0.0
        # Take the result shape from np.sum rather than np.mean: summing an
        # empty slice is well defined (0) and silent, whereas np.mean on it
        # emits the "Mean of empty slice" RuntimeWarning this helper exists
        # to avoid.
        reduced_shape = np.sum(x, axis=axis, keepdims=keepdims).shape
        return np.zeros(reduced_shape, dtype=float)

    # normal mean is fine when non-empty
    return np.mean(x, axis=axis, keepdims=keepdims)



def np_safe_masked_mean(x: np.ndarray,
                        mask: np.ndarray,
                        axis=None,
                        keepdims=False):
    """
    Mean of x over the entries where mask is 1/True.

    x: values.
    mask: same shape as x or broadcastable against x.
    axis: axis or axes to reduce (None reduces everything).
    keepdims: keep reduced axes with size 1.

    Where no entry is selected the count is floored at 1, so the result is
    0 rather than NaN.
    """
    # Broadcast both to a common shape before reducing. Otherwise a mask of
    # lower rank (or with size-1 axes) is counted at its own smaller shape,
    # and the denominator no longer matches the number of summed elements.
    x, m = np.broadcast_arrays(x, mask.astype(float))
    masked = x * m

    num = masked.sum(axis=axis, keepdims=keepdims)
    count = m.sum(axis=axis, keepdims=keepdims)
    # floor the count at 1: an all-masked slice becomes 0 / 1 = 0
    safe_count = np.clip(count, 1.0, None)
    return num / safe_count



def np_safe_segment_mean(values: np.ndarray,
                         segment_ids: np.ndarray,
                         num_segments: int | None = None):
    if num_segments is None:
        num_segments = int(segment_ids.max()) + 1

    rest_shape = values.shape[1:]
    sums = np.zeros((num_segments, *rest_shape), dtype=values.dtype)
    counts = np.zeros(num_segments, dtype=float)

    np.add.at(sums, segment_ids, values)
    np.add.at(counts, segment_ids, 1.0)

    counts = np.clip(counts, 1.0, None)
    # reshape for broadcast
    while counts.ndim < sums.ndim:
        counts = counts[..., None]

    return sums / counts




In [None]:
import os, time, math
import numpy as np
import torch

# Pick the GPU when torch reports CUDA as available, otherwise use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# =============================
# 1. Safe mean helpers (PyTorch / NumPy)
# =============================

def torch_safe_mean(x: torch.Tensor, dim: int = -1, keepdim: bool = False):
    """
    Mean along `dim` that yields zeros instead of NaN for an empty reduction.

    When x has no elements, or `dim` has size 0, the result is a zero tensor
    of the reduced shape (respecting `keepdim`) with x's dtype and device.
    Otherwise this is x.mean(dim=dim, keepdim=keepdim).
    """
    if x.numel() > 0 and x.size(dim) > 0:
        return x.mean(dim=dim, keepdim=keepdim)

    # Empty case: derive the reduced shape by hand. An empty shape list
    # produces a 0-dim (scalar) tensor.
    reduced = list(x.shape)
    if keepdim:
        reduced[dim] = 1
    else:
        reduced.pop(dim)
    return torch.zeros(reduced, dtype=x.dtype, device=x.device)

def torch_masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = -1, eps: float = 1e-8):
    """
    Masked mean: mean over elements where mask==1, reduced along `dim`.
    If all masked-out, returns 0.

    x: (..., D), mask: same shape as x or broadcastable.
    dim: dimension to reduce (indexes the broadcast shape).
    eps: lower bound on the denominator so the division never hits 0.
    """
    # Broadcast first so the denominator counts exactly the elements that
    # enter the numerator. Summing e.g. a (B, 1) mask at its own shape
    # would give 1 instead of D.
    x, mask = torch.broadcast_tensors(x, mask.to(dtype=x.dtype))
    num = (x * mask).sum(dim=dim)
    denom = mask.sum(dim=dim)
    mean = num / torch.clamp(denom, min=eps)
    # Zero out where denom is zero
    mean = torch.where(denom > 0, mean, torch.zeros_like(mean))
    return mean

def torch_segment_mean(x: torch.Tensor, segment_ids: torch.Tensor, num_segments: int):
    """
    Per-segment mean via scatter_add (works on whatever device x lives on).

    x: (N, D) values, segment_ids: (N,) segment index per row,
    num_segments: K. Returns (K, D); segments with no rows are 0.
    """
    n_rows, width = x.shape
    like_x = dict(dtype=x.dtype, device=x.device)
    row_index = segment_ids.view(-1, 1)            # (N, 1)

    # Sum each row into its segment; the index is repeated across columns.
    totals = torch.zeros(num_segments, width, **like_x)
    totals.scatter_add_(0, row_index.expand(-1, width), x)

    # Count the rows that landed in each segment.
    sizes = torch.zeros(num_segments, 1, **like_x)
    sizes.scatter_add_(0, row_index, torch.ones(n_rows, 1, **like_x))

    # Floor the count at 1 so empty segments divide 0 by 1.
    result = totals / torch.clamp(sizes, min=1.0)
    result[sizes.squeeze(-1) == 0] = 0.0
    return result

# NumPy equivalents
def np_safe_mean(x: np.ndarray, axis: int = -1, keepdims: bool = False):
    """
    NumPy mean along `axis` that yields zeros instead of NaN when empty.

    When x has no elements, or `axis` has length 0, the result is a zero
    array of the reduced shape (respecting `keepdims`) with x's dtype; a
    fully reduced shape gives a 0-d array. Otherwise this is x.mean(...).
    """
    if x.size != 0 and x.shape[axis] != 0:
        return x.mean(axis=axis, keepdims=keepdims)

    # Empty case: build the reduced shape by hand.
    reduced = list(x.shape)
    if keepdims:
        reduced[axis] = 1
    else:
        reduced.pop(axis)
    return np.zeros(reduced, dtype=x.dtype)

def np_masked_mean(x: np.ndarray, mask: np.ndarray, axis: int = -1, eps: float = 1e-8):
    """
    Masked mean along `axis`: mean over elements where mask==1.
    Slices with nothing selected return 0.

    x: values; mask: same shape as x or broadcastable against x.
    axis: axis to reduce (indexes the broadcast shape).
    eps: lower bound on the denominator so the division never hits 0.
    """
    # Broadcast first so the denominator counts exactly the elements that
    # enter the numerator. Summing e.g. a (B, 1) mask at its own shape
    # would give 1 instead of D.
    x, mask = np.broadcast_arrays(x, mask.astype(x.dtype))
    num = (x * mask).sum(axis=axis)
    denom = mask.sum(axis=axis)
    mean = num / np.clip(denom, eps, None)
    # zero out slices where nothing was selected
    mean = np.where(denom > 0, mean, 0.0)
    return mean

def np_segment_mean(x: np.ndarray, segment_ids: np.ndarray, num_segments: int):
    """
    Segment mean for NumPy:
    x: (N, D) or (N,), segment_ids: (N,), num_segments=K
    returns: (K, D), or (K,) when x was 1-D

    Row k of the result is the mean of the rows of x whose segment id is k;
    segments with no rows are 0.
    """
    if x.ndim == 1:
        x = x[:, None]
        squeeze = True
    else:
        squeeze = False

    N, D = x.shape
    out = np.zeros((num_segments, D), dtype=x.dtype)
    count = np.zeros((num_segments, 1), dtype=x.dtype)

    # Vectorized accumulation instead of a Python loop over rows. np.add.at
    # is unbuffered, so repeated segment ids accumulate correctly. Ids are
    # truncated to integers, and only the first N are used.
    ids = np.asarray(segment_ids)[:N].astype(np.intp)
    np.add.at(out, ids, x)
    np.add.at(count, ids, 1)

    # floor the count at 1 so empty segments divide 0 by 1
    denom = np.clip(count, 1.0, None)
    mean = out / denom
    # zero out segments with count=0
    mean[count[:, 0] == 0] = 0.0

    if squeeze:
        mean = mean[:, 0]
    return mean

# ======================================
# 2. Triton kernels: safe mean / masked mean / segment mean
# ======================================

# Triton is optional: record whether it imported in HAS_TRITON so the kernel
# definitions below can be skipped when it is missing.
try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except Exception as e:
    # Deliberately broad: any failure while importing Triton (not only
    # ImportError) is reported and treated as "Triton unavailable".
    print("Triton not available:", e)
    HAS_TRITON = False

if HAS_TRITON:
    # Row-wise mean kernel: program `row_id` (grid axis 0) reduces one row of
    # x and stores a single float32 value to out_ptr[row_id].
    # Addressing is x_ptr + row_id * D + idx, i.e. it assumes x is a
    # contiguous row-major (B, D) buffer. TODO confirm against the launcher.
    # B is accepted but never read inside the kernel.
    @triton.jit
    def triton_row_mean_kernel(
        x_ptr, out_ptr,
        B, D,
        BLOCK_SIZE: tl.constexpr,
    ):
        row_id = tl.program_id(0)
        # Each program computes mean over one row x[row_id, :]
        offs = tl.arange(0, BLOCK_SIZE)
        acc = tl.zeros((), dtype=tl.float32)
        # Loop over D in chunks of BLOCK_SIZE
        for start in range(0, D, BLOCK_SIZE):
            idx = start + offs
            # guard the tail chunk: lanes past D load 0.0 and add nothing
            mask = idx < D
            vals = tl.load(x_ptr + row_id * D + idx, mask=mask, other=0.0)
            acc += tl.sum(vals.to(tl.float32), axis=0)
        # Intent: divide by max(D, 1) so an empty row gives 0 instead of NaN.
        # NOTE(review): tl.max is documented as a reduction along an axis and
        # tl.float32 as a dtype; an elementwise max is normally spelled
        # tl.maximum(...). Verify this line compiles on the Triton version
        # in use.
        mean = acc / tl.max(tl.float32(D), 1.0)
        tl.store(out_ptr + row_id, mean)

    # Row-wise masked mean kernel: program `row_id` (grid axis 0) accumulates
    # sum(x * m) and sum(m) over one row and stores their ratio as float32 to
    # out_ptr[row_id]; a row whose mask sums to 0 stores 0.
    # x and mask are both addressed as base + row_id * D + idx, i.e. this
    # assumes two contiguous row-major (B, D) buffers. TODO confirm against
    # the launcher.
    # B is accepted but never read inside the kernel.
    @triton.jit
    def triton_row_masked_mean_kernel(
        x_ptr, mask_ptr, out_ptr,
        B, D,
        BLOCK_SIZE: tl.constexpr,
    ):
        row_id = tl.program_id(0)
        offs = tl.arange(0, BLOCK_SIZE)
        sum_acc = tl.zeros((), dtype=tl.float32)
        cnt_acc = tl.zeros((), dtype=tl.float32)
        # Walk the row in chunks of BLOCK_SIZE
        for start in range(0, D, BLOCK_SIZE):
            idx = start + offs
            # `mask` here is the in-bounds guard for the tail chunk; the
            # user-supplied mask values are loaded into `m`
            mask = idx < D
            vals = tl.load(x_ptr + row_id * D + idx, mask=mask, other=0.0)
            m = tl.load(mask_ptr + row_id * D + idx, mask=mask, other=0)
            vals = vals.to(tl.float32)
            m = m.to(tl.float32)
            sum_acc += tl.sum(vals * m, axis=0)
            cnt_acc += tl.sum(m, axis=0)
        # Swap a zero count for 1 to avoid 0/0, then force those rows to 0
        denom = tl.where(cnt_acc > 0, cnt_acc, 1.0)
        mean = sum_acc / denom
        mean = tl.where(cnt_acc > 0, mean, 0.0)
        tl.store(out_ptr + row_id, mean)

    @triton.jit
    def triton_segment_sum_kernel(
        x_ptr, seg_ptr, out_ptr, cnt_ptr,
        N, D, K,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_D: tl.constexpr,
    ):
        """
        Atomically accumulate segment sums and counts:
        x: (N, D) -> out: (K, D), cnt: (K,)

        Grid axis 0 tiles the rows (BLOCK_SIZE_N per program) and grid axis 1
        tiles the columns (BLOCK_SIZE_D per program). Only sums and counts
        are produced here; no division by the count happens in this kernel.
        K is accepted but never read.
        """
        n_offsets = tl.program_id(0) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        d_offsets = tl.program_id(1) * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)

        # NOTE(review): mask_n is computed but never used below.
        mask_n = n_offsets < N
        # For each (n, d), accumulate into out[seg, d] and cnt[seg]
        # NOTE(review): this iterates element-wise over a tl.arange block and
        # uses `continue`; confirm the Triton JIT in use accepts both. The
        # usual form is masked block loads plus tl.atomic_add with a mask.
        for n in n_offsets:
            if n >= N:
                continue
            seg = tl.load(seg_ptr + n).to(tl.int32)
            # inner loop over d
            x_row_ptr = x_ptr + n * D
            for d in d_offsets:
                if d >= D:
                    continue
                val = tl.load(x_row_ptr + d)
                tl.atomic_add(out_ptr + seg * D + d, val.to(tl.float32))
            # count (only once per row *per program*)
            # NOTE(review): every program along grid axis 1 runs this for the
            # same rows, so cnt is over-counted if the grid has more than one
            # program on axis 1. Check the launch grid.
            tl.atomic_add(cnt_ptr + seg, 1.0)

# ======================================
# 3. CUDA kernels via torch.utils.cpp_extension
# ======================================

from torch.utils.cpp_extension import load

# CUDA/C++ source of the custom extension. It defines three kernels
# (row mean, row masked mean, segment sum+count), their host wrappers and the
# pybind11 module exporting `row_mean`, `row_masked_mean`, `segment_mean_raw`.
# The host wrappers validate dtype/shape, make inputs contiguous (the kernels
# index rows as base + row * D) and skip the launch for empty inputs.
cuda_src = r"""
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

// One block per row. blockDim.x must be a power of two for the tree reduction.
template<typename scalar_t>
__global__ void row_mean_kernel(const scalar_t* __restrict__ x,
                                float* __restrict__ out,
                                int B, int D) {
  int row = blockIdx.x;
  if (row >= B) return;
  extern __shared__ float sdata[];
  int tid = threadIdx.x;
  // 64-bit row offset: row * D can exceed INT_MAX for large inputs.
  const scalar_t* x_row = x + static_cast<int64_t>(row) * D;
  sdata[tid] = 0.0f;

  for (int col = tid; col < D; col += blockDim.x) {
    float v = static_cast<float>(x_row[col]);
    sdata[tid] += v;
  }
  __syncthreads();

  // reduction in shared mem
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) {
      sdata[tid] += sdata[tid + s];
    }
    __syncthreads();
  }

  if (tid == 0) {
    float denom = D > 0 ? static_cast<float>(D) : 1.0f;
    out[row] = (D > 0) ? (sdata[0] / denom) : 0.0f;
  }
}

// One block per row. Shared memory holds blockDim.x sums followed by
// blockDim.x counts. Rows with no selected element produce 0.
template<typename scalar_t>
__global__ void row_masked_mean_kernel(const scalar_t* __restrict__ x,
                                       const uint8_t* __restrict__ mask,
                                       float* __restrict__ out,
                                       int B, int D) {
  int row = blockIdx.x;
  if (row >= B) return;
  extern __shared__ float sdata[];
  float* s_sum = sdata;
  float* s_cnt = sdata + blockDim.x;

  int tid = threadIdx.x;
  s_sum[tid] = 0.0f;
  s_cnt[tid] = 0.0f;

  const int64_t row_base = static_cast<int64_t>(row) * D;
  for (int col = tid; col < D; col += blockDim.x) {
    int64_t idx = row_base + col;
    uint8_t m = mask[idx];
    if (m) {
      float v = static_cast<float>(x[idx]);
      s_sum[tid] += v;
      s_cnt[tid] += 1.0f;
    }
  }
  __syncthreads();

  // reduce
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) {
      s_sum[tid] += s_sum[tid + s];
      s_cnt[tid] += s_cnt[tid + s];
    }
    __syncthreads();
  }

  if (tid == 0) {
    float denom = s_cnt[0] > 0.0f ? s_cnt[0] : 1.0f;
    float mean = (s_cnt[0] > 0.0f) ? (s_sum[0] / denom) : 0.0f;
    out[row] = mean;
  }
}

// One block per input row n: adds x[n, :] into out[seg, :] and 1 into cnt[seg].
// Rows whose segment id is outside [0, K) are ignored.
template<typename scalar_t>
__global__ void segment_sum_kernel(const scalar_t* __restrict__ x,
                                   const int32_t* __restrict__ seg_ids,
                                   float* __restrict__ out,
                                   float* __restrict__ cnt,
                                   int N, int D, int K) {
  int n = blockIdx.x;
  if (n >= N) return;
  int tid = threadIdx.x;
  int seg = seg_ids[n];
  if (seg < 0 || seg >= K) return;

  const scalar_t* x_row = x + static_cast<int64_t>(n) * D;
  float* out_row = out + static_cast<int64_t>(seg) * D;
  for (int d = tid; d < D; d += blockDim.x) {
    float v = static_cast<float>(x_row[d]);
    atomicAdd(out_row + d, v);
  }
  // count once per row (thread 0)
  if (tid == 0) {
    atomicAdd(cnt + seg, 1.0f);
  }
}

torch::Tensor cuda_row_mean(torch::Tensor x) {
  TORCH_CHECK(x.is_cuda(), "x must be CUDA");
  TORCH_CHECK(x.dim() == 2, "x must be (B, D)");
  x = x.contiguous();  // kernel assumes dense rows of length D
  const auto B = x.size(0);
  const auto D = x.size(1);
  auto out = torch::empty({B}, x.options().dtype(torch::kFloat32));
  if (B == 0) return out;  // a zero-block launch is an invalid configuration
  const int threads = 256;
  const int blocks = B;
  const size_t shmem = threads * sizeof(float);
  AT_DISPATCH_ALL_TYPES_AND(torch::ScalarType::Half, x.scalar_type(), "row_mean_kernel", [&] {
    row_mean_kernel<scalar_t><<<blocks, threads, shmem>>>(
      x.data_ptr<scalar_t>(),
      out.data_ptr<float>(),
      B, D
    );
  });
  return out;
}

torch::Tensor cuda_row_masked_mean(torch::Tensor x, torch::Tensor mask) {
  TORCH_CHECK(x.is_cuda(), "x must be CUDA");
  TORCH_CHECK(mask.is_cuda(), "mask must be CUDA");
  TORCH_CHECK(x.sizes() == mask.sizes(), "x and mask shape mismatch");
  TORCH_CHECK(x.dim() == 2, "x must be (B, D)");
  TORCH_CHECK(mask.scalar_type() == torch::kUInt8, "mask must be uint8");
  TORCH_CHECK(x.device() == mask.device(), "x and mask must be on the same device");
  x = x.contiguous();
  mask = mask.contiguous();
  const auto B = x.size(0);
  const auto D = x.size(1);
  auto out = torch::empty({B}, x.options().dtype(torch::kFloat32));
  if (B == 0) return out;
  const int threads = 256;
  const int blocks = B;
  const size_t shmem = threads * sizeof(float) * 2;
  AT_DISPATCH_ALL_TYPES_AND(torch::ScalarType::Half, x.scalar_type(), "row_masked_mean_kernel", [&] {
    row_masked_mean_kernel<scalar_t><<<blocks, threads, shmem>>>(
      x.data_ptr<scalar_t>(),
      mask.data_ptr<uint8_t>(),
      out.data_ptr<float>(),
      B, D
    );
  });
  return out;
}

std::vector<torch::Tensor> cuda_segment_mean(torch::Tensor x,
                                             torch::Tensor seg_ids,
                                             int64_t K) {
  TORCH_CHECK(x.is_cuda(), "x must be CUDA");
  TORCH_CHECK(seg_ids.is_cuda(), "seg_ids must be CUDA");
  TORCH_CHECK(x.dim() == 2, "x must be (N, D)");
  TORCH_CHECK(seg_ids.dim() == 1, "seg_ids must be (N)");
  TORCH_CHECK(seg_ids.scalar_type() == torch::kInt32, "seg_ids must be int32");
  TORCH_CHECK(seg_ids.size(0) == x.size(0), "seg_ids must have one entry per row of x");
  TORCH_CHECK(x.device() == seg_ids.device(), "x and seg_ids must be on the same device");
  TORCH_CHECK(K >= 0, "K must be non-negative");
  x = x.contiguous();
  seg_ids = seg_ids.contiguous();
  const auto N = x.size(0);
  const auto D = x.size(1);
  auto out = torch::zeros({K, D}, x.options().dtype(torch::kFloat32));
  auto cnt = torch::zeros({K}, x.options().dtype(torch::kFloat32));
  if (N == 0) return {out, cnt};
  const int threads = 256;
  const int blocks = N;
  AT_DISPATCH_ALL_TYPES_AND(torch::ScalarType::Half, x.scalar_type(), "segment_sum_kernel", [&] {
    segment_sum_kernel<scalar_t><<<blocks, threads>>>(
      x.data_ptr<scalar_t>(),
      seg_ids.data_ptr<int32_t>(),
      out.data_ptr<float>(),
      cnt.data_ptr<float>(),
      N, D, (int)K
    );
  });
  return {out, cnt};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("row_mean", &cuda_row_mean, "Row-wise mean (safe) [CUDA]");
  m.def("row_masked_mean", &cuda_row_masked_mean, "Row-wise masked mean (safe) [CUDA]");
  m.def("segment_mean_raw", &cuda_segment_mean, "Segment sum+count (CUDA)");
}
"""

# Build the extension only when a GPU is present; otherwise leave
# `cuda_kernels` as None so later code can skip the custom-kernel benchmarks.
if device == "cuda":
    # `load(sources=...)` expects paths to source files on disk. `cuda_src`
    # is the source text itself, so it has to go through `load_inline`.
    from torch.utils.cpp_extension import load_inline

    cuda_kernels = load_inline(
        name="mean_kernels",
        # The pybind11 module is already defined inside `cuda_src`, so no
        # C++ sources and no auto-generated `functions=` bindings are needed.
        cpp_sources="",
        cuda_sources=[cuda_src],
        verbose=False,
    )
    print("Loaded custom CUDA kernels.")
else:
    cuda_kernels = None
    print("CUDA not available, skipping custom kernels.")

# ======================================
# 4. Benchmark harness
# ======================================

def bench(fn, iters=50, cuda_sync=True):
    """Return the average wall-clock time of ``fn()`` in milliseconds.

    ``fn`` is called 5 times untimed as warm-up, then ``iters`` times inside
    the timed region. When ``cuda_sync`` is true and CUDA is available, the
    device is synchronized once before the timer starts and once before it
    stops, so queued GPU work is included in the measurement.
    """
    warmup_calls = 5
    for _ in range(warmup_calls):
        fn()

    must_sync = cuda_sync and torch.cuda.is_available()
    if must_sync:
        torch.cuda.synchronize()

    started = time.perf_counter()
    for _ in range(iters):
        fn()
    if must_sync:
        torch.cuda.synchronize()
    elapsed_s = time.perf_counter() - started

    return elapsed_s * 1000.0 / iters  # ms

# ======================================
# 5. Run benchmarks for various dtypes
# ======================================

B, D = 4096, 1024      # row-wise mean shape
N, K = 8192, 128       # segment shape (N rows, K segments)

dtypes = [torch.float32, torch.float16]  # you can add int8 etc. for CPU / cast

# For every dtype: build random inputs once, then time each implementation
# (PyTorch CPU, NumPy CPU, PyTorch GPU, Triton, custom CUDA) with `bench`.
for dtype in dtypes:
    print(f"\n=== dtype: {dtype} ===")

    # generate data
    x_cpu = torch.randn(B, D, dtype=dtype)
    mask_cpu = (torch.rand(B, D) > 0.5).to(torch.bool)
    seg_ids_cpu = torch.randint(low=0, high=K, size=(N,), dtype=torch.int64)
    x_seg_cpu = torch.randn(N, D, dtype=dtype)

    # PyTorch CPU
    def fn_torch_cpu_batch_mean():
        return torch_safe_mean(x_cpu, dim=-1)

    def fn_torch_cpu_masked_mean():
        return torch_masked_mean(x_cpu, mask_cpu, dim=-1)

    def fn_torch_cpu_segment_mean():
        return torch_segment_mean(x_seg_cpu, seg_ids_cpu, num_segments=K)

    print("=== PyTorch CPU ===")
    print("torch_cpu_batch_mean:   %.3f ms" % bench(fn_torch_cpu_batch_mean, cuda_sync=False))
    print("torch_cpu_masked_mean:  %.3f ms" % bench(fn_torch_cpu_masked_mean, cuda_sync=False))
    print("torch_cpu_segment_mean: %.3f ms" % bench(fn_torch_cpu_segment_mean, cuda_sync=False))

    # NumPy CPU
    x_np = x_cpu.numpy().astype(np.float32)   # use float32 internally
    mask_np = mask_cpu.numpy()
    x_seg_np = x_seg_cpu.numpy().astype(np.float32)
    seg_ids_np = seg_ids_cpu.numpy()

    def fn_numpy_batch_mean():
        return np_safe_mean(x_np, axis=-1)

    def fn_numpy_masked_mean():
        return np_masked_mean(x_np, mask_np, axis=-1)

    def fn_numpy_segment_mean():
        return np_segment_mean(x_seg_np, seg_ids_np, num_segments=K)

    print("=== NumPy CPU ===")
    print("numpy_batch_mean:   %.3f ms" % bench(fn_numpy_batch_mean, cuda_sync=False))
    print("numpy_masked_mean:  %.3f ms" % bench(fn_numpy_masked_mean, cuda_sync=False))
    print("numpy_segment_mean: %.3f ms" % bench(fn_numpy_segment_mean, cuda_sync=False))

    if device == "cuda":
        x_gpu = x_cpu.to(device)
        mask_gpu = mask_cpu.to(device)
        x_seg_gpu = x_seg_cpu.to(device)
        seg_ids_gpu = seg_ids_cpu.to(device).to(torch.int32)

        # Input preparation that does not depend on the implementation under
        # test is done once here, so the timed closures measure the mean
        # computation rather than repeated mask / index dtype conversions.
        seg_ids_gpu_long = seg_ids_cpu.to(device)       # int64 ids for index_add_
        mask_gpu_u8 = mask_gpu.to(torch.uint8)          # uint8 mask for the kernels

        def fn_torch_gpu_batch_mean():
            return torch_safe_mean(x_gpu, dim=-1)

        def fn_torch_gpu_masked_mean():
            return torch_masked_mean(x_gpu, mask_gpu, dim=-1)

        def fn_torch_gpu_segment_mean():
            return torch_segment_mean(x_seg_gpu, seg_ids_gpu_long, num_segments=K)

        print("=== PyTorch GPU ===")
        print("torch_gpu_batch_mean:   %.3f ms" % bench(fn_torch_gpu_batch_mean))
        print("torch_gpu_masked_mean:  %.3f ms" % bench(fn_torch_gpu_masked_mean))
        print("torch_gpu_segment_mean: %.3f ms" % bench(fn_torch_gpu_segment_mean))

        # Triton
        if HAS_TRITON:
            B_, D_ = x_gpu.shape

            # NOTE: the float32 cast of x stays inside the timed closures; it
            # is a no-op for float32 inputs and part of the cost of feeding
            # float16 inputs to these float32 kernels.
            def fn_triton_batch_mean():
                x32 = x_gpu.to(torch.float32)
                out = torch.empty(B_, device=device, dtype=torch.float32)
                grid = (B_,)
                triton_row_mean_kernel[grid](
                    x32, out, B_, D_,
                    BLOCK_SIZE=128,
                )
                return out

            def fn_triton_masked_mean():
                x32 = x_gpu.to(torch.float32)
                out = torch.empty(B_, device=device, dtype=torch.float32)
                grid = (B_,)
                triton_row_masked_mean_kernel[grid](
                    x32, mask_gpu_u8, out, B_, D_,
                    BLOCK_SIZE=128,
                )
                return out

            def fn_triton_segment_mean():
                # segment_sum -> divide
                x32 = x_seg_gpu.to(torch.float32)
                out = torch.zeros(K, D, dtype=torch.float32, device=device)
                cnt = torch.zeros(K, dtype=torch.float32, device=device)
                grid = (triton.cdiv(N, 32), triton.cdiv(D, 32))
                triton_segment_sum_kernel[grid](
                    x32, seg_ids_gpu, out, cnt,
                    N, D, K,
                    BLOCK_SIZE_N=32,
                    BLOCK_SIZE_D=32,
                )
                denom = torch.clamp(cnt.view(K, 1), min=1.0)
                mean = out / denom
                mean[cnt == 0] = 0.0
                return mean

            print("=== Triton GPU ===")
            print("triton_batch_mean:   %.3f ms" % bench(fn_triton_batch_mean))
            print("triton_masked_mean:  %.3f ms" % bench(fn_triton_masked_mean))
            print("triton_segment_mean: %.3f ms" % bench(fn_triton_segment_mean))

        # CUDA custom
        if cuda_kernels is not None:
            def fn_cuda_batch_mean():
                # returns float32
                return cuda_kernels.row_mean(x_gpu)

            def fn_cuda_masked_mean():
                return cuda_kernels.row_masked_mean(x_gpu, mask_gpu_u8)

            def fn_cuda_segment_mean():
                out, cnt = cuda_kernels.segment_mean_raw(x_seg_gpu, seg_ids_gpu, K)
                denom = torch.clamp(cnt.view(K, 1), min=1.0)
                mean = out / denom
                mean[cnt == 0] = 0.0
                return mean

            print("=== Custom CUDA ===")
            print("cuda_batch_mean:   %.3f ms" % bench(fn_cuda_batch_mean))
            print("cuda_masked_mean:  %.3f ms" % bench(fn_cuda_masked_mean))
            print("cuda_segment_mean: %.3f ms" % bench(fn_cuda_segment_mean))

In [None]:
# triton_safe_masked_mean.py
import torch
import triton
import triton.language as tl

@triton.jit
def safe_masked_mean_rowwise_kernel(
    x_ptr,         # pointer to x; values are cast to float32 after loading
    mask_ptr,      # pointer to a 0/1 mask; only dereferenced when HAS_MASK
    out_ptr,       # *f32
    B, N,
    stride_xb, stride_xn,
    stride_mb, stride_mn,
    BLOCK_N: tl.constexpr,
    HAS_MASK: tl.constexpr = True,
):
    # One program per row; the row is walked in tiles of BLOCK_N columns.
    # HAS_MASK is a compile-time flag: a pointer cannot be compared against 0
    # inside the kernel, so "no mask" is selected by specialisation instead.
    b = tl.program_id(0).to(tl.int64)  # batch index (64-bit for large offsets)

    offs_n = tl.arange(0, BLOCK_N)

    # Per-lane accumulators, reduced once after the loop (0-d blocks are not
    # valid Triton tensors).
    acc_sum = tl.zeros((BLOCK_N,), dtype=tl.float32)
    acc_count = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for start_n in range(0, N, BLOCK_N):
        cols = start_n + offs_n
        in_bounds = cols < N

        x_vals = tl.load(x_ptr + b * stride_xb + cols * stride_xn,
                         mask=in_bounds,
                         other=0.0).to(tl.float32)

        if HAS_MASK:
            m_vals = tl.load(mask_ptr + b * stride_mb + cols * stride_mn,
                             mask=in_bounds,
                             other=0)
            # assume mask is 0/1, cast to float
            weights = m_vals.to(tl.float32)
        else:
            weights = tl.where(in_bounds, 1.0, 0.0)

        acc_sum += x_vals * weights
        acc_count += weights

    total = tl.sum(acc_sum, axis=0)
    count = tl.sum(acc_count, axis=0)

    # safe mean: if count == 0, define mean = 0
    mean = total / tl.maximum(count, 1.0)
    mean = tl.where(count > 0, mean, 0.0)

    tl.store(out_ptr + b, mean)

def safe_masked_mean_rowwise(x: torch.Tensor, mask: torch.Tensor | None = None):
    """
    Row-wise mean of x over the entries selected by mask.

    x: (B, N)
    mask: (B, N) or None; 0/1 or bool. None averages over every column.
    returns: (B,) float32; rows with no selected entry yield 0.

    Raises ValueError if mask is given with a shape different from x.
    """
    assert x.dim() == 2
    B, N = x.shape

    has_mask = mask is not None
    if has_mask:
        if mask.shape != x.shape:
            raise ValueError(
                f"mask shape {tuple(mask.shape)} does not match x shape {tuple(x.shape)}"
            )
        mask = mask.to(device=x.device, dtype=torch.int32)
        mask_arg = mask
        stride_mb, stride_mn = mask.stride(0), mask.stride(1)
    else:
        # The kernel needs some valid pointer argument; with HAS_MASK=False
        # it is never dereferenced, so x itself serves as the placeholder.
        mask_arg = x
        stride_mb, stride_mn = 0, 0

    out = torch.empty(B, device=x.device, dtype=torch.float32)

    BLOCK_N = 128
    grid = (B,)

    # Strides are passed explicitly, so x does not need to be contiguous.
    safe_masked_mean_rowwise_kernel[grid](
        x,
        mask_arg,
        out,
        B, N,
        x.stride(0), x.stride(1),
        stride_mb, stride_mn,
        BLOCK_N=BLOCK_N,
        HAS_MASK=has_mask,
    )
    return out

// safe_masked_mean.cu
#include <cuda_runtime.h>
#include <stdint.h>

// Accumulates the sum and the number of selected elements of x[0..N) into
// *out_sum / *out_count (both must be zero-initialised by the caller).
// An element is selected when mask is nullptr or mask[i] != 0.
// Launch requirements: dynamic shared memory of 2 * blockDim.x floats, and
// blockDim.x a power of two (the tree reduction halves the stride each step).
__global__ void safe_masked_mean_kernel(
    const float* __restrict__ x,
    const uint8_t* __restrict__ mask,  // 0 or 1; nullptr for unmasked
    int64_t N,
    float* __restrict__ out_sum,
    float* __restrict__ out_count
) {
    extern __shared__ float shmem[];
    float* sh_sum   = shmem;
    float* sh_count = shmem + blockDim.x;

    // 64-bit index and stride: N is int64_t, and the 32-bit products
    // blockIdx.x * blockDim.x / blockDim.x * gridDim.x can wrap for large grids.
    const int64_t first  = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

    float local_sum = 0.0f;
    float local_count = 0.0f;

    // Grid-stride loop: each thread visits first, first + stride, ...
    for (int64_t i = first; i < N; i += stride) {
        uint8_t m = mask ? mask[i] : 1;
        if (m) {
            local_sum += x[i];
            local_count += 1.0f;
        }
    }

    sh_sum[threadIdx.x] = local_sum;
    sh_count[threadIdx.x] = local_count;
    __syncthreads();

    // Block-level tree reduction in shared memory.
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            sh_sum[threadIdx.x] += sh_sum[threadIdx.x + s];
            sh_count[threadIdx.x] += sh_count[threadIdx.x + s];
        }
        __syncthreads();
    }

    // One atomic per block merges the block's partial result.
    if (threadIdx.x == 0) {
        atomicAdd(out_sum, sh_sum[0]);
        atomicAdd(out_count, sh_count[0]);
    }
}


#include <torch/extension.h>  // or your own wrapper
#include <cuda_runtime.h>

std::pair<float, float> safe_masked_mean_cuda(
    const float* d_x,
    const uint8_t* d_mask,
    int64_t N
) {
    float h_sum = 0.0f;
    float h_count = 0.0f;

    float* d_sum;
    float* d_count;
    cudaMalloc(&d_sum, sizeof(float));
    cudaMalloc(&d_count, sizeof(float));
    cudaMemcpy(d_sum, &h_sum, sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_count, &h_count, sizeof(float), cudaMemcpyHostToDevice);

    int threads = 256;
    int blocks = (N + threads - 1) / threads;
    size_t shmem = 2 * threads * sizeof(float);

    safe_masked_mean_kernel<<<blocks, threads, shmem>>>(d_x, d_mask, N, d_sum, d_count);

    cudaMemcpy(&h_sum, d_sum, sizeof(float), cudaMemcpyDeviceToHost);
    cudaMemcpy(&h_count, d_count, sizeof(float), cudaMemcpyDeviceToHost);

    cudaFree(d_sum);
    cudaFree(d_count);

    // safe mean
    float mean = (h_count > 0.0f) ? (h_sum / h_count) : 0.0f;
    return {mean, h_count};
}

# sum[seg]   = Σ x[i]  for i with segment_ids[i] == seg and mask[i] == 1
# count[seg] = Σ 1     for the same i
# mean[seg]  = sum[seg] / max(count[seg], 1)

// segment_mean.cu
#include <cuda_runtime.h>
#include <stdint.h>

// For every selected element i (mask == nullptr or mask[i] != 0) whose
// segment id lies in [0, num_segments), atomically adds x[i] to
// seg_sums[id] and 1 to seg_counts[id]. Both outputs must be zeroed first.
__global__ void segment_sum_count_kernel(
    const float* __restrict__ x,
    const int32_t* __restrict__ segment_ids,
    const uint8_t* __restrict__ mask,    // 0/1 or nullptr
    int64_t N,
    float* __restrict__ seg_sums,
    float* __restrict__ seg_counts,
    int32_t num_segments
) {
    // 64-bit index and stride: N is int64_t and the 32-bit products can wrap.
    const int64_t first  = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

    for (int64_t i = first; i < N; i += stride) {
        uint8_t m = mask ? mask[i] : 1;
        if (!m) {
            continue;
        }

        int32_t seg = segment_ids[i];
        if (seg < 0 || seg >= num_segments) {
            continue;  // out-of-range ids are silently skipped
        }

        float val = x[i];
        atomicAdd(&seg_sums[seg], val);
        atomicAdd(&seg_counts[seg], 1.0f);
    }
}

// One thread per segment: writes sum / count, or 0 when the segment is empty.
__global__ void segment_safe_mean_kernel(
    const float* __restrict__ seg_sums,
    const float* __restrict__ seg_counts,
    float* __restrict__ seg_means,
    int32_t num_segments
) {
    const int k = blockIdx.x * blockDim.x + threadIdx.x;
    if (k < num_segments) {
        const float total = seg_sums[k];
        const float n = seg_counts[k];
        // Empty segments get a mean of 0 instead of dividing by zero.
        seg_means[k] = (n > 0.0f) ? (total / n) : 0.0f;
    }
}

#include <torch/extension.h>  // or your own wrapper
#include <cuda_runtime.h>

std::pair<float, float> safe_masked_mean_cuda(
    const float* d_x,
    const uint8_t* d_mask,
    int64_t N
) {
    float h_sum = 0.0f;
    float h_count = 0.0f;

    float* d_sum;
    float* d_count;
    cudaMalloc(&d_sum, sizeof(float));
    cudaMalloc(&d_count, sizeof(float));
    cudaMemcpy(d_sum, &h_sum, sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_count, &h_count, sizeof(float), cudaMemcpyHostToDevice);

    int threads = 256;
    int blocks = (N + threads - 1) / threads;
    size_t shmem = 2 * threads * sizeof(float);

    safe_masked_mean_kernel<<<blocks, threads, shmem>>>(d_x, d_mask, N, d_sum, d_count);

    cudaMemcpy(&h_sum, d_sum, sizeof(float), cudaMemcpyDeviceToHost);
    cudaMemcpy(&h_count, d_count, sizeof(float), cudaMemcpyDeviceToHost);

    cudaFree(d_sum);
    cudaFree(d_count);

    // safe mean
    float mean = (h_count > 0.0f) ? (h_sum / h_count) : 0.0f;
    return {mean, h_count};
}

#sum[seg] = Σ x[i] for i with segment_ids[i] == seg and mask[i]==1
#count[seg] = Σ 1    for same
#mean[seg] = sum[seg] / max(count[seg], 1)

// segment_mean.cu
#include <cuda_runtime.h>
#include <stdint.h>

// For every selected element i (mask == nullptr or mask[i] != 0) whose
// segment id lies in [0, num_segments), atomically adds x[i] to
// seg_sums[id] and 1 to seg_counts[id]. Both outputs must be zeroed first.
__global__ void segment_sum_count_kernel(
    const float* __restrict__ x,
    const int32_t* __restrict__ segment_ids,
    const uint8_t* __restrict__ mask,    // 0/1 or nullptr
    int64_t N,
    float* __restrict__ seg_sums,
    float* __restrict__ seg_counts,
    int32_t num_segments
) {
    // 64-bit index and stride: N is int64_t and the 32-bit products can wrap.
    const int64_t first  = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

    for (int64_t i = first; i < N; i += stride) {
        uint8_t m = mask ? mask[i] : 1;
        if (!m) {
            continue;
        }

        int32_t seg = segment_ids[i];
        if (seg < 0 || seg >= num_segments) {
            continue;  // out-of-range ids are silently skipped
        }

        float val = x[i];
        atomicAdd(&seg_sums[seg], val);
        atomicAdd(&seg_counts[seg], 1.0f);
    }
}

// One thread per segment: writes sum / count, or 0 when the segment is empty.
__global__ void segment_safe_mean_kernel(
    const float* __restrict__ seg_sums,
    const float* __restrict__ seg_counts,
    float* __restrict__ seg_means,
    int32_t num_segments
) {
    const int k = blockIdx.x * blockDim.x + threadIdx.x;
    if (k < num_segments) {
        const float total = seg_sums[k];
        const float n = seg_counts[k];
        // Empty segments get a mean of 0 instead of dividing by zero.
        seg_means[k] = (n > 0.0f) ? (total / n) : 0.0f;
    }
}

// Host wrapper: computes the per-segment safe mean of d_x into d_out_means
// (num_segments floats on the device). Empty segments get 0.
//   d_segment_ids - device pointer to N int32 segment ids
//   d_mask        - device pointer to N bytes (non-zero = selected), or nullptr
// NOTE(review): CUDA return codes are not checked here, as in the original.
void segment_mean_cuda(
    const float* d_x,
    const int32_t* d_segment_ids,
    const uint8_t* d_mask,      // may be nullptr
    int64_t N,
    int32_t num_segments,
    float* d_out_means
) {
    // No output to write; zero-sized allocations and launches are avoided.
    if (num_segments <= 0) {
        return;
    }

    const size_t seg_bytes = static_cast<size_t>(num_segments) * sizeof(float);

    float* d_sums = nullptr;
    float* d_counts = nullptr;
    cudaMalloc(&d_sums,   seg_bytes);
    cudaMalloc(&d_counts, seg_bytes);
    cudaMemset(d_sums,   0, seg_bytes);
    cudaMemset(d_counts, 0, seg_bytes);

    const int threads = 256;

    // Skip the accumulation launch for empty input (a zero-block launch is an
    // invalid configuration); the means kernel below still writes zeros.
    if (N > 0) {
        // The accumulation kernel strides over the input, so the grid can be
        // capped; this keeps the block count within int range for large N.
        const int64_t max_blocks = 65535;
        int64_t wanted_blocks = (N + threads - 1) / threads;
        int blocks = static_cast<int>(wanted_blocks < max_blocks ? wanted_blocks : max_blocks);

        segment_sum_count_kernel<<<blocks, threads>>>(
            d_x, d_segment_ids, d_mask, N, d_sums, d_counts, num_segments
        );
    }

    int blocks_seg = (num_segments + threads - 1) / threads;
    segment_safe_mean_kernel<<<blocks_seg, threads>>>(
        d_sums, d_counts, d_out_means, num_segments
    );

    cudaFree(d_sums);
    cudaFree(d_counts);
}


#triton segment mean sketch
@triton.jit
def segment_sum_count_kernel_triton(
    x_ptr, seg_id_ptr, mask_ptr,
    seg_sums_ptr, seg_counts_ptr,
    N, NUM_SEGMENTS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HAS_MASK: tl.constexpr = True,
):
    # Each program handles BLOCK_SIZE consecutive elements of x and atomically
    # adds them (and a count of 1) into the slot of their segment id.
    # HAS_MASK is a compile-time flag: a pointer cannot be compared with 0
    # inside the kernel, so callers without a mask pass HAS_MASK=False plus
    # any valid pointer for mask_ptr (it is then never dereferenced).
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < N

    x = tl.load(x_ptr + offs, mask=in_bounds, other=0.0)
    seg = tl.load(seg_id_ptr + offs, mask=in_bounds, other=0)

    if HAS_MASK:
        m = tl.load(mask_ptr + offs, mask=in_bounds, other=0)
        valid = in_bounds & (m != 0)
    else:
        valid = in_bounds

    # Ignore ids outside [0, NUM_SEGMENTS) so the atomics stay in bounds.
    valid = valid & (seg >= 0) & (seg < NUM_SEGMENTS)

    x = tl.where(valid, x, 0.0)
    seg = tl.where(valid, seg, 0)

    # atomic adds
    tl.atomic_add(seg_sums_ptr + seg, x, mask=valid)
    tl.atomic_add(seg_counts_ptr + seg,
                  tl.where(valid, 1.0, 0.0),
                  mask=valid)

In [None]:
# Colab-ready benchmark: safe mean, masked mean, segment mean, batch mean
# Across: PyTorch CPU, PyTorch GPU, NumPy CPU, Triton kernel, CUDA (CuPy) kernel
import torch
!pip install -q triton==3.0.0 cupy-cuda12x

import time
import math
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn.functional as F

import triton
import triton.language as tl

import cupy as cp

# ------------------------------------------------------------
# 0. Config
# ------------------------------------------------------------
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)

# Problem sizes: BATCH rows of length N, grouped into SEG_K segments
# for the segment-mean benchmarks.
BATCH, N, SEG_K = 4096, 1024, 128

# Untimed warm-up calls and timed calls per benchmark.
WARMUP_ITERS, BENCH_ITERS = 5, 20

# ------------------------------------------------------------
# 1. Safe mean helpers (PyTorch & NumPy)
# ------------------------------------------------------------

def torch_safe_mean(x: torch.Tensor, dim=None, keepdim=False) -> torch.Tensor:
    """
    Safe mean: if count == 0, returns 0 (no NaN).
    Works with arbitrary dimension, keeps gradients.

    x: input tensor (may be non-contiguous).
    dim: None to average over all elements (``keepdim`` is ignored and a
         0-d tensor is returned), otherwise an int or tuple of ints; in that
         case x is cast to float32 before reducing.
    keepdim: forwarded to ``sum`` when ``dim`` is given.
    """
    if dim is None:
        count = x.numel()
        if count == 0:
            return x.new_zeros(())
        # reshape (not view) so non-contiguous inputs, e.g. a transpose, work
        return x.reshape(-1).sum() / count

    x = x.float()
    s = x.sum(dim=dim, keepdim=keepdim)
    # Number of elements folded into each output entry. Derived from the
    # shapes instead of summing a full-size tensor of ones.
    reduced = x.numel() // s.numel() if s.numel() > 0 else 0
    # When nothing was reduced the sum is already 0, so dividing by 1 gives
    # the safe result without producing NaN.
    return s / max(reduced, 1)

def torch_masked_mean(x: torch.Tensor, mask: torch.Tensor, dim=-1, keepdim=False):
    """
    Mean of ``x`` along ``dim`` over the entries selected by ``mask``.

    mask: bool or 0/1, same shape as x.
    Slices in which the mask selects nothing yield 0 instead of NaN.
    The computation and the result are float32.
    """
    values = x.float()
    weights = mask.to(dtype=values.dtype)

    total = torch.sum(values * weights, dim=dim, keepdim=keepdim)
    n_selected = torch.sum(weights, dim=dim, keepdim=keepdim)

    # Divide by at least 1 so empty slices do not produce NaN, then force
    # those slices to exactly 0.
    average = total / n_selected.clamp_min(1.0)
    return torch.where(n_selected > 0, average, torch.zeros_like(average))

def torch_segment_mean(x: torch.Tensor, segment_ids: torch.Tensor, num_segments: int):
    """
    Mean of the rows of ``x`` grouped by ``segment_ids``.

    x: [N, D] or [N]; segment_ids: [N] in [0, num_segments-1]
    Returns [num_segments, D] or [num_segments]; empty segments are 0.

    float16 / bfloat16 inputs are accumulated in float32 and the result is
    cast back to the input dtype. Accumulating in half precision saturates
    quickly (a running count cannot grow past 2048 in float16).
    """
    if x.dim() == 1:
        x = x[:, None]
        squeeze = True
    else:
        squeeze = False

    N, D = x.shape
    device = x.device

    low_precision = x.dtype in (torch.float16, torch.bfloat16)
    acc_dtype = torch.float32 if low_precision else x.dtype

    segment_ids = segment_ids.to(device=device, dtype=torch.long)
    out = torch.zeros(num_segments, D, device=device, dtype=acc_dtype)
    count = torch.zeros(num_segments, 1, device=device, dtype=acc_dtype)

    out.index_add_(0, segment_ids, x.to(acc_dtype))
    ones = torch.ones(N, 1, device=device, dtype=acc_dtype)
    count.index_add_(0, segment_ids, ones)

    # Clamp the denominator so empty segments do not divide by zero, then
    # force those segments to exactly 0.
    denom = count.clamp_min(1.0)
    mean = out / denom
    mean = torch.where(count > 0, mean, torch.zeros_like(mean))

    if low_precision:
        mean = mean.to(x.dtype)

    if squeeze:
        mean = mean[:, 0]
    return mean

# NumPy equivalents (no gradient)
def np_safe_mean(x: np.ndarray, axis=None, keepdims=False):
    """Mean of ``x`` that returns 0 instead of NaN when nothing is averaged.

    axis: None (all elements), an int, or a tuple of ints.
    keepdims: forwarded to ``sum``.

    With ``axis=None`` and an empty ``x`` a 0-d zero of ``x.dtype`` is
    returned. With an explicit axis the result always has the shape of the
    corresponding reduction, including when the reduced axis is empty.
    """
    if axis is None and x.size == 0:
        return np.zeros((), dtype=x.dtype)

    s = x.sum(axis=axis, keepdims=keepdims)
    out_size = np.size(s)
    # Number of elements folded into each output entry; works for int and
    # tuple axes alike.
    count = x.size // out_size if out_size > 0 else 0
    # An empty reduction sums to 0, so dividing by 1 yields the safe result.
    return s / max(count, 1)

def np_masked_mean(x: np.ndarray, mask: np.ndarray, axis=-1, keepdims=False):
    """Mean of ``x`` along ``axis`` over the entries selected by ``mask``.

    Both inputs are converted to float32. Slices in which the mask selects
    nothing yield 0 instead of NaN.
    """
    values = x.astype(np.float32)
    weights = mask.astype(np.float32)

    total = np.sum(values * weights, axis=axis, keepdims=keepdims)
    n_selected = np.sum(weights, axis=axis, keepdims=keepdims)

    # Divide by at least 1 so empty slices do not produce NaN, then force
    # those slices to exactly 0.
    average = total / np.clip(n_selected, 1.0, None)
    return np.where(n_selected > 0, average, np.zeros_like(average))

import numpy as np

def np_segment_mean(x: np.ndarray,
                    segment_ids: np.ndarray,
                    num_segments: int):
    """Mean of the rows of ``x`` grouped by ``segment_ids``.

    x: (N, D) or (N,); segment_ids: one id per row.
    Returns (num_segments, D) or (num_segments,). Rows whose id is outside
    [0, num_segments) are ignored and empty segments yield 0.
    """
    # Handle 1D x by temporarily promoting to 2D
    if x.ndim == 1:
        x = x[:, None]   # (N,) -> (N, 1)
        squeeze = True
    else:
        squeeze = False

    N, D = x.shape

    # One id per row; ids are truncated to integers as int() would do.
    seg = np.asarray(segment_ids)[:N].astype(np.int64)
    valid = (seg >= 0) & (seg < num_segments)
    seg = seg[valid]

    # Vectorised accumulation instead of a Python loop over rows.
    # np.add.at is unbuffered, so repeated ids accumulate correctly.
    out = np.zeros((num_segments, D), dtype=x.dtype)      # sum per segment
    np.add.at(out, seg, x[valid])
    count = np.bincount(seg, minlength=num_segments)      # count per segment

    # Avoid div-by-zero by clamping denominator to at least 1
    denom = np.maximum(count, 1).reshape(num_segments, 1)  # (K, 1)
    mean = out / denom                                    # (K, D)

    # For segments where count == 0, explicitly set mean to 0
    mean[count == 0] = 0.0

    if squeeze:
        mean = mean[:, 0]  # (K,)
    return mean

# ------------------------------------------------------------
# 2. Data setup
# ------------------------------------------------------------
# Fix both RNGs so every run benchmarks the same data.
torch.manual_seed(0)
np.random.seed(0)

# CPU inputs: values, a boolean mask, and one segment id per row.
# The order of the three random draws is kept, since it determines the data.
x_torch_cpu = torch.randn(BATCH, N, dtype=torch.float32)
mask_torch_cpu = torch.rand(BATCH, N).gt(0.3).to(torch.bool)

seg_ids_cpu = torch.randint(0, SEG_K, (BATCH,), dtype=torch.long)

# NumPy counterparts of the CPU tensors.
x_np = x_torch_cpu.numpy()
mask_np = mask_torch_cpu.numpy().astype(np.bool_)
seg_ids_np = seg_ids_cpu.numpy()

# GPU copies of the same data, or None placeholders when running on CPU.
x_torch_gpu = x_torch_cpu.to("cuda") if DEVICE == "cuda" else None
mask_torch_gpu = mask_torch_cpu.to("cuda") if DEVICE == "cuda" else None
seg_ids_gpu = seg_ids_cpu.to("cuda") if DEVICE == "cuda" else None

# ------------------------------------------------------------
# 3. Triton kernel: row-wise mean (batch mean) for 2D tensor
# ------------------------------------------------------------

@triton.jit
def row_mean_kernel(X_ptr, Y_ptr, BATCH, N, BLOCK_SIZE: tl.constexpr):
    # One program per row. The row is walked in tiles of BLOCK_SIZE columns,
    # so the result is correct for any power-of-two BLOCK_SIZE, not only
    # BLOCK_SIZE >= N. Rows are assumed densely packed (row stride == N).
    # BATCH is unused; the launch grid fixes the number of rows.
    row_id = tl.program_id(0).to(tl.int64)  # 64-bit so row_id * N cannot wrap
    cols = tl.arange(0, BLOCK_SIZE)
    row_ptr = X_ptr + row_id * N

    # Per-lane float32 accumulator, reduced once after the loop.
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for start in range(0, N, BLOCK_SIZE):
        idx = start + cols
        x = tl.load(row_ptr + idx, mask=idx < N, other=0.0)
        acc += x.to(tl.float32)

    s = tl.sum(acc, axis=0)
    # Safe mean: an empty row (N == 0) gives 0 instead of NaN.
    denom = tl.maximum(N, 1)
    mean = s / denom
    tl.store(Y_ptr + row_id, mean)

def triton_row_mean(x: torch.Tensor) -> torch.Tensor:
    """
    Row-wise mean computed by the Triton ``row_mean_kernel``.

    x: [B, N] on CUDA
    returns: [B], same dtype and device as x (zeros when N == 0)
    """
    assert x.is_cuda
    B, N = x.shape
    # The kernel addresses element (row, col) as row * N + col.
    x = x.contiguous()
    y = torch.empty(B, device=x.device, dtype=x.dtype)

    if B == 0:
        return y
    if N == 0:
        # Nothing to average; define the mean of an empty row as 0.
        return y.zero_()

    # tl.arange needs a power-of-two extent; the kernel masks columns >= N,
    # so rounding up keeps the single-tile-per-row scheme for any N.
    BLOCK_SIZE = triton.next_power_of_2(N)
    grid = (B,)

    row_mean_kernel[grid](
        x, y,
        BATCH=B,
        N=N,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=4,
    )
    return y

# ------------------------------------------------------------
# 4. CUDA kernel with CuPy: row-wise mean
# ------------------------------------------------------------

# CUDA C source for the CuPy RawKernel `row_mean`: one block per row of a
# dense float32 [B, N] matrix, writing the row mean (0 for N <= 0) to y[row].
# The static shared buffer holds 256 floats, so the kernel must be launched
# with at most 256 threads per block, and with a power-of-two block size for
# the tree reduction.
cuda_row_mean_src = r"""
extern "C" __global__
void row_mean(const float* __restrict__ x,
              float* __restrict__ y,
              int B, int N) {
    int row = blockIdx.x;
    if (row >= B) return;

    // 64-bit row offset: row * N can exceed INT_MAX for large inputs.
    const float* x_row = x + (long long)row * N;

    float sum = 0.0f;
    for (int i = threadIdx.x; i < N; i += blockDim.x) {
        sum += x_row[i];
    }

    __shared__ float smem[256]; // up to 256 threads
    int tid = threadIdx.x;
    smem[tid] = sum;
    __syncthreads();

    // simple reduction in shared memory
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            smem[tid] += smem[tid + stride];
        }
        __syncthreads();
    }

    if (tid == 0) {
        // safe mean: an empty row yields 0 instead of NaN
        y[row] = (N > 0) ? (smem[0] / (float)N) : 0.0f;
    }
}
"""

row_mean_kernel_cuda = cp.RawKernel(cuda_row_mean_src, "row_mean")

def cuda_row_mean(x_torch: torch.Tensor) -> torch.Tensor:
    """
    Takes a CUDA torch tensor [B, N], uses CuPy to run custom kernel,
    returns torch tensor [B] (float32).

    The kernel reads dense float32 rows, so the input is made contiguous and
    cast to float32 first (both are no-ops for a contiguous float32 tensor).
    """
    assert x_torch.is_cuda
    B, N = x_torch.shape
    if B == 0:
        # A zero-block launch is invalid; there is nothing to compute.
        return torch.empty(0, device=x_torch.device, dtype=torch.float32)

    x_torch = x_torch.contiguous().float()

    # Allocate and launch on the tensor's own device.
    with cp.cuda.Device(x_torch.device.index):
        # zero-copy via the DLPack protocol
        x_cu = cp.from_dlpack(x_torch)
        y_cu = cp.empty((B,), dtype=cp.float32)

        threads_per_block = 256
        blocks = (B,)

        # The kernel declares `int B, int N`; plain Python ints would be
        # passed as 64-bit values, so they are converted explicitly.
        row_mean_kernel_cuda(blocks, (threads_per_block,),
                             (x_cu, y_cu, cp.int32(B), cp.int32(N)))

        # back to torch
        y_torch = torch.from_dlpack(y_cu)
    return y_torch

# ------------------------------------------------------------
# 5. Benchmark helpers
# ------------------------------------------------------------

@dataclass
class BenchResult:
    """One benchmark measurement: a label and its timing."""
    name: str       # label of the benchmarked function
    time_ms: float  # timing in milliseconds

def bench(fn, iters=BENCH_ITERS, warmup=WARMUP_ITERS):
    """Return the average wall-clock time of ``fn()`` in milliseconds.

    ``fn`` is called ``warmup`` times untimed, then ``iters`` times inside
    the timed region. After every call whose result is a CUDA tensor the
    device is synchronized, so the GPU work of that call has finished before
    the next call (or the end of the timer).
    """
    # Untimed warm-up calls
    for _ in range(warmup):
        result = fn()
        if isinstance(result, torch.Tensor) and result.is_cuda:
            torch.cuda.synchronize()

    started = time.perf_counter()
    for _ in range(iters):
        result = fn()
        if isinstance(result, torch.Tensor) and result.is_cuda:
            torch.cuda.synchronize()
    elapsed_s = time.perf_counter() - started

    return elapsed_s * 1000.0 / iters

# Collected BenchResult entries for every benchmark run below.
results = []

# ------------------------------------------------------------
# 6. PyTorch CPU benchmarks
# ------------------------------------------------------------
print("\n=== PyTorch CPU ===")


def fn_torch_cpu_batch_mean():
    """Row-wise safe mean of the CPU tensor -> [B]."""
    return torch_safe_mean(x_torch_cpu, dim=1)


def fn_torch_cpu_masked_mean():
    """Row-wise masked mean of the CPU tensor."""
    return torch_masked_mean(x_torch_cpu, mask_torch_cpu, dim=1)


def fn_torch_cpu_segment_mean():
    """Segment mean over the batch dimension of the CPU tensor."""
    return torch_segment_mean(x_torch_cpu, seg_ids_cpu, num_segments=SEG_K)


# Time each function, record the result and print it.
for _name, _fn in (
    ("torch_cpu_batch_mean", fn_torch_cpu_batch_mean),
    ("torch_cpu_masked_mean", fn_torch_cpu_masked_mean),
    ("torch_cpu_segment_mean", fn_torch_cpu_segment_mean),
):
    t = bench(_fn)
    results.append(BenchResult(_name, t))
    print("%s: %.3f ms" % (_name, t))

# ------------------------------------------------------------
# 7. PyTorch GPU benchmarks (if available)
# ------------------------------------------------------------
if DEVICE == "cuda":
    print("\n=== PyTorch GPU ===")

    # Same three reductions as the CPU section, run on the GPU tensors.
    def fn_torch_gpu_batch_mean():
        return torch_safe_mean(x_torch_gpu, dim=1)

    t = bench(fn_torch_gpu_batch_mean)
    results.append(BenchResult(name="torch_gpu_batch_mean", time_ms=t))
    print(f"torch_gpu_batch_mean: {t:.3f} ms")

    def fn_torch_gpu_masked_mean():
        return torch_masked_mean(x_torch_gpu, mask_torch_gpu, dim=1)

    t = bench(fn_torch_gpu_masked_mean)
    results.append(BenchResult(name="torch_gpu_masked_mean", time_ms=t))
    print(f"torch_gpu_masked_mean: {t:.3f} ms")

    def fn_torch_gpu_segment_mean():
        return torch_segment_mean(x_torch_gpu, seg_ids_gpu, num_segments=SEG_K)

    t = bench(fn_torch_gpu_segment_mean)
    results.append(BenchResult(name="torch_gpu_segment_mean", time_ms=t))
    print(f"torch_gpu_segment_mean: {t:.3f} ms")

# ------------------------------------------------------------
# 8. NumPy CPU benchmarks
# ------------------------------------------------------------
print("\n=== NumPy CPU ===")


# NumPy counterparts of the three reductions (axis=1 / segment ids).
def fn_numpy_batch_mean():
    return np_safe_mean(x_np, axis=1)


t = bench(fn_numpy_batch_mean)
results.append(BenchResult(name="numpy_batch_mean", time_ms=t))
print(f"numpy_batch_mean: {t:.3f} ms")


def fn_numpy_masked_mean():
    return np_masked_mean(x_np, mask_np, axis=1)


t = bench(fn_numpy_masked_mean)
results.append(BenchResult(name="numpy_masked_mean", time_ms=t))
print(f"numpy_masked_mean: {t:.3f} ms")


def fn_numpy_segment_mean():
    return np_segment_mean(x_np, seg_ids_np, num_segments=SEG_K)


t = bench(fn_numpy_segment_mean)
results.append(BenchResult(name="numpy_segment_mean", time_ms=t))
print(f"numpy_segment_mean: {t:.3f} ms")

# ------------------------------------------------------------
# 9. Triton benchmark (row-wise mean only)
# ------------------------------------------------------------
if DEVICE == "cuda":
    print("\n=== Triton row-wise mean (CUDA) ===")

    # Only the row-wise mean has a Triton implementation.
    def fn_triton_row_mean():
        return triton_row_mean(x_torch_gpu)

    t = bench(fn_triton_row_mean)
    results.append(BenchResult(name="triton_row_mean", time_ms=t))
    print(f"triton_row_mean: {t:.3f} ms")

# ------------------------------------------------------------
# 10. CUDA (CuPy) kernel benchmark (row-wise mean only)
# ------------------------------------------------------------
if DEVICE == "cuda":
    print("\n=== CUDA (CuPy) row-wise mean ===")

    # Only the row-wise mean has a hand-written CUDA kernel.
    def fn_cuda_row_mean():
        return cuda_row_mean(x_torch_gpu)

    t = bench(fn_cuda_row_mean)
    results.append(BenchResult(name="cuda_row_mean", time_ms=t))
    print(f"cuda_row_mean: {t:.3f} ms")

# ------------------------------------------------------------
# 11. Summary
# ------------------------------------------------------------
# One aligned line per recorded benchmark.
print("\n=== Summary (ms per call) ===")
for r in results:
    print("{:30s}: {:8.3f} ms".format(r.name, r.time_ms))

Math trick in self-attention
Karpathy (YouTube): https://www.youtube.com/watch?v=kCc8FmEb1nY at 42:27


In [None]:

torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn(B, T, C)
x.shape
# We want every position to look at the tokens up to and including itself.
# The simplest way to "communicate" with the past is to average them:
# xbow[b, t] is the mean of x[b, 0..t] over the time axis ("bag of words").
xbow = torch.zeros((B, T, C))
# NOTE: the loop variables are deliberately named b / t. Naming one of them
# `time` would overwrite the `time` module used elsewhere in this notebook.
for b in range(B):
    for t in range(T):
        xprev = x[b, : t + 1]  # (t + 1, C): tokens 0..t of sequence b
        xbow[b, t] = torch.mean(xprev, dim=0)


In [None]:
# Reference: https://github.com/facebookresearch/MobileLLM-R1
# Open question: is this really true? It doesn't match Karpathy's progression
# with nanoGPT (i.e. a faster tokenizer being needed, etc.).
from google.colab import drive
drive.mount('/content/drive')

# Vim tip: to prevent misaligned tabs when pasting code from ChatGPT into vim,
# run this first:
# :set paste

In [None]:
%cd /content/drive/MyDrive/quora

Simple pytorch dataset

In [None]:
import pandas as pd
import numpy as np

PATH = "/content/drive/MyDrive/quora/questions 3.csv"
df = pd.read_csv(PATH)

# Keep only rows where both questions and the label are present.
df = df.dropna(subset=["question1", "question2", "is_duplicate"])
# Subsample for speed. Cap the size at the number of remaining rows:
# DataFrame.sample raises if asked for more rows than exist (without replace).
df = df.sample(min(50000, len(df)), random_state=0)

df.head()

$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V}$$

In [None]:
# attention
from transformers import AutoTokenizer, AutoModel
import torch


# Loaded lazily once and reused; re-creating the tokenizer and model on every
# call re-reads the checkpoint each time.
_BERT_NAME = "bert-base-uncased"
_bert_cache = {}


def _load_bert():
    """Return the cached (tokenizer, model) pair, loading it on first use."""
    if not _bert_cache:
        _bert_cache["tokenizer"] = AutoTokenizer.from_pretrained(_BERT_NAME)
        _bert_cache["model"] = AutoModel.from_pretrained(_BERT_NAME)
    return _bert_cache["tokenizer"], _bert_cache["model"]


def get_embeddings(text):
    """
    Return BERT's last hidden state for ``text``.

    The tokenized inputs are printed for inspection, then the model is run
    on them.

    Args:
        text: input passed straight to the tokenizer.

    Returns:
        ``outputs.last_hidden_state`` of the bert-base-uncased model.
    """
    tokenizer, model = _load_bert()
    inputs = tokenizer(text, return_tensors="pt")
    print(inputs)
    outputs = model(**inputs)
    return outputs.last_hidden_state

print(get_embeddings("hello").shape)
print(get_embeddings("hello"))
# First dim: batch size (how many sentences).
# Second dim: sequence length (token count). The tokenizer wraps "hello" in
# the special [CLS] (classification) and [SEP] (separator) tokens.
# There is one row (embedding vector) per token.

In [None]:
def sdpa(Wq, Wk, Wv, text):
    """
    Single-head scaled dot-product attention over the BERT embeddings of ``text``.

    Computes softmax(Q K^T / sqrt(d_k)) V with Q = E Wq, K = E Wk, V = E Wv,
    where E is the output of ``get_embeddings(text)``.

    Args:
        Wq, Wk, Wv: projection matrices applied to the embeddings on the right.
        text: input text forwarded to ``get_embeddings``.

    Returns:
        The attention output, one row per token.
    """
    embed = get_embeddings(text)
    Q = torch.matmul(embed, Wq)
    K = torch.matmul(embed, Wk)
    V = torch.matmul(embed, Wv)
    print(Q.shape,K.shape,embed.shape)

    # d_k is the width of the key vectors (last dim of K), not the embedding
    # width; the two only coincide when Wk is square.
    d_k = K.shape[-1]
    # Transpose the last two dims of K so each query row is dotted with every key.
    attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    print("attn_scores shape",attn_scores.shape)

    # dim=-1: normalise over the keys, so each query's weights sum to 1.
    weights = torch.softmax(attn_scores, dim=-1)
    attn = torch.matmul(weights, V)

    return attn

Wq = torch.rand(768,768)
Wk = torch.rand(768,768)
Wv = torch.rand(768,768)

print(sdpa(Wq, Wk, Wv, "hello").shape)



$$\text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \text{head}_2, \dots, \text{head}_h)\mathbf{W}^O$$

In [None]:
# Karpathy, "Let's build GPT", 1:02:00 -- causal averaging via masked softmax.
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# Lower-triangular ones: position t may see positions 0..t only.
tril = torch.tril(torch.ones(T, T))
# Start from zero affinities and set the "future" entries to -inf ...
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
# ... so softmax gives uniform weights over the visible positions of each row.
wei = F.softmax(wei, dim=-1)

# (T, T) @ (B, T, C) broadcasts over the batch -> (B, T, C) running averages.
out = wei @ x


Triton SDPA

When $\text{softmax}$ is applied:

$$\text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}} \text{ with causal mask}\right)$$

the $e^{-\infty}$ terms become zero, effectively giving zero attention weight to all future tokens.

In [None]:
import torch
import triton
import triton.language as tl
import math

# Define the Triton Kernel
@triton.jit
def causal_attention_kernel(
    Q, K, V, O,  # Data pointers for Query, Key, Value, Output
    sm_scale,  # Scaling factor: 1/sqrt(d_k)
    Lq, Lk, Lv, Lo,  # Row strides (in elements) of Q, K, V, O
    N_CTX,  # Sequence length (number of rows in Q/K/V/O)
    D_HEAD,  # Head dimension (d_k)
    # Block dimensions (fixed by the caller at launch time)
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    """
    Causal scaled dot-product attention for one (N_CTX, D_HEAD) head.

    Each program handles BLOCK_M query rows and streams over the keys/values
    in tiles of BLOCK_N rows, keeping a running row maximum and running
    softmax denominator (online softmax) so the N_CTX x N_CTX score matrix is
    never written to memory.

    Layout assumptions:
      * Lq/Lk/Lv/Lo are ROW strides; the last dimension is assumed to have
        stride 1 (contiguous rows).
      * BLOCK_DMODEL >= D_HEAD; columns >= D_HEAD and rows >= N_CTX are masked.
    """
    # 1. Which block of query rows this program owns.
    pid_m = tl.program_id(0)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # query row indices
    offs_d = tl.arange(0, BLOCK_DMODEL)               # feature indices
    row_ok = offs_m < N_CTX
    d_ok = offs_d < D_HEAD
    qo_mask = row_ok[:, None] & d_ok[None, :]

    # Row index * row stride + column index (unit column stride).
    q_ptrs = Q + offs_m[:, None] * Lq + offs_d[None, :]
    o_ptrs = O + offs_m[:, None] * Lo + offs_d[None, :]

    # 2. Load the query tile once; it is reused for every key tile.
    q = tl.load(q_ptrs, mask=qo_mask, other=0.0)

    # Online-softmax state: running max, running denominator, and the
    # un-normalised output accumulator (all float32).
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # 3. Stream over key/value tiles along the sequence dimension.
    for start_n in range(0, N_CTX, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)  # key row indices
        col_ok = offs_n < N_CTX

        # K is loaded already transposed: tile shape [BLOCK_DMODEL, BLOCK_N].
        k_ptrs = K + offs_n[None, :] * Lk + offs_d[:, None]
        v_ptrs = V + offs_n[:, None] * Lv + offs_d[None, :]
        k = tl.load(k_ptrs, mask=d_ok[:, None] & col_ok[None, :], other=0.0)
        v = tl.load(v_ptrs, mask=col_ok[:, None] & d_ok[None, :], other=0.0)

        # 4. Scores S = (Q K^T) * sm_scale, shape [BLOCK_M, BLOCK_N].
        # The scale is applied to the float32 product so q and k keep the
        # same dtype going into tl.dot.
        s = tl.dot(q, k) * sm_scale

        # 5. Causal mask: a query at row i may attend to keys j <= i only.
        # Out-of-range key columns are excluded as well.
        keep = (offs_m[:, None] >= offs_n[None, :]) & col_ok[None, :]
        s = tl.where(keep, s, float("-inf"))

        # 6. Online softmax update.
        # The first tile always contains key 0, which every row may see, so
        # m_new is finite from the first iteration on and exp() never sees
        # (-inf) - (-inf).
        m_new = tl.maximum(m_i, tl.max(s, 1))
        alpha = tl.exp(m_i - m_new)      # rescales the previous state
        p = tl.exp(s - m_new[:, None])   # un-normalised weights for this tile

        l_i = alpha * l_i + tl.sum(p, 1)
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m_i = m_new

    # 7. Normalise once at the end and write the valid region back.
    tl.store(o_ptrs, acc / l_i[:, None], mask=qo_mask)

# ----------------------------------------------------------------------
# Python Host Wrapper
# ----------------------------------------------------------------------

def run_attention_kernel(q, k, v, is_causal=True):
    """
    Runs the Triton causal attention kernel on torch tensors.

    Args:
        q (torch.Tensor): Query tensor (L, D).
        k (torch.Tensor): Key tensor (L, D).
        v (torch.Tensor): Value tensor (L, D).
        is_causal (bool): Must be True; the kernel always applies the causal
            mask and has no unmasked mode.

    Returns:
        torch.Tensor: attention output with the same shape and dtype as ``q``.

    Raises:
        NotImplementedError: if ``is_causal`` is False.
        AssertionError: if the inputs disagree in shape, are not 2-D CUDA
            tensors, or D is not a power of two >= 16.
    """
    if not is_causal:
        # Previously this flag was silently ignored and a causal result was
        # returned regardless.
        raise NotImplementedError(
            "causal_attention_kernel only implements causal attention; "
            "is_causal=False is not supported."
        )

    assert q.shape == k.shape == v.shape, "Q, K, V must have the same shape (L, D) for self-attention."
    assert q.is_cuda and k.is_cuda and v.is_cuda, "Inputs must be on the GPU."
    assert q.dim() == 2, "Q, K, V must be 2-D tensors of shape (L, D)."

    N_CTX, D_HEAD = q.shape[0], q.shape[1]
    # BLOCK_DMODEL is set to D_HEAD below; Triton block sizes must be powers
    # of two and tl.dot needs every dimension to be at least 16.
    assert D_HEAD >= 16 and (D_HEAD & (D_HEAD - 1)) == 0, \
        "Head dimension D must be a power of two >= 16."

    # Only row strides are passed to the kernel, so rows must be contiguous.
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    # Hyperparameters: Tune these for performance
    BLOCK_M = 64  # Block size for the Query dimension (rows)
    BLOCK_N = 64  # Block size for the Key dimension (columns)
    BLOCK_DMODEL = D_HEAD  # Block size for the Head dimension

    # Scaling factor: 1/sqrt(d_k)
    sm_scale = 1.0 / math.sqrt(D_HEAD)

    # Output tensor (N_CTX, D_HEAD)
    o = torch.empty_like(q)

    # 1D launch grid: one program per block of BLOCK_M query rows.
    grid = lambda META: (triton.cdiv(N_CTX, META['BLOCK_M']),)

    causal_attention_kernel[grid](
        q, k, v, o,  # Pointers
        sm_scale,  # Scaling factor
        q.stride(0), k.stride(0), v.stride(0), o.stride(0),  # Row strides
        N_CTX, D_HEAD,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=BLOCK_DMODEL,
        num_warps=4,
        num_stages=3
    )
    return o

# ----------------------------------------------------------------------
# Example Usage
# ----------------------------------------------------------------------
if __name__ == '__main__':
    # float16 is the typical dtype for attention kernels on GPU.
    dtype = torch.float16
    device = 'cuda'

    # Sequence length (N_CTX) and head dimension (D_HEAD)
    N_CTX = 256
    D_HEAD = 64

    # Random Q, K, V of shape [sequence length, head dimension].
    Q_data = torch.randn(N_CTX, D_HEAD, dtype=dtype, device=device)
    K_data = torch.randn(N_CTX, D_HEAD, dtype=dtype, device=device)
    V_data = torch.randn(N_CTX, D_HEAD, dtype=dtype, device=device)

    # Run the Triton kernel
    output_triton = run_attention_kernel(Q_data, K_data, V_data)

    print(f"Input Shape (Q, K, V): {Q_data.shape}")
    print(f"Output Shape (O):      {output_triton.shape}")
    print(f"\nExample Output (First 5 elements of first row):\n{output_triton[0, :5]}")

    # Compare against a plain PyTorch causal attention computed in float32.
    # A small error (fp16 rounding) means the kernel output is trustworthy.
    scores = (Q_data.float() @ K_data.float().T) / math.sqrt(D_HEAD)
    future = torch.triu(
        torch.ones(N_CTX, N_CTX, dtype=torch.bool, device=device), diagonal=1
    )
    reference = torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1) @ V_data.float()
    max_err = (output_triton.float() - reference).abs().max().item()
    print(f"Max abs error vs. PyTorch reference: {max_err:.3e}")

In [None]:
# Labels as an array; the two question columns as plain Python lists.
y = df["is_duplicate"].values
texts1 = df["question1"].tolist()
texts2 = df["question2"].tolist()

In [None]:
!pip install -q sentence-transformers

In [None]:
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

model_name = "sentence-transformers/all-MiniLM-L6-v2"
st_model = SentenceTransformer(model_name)

# Encode each side once; the embeddings are reused by the later feature cells.
emb1 = st_model.encode(
    texts1,
    batch_size=128,
    convert_to_numpy=True,
)
emb2 = st_model.encode(
    texts2,
    batch_size=128,
    convert_to_numpy=True,
)

In [None]:
# Sanity check: container types, lengths, and the first 100 characters of
# the first question on each side.
print(*map(type, (texts1, texts2)))
print(*map(len, (texts1, texts2)))
print(texts1[0][:100])
print(texts2[0][:100])

In [None]:
len(texts1),len(texts2)

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer

# TF-IDF vocabulary shared by both question columns (basic lexical features).
tfidf = TfidfVectorizer(
    stop_words="english",
    min_df=5,
    max_features=20000,
    dtype=np.float64,
)

tfidf.fit(texts1 + texts2)


import numpy as np

def lexical_features(q1_list, q2_list):
    """
    Build five lexical features for each (question1, question2) pair.

    Columns of the result, in order:
      0. cosine similarity of the TF-IDF vectors (global ``tfidf``)
      1. whitespace token count of question 1
      2. whitespace token count of question 2
      3. absolute difference of the two token counts
      4. Jaccard overlap of the lower-cased whitespace token sets
         (0.0 when either set is empty)

    Args:
        q1_list: list of strings (first question of each pair).
        q2_list: list of strings (second question), same length as q1_list.

    Returns:
        np.ndarray of shape [N, 5].
    """
    assert len(q1_list) == len(q2_list), "q1_list and q2_list must have same length"

    # TF-IDF vectors, sparse [N, V]
    v1 = tfidf.transform(q1_list)
    v2 = tfidf.transform(q2_list)

    def _row_dot(a, b):
        # Row-wise dot product of two sparse matrices -> 1-D array [N].
        # np.asarray(...).ravel() flattens both the np.matrix returned by
        # csr_matrix.sum and the ndarray returned by sparse arrays, whereas
        # `.A1` exists only on np.matrix.
        return np.asarray(a.multiply(b).sum(axis=1)).ravel()

    # L2 norms; the epsilon avoids division by zero for all-zero rows.
    v1_norm = np.sqrt(_row_dot(v1, v1)) + 1e-8
    v2_norm = np.sqrt(_row_dot(v2, v2)) + 1e-8

    # Cosine similarity per pair
    cos = _row_dot(v1, v2) / (v1_norm * v2_norm)

    # Length-based features
    len1 = np.array([len(t.split()) for t in q1_list])
    len2 = np.array([len(t.split()) for t in q2_list])
    len_diff = np.abs(len1 - len2)

    def _jaccard(a, b):
        # Crude token overlap: |intersection| / |union| of lower-cased tokens.
        s1 = set(a.lower().split())
        s2 = set(b.lower().split())
        if not s1 or not s2:
            return 0.0
        return len(s1 & s2) / len(s1 | s2)

    overlap = np.array([_jaccard(a, b) for a, b in zip(q1_list, q2_list)])

    # Stack into feature matrix [N, 5]
    feats = np.stack([cos, len1, len2, len_diff, overlap], axis=1)
    return feats

# import numpy as np

# def lexical_features(q1_list, q2_list):
#     assert len(q1_list) == len(q2_list), "q1_list and q2_list must have same length"

#     # TF-IDF vectors for each side
#     v1 = tfidf.transform(q1_list)   # sparse [N, V]
#     v2 = tfidf.transform(q2_list)   # sparse [N, V]

#     # L2 norms
#     v1_norm = np.sqrt(v1.multiply(v1).sum(axis=1)) + 1e-8   # [N, 1]
#     v2_norm = np.sqrt(v2.multiply(v2).sum(axis=1)) + 1e-8   # [N, 1]

#     # Cosine similarity for each pair
#     cos = (v1.multiply(v2).sum(axis=1) / (v1_norm * v2_norm)).A1  # -> [N]

#     # Length-based features
#     len1 = np.array([len(t.split()) for t in q1_list])           # [N]
#     len2 = np.array([len(t.split()) for t in q2_list])           # [N]
#     len_diff = np.abs(len1 - len2)                               # [N]

#     # Crude token overlap (Jaccard)
#     overlap = []
#     for a, b in zip(q1_list, q2_list):
#         s1 = set(a.lower().split())
#         s2 = set(b.lower().split())
#         if not s1 or not s2:
#             overlap.append(0.0)
#         else:
#             overlap.append(len(s1 & s2) / len(s1 | s2))
#     overlap = np.array(overlap)                                  # [N]

#     # Stack into feature matrix [N, 5]
#     feats = np.stack([cos, len1, len2, len_diff, overlap], axis=1)
#     return feats


# def lexical_features(q1_list, q2_list):
#     # naive TF-IDF cosine similarity & lengths & token overlap
#     v1 = tfidf.transform(q1_list)
#     v2 = tfidf.transform(q2_list)

#     v1_norm = np.sqrt(v1.multiply(v1).sum(axis=1)) + 1e-8
#     v2_norm = np.sqrt(v2.multiply(v2).sum(axis=1)) + 1e-8
#     cos = (v1.multiply(v2).sum(axis=1) / (v1_norm * v2_norm)).A1  # cosine similarity

#     len1 = np.array([len(t.split()) for t in q1_list])
#     len2 = np.array([len(t.split()) for t in q2_list])
#     len_diff = np.abs(len1 - len2)

#     # crude token overlap
#     overlap = []
#     for a, b in zip(q1_list, q2_list):
#         s1 = set(a.lower().split())
#         s2 = set(b.lower().split())
#         if not s1 or not s2:
#             overlap.append(0.0)
#         else:
#             overlap.append(len(s1 & s2) / len(s1 | s2))
#     overlap = np.array(overlap)

#     return np.stack([cos, len1, len2, len_diff, overlap], axis=1)

# Lexical feature matrix for every question pair; expected shape is [N, 5].
lex_feats = lexical_features(texts1, texts2)
print(lex_feats.shape)

In [None]:
# Pairwise embedding features: both embeddings, |difference| and product.
diff = np.abs(emb1 - emb2)
prod = emb1 * emb2
pair_emb = np.concatenate([emb1, emb2, diff, prod], axis=1)

# Full (unscaled) feature matrix: embedding features + lexical features.
X_all = np.concatenate([pair_emb, lex_feats], axis=1)

# Split BEFORE scaling. Fitting the scaler on all rows would leak validation
# statistics (mean/variance) into training.
X_train, X_val, y_train, y_val = train_test_split(
    X_all, y, test_size=0.2, random_state=0, stratify=y
)

# Fit the scaler on the training rows only, then apply it to both splits.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

def train_model(X, y, sample_weight=None):
    """
    Fit a logistic-regression classifier and return it.

    Args:
        X: feature matrix.
        y: labels.
        sample_weight: optional per-sample weights forwarded to ``fit``.

    Returns:
        The fitted ``LogisticRegression`` instance.
    """
    model = LogisticRegression(max_iter=200, class_weight=None, n_jobs=-1)
    model.fit(X, y, sample_weight=sample_weight)
    return model

def eval_model(clf, X_val, y_val):
    """Return the F1 score of ``clf``'s predictions on ``X_val`` against ``y_val``."""
    return f1_score(y_val, clf.predict(X_val))

# Baseline: unweighted model, scored on the validation split.
base_clf = train_model(X_train, y_train)
base_f1 = eval_model(base_clf, X_val, y_val)
print("Base F1:", base_f1)

# Confidence of the baseline on its own training rows: distance of the
# predicted positive-class probability from 0.5.
p_train = base_clf.predict_proba(X_train)[:, 1]
conf = np.abs(p_train - 0.5)  # small = uncertain -> more likely a noisy label

In [None]:
candidate_policies = [
    {"theta": 0.05, "noisy_weight": 0.5},
    {"theta": 0.10, "noisy_weight": 0.5},
    {"theta": 0.10, "noisy_weight": 0.2},
    {"theta": 0.15, "noisy_weight": 0.2},
    {"theta": 0.20, "noisy_weight": 0.0},  # drop most uncertain
    {"theta": 0.30, "noisy_weight": 0.0},
]

def apply_policy(policy, conf):
    """
    Turn a re-weighting policy into per-sample weights.

    Samples whose confidence is strictly below ``policy["theta"]`` receive
    ``policy["noisy_weight"]``; all other samples receive weight 1.0.

    Args:
        policy: dict with keys "theta" and "noisy_weight".
        conf: array of per-sample confidences.

    Returns:
        Float array of weights with the same shape as ``conf``.
    """
    is_uncertain = conf < policy["theta"]
    weights = np.ones_like(conf, dtype=float)
    weights[is_uncertain] = policy["noisy_weight"]
    return weights

# Epsilon-greedy bandit over the candidate re-weighting policies.
n_trials = 20
epsilon = 0.3
K = len(candidate_policies)
Q = np.zeros(K)  # running mean reward (validation F1) per policy
N = np.zeros(K)  # number of times each policy was tried

# train_model fits a LogisticRegression with no randomness, so a given policy
# always yields the same F1. Remember it instead of retraining on every pull.
_reward_by_policy = {}

for t in range(n_trials):
    # Explore a random policy with probability epsilon, else exploit the best.
    if np.random.rand() < epsilon:
        k = np.random.randint(K)
    else:
        k = np.argmax(Q)
    k = int(k)

    policy = candidate_policies[k]
    if k not in _reward_by_policy:
        sample_weight = apply_policy(policy, conf)
        clf = train_model(X_train, y_train, sample_weight=sample_weight)
        _reward_by_policy[k] = eval_model(clf, X_val, y_val)
    reward = _reward_by_policy[k]

    # Incremental mean update.
    N[k] += 1
    Q[k] += (reward - Q[k]) / N[k]

    print(f"Iter {t:02d} | trial policy {k} {policy} | F1={reward:.4f}")

best_k = np.argmax(Q)
best_policy = candidate_policies[best_k]
print("\nBase F1:", base_f1)
print("Best policy from bandit:", best_policy, "Estimated F1:", Q[best_k])

In [None]:
!pip install -q cleanlab

In [None]:
from typing import Optional, Literal, Dict, Any, Tuple, Callable

class NoisyDatasetCleaner:
    """
    A generic wrapper that learns per-sample weights for a noisy supervised dataset.

    Strategies:
      - 'none':          every sample gets weight 1.0
      - 'bandit_weight': epsilon-greedy search over confidence-threshold
                         down-weighting policies
      - 'cleanlab':      down-weight samples flagged by cleanlab's CleanLearning

    Typical use: ``fit(X, y)``, then ``get_weights()`` or
    ``fit_clean_model(X, y)``. You can extend it with active-relabelling logic.
    """
    def __init__(self,
                 strategy: Literal["none", "bandit_weight", "cleanlab"] = "none",
                 clf_factory: Optional[Callable[[], Any]] = None,
                 strategy_kwargs: Optional[Dict[str, Any]] = None):
        """
        Args:
            strategy: which weighting strategy ``fit`` should run.
            clf_factory: zero-argument callable returning a fresh, unfitted
                classifier; defaults to LogisticRegression(max_iter=300).
            strategy_kwargs: optional settings; 'bandit_weight' reads
                "candidate_policies", "n_trials" and "epsilon".
        """
        self.strategy = strategy
        self.clf_factory = clf_factory or (lambda: LogisticRegression(max_iter=300, n_jobs=-1))
        self.strategy_kwargs = strategy_kwargs or {}
        self.sample_weight_ = None
        self._cleanlab_model = None

    def fit(self, X, y) -> "NoisyDatasetCleaner":
        """
        Learn ``sample_weight_`` for (X, y) with the configured strategy.

        Returns:
            self, to allow chaining.

        Raises:
            ValueError: if ``strategy`` is not one of the supported names.
        """
        if self.strategy == "none":
            self.sample_weight_ = np.ones(len(y), dtype=float)

        elif self.strategy == "bandit_weight":
            self._fit_bandit_weight(X, y)

        elif self.strategy == "cleanlab":
            self._fit_cleanlab(X, y)

        else:
            raise ValueError(f"Unknown strategy: {self.strategy}")

        return self

    def _fit_bandit_weight(self, X, y):
        """Pick the best down-weighting policy with an epsilon-greedy bandit."""
        # Base model: its distance from p=0.5 is used as a confidence score.
        base_clf = self.clf_factory()
        base_clf.fit(X, y)
        p = base_clf.predict_proba(X)[:, 1]
        conf = np.abs(p - 0.5)

        # Simple discrete bandit over (threshold, weight) policies.
        candidate_policies = self.strategy_kwargs.get("candidate_policies") or [
            {"theta": 0.05, "noisy_weight": 0.5},
            {"theta": 0.10, "noisy_weight": 0.5},
            {"theta": 0.10, "noisy_weight": 0.2},
            {"theta": 0.15, "noisy_weight": 0.2},
            {"theta": 0.20, "noisy_weight": 0.0},
            {"theta": 0.30, "noisy_weight": 0.0},
        ]
        n_trials = self.strategy_kwargs.get("n_trials", 10)
        epsilon  = self.strategy_kwargs.get("epsilon", 0.3)

        # train_test_split returns a (train, test) pair for EVERY array passed
        # in, so three inputs give six outputs. The validation confidences are
        # not needed, but must still be unpacked.
        X_train, X_val, y_train, y_val, conf_train, _conf_val = train_test_split(
            X, y, conf, test_size=0.2, random_state=0, stratify=y
        )

        def apply_policy(policy, conf_vec):
            # Weight 1.0 everywhere except low-confidence samples.
            w = np.ones_like(conf_vec, dtype=float)
            w[conf_vec < policy["theta"]] = policy["noisy_weight"]
            return w

        K = len(candidate_policies)
        Q = np.zeros(K)  # running mean validation F1 per policy
        N = np.zeros(K)  # pull counts

        for _ in range(n_trials):
            if np.random.rand() < epsilon:
                k = np.random.randint(K)
            else:
                k = np.argmax(Q)

            policy = candidate_policies[k]
            weights_train = apply_policy(policy, conf_train)

            clf = self.clf_factory()
            clf.fit(X_train, y_train, sample_weight=weights_train)
            y_pred_val = clf.predict(X_val)
            reward = f1_score(y_val, y_pred_val)

            N[k] += 1
            Q[k] += (reward - Q[k]) / N[k]

        best_k = np.argmax(Q)
        best_policy = candidate_policies[best_k]

        # Final weights on the full data.
        self.sample_weight_ = apply_policy(best_policy, conf)

    def _fit_cleanlab(self, X, y):
        """Down-weight the samples cleanlab flags as likely label issues."""
        from cleanlab.classification import CleanLearning

        base_clf = self.clf_factory()
        cl = CleanLearning(clf=base_clf)
        cl.fit(X, y)
        self._cleanlab_model = cl

        # Label issues summary -> weights.
        issues = cl.get_label_issues()
        is_issue = issues["is_label_issue"].values
        # Simple scheme: 0.3 weight for suspected issues, 1.0 otherwise.
        w = np.ones(len(y), dtype=float)
        w[is_issue] = 0.3
        self.sample_weight_ = w

    def get_weights(self) -> np.ndarray:
        """
        Return the learned per-sample weights.

        Raises:
            RuntimeError: if ``fit`` has not been called yet.
        """
        if self.sample_weight_ is None:
            raise RuntimeError("Call fit() first")
        return self.sample_weight_

    def fit_clean_model(self, X, y):
        """Train a final classifier using the learned sample weights."""
        w = self.get_weights()
        clf = self.clf_factory()
        clf.fit(X, y, sample_weight=w)
        return clf

In [None]:
# Bandit-based re-weighting: learn the weights, then train the final model.
cleaner = NoisyDatasetCleaner(strategy="bandit_weight").fit(X_train, y_train)
weights = cleaner.get_weights()

clf_clean = cleaner.fit_clean_model(X_train, y_train)
f1_clean = eval_model(clf_clean, X_val, y_val)
print("F1 with NoisyDatasetCleaner (bandit_weight):", f1_clean)

# Alternative: cleanlab-based re-weighting.
cleaner_cl = NoisyDatasetCleaner(strategy="cleanlab").fit(X_train, y_train)
clf_cleanlab = cleaner_cl.fit_clean_model(X_train, y_train)
print("F1 with NoisyDatasetCleaner (cleanlab):", eval_model(clf_cleanlab, X_val, y_val))

In [None]:
# old dont run
from typing import Optional, Literal, Dict, Any, Tuple, Callable

class NoisyDatasetCleaner:
    """
    A generic wrapper that learns per-sample weights for a noisy supervised dataset.

    Strategies:
      - 'none':          every sample gets weight 1.0
      - 'bandit_weight': epsilon-greedy search over confidence-threshold
                         down-weighting policies
      - 'cleanlab':      down-weight samples flagged by cleanlab's CleanLearning

    Typical use: ``fit(X, y)``, then ``get_weights()`` or
    ``fit_clean_model(X, y)``. You can extend it with active-relabelling logic.
    """
    def __init__(self,
                 strategy: Literal["none", "bandit_weight", "cleanlab"] = "none",
                 clf_factory: Optional[Callable[[], Any]] = None,
                 strategy_kwargs: Optional[Dict[str, Any]] = None):
        """
        Args:
            strategy: which weighting strategy ``fit`` should run.
            clf_factory: zero-argument callable returning a fresh, unfitted
                classifier; defaults to LogisticRegression(max_iter=300).
            strategy_kwargs: optional settings; 'bandit_weight' reads
                "candidate_policies", "n_trials" and "epsilon".
        """
        self.strategy = strategy
        self.clf_factory = clf_factory or (lambda: LogisticRegression(max_iter=300, n_jobs=-1))
        self.strategy_kwargs = strategy_kwargs or {}
        self.sample_weight_ = None
        self._cleanlab_model = None

    def fit(self, X, y) -> "NoisyDatasetCleaner":
        """
        Learn ``sample_weight_`` for (X, y) with the configured strategy.

        Returns:
            self, to allow chaining.

        Raises:
            ValueError: if ``strategy`` is not one of the supported names.
        """
        if self.strategy == "none":
            self.sample_weight_ = np.ones(len(y), dtype=float)

        elif self.strategy == "bandit_weight":
            self._fit_bandit_weight(X, y)

        elif self.strategy == "cleanlab":
            self._fit_cleanlab(X, y)

        else:
            raise ValueError(f"Unknown strategy: {self.strategy}")

        return self

    def _fit_bandit_weight(self, X, y):
        """Pick the best down-weighting policy with an epsilon-greedy bandit."""
        # Base model: its distance from p=0.5 is used as a confidence score.
        base_clf = self.clf_factory()
        base_clf.fit(X, y)
        p = base_clf.predict_proba(X)[:, 1]
        conf = np.abs(p - 0.5)

        # Simple discrete bandit over (threshold, weight) policies.
        candidate_policies = self.strategy_kwargs.get("candidate_policies") or [
            {"theta": 0.05, "noisy_weight": 0.5},
            {"theta": 0.10, "noisy_weight": 0.5},
            {"theta": 0.10, "noisy_weight": 0.2},
            {"theta": 0.15, "noisy_weight": 0.2},
            {"theta": 0.20, "noisy_weight": 0.0},
            {"theta": 0.30, "noisy_weight": 0.0},
        ]
        n_trials = self.strategy_kwargs.get("n_trials", 10)
        epsilon  = self.strategy_kwargs.get("epsilon", 0.3)

        # train_test_split returns a (train, test) pair for EVERY array passed
        # in, so three inputs give six outputs. The validation confidences are
        # not needed, but must still be unpacked.
        X_train, X_val, y_train, y_val, conf_train, _conf_val = train_test_split(
            X, y, conf, test_size=0.2, random_state=0, stratify=y
        )

        def apply_policy(policy, conf_vec):
            # Weight 1.0 everywhere except low-confidence samples.
            w = np.ones_like(conf_vec, dtype=float)
            w[conf_vec < policy["theta"]] = policy["noisy_weight"]
            return w

        K = len(candidate_policies)
        Q = np.zeros(K)  # running mean validation F1 per policy
        N = np.zeros(K)  # pull counts

        for _ in range(n_trials):
            if np.random.rand() < epsilon:
                k = np.random.randint(K)
            else:
                k = np.argmax(Q)

            policy = candidate_policies[k]
            weights_train = apply_policy(policy, conf_train)

            clf = self.clf_factory()
            clf.fit(X_train, y_train, sample_weight=weights_train)
            y_pred_val = clf.predict(X_val)
            reward = f1_score(y_val, y_pred_val)

            N[k] += 1
            Q[k] += (reward - Q[k]) / N[k]

        best_k = np.argmax(Q)
        best_policy = candidate_policies[best_k]

        # Final weights on the full data.
        self.sample_weight_ = apply_policy(best_policy, conf)

    def _fit_cleanlab(self, X, y):
        """Down-weight the samples cleanlab flags as likely label issues."""
        from cleanlab.classification import CleanLearning

        base_clf = self.clf_factory()
        cl = CleanLearning(clf=base_clf)
        cl.fit(X, y)
        self._cleanlab_model = cl

        # Label issues summary -> weights.
        issues = cl.get_label_issues()
        is_issue = issues["is_label_issue"].values
        # Simple scheme: 0.3 weight for suspected issues, 1.0 otherwise.
        w = np.ones(len(y), dtype=float)
        w[is_issue] = 0.3
        self.sample_weight_ = w

    def get_weights(self) -> np.ndarray:
        """
        Return the learned per-sample weights.

        Raises:
            RuntimeError: if ``fit`` has not been called yet.
        """
        if self.sample_weight_ is None:
            raise RuntimeError("Call fit() first")
        return self.sample_weight_

    def fit_clean_model(self, X, y):
        """Train a final classifier using the learned sample weights."""
        w = self.get_weights()
        clf = self.clf_factory()
        clf.fit(X, y, sample_weight=w)
        return clf

In [None]:
# tinyllama_server_kv_stream.py
import asyncio
import json
import time
import uuid
from typing import List, Optional

import torch
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---------------------------------------------------------
# Model load
# ---------------------------------------------------------
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print(f"Loading model: {MODEL_NAME}")

# Half precision on GPU, full precision on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # generate() needs a pad token id; reuse EOS when none is defined.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=dtype,
    device_map="auto" if device == "cuda" else None,
)
# With device_map="auto" the weights are already placed on the GPU(s) at load
# time; calling .to() on a dispatched model is redundant and can conflict with
# that placement. Only move the model explicitly on the CPU path.
if device != "cuda":
    model.to(device)
model.eval()

# ---------------------------------------------------------
# FastAPI + schemas
# ---------------------------------------------------------
app = FastAPI(title="TinyLlama KV streaming demo")

# One chat turn in the request payload.
class ChatMessage(BaseModel):
    role: str  # speaker label, e.g. "user" or "assistant"
    content: str  # text of the turn

# Request body for the chat endpoint.
class ChatRequest(BaseModel):
    model: Optional[str] = None  # echoed back in responses; falls back to MODEL_NAME
    messages: List[ChatMessage]  # conversation so far, in order
    max_tokens: int = 128  # cap on newly generated tokens
    temperature: float = 0.7  # <= 0 selects greedy decoding
    top_p: float = 1.0  # nucleus-sampling threshold
    n: int = 1  # NOTE(review): not read by the non-streaming handler -- confirm usage
    stream: bool = False  # presumably selects the streaming handler -- confirm at the route
    stop: Optional[List[str]] = None  # NOTE(review): not read by the non-streaming handler -- confirm usage

# ---------------------------------------------------------
# Prompt formatting
# ---------------------------------------------------------
def build_prompt(messages: List[ChatMessage]) -> str:
    """
    Flatten a chat history into a plain-text prompt.

    Every message becomes one ``"<Role>: <content>"`` line, with the role
    capitalised ("user" -> "User", "assistant" -> "Assistant", anything else
    via ``str.capitalize``). A trailing ``"Assistant:"`` line cues the model
    to answer.

    Args:
        messages: chat turns in order.

    Returns:
        The newline-joined prompt string.
    """
    lines = [f"{m.role.capitalize()}: {m.content}" for m in messages]
    lines.append("Assistant:")
    return "\n".join(lines)

# ---------------------------------------------------------
# Sampling helpers
# ---------------------------------------------------------

def sample_next_token(
    logits: torch.Tensor,
    temperature: float,
    top_p: float,
) -> int:
    """
    Pick the next token id from a single position's logits.

    Args:
        logits: tensor of shape [vocab_size].
        temperature: <= 0 selects greedy (argmax) decoding; otherwise the
            logits are divided by it before the softmax.
        top_p: nucleus threshold; values < 1.0 restrict sampling to the
            smallest set of most-probable tokens whose mass reaches top_p.

    Returns:
        The chosen token id as a Python int.
    """
    if temperature <= 0:
        # greedy
        return int(torch.argmax(logits, dim=-1).item())

    # Work in float32: half-precision logits divided by a small temperature
    # overflow to inf, which turns the softmax into NaNs and makes
    # torch.multinomial fail.
    probs = torch.softmax(logits.float() / temperature, dim=-1)

    if top_p >= 1.0:
        # plain multinomial over the full vocab
        return int(torch.multinomial(probs, 1).item())

    # Nucleus sampling: sort by probability and drop every token whose
    # PRECEDING cumulative mass already exceeds top_p. The most likely token
    # is therefore always kept.
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    drop = cumulative - sorted_probs > top_p
    sorted_probs[drop] = 0
    sorted_probs = sorted_probs / sorted_probs.sum()

    idx = torch.multinomial(sorted_probs, 1)
    return int(sorted_indices[idx].item())

# ---------------------------------------------------------
# Non-streaming (simple single-call generate)
# ---------------------------------------------------------

async def handle_non_stream(req: ChatRequest):
    """
    Produce a complete (non-streamed) chat completion for the request.

    Builds the prompt from req.messages, runs a single model.generate
    call, decodes only the newly generated tokens and returns a dict with
    "object": "chat.completion", one choice and token usage counts.

    finish_reason is "length" when the completion used up req.max_tokens
    without ending on the tokenizer's EOS token, and "stop" otherwise,
    matching what handle_stream reports for the same situations.
    """
    prompt = build_prompt(req.messages)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # NOTE(review): model.generate is a synchronous call inside an async
    # handler; presumably it blocks the event loop while it runs - confirm
    # whether that is acceptable for this server.
    with torch.no_grad():
        gen_ids = model.generate(
            **inputs,
            max_new_tokens=req.max_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            do_sample=req.temperature > 0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate returns prompt + continuation; keep only the continuation.
    input_len = inputs["input_ids"].shape[1]
    new_tokens = gen_ids[0, input_len:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True)

    prompt_tokens = int(inputs["input_ids"].numel())
    completion_tokens = int(new_tokens.numel())
    total_tokens = prompt_tokens + completion_tokens

    # Distinguish a natural stop (EOS) from running out of token budget.
    eos_id = tokenizer.eos_token_id
    ended_on_eos = (
        completion_tokens > 0
        and eos_id is not None
        and int(new_tokens[-1].item()) == eos_id
    )
    hit_budget = (
        req.max_tokens is not None and completion_tokens >= req.max_tokens
    )
    finish_reason = "length" if hit_budget and not ended_on_eos else "stop"

    resp = {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model or MODEL_NAME,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": finish_reason,
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
        },
    }
    return resp

# ---------------------------------------------------------
# Real per-token streaming with KV cache
# ---------------------------------------------------------

async def handle_stream(req: ChatRequest):
    """
    Stream a chat completion as server-sent events, one model step at a time.

    Returns a StreamingResponse (media type "text/event-stream") whose body
    is a sequence of "data: <json>" frames with
    "object": "chat.completion.chunk":
      1. a first chunk whose delta carries the assistant role,
      2. one chunk per piece of newly decoded text,
      3. a final chunk with an empty delta and finish_reason "stop"
         (EOS sampled) or "length" (req.max_tokens steps exhausted),
      4. the literal "data: [DONE]" terminator.

    The prompt is run through the model once; every later step feeds only
    the last sampled token together with the cached past_key_values.
    """
    prompt = build_prompt(req.messages)
    request_id = f"chatcmpl-{uuid.uuid4().hex}"
    model_name = req.model or MODEL_NAME
    created = int(time.time())

    def sse_chunk(delta: dict, finish_reason: Optional[str] = None) -> str:
        """Serialize one chunk payload as an SSE "data:" frame."""
        payload = {
            "id": request_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "delta": delta,
                    "finish_reason": finish_reason,
                }
            ],
        }
        return f"data: {json.dumps(payload)}\n\n"

    async def event_stream():
        # 1) initial input: full prompt
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        input_ids = inputs["input_ids"]  # [1, seq_len]
        attention_mask = inputs["attention_mask"]

        max_new_tokens = req.max_tokens
        eos_id = tokenizer.eos_token_id

        # Send initial chunk with role (OpenAI-style)
        yield sse_chunk({"role": "assistant"})

        past_key_values = None
        generated: List[int] = []  # sampled ids so far (EOS excluded)
        emitted_chars = 0          # how much of the decoded text was already sent
        finish_reason = "length"   # overwritten with "stop" if EOS is sampled

        for _ in range(max_new_tokens):
            # 2) forward pass: either full prompt (first step) or just last token (subsequent steps)
            with torch.no_grad():
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                )

            logits = outputs.logits[:, -1, :]  # [1, vocab]
            past_key_values = outputs.past_key_values

            # 3) sample next token
            next_token_id = sample_next_token(
                logits[0], req.temperature, req.top_p
            )

            if next_token_id == eos_id:
                finish_reason = "stop"
                break
            generated.append(next_token_id)

            # 4) decode everything generated so far and emit only the new
            #    suffix. Decoding a token on its own can lose the space in
            #    front of a word and yields U+FFFD when one character spans
            #    several tokens; decoding the whole sequence avoids both and
            #    matches the text the non-streaming path produces. This
            #    re-decodes on every step, which is simple but grows with
            #    the completion length.
            text = tokenizer.decode(generated, skip_special_tokens=True)
            # Hold text back while it ends in U+FFFD: the character may be
            # completed by the next token.
            if len(text) > emitted_chars and not text.endswith("\ufffd"):
                yield sse_chunk({"content": text[emitted_chars:]})
                emitted_chars = len(text)

            # 5) prepare for next step: feed only this token, no need to resend the whole prompt
            input_ids = torch.tensor([[next_token_id]], device=device)
            attention_mask = None  # not strictly needed when using past with single token

            # let event loop breathe a bit
            await asyncio.sleep(0)

        # Flush any text that was held back waiting for more tokens.
        if generated:
            text = tokenizer.decode(generated, skip_special_tokens=True)
            if len(text) > emitted_chars:
                yield sse_chunk({"content": text[emitted_chars:]})

        # final empty delta with finish_reason
        yield sse_chunk({}, finish_reason)
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")

# ---------------------------------------------------------
# Main endpoint
# ---------------------------------------------------------

@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
    """
    POST /v1/chat/completions entry point.

    Dispatches to handle_stream when req.stream is truthy and to
    handle_non_stream otherwise, returning whatever the handler returns.
    """
    handler = handle_stream if req.stream else handle_non_stream
    return await handler(req)

In [None]:
!pip install "fastapi[all]" uvicorn transformers torch
uvicorn tinyllama_server_kv_stream:app --host 0.0.0.0 --port 9000