# Patch extraction optimization: why `unfold` wins

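This notebook demonstrates how to speed up patch extraction from 2D arrays. In the benchmarks below (run on an Apple MPS device), replacing a Python loop with `unfold` gives speedups of roughly 9× to 50×, depending on the configuration.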

## The problem

In ptychography (and many imaging applications), we need to extract many overlapping patches from a large object array. A naive loop-based approach is slow because it performs one Python-level slicing operation per patch and then has to gather thousands of small views into a single tensor with `torch.stack`.

## The solution

Use `Tensor.unfold()`, a native sliding-window operation implemented at the C++ level. Note that this is the tensor method, not `torch.nn.functional.unfold`, which has different semantics.

## Why unfold is the best

| Factor | Loop | unfold |
|--------|------|--------|
| **Operations dispatched** | N slices plus a `stack` over N tensors | Two `unfold` calls plus one `reshape` |
| **Memory access** | Many small strided reads | One large strided copy |
| **Python overhead** | High (loop + append) | Minimal (three calls) |
| **GPU parallelism** | Limited by per-patch overhead | One large parallel copy |

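The key insight: `unfold` itself doesn't copy data. It creates a *view* with clever stride manipulation, making the windowing step essentially free. The only real work is the final `reshape`, which materializes the overlapping windows as one contiguous tensor in a single bulk copy.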
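To make the view-versus-copy distinction concrete, here is a minimal sketch on a small CPU tensor. The variable names are illustrative only; the check relies on `data_ptr()` to compare memory addresses.

```python
import torch

x = torch.arange(16.0).reshape(4, 4)

# Two unfold calls give a strided view over the same memory as x.
view = x.unfold(0, 2, 1).unfold(1, 2, 1)       # shape (3, 3, 2, 2)
shares_memory = view.data_ptr() == x.data_ptr()

# Overlapping windows cannot be expressed as a contiguous (9, 2, 2) layout,
# so reshape has to allocate new memory and copy.
flat = view.reshape(-1, 2, 2)
reshape_copied = flat.data_ptr() != x.data_ptr()

# A write to the source is visible through the view, but not in the copy.
x[0, 0] = 100.0
print(shares_memory, reshape_copied)
print(view[0, 0, 0, 0].item(), flat[0, 0, 0].item())
```

If the patches are only read (for example, multiplied by a probe), it may be possible to keep working with the 4D view and skip the copy altogether.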

In [1]:
import torch
import time

device = torch.device('cuda' if torch.cuda.is_available() else 
                      'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Device: {device}")

Device: mps


## Setup: test data

In [2]:
# Configuration
OBJECT_SIZE = 256
PROBE_SIZE = 80
SCAN_GRID = 64  # 64×64 = 4096 patches

# Create complex-valued object
obj = torch.complex(
    torch.randn(OBJECT_SIZE, OBJECT_SIZE, device=device),
    torch.randn(OBJECT_SIZE, OBJECT_SIZE, device=device)
)

# Compute step size for uniform grid coverage
step = (OBJECT_SIZE - PROBE_SIZE) / (SCAN_GRID - 1)
step_int = max(1, int(round(step)))

# Generate scan positions
positions = [(int(i * step), int(j * step)) 
             for i in range(SCAN_GRID) for j in range(SCAN_GRID)]

print(f"Object: {OBJECT_SIZE}×{OBJECT_SIZE}, Patch: {PROBE_SIZE}×{PROBE_SIZE}")
print(f"Positions: {len(positions)}, Step: {step:.2f}px, Overlap: {(1-step/PROBE_SIZE)*100:.0f}%")

Object: 256×256, Patch: 80×80
Positions: 4096, Step: 2.79px, Overlap: 97%


## Method 1: slow loop (baseline)

In [3]:
def extract_slow(obj, positions, size):
    """Extract patches one at a time with a Python loop."""
    patches = []
    for y, x in positions:
        patches.append(obj[y:y+size, x:x+size])
    return torch.stack(patches)

In [4]:
def sync():
    if device.type == 'cuda': torch.cuda.synchronize()
    elif device.type == 'mps': torch.mps.synchronize()

# Warmup and benchmark
_ = extract_slow(obj, positions[:10], PROBE_SIZE)
sync()

N = 50
t0 = time.perf_counter()
for _ in range(N):
    patches_slow = extract_slow(obj, positions, PROBE_SIZE)
sync()
time_slow = (time.perf_counter() - t0) / N * 1000

print(f"Loop method: {time_slow:.2f} ms")

Loop method: 23.31 ms


## Method 2: fast unfold

In [5]:
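def extract_unfold(obj, size, step):
    """Extract all patches at once using a sliding window.

    `step` must be an integer; patches come back in row-major scan order.
    """
    # unfold(dim, size, step) returns a strided view of sliding windows
    patches = obj.unfold(0, size, step).unfold(1, size, step)
    # Overlapping windows are not contiguous, so this reshape makes one bulk copy
    return patches.reshape(-1, size, size)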

In [6]:
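# Warmup and benchmark.
# Note: with step_int = 3, unfold yields 59×59 = 3481 patches, not the 4096
# extracted by the loop, so the two timings cover slightly different workloads.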
_ = extract_unfold(obj, PROBE_SIZE, step_int)
sync()

t0 = time.perf_counter()
for _ in range(N):
    patches_fast = extract_unfold(obj, PROBE_SIZE, step_int)
sync()
time_fast = (time.perf_counter() - t0) / N * 1000

print(f"Unfold method: {time_fast:.2f} ms")
print(f"Speedup: {time_slow/time_fast:.1f}×")

Unfold method: 2.44 ms
Speedup: 9.5×
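A speedup only matters if both methods return the same patches. The sketch below checks this on the CPU, assuming the scan positions lie exactly on the integer grid that `unfold` uses. The name `extract_loop` is a stand-in for `extract_slow`, and a real-valued tensor is used for brevity.

```python
import torch

def extract_loop(obj, positions, size):
    # Same logic as the loop baseline: slice each patch, then stack.
    return torch.stack([obj[y:y + size, x:x + size] for y, x in positions])

def extract_unfold(obj, size, step):
    return obj.unfold(0, size, step).unfold(1, size, step).reshape(-1, size, size)

obj = torch.randn(256, 256)
size, step = 80, 3

# Number of windows unfold produces per axis: floor((256 - 80) / 3) + 1 = 59
n = (obj.shape[0] - size) // step + 1
positions = [(i * step, j * step) for i in range(n) for j in range(n)]

patches_loop = extract_loop(obj, positions, size)
patches_unfold = extract_unfold(obj, size, step)
same = torch.equal(patches_loop, patches_unfold)
print(patches_unfold.shape, same)
```

The two results match because `reshape(-1, size, size)` flattens the window grid in row-major order, which is the order the position list is generated in.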


## Benchmark across different configurations

In [7]:
# (object_size, probe_size, scan_grid)
#  - object_size: size of the 2D object array (NxN)
#  - probe_size: size of each patch to extract (PxP)  
#  - scan_grid: number of scan positions per dimension (GxG total patches)
N_BENCH = 100  # iterations per benchmark

CONFIGS = [
    (128, 32, 32),    # 128×128 object, 32×32 patches, 1024 positions
    (256, 80, 64),    # 256×256 object, 80×80 patches, 4096 positions
    (512, 64, 128),   # 512×512 object, 64×64 patches, 16384 positions
]

print(f"{'Config':>20s} | {'Loop (ms)':>10s} | {'Unfold (ms)':>12s} | {'Speedup':>8s}")
print("-" * 60)

for obj_size, probe_size, scan_grid in CONFIGS:
    # Create test object
    test_obj = torch.complex(
        torch.randn(obj_size, obj_size, device=device),
        torch.randn(obj_size, obj_size, device=device)
    )
    
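    # Compute positions.
    # Caution: this reassigns the notebook-level `step` and `step_int`,
    # so later cells see the values from the last configuration.
    step = (obj_size - probe_size) / (scan_grid - 1)
    step_int = max(1, int(round(step)))
    pos = [(int(i*step), int(j*step)) for i in range(scan_grid) for j in range(scan_grid)]
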
    # Benchmark loop
    sync()
    t0 = time.perf_counter()
    for _ in range(N_BENCH):
        _ = extract_slow(test_obj, pos, probe_size)
    sync()
    t_loop = (time.perf_counter() - t0) / N_BENCH * 1000
    
    # Benchmark unfold
    sync()
    t0 = time.perf_counter()
    for _ in range(N_BENCH):
        _ = extract_unfold(test_obj, probe_size, step_int)
    sync()
    t_unfold = (time.perf_counter() - t0) / N_BENCH * 1000
    
    config = f"{obj_size}×{probe_size}×{scan_grid}"
    print(f"{config:>20s} | {t_loop:>10.2f} | {t_unfold:>12.2f} | {t_loop/t_unfold:>7.1f}×")

              Config |  Loop (ms) |  Unfold (ms) |  Speedup
------------------------------------------------------------
           128×32×32 |       4.56 |         0.09 |    50.1×
           256×80×64 |      18.24 |         1.95 |     9.3×
          512×64×128 |     113.19 |         3.73 |    30.3×
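These numbers are not a strictly like-for-like comparison: the loop visits `scan_grid`² truncated positions, while `unfold` slides with the rounded integer step and so returns a different number of patches. When the positions do not fall on a regular integer grid, one alternative is a single batched gather with advanced indexing. The helper below, `extract_gather`, is a hypothetical sketch and was not benchmarked in this notebook.

```python
import torch

def extract_gather(obj, positions, size):
    """Gather size×size patches at arbitrary (y, x) top-left corners."""
    pos = torch.as_tensor(positions, device=obj.device)       # (N, 2), int64
    offsets = torch.arange(size, device=obj.device)
    rows = pos[:, 0, None, None] + offsets[None, :, None]     # (N, size, 1)
    cols = pos[:, 1, None, None] + offsets[None, None, :]     # (N, 1, size)
    return obj[rows, cols]                                    # (N, size, size)

# Same position formula as the notebook, using the smallest configuration.
obj_size, probe_size, scan_grid = 128, 32, 32
obj = torch.randn(obj_size, obj_size)
step = (obj_size - probe_size) / (scan_grid - 1)
positions = [(int(i * step), int(j * step))
             for i in range(scan_grid) for j in range(scan_grid)]

patches_gather = extract_gather(obj, positions, probe_size)
patches_loop = torch.stack([obj[y:y + probe_size, x:x + probe_size]
                            for y, x in positions])
print(patches_gather.shape, torch.equal(patches_gather, patches_loop))
```

Unlike `unfold`, this always copies and needs index tensors of size N × size, but it reproduces the loop's positions exactly.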


## Appendix: how unfold works

Now that you've seen the speedup, here's how `unfold` actually works.

`unfold(dim, size, step)` extracts sliding windows along a dimension.

### 1D example

```text
tensor: [0, 1, 2, 3, 4, 5, 6, 7]

unfold(dim=0, size=3, step=2):

  Start at 0:  [0, 1, 2]
  Start at 2:  [2, 3, 4]
  Start at 4:  [4, 5, 6]

Result: [[0, 1, 2],
         [2, 3, 4],
         [4, 5, 6]]

Shape: (3, 3) = (num_windows, window_size)
```
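The number of windows follows from the length, window size, and step: `floor((length - size) / step) + 1`. Trailing elements that do not fill a complete window are dropped (element 7 in the example above). A small sketch, with `num_windows` as an illustrative helper name:

```python
import torch

def num_windows(length, size, step):
    # Count of complete windows of `size` taken every `step` elements.
    return (length - size) // step + 1

x = torch.arange(8)
windows = x.unfold(0, 3, 2)
print(windows.shape[0], num_windows(8, 3, 2))              # 3 3

# For a 256-pixel axis and 80-pixel patches:
print(num_windows(256, 80, 3), num_windows(256, 80, 4))    # 59 45
```

This is why a rounded integer step generally does not reproduce the requested number of scan positions per axis.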

### 2D example (extracting patches)

We apply unfold twice: once for rows, once for columns.

```text
Original 4×4 grid:

     0   1   2   3
   ┌───┬───┬───┬───┐
 0 │ 0 │ 1 │ 2 │ 3 │
   ├───┼───┼───┼───┤
 1 │ 4 │ 5 │ 6 │ 7 │
   ├───┼───┼───┼───┤
 2 │ 8 │ 9 │10 │11 │
   ├───┼───┼───┼───┤
 3 │12 │13 │14 │15 │
   └───┴───┴───┴───┘
```

**Step 1: `unfold(dim=0, size=2, step=1)`** — slide 2-row window down

```text
Window at row 0:     Window at row 1:     Window at row 2:
┌───┬───┬───┬───┐    ┌───┬───┬───┬───┐    ┌───┬───┬───┬───┐
│ 0 │ 1 │ 2 │ 3 │    │ 4 │ 5 │ 6 │ 7 │    │ 8 │ 9 │10 │11 │
├───┼───┼───┼───┤    ├───┼───┼───┼───┤    ├───┼───┼───┼───┤
│ 4 │ 5 │ 6 │ 7 │    │ 8 │ 9 │10 │11 │    │12 │13 │14 │15 │
└───┴───┴───┴───┘    └───┴───┴───┴───┘    └───┴───┴───┴───┘

Shape after step 1: (3, 4, 2) = (3 row positions, 4 columns, 2 rows each)
```

**Step 2: `unfold(dim=1, size=2, step=1)`** — slide 2-col window across

```text
Now we have 3×3 = 9 patches of size 2×2:

Patch[0,0]  Patch[0,1]  Patch[0,2]    (row window 0)
┌───┬───┐  ┌───┬───┐  ┌───┬───┐
│ 0 │ 1 │  │ 1 │ 2 │  │ 2 │ 3 │
├───┼───┤  ├───┼───┤  ├───┼───┤
│ 4 │ 5 │  │ 5 │ 6 │  │ 6 │ 7 │
└───┴───┘  └───┴───┘  └───┴───┘

Patch[1,0]  Patch[1,1]  Patch[1,2]    (row window 1)
┌───┬───┐  ┌───┬───┐  ┌───┬───┐
│ 4 │ 5 │  │ 5 │ 6 │  │ 6 │ 7 │
├───┼───┤  ├───┼───┤  ├───┼───┤
│ 8 │ 9 │  │ 9 │10 │  │10 │11 │
└───┴───┘  └───┴───┘  └───┴───┘

Patch[2,0]  Patch[2,1]  Patch[2,2]    (row window 2)
┌───┬───┐  ┌───┬───┐  ┌───┬───┐
│ 8 │ 9 │  │ 9 │10 │  │10 │11 │
├───┼───┤  ├───┼───┤  ├───┼───┤
│12 │13 │  │13 │14 │  │14 │15 │
└───┴───┘  └───┴───┘  └───┴───┘

Final shape: (3, 3, 2, 2) = (3 row pos, 3 col pos, 2×2 patch)
After reshape(-1, 2, 2): (9, 2, 2) = 9 patches of 2×2
```

In [8]:
# 1D example
x = torch.arange(8)
print("Original 1D tensor:", x.tolist())

windows = x.unfold(0, 3, 2)
print(f"After unfold(dim=0, size=3, step=2): shape {windows.shape}")
print(windows)

Original 1D tensor: [0, 1, 2, 3, 4, 5, 6, 7]
After unfold(dim=0, size=3, step=2): shape torch.Size([3, 3])
tensor([[0, 1, 2],
        [2, 3, 4],
        [4, 5, 6]])


In [9]:
# 2D example - extracting patches
grid = torch.arange(16).reshape(4, 4)
print("Original 4×4 grid:")
print(grid)
print()

# Double unfold for 2D patches
patches = grid.unfold(0, 2, 1).unfold(1, 2, 1)
print(f"After double unfold: shape {patches.shape}")
print()

# Reshape to list of patches
patches_flat = patches.reshape(-1, 2, 2)
print(f"Reshaped to {patches_flat.shape}: 9 patches of 2×2")
print("First 3 patches:")
for i in range(3):
    print(f"  Patch {i}: {patches_flat[i].tolist()}")

Original 4×4 grid:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

After double unfold: shape torch.Size([3, 3, 2, 2])

Reshaped to torch.Size([9, 2, 2]): 9 patches of 2×2
First 3 patches:
  Patch 0: [[0, 1], [4, 5]]
  Patch 1: [[1, 2], [5, 6]]
  Patch 2: [[2, 3], [6, 7]]


### Why it's fast: zero-copy view

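`unfold` itself doesn't copy data. It creates a **view** by manipulating tensor strides. Only the final `reshape` in `extract_unfold` copies, because overlapping windows cannot be laid out contiguously without duplicating elements.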

In [10]:
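# unfold creates a VIEW, not a copy.
# Note: step_int is 4 here, because the benchmark loop above reassigned it
# (last config: 512, 64, 128). That is why the view has 45×45 windows.
patches_view = obj.unfold(0, PROBE_SIZE, step_int).unfold(1, PROBE_SIZE, step_int)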

print("Original object:")
print(f"  Shape: {obj.shape}, Strides: {obj.stride()}")
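# untyped_storage() avoids the TypedStorage deprecation warning raised by storage()
print(f"  Storage size: {obj.untyped_storage().nbytes() // obj.element_size()} elements")
print()
print("After unfold (VIEW, no copy!):")
print(f"  Shape: {patches_view.shape}, Strides: {patches_view.stride()}")
print(f"  Storage size: {patches_view.untyped_storage().nbytes() // patches_view.element_size()} elements (same!)")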

Original object:
  Shape: torch.Size([256, 256]), Strides: (256, 1)
  Storage size: 65536 elements

After unfold (VIEW, no copy!):
  Shape: torch.Size([45, 45, 80, 80]), Strides: (1024, 4, 256, 1)
  Storage size: 65536 elements (same!)
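The strides in this output can be derived by hand: stepping to the next window moves `step` rows or columns through the original array, while moving inside a window uses the original strides. The sketch below rebuilds the same view with `torch.as_strided`, assuming a step of 4 as in the output above. The helper name `unfold2d_as_strided` is illustrative only.

```python
import torch

def unfold2d_as_strided(x, size, step):
    """Rebuild x.unfold(0, size, step).unfold(1, size, step) from strides."""
    n_rows = (x.shape[0] - size) // step + 1
    n_cols = (x.shape[1] - size) // step + 1
    s0, s1 = x.stride()
    return torch.as_strided(
        x,
        size=(n_rows, n_cols, size, size),
        # window-to-window strides, then the original within-patch strides
        stride=(s0 * step, s1 * step, s0, s1),
    )

x = torch.randn(256, 256)
ref = x.unfold(0, 80, 4).unfold(1, 80, 4)
alt = unfold2d_as_strided(x, 80, 4)

print(alt.shape, alt.stride())   # torch.Size([45, 45, 80, 80]) (1024, 4, 256, 1)
print(torch.equal(ref, alt), ref.stride() == alt.stride())
```

In practice `unfold` is the safer choice, since `as_strided` performs no bounds checking against the intended window layout.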


