# 8 Masked "last valid" gather + pooled stats

Theme: variable-length sequences, masking, advanced indexing, reductions

You’re given batched sequences with padding.

x: float tensor of shape [B, T, D]

valid: bool tensor of shape [B, T] where True means the timestep is real data, False is padding

Assume each batch element has at least one valid timestep.

## Requirements

- No looping over B or T.
- Use masking + reductions correctly (avoid padded values corrupting max / mean).
- Must handle cases where valid timesteps are not contiguous (e.g., True, False, True).

## Example

If valid[b] = [True, True, False, True, False], then last[b] = x[b, 3].

## What it tests

- Turning boolean masks into indices (e.g., “last True per row”)
- Masked reductions (mean, max)
- Numerical safety (avoid dividing by zero; although guaranteed at least one True, candidates should still code defensively)

## Nice follow-ups
- Also return argmax_t: the time index of max value per feature or per sequence (clarify axis!)
- Compute mean and var in one pass (stability)

## Implement the Function

In [48]:
import numpy as np

def summarize_sequences(x, valid):
    """
    Returns:
      last: [B, D]   # last valid vector per sequence
      mean: [B, D]   # mean over valid timesteps
      max_: [B, D]   # max over valid timesteps
    """
    valid_data = x * np.repeat(valid[:, :, None], x.shape[-1], axis=2)
    num_valids = valid.sum(axis=1)[:, None]

    numerate = np.array([np.arange(x.shape[1]) for _ in range(x.shape[0])])
    numerate = numerate * valid

    last_valid_index = numerate.argmax(axis=1)
    last = x[np.arange(x.shape[0]), last_valid_index, :]
    mean = valid_data.sum(axis=1) / num_valids

    masked_data = np.where(valid[..., None], x, -np.inf)
    max_ = masked_data.max(axis=1)
    return dict(
        last=last,
        mean=mean,
        max=max_
    )

## Check Test Cases

In [20]:
test_cases = [
    # 1) Basic contiguous padding
    {
        "x": np.array(
            [
                [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [0.0, 0.0]],
                [[4.0, 40.0], [5.0, 50.0], [0.0, 0.0], [0.0, 0.0]],
            ],
            dtype=np.float32,
        ),
        "valid": np.array(
            [
                [True, True, True, False],
                [True, True, False, False],
            ],
            dtype=bool,
        ),
        "output": {
            "last": np.array([[3.0, 30.0], [5.0, 50.0]], dtype=np.float32),
            "mean": np.array([[2.0, 20.0], [4.5, 45.0]], dtype=np.float32),
            "max":  np.array([[3.0, 30.0], [5.0, 50.0]], dtype=np.float32),
        },
    },

    # 2) Non-contiguous valid mask
    {
        "x": np.array(
            [
                [[1.0, -1.0], [100.0, -100.0], [2.0, -2.0], [3.0, -3.0]],
            ],
            dtype=np.float32,
        ),
        "valid": np.array([[True, False, True, False]], dtype=bool),
        "output": {
            "last": np.array([[2.0, -2.0]], dtype=np.float32),
            "mean": np.array([[1.5, -1.5]], dtype=np.float32),
            "max":  np.array([[2.0, -1.0]], dtype=np.float32),
        },
    },

    # 3) Valid only at final timestep
    {
        "x": np.array(
            [
                [[9.0, 1.0], [8.0, 2.0], [7.0, 3.0], [6.0, 4.0]],
            ],
            dtype=np.float32,
        ),
        "valid": np.array([[False, False, False, True]], dtype=bool),
        "output": {
            "last": np.array([[6.0, 4.0]], dtype=np.float32),
            "mean": np.array([[6.0, 4.0]], dtype=np.float32),
            "max":  np.array([[6.0, 4.0]], dtype=np.float32),
        },
    },

    # 4) Single timestep sequences (T = 1)
    {
        "x": np.array(
            [
                [[-5.0, 0.5]],
                [[10.0, -0.5]],
            ],
            dtype=np.float32,
        ),
        "valid": np.array([[True], [True]], dtype=bool),
        "output": {
            "last": np.array([[-5.0, 0.5], [10.0, -0.5]], dtype=np.float32),
            "mean": np.array([[-5.0, 0.5], [10.0, -0.5]], dtype=np.float32),
            "max":  np.array([[-5.0, 0.5], [10.0, -0.5]], dtype=np.float32),
        },
    },

    # 5) All timesteps valid
    {
        "x": np.array(
            [
                [[1.0, 2.0], [3.0, 4.0]],
                [[-1.0, -2.0], [-3.0, -4.0]],
            ],
            dtype=np.float32,
        ),
        "valid": np.array([[True, True], [True, True]], dtype=bool),
        "output": {
            "last": np.array([[3.0, 4.0], [-3.0, -4.0]], dtype=np.float32),
            "mean": np.array([[2.0, 3.0], [-2.0, -3.0]], dtype=np.float32),
            "max":  np.array([[3.0, 4.0], [-1.0, -2.0]], dtype=np.float32),
        },
    },

    # 6) Padded large values (must be ignored in max)
    {
        "x": np.array(
            [
                [[-10.0, -1.0], [-20.0, -2.0], [999.0, 999.0], [999.0, 999.0]],
            ],
            dtype=np.float32,
        ),
        "valid": np.array([[True, True, False, False]], dtype=bool),
        "output": {
            "last": np.array([[-20.0, -2.0]], dtype=np.float32),
            "mean": np.array([[-15.0, -1.5]], dtype=np.float32),
            "max":  np.array([[-10.0, -1.0]], dtype=np.float32),
        },
    },

    # 7) Mixed batch + D=3 (broadcasting stress test)
    {
        "x": np.array(
            [
                [[1.0,  2.0,  3.0],
                 [4.0,  5.0,  6.0],
                 [7.0,  8.0,  9.0],
                 [10.0, 11.0, 12.0]],
                [[-1.0, -2.0, -3.0],
                 [100.0, 100.0, 100.0],
                 [200.0, 200.0, 200.0],
                 [300.0, 300.0, 300.0]],
            ],
            dtype=np.float32,
        ),
        "valid": np.array(
            [
                [False, True, False, True],
                [True, False, False, False],
            ],
            dtype=bool,
        ),
        "output": {
            "last": np.array([[10.0, 11.0, 12.0], [-1.0, -2.0, -3.0]], dtype=np.float32),
            "mean": np.array([[7.0, 8.0, 9.0],    [-1.0, -2.0, -3.0]], dtype=np.float32),
            "max":  np.array([[10.0, 11.0, 12.0], [-1.0, -2.0, -3.0]], dtype=np.float32),
        },
    },
]


In [49]:
from tqdm import tqdm

test_func = summarize_sequences
error_list = []; correct = 0
num = len(test_cases)
for i, t in tqdm(enumerate(test_cases)):
    result = test_func(t['x'], t['valid'])
    # if np.all(result == t['output']):
    if np.all(result['last'] == t['output']['last']) and np.all(result['mean'] == t['output']['mean']) and np.all(result['max'] == t['output']['max']):
        correct += 1
    else:
        error_list.append(f'Received {result} on input {t["x"]}, {t["valid"]}. Expected {t["output"]}.')

print(f'Test cases passed: {correct}/{num}')

if correct == num:
    print('Success!')
else:
    [print(e) for e in error_list]

7it [00:00, 1058.40it/s]

Test cases passed: 7/7
Success!





## Solutions

My solution (2026-02-04):

```python
def summarize_sequences(x, valid):
    """
    Returns:
      last: [B, D]   # last valid vector per sequence
      mean: [B, D]   # mean over valid timesteps
      max_: [B, D]   # max over valid timesteps
    """
    valid_data = x * np.repeat(valid[:, :, None], x.shape[-1], axis=2)
    num_valids = valid.sum(axis=1)[:, None]

    numerate = np.array([np.arange(x.shape[1]) for _ in range(x.shape[0])])
    numerate = numerate * valid

    last_valid_index = numerate.argmax(axis=1)
    last = x[np.arange(x.shape[0]), last_valid_index, :]
    mean = valid_data.sum(axis=1) / num_valids

    masked_data = np.where(valid[..., None], x, -np.inf)
    max_ = masked_data.max(axis=1)
    return dict(
        last=last,
        mean=mean,
        max=max_
    )
```

Need to be careful about this line: `last = x[np.arange(x.shape[0]), last_valid_index, :]`. Originally, I wrote it as `last = x[:, last_valid_index, :]`, NumPy indexes `[B, B, D]` instead of `[B, D]` as intended.

This is what ChatGPT came up with:

```python
def summarize_sequences(x: np.ndarray, valid: np.ndarray):
    """
    Args:
      x:     float array [B, T, D]
      valid: bool  array [B, T]

    Returns:
      dict with:
        last: [B, D]  last valid vector per sequence
        mean: [B, D]  mean over valid timesteps
        max:  [B, D]  max over valid timesteps
    """
    B, T, D = x.shape

    # ---- last valid index per batch ----
    # Build timestep indices [1, T] and mask invalid positions to -1.
    t_idx = np.arange(T)[None, :]                 # [1, T]
    masked_idx = np.where(valid, t_idx, -1)       # [B, T]
    last_t = masked_idx.max(axis=1)               # [B]
    # (Spec guarantees at least one True per row, so last_t is always >= 0)

    batch_idx = np.arange(B)
    last = x[batch_idx, last_t, :]                # [B, D]

    # ---- mean over valid ----
    valid_f = valid[..., None].astype(x.dtype)    # [B, T, 1]
    summed = (x * valid_f).sum(axis=1)            # [B, D]
    counts = valid.sum(axis=1, keepdims=True)     # [B, 1]
    mean = summed / counts                        # [B, D]

    # ---- max over valid ----
    # Set invalid positions to -inf so they never win the max.
    masked_x = np.where(valid[..., None], x, -np.inf)  # [B, T, D]
    max_ = masked_x.max(axis=1)                         # [B, D]

    return {"last": last, "mean": mean, "max": max_}
```