# In [1]:
import numpy as np

# --------- Utilities ---------
def arrays_allclose(a: np.ndarray, b: np.ndarray, rtol=1e-12, atol=0.0) -> bool:
    """Return True when every element of ``a`` is close to the matching element of ``b``.

    Closeness is decided by NumPy's element-wise test with the given relative
    (``rtol``) and absolute (``atol``) tolerances. With the default
    ``atol=0.0`` the comparison is purely relative.
    """
    elementwise = np.isclose(a, b, rtol=rtol, atol=atol)
    return bool(elementwise.all())

def sum_array(a: np.ndarray) -> float:
    """Add up all elements of ``a`` and return the total as a built-in float."""
    total = np.sum(a)
    return float(total)

# --------- Baseline: single-loop DAXPY ---------
def daxpy_single(a: float, x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    Reference DAXPY implementation using ONE explicit loop over all elements.

    For every element (in flattened order): d[i] = a * x[i] + y[i]

    Parameters
    ----------
    a : scalar multiplier applied to each element of ``x``.
    x, y : arrays of identical shape (any number of dimensions).

    Returns
    -------
    A float array with the same shape as ``x``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not have the same shape.
    """
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape")
    d = np.empty_like(x, dtype=float)

    # The loop counter runs over x.size (the total element count), so index
    # through flat iterators. Plain d[i] / x[i] would index the first axis
    # only and raise IndexError for any input that is not 1-D.
    x_flat, y_flat, d_flat = x.flat, y.flat, d.flat

    # single loop
    for i in range(x.size):
        d_flat[i] = a * x_flat[i] + y_flat[i]
    return d

# --------- Chunked DAXPY + partial sums ---------
def daxpy_chunked(a: float, x: np.ndarray, y: np.ndarray, chunk_size: int):
    """
    Chunked DAXPY that also records the sum of each chunk.

      - Outer loop over chunks
      - Inner loop over elements in the chunk
      - Returns (d, partial_chunk_sum)

    Elements are visited in flattened order, so chunk ``k`` covers flat
    indices ``[k * chunk_size, min((k + 1) * chunk_size, n))``; the last chunk
    may be shorter than ``chunk_size``.

    Parameters
    ----------
    a : scalar multiplier applied to each element of ``x``.
    x, y : arrays of identical shape (any number of dimensions).
    chunk_size : number of elements per chunk; must be >= 1.

    Returns
    -------
    (d, partial_chunk_sum)
        ``d`` is a float array shaped like ``x`` holding ``a * x + y``;
        ``partial_chunk_sum`` is a 1-D float array with one entry per chunk
        holding the sum of the ``d`` values in that chunk.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive, or ``x`` and ``y`` differ in shape.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be >= 1")
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape")

    n = x.size
    d = np.empty_like(x, dtype=float)

    # Loop indices run over the total element count, so index through flat
    # iterators. Plain d[i] / x[i] would index the first axis only and raise
    # IndexError for any input that is not 1-D.
    x_flat, y_flat, d_flat = x.flat, y.flat, d.flat

    # number of chunks (ceiling division)
    chunks = (n + chunk_size - 1) // chunk_size
    partial_chunk_sum = np.empty(chunks, dtype=float)

    for k in range(chunks):
        start = k * chunk_size
        end = min(start + chunk_size, n)  # clamp the final, possibly short, chunk

        local_sum = 0.0
        for i in range(start, end):
            d_flat[i] = a * x_flat[i] + y_flat[i]
            # Accumulate the stored (float) value so the partial sum matches d exactly.
            local_sum += d_flat[i]

        partial_chunk_sum[k] = local_sum

    return d, partial_chunk_sum

# --------- Demo / self-check ---------
if __name__ == "__main__":
    # Demo / self-check: run both DAXPY variants on the same random data and
    # verify that (A) the outputs agree element-wise and (B) the per-chunk
    # partial sums add up to the sum of the baseline result.

    # Parameters — change as you like
    n = 100
    chunk_size = 8
    a = 2.0

    # Reproducible data in [-1, 1] (fixed seed)
    rng = np.random.default_rng(42)
    x = rng.uniform(-1.0, 1.0, size=n)
    y = rng.uniform(-1.0, 1.0, size=n)

    # 1) Single-loop baseline
    d_single = daxpy_single(a, x, y)

    # 2) Chunked version + partial sums
    d_chunk, partial_chunk_sum = daxpy_chunked(a, x, y, chunk_size)

    # A) Element-wise equality
    same = arrays_allclose(d_single, d_chunk, rtol=1e-12, atol=0.0)
    print(f"[CHECK] d(single) == d(chunked)? {'YES' if same else 'NO'}")

    # B) Sum(partials) == Sum(d_single)
    # The two sums add the same values in a different order, so they are
    # compared with a relative tolerance rather than for exact equality.
    sum_partials = sum_array(partial_chunk_sum)
    sum_single = sum_array(d_single)
    diff = abs(sum_partials - sum_single)
    ok_sum = np.isclose(sum_partials, sum_single, rtol=1e-12, atol=0.0)

    print(
        "[CHECK] sum(partials) = {:.15f} | sum(d_single) = {:.15f} | |diff| = {:.3e}  => {}"
        .format(sum_partials, sum_single, diff, "MATCH" if ok_sum else "MISMATCH")
    )

    # Fail fast in notebooks/scripts. An explicit raise is used instead of an
    # `assert` statement because asserts are stripped under `python -O`,
    # which would silently disable this check.
    if not (same and ok_sum):
        raise AssertionError("Chunked implementation or partial sums check failed!")


# Output:
# [CHECK] d(single) == d(chunked)? YES
# [CHECK] sum(partials) = -5.108486612323954 | sum(d_single) = -5.108486612323953 | |diff| = 8.882e-16  => MATCH
