[![Run in Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bhosmer/fold/blob/main/notebooks/broadcast.ipynb)

#### Colab setup (skip if running locally)

In [63]:
!git clone https://github.com/bhosmer/fold.git
import sys
sys.path.insert(0,'/content/fold')

Cloning into 'fold'...
remote: Enumerating objects: 102, done.
remote: Counting objects: 100% (98/98), done.
remote: Compressing objects: 100% (69/69), done.
remote: Total 102 (delta 51), reused 67 (delta 28), pack-reused 4
Receiving objects: 100% (102/102), 112.35 KiB | 5.62 MiB/s, done.
Resolving deltas: 100% (51/51), done.


In [2]:
from fold import *

## Broadcasting ragged arrays

As in PyTorch, broadcast is a wrapper around `expand`, which follows the usual rules: singleton dimensions of the array being expanded can be expanded to any size, its other dimensions must match the corresponding entries of the expansion shape, and an entry of `-1` leaves a dimension unchanged.
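For dense (non-ragged) shapes, the rule can be sketched in plain Python. `expand_shape` is a hypothetical helper written for illustration, not part of `fold`. It follows PyTorch's documented `Tensor.expand` semantics, including the ability to prepend new leading dimensions, which must be given explicit sizes.

```python
def expand_shape(shape, sizes):
    """Resolve a PyTorch-style expansion of a dense shape (illustrative only)."""
    extra = len(sizes) - len(shape)
    if extra < 0:
        raise ValueError("expansion shape has fewer dims than the array")
    out = []
    for dim, new in enumerate(sizes):
        if dim < extra:
            # New leading dimension: an explicit size is required
            if new == -1:
                raise ValueError(f"dim {dim}: -1 not allowed for a new leading dim")
            out.append(new)
            continue
        old = shape[dim - extra]
        if new == -1 or new == old:
            out.append(old)          # keep the existing size
        elif old == 1:
            out.append(new)          # singleton: expand to any size
        else:
            raise ValueError(f"dim {dim}: cannot expand size {old} to {new}")
    return tuple(out)


print(expand_shape((2, 3, 1), (-1, -1, 4)))  # (2, 3, 4)
print(expand_shape((4,), (2, 3, -1)))        # (2, 3, 4)
```

The ragged cases below keep these rules per dimension; what changes is how a ragged dimension's lengths are carried along.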

### Expansion within a ragged dimension

Expansion within (to the right of) a ragged dimension leaves the ragged dimension unaffected:

In [100]:
a = arange(2, [2, 3], 1)
print("a:")
print("shape:", a.shape)
print(a)

ex = a.expand(-1, -1, 3)
print("\nexpanded at dim 2:")
print("shape:", ex.shape)
print(ex)

a:
shape: (2, [2, 3], 1)
[[[0],
  [1]],

 [[2],
  [3],
  [4]]]

expanded at dim 2:
shape: (2, [2, 3], 3)
[[[0, 0, 0],
  [1, 1, 1]],

 [[2, 2, 2],
  [3, 3, 3],
  [4, 4, 4]]]
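Modelling the ragged array as nested Python lists makes it easy to see why the ragged dimension is untouched: only the innermost singleton lists grow. This is a sketch for intuition only, and `expand_innermost` is a hypothetical helper, not `fold`'s implementation.

```python
def expand_innermost(nested, n):
    """Expand the trailing singleton dimension of a nested-list array to size n.

    Recurses through any number of outer (possibly ragged) list levels;
    at the innermost level each length-1 list [x] becomes [x] * n.
    """
    if nested and isinstance(nested[0], list):
        return [expand_innermost(sub, n) for sub in nested]
    if len(nested) != 1:
        raise ValueError(f"cannot expand size {len(nested)} to {n}")
    return nested * n


a = [[[0], [1]],
     [[2], [3], [4]]]              # shape (2, [2, 3], 1)
ex = expand_innermost(a, 3)        # shape (2, [2, 3], 3)
print(ex)                          # [[[0, 0, 0], [1, 1, 1]], [[2, 2, 2], [3, 3, 3], [4, 4, 4]]]
print([len(rows) for rows in ex])  # [2, 3]: ragged lengths unchanged
```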


### Expansion outside a ragged dimension

Expansion *outside* (to the left of) a ragged dimension can have different effects, depending on whether the ragged dimension nests within the subarray being expanded.

#### Nested outside expansion

When the ragged dimension nests within the subarray being expanded, it repeats (along with the rest of the subarray shape). For example:

In [91]:
a = arange(1, 3, [1, 2, 3])
print("a:")
print("shape:", a.shape)
print(a)

print("\nexpanded at dim 0:")
ex = a.expand(2, -1, -1)
print("shape:", ex.shape)
print(ex)

a:
shape: (1, 3, [1, 2, 3])
[[[0],
  [1, 2],
  [3, 4, 5]]]

expanded at dim 0:
shape: (2, 3, Repeat([1, 2, 3], 2))
[[[0],
  [1, 2],
  [3, 4, 5]],

 [[0],
  [1, 2],
  [3, 4, 5]]]
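One way to read `Repeat([1, 2, 3], 2)` is as the flattened list of row lengths along the ragged dimension: the whole sequence of lengths is tiled once per copy of the subarray. That reading is inferred from the printed output above, not from `fold`'s source. The sketch below mimics it with nested lists; `repeat_lengths` is a hypothetical helper.

```python
def repeat_lengths(lengths, times):
    # Tile the whole sequence: [1, 2, 3] x 2 -> [1, 2, 3, 1, 2, 3]
    return list(lengths) * times


a = [[[0], [1, 2], [3, 4, 5]]]   # shape (1, 3, [1, 2, 3])

# Expand dim 0 to 2: both entries reference the same inner lists (no data is copied)
ex = [a[0]] * 2

print(repeat_lengths([1, 2, 3], 2))              # [1, 2, 3, 1, 2, 3]
print([len(row) for sub in ex for row in sub])   # [1, 2, 3, 1, 2, 3]
```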


#### Non-nested outside expansion

On the other hand, when the ragged dimension does not nest within the subarray being expanded, each *component* of the ragged dimension is expanded individually, rather than the dimension as a whole. This usually manifests as the introduction of `Runs`.

For example:

In [101]:
a = arange(3, 1, [1, 2, 3])
print("a:")
print("shape:", a.shape)
print(a)

print("\nexpanded at dim 1:")
ex = a.expand(-1, 3, -1)
print("shape:", ex.shape)
print(ex)

a:
shape: (3, 1, [1, 2, 3])
[[[0]],

 [[1, 2]],

 [[3, 4, 5]]]

expanded at dim 1:
shape: (3, 3, Runs([1, 2, 3], 3))
[[[0],
  [0],
  [0]],

 [[1, 2],
  [1, 2],
  [1, 2]],

 [[3, 4, 5],
  [3, 4, 5],
  [3, 4, 5]]]
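By contrast with `Repeat`, `Runs([1, 2, 3], 3)` can be read as repeating each length in place, once per copy of its row. Again this interpretation is inferred from the output above. `runs_lengths` is a hypothetical helper that accepts either a single count or one count per value.

```python
def runs_lengths(values, reps):
    """Repeat each value in place; reps is one count, or one count per value."""
    if isinstance(reps, int):
        reps = [reps] * len(values)
    if len(reps) != len(values):
        raise ValueError("need one repeat count per value")
    return [v for v, r in zip(values, reps) for _ in range(r)]


a = [[[0]], [[1, 2]], [[3, 4, 5]]]   # shape (3, 1, [1, 2, 3])
ex = [sub * 3 for sub in a]          # expand dim 1 from 1 to 3

print(runs_lengths([1, 2, 3], 3))                # [1, 1, 1, 2, 2, 2, 3, 3, 3]
print([len(row) for sub in ex for row in sub])   # [1, 1, 1, 2, 2, 2, 3, 3, 3]
```

The per-value form, e.g. `runs_lengths([4, 8, 6], [2, 4, 3])`, gives `[4, 4, 8, 8, 8, 8, 6, 6, 6]`.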


### Ragged expansions

There can also be ragged expansions (where the expansion shape contains ragged dimensions):

In [93]:
batch_size = 2
seq_lens = [3, 4]
embed_dim = 4  # num_heads * head_dim, left flattened for this example

print(f"{batch_size=} {seq_lens=} {embed_dim=}")

print("\nseqs:")
seqs = rand(batch_size, seq_lens, embed_dim)
print("shape:", seqs.shape)
print(seqs)

print("\nembedding dim indexes:")
channels = arange(embed_dim)
print("shape:", channels.shape)
print(channels)

print("\nembedding dim indexes expanded to (batch_size, seq_lens):")
seq_channels = channels.expand(batch_size, seq_lens, -1)
print("shape:", seq_channels.shape)
print(seq_channels)

batch_size=2 seq_lens=[3, 4] embed_dim=4

seqs:
shape: (2, [3, 4], 4)
[[[0.0135, 0.9317, 0.2856, 0.1192],
  [0.8705, 0.3260, 0.5224, 0.1272],
  [0.3315, 0.6040, 0.1599, 0.9292]],

 [[0.7621, 0.8405, 0.6077, 0.4733],
  [0.4188, 0.9886, 0.5000, 0.2208],
  [0.1477, 0.8395, 0.2176, 0.3367],
  [0.8606, 0.9319, 0.9540, 0.5876]]]

embedding dim indexes:
shape: (4)
[0, 1, 2, 3]

embedding dim indexes expanded to (batch_size, seq_lens):
shape: (2, [3, 4], 4)
[[[0, 1, 2, 3],
  [0, 1, 2, 3],
  [0, 1, 2, 3]],

 [[0, 1, 2, 3],
  [0, 1, 2, 3],
  [0, 1, 2, 3],
  [0, 1, 2, 3]]]
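In nested-list terms, a ragged expansion gives each batch entry its own number of copies of the expanded vector, one per sequence position. A plain-Python sketch; `expand_to_ragged` is a hypothetical helper, not part of `fold`.

```python
def expand_to_ragged(vec, batch_size, seq_lens):
    """Broadcast a 1-D sequence to shape (batch_size, seq_lens, len(vec)).

    seq_lens gives one sequence length per batch entry, so each batch
    entry gets its own number of copies of vec.
    """
    if len(seq_lens) != batch_size:
        raise ValueError("need one sequence length per batch entry")
    return [[list(vec) for _ in range(n)] for n in seq_lens]


seq_channels = expand_to_ragged(range(4), 2, [3, 4])
print([len(seq) for seq in seq_channels])   # [3, 4]
print(seq_channels[1][3])                   # [0, 1, 2, 3]
```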


### Ragged dimensions and ragged expansions

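Ragged dimensions can also nest inside other ragged dimensions, and ragged expansions can target such shapes. Here each image has its own height, and every row of an image shares that image's width; `Runs([4, 8, 6], [2, 4, 3])` expresses this by repeating each width once per row of its image: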

In [102]:
batch_size = 3
image_heights = [2, 4, 3]
image_wids = Runs([4, 8, 6], [2, 4, 3])

print(f"{batch_size=} {image_heights=} {image_wids=}")

print("\nimages:")
images = rand(batch_size, image_heights, image_wids)
print("shape:", images.shape)
print(images)

batch_size=3 image_heights=[2, 4, 3] image_wids=Runs(Seq([4, 8, 6]), Seq([2, 4, 3]))

images:
shape: (3, [2, 4, 3], Runs([4, 8, 6], [2, 4, 3]))
[[[0.7037, 0.7386, 0.4093, 0.2948],
  [0.8343, 0.1320, 0.0042, 0.4360]],

 [[0.2702, 0.0008, 0.4312, 0.4802, 0.5988, 0.1102, 0.6474, 0.5451],
  [0.9821, 0.0201, 0.5629, 0.0831, 0.3201, 0.7943, 0.0309, 0.8344],
  [0.8845, 0.1532, 0.4337, 0.3896, 0.3479, 0.8052, 0.1517, 0.3859],
  [0.2072, 0.7666, 0.0868, 0.7025, 0.4260, 0.6807, 0.1788, 0.5760]],

 [[0.1384, 0.6017, 0.3129, 0.6903, 0.8999, 0.0268],
  [0.5971, 0.6361, 0.2023, 0.9147, 0.3801, 0.3365],
  [0.8672, 0.5390, 0.6838, 0.2056, 0.4967, 0.8730]]]
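The arithmetic behind this shape is easy to check by hand. Purely for illustration, the sketch assumes the scalars are stored in one flat buffer, image after image; nothing above shows how `fold` actually lays out its data.

```python
from itertools import accumulate

heights = [2, 4, 3]
widths = [4, 8, 6]

# One width per row: each image's width repeated once per row of that image
row_widths = [w for w, h in zip(widths, heights) for _ in range(h)]

# Scalars per image, and where each image would start in a flat buffer
sizes = [h * w for h, w in zip(heights, widths)]
offsets = [0] + list(accumulate(sizes))[:-1]

print(row_widths)   # [4, 4, 8, 8, 8, 8, 6, 6, 6]
print(sizes)        # [8, 32, 18]
print(offsets)      # [0, 8, 40]
print(sum(sizes))   # 58
```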


In [103]:
print("\nimage indexes (unsqueezed):")
image_indexes = arange(batch_size, 1, 1)
print("shape:", image_indexes.shape)
print(image_indexes)

print("\nimage indexes expanded to image shapes:")
image_indexes_exp = image_indexes.expand(batch_size, image_heights, image_wids)
print("shape:", image_indexes_exp.shape)
print(image_indexes_exp)


image indexes (unsqueezed):
shape: (3, 1, 1)
[[[0]],

 [[1]],

 [[2]]]

image indexes expanded to image shapes:
shape: (3, [2, 4, 3], Runs([4, 8, 6], [2, 4, 3]))
[[[0, 0, 0, 0],
  [0, 0, 0, 0]],

 [[1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1]],

 [[2, 2, 2, 2, 2, 2],
  [2, 2, 2, 2, 2, 2],
  [2, 2, 2, 2, 2, 2]]]
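Here both singleton dimensions expand at once: the single row is repeated once per row of the image, and its single element is repeated across that image's width. A nested-list sketch of the same result, for intuition only:

```python
heights = [2, 4, 3]
widths = [4, 8, 6]

# Shape (3, 1, 1): one index per image
image_indexes = [[[i]] for i in range(len(heights))]

# Expand both singleton dims: repeat the single row h times,
# and its single element w times
image_indexes_exp = [
    [[sub[0][0]] * w for _ in range(h)]
    for sub, h, w in zip(image_indexes, heights, widths)
]
print(image_indexes_exp[0])   # [[0, 0, 0, 0], [0, 0, 0, 0]]
```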


In [104]:
print("\nrow indexes (unsqueezed):")
row_indexes = arange(batch_size, image_heights, 1)
print("shape:", row_indexes.shape)
print(row_indexes)

print("\nrow indexes expanded to image shapes:")
row_indexes_exp = row_indexes.expand(batch_size, image_heights, image_wids)
print("shape:", row_indexes_exp.shape)
print(row_indexes_exp)


row indexes (unsqueezed):
shape: (3, [2, 4, 3], 1)
[[[0],
  [1]],

 [[2],
  [3],
  [4],
  [5]],

 [[6],
  [7],
  [8]]]

row indexes expanded to image shapes:
shape: (3, [2, 4, 3], Runs([4, 8, 6], [2, 4, 3]))
[[[0, 0, 0, 0],
  [1, 1, 1, 1]],

 [[2, 2, 2, 2, 2, 2, 2, 2],
  [3, 3, 3, 3, 3, 3, 3, 3],
  [4, 4, 4, 4, 4, 4, 4, 4],
  [5, 5, 5, 5, 5, 5, 5, 5]],

 [[6, 6, 6, 6, 6, 6],
  [7, 7, 7, 7, 7, 7],
  [8, 8, 8, 8, 8, 8]]]
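In this case only the trailing singleton expands; the ragged height dimension is already present, so each row index is simply repeated across its image's width. Note that `arange` numbers rows consecutively across images, as the output above shows. A nested-list sketch:

```python
from itertools import accumulate

heights = [2, 4, 3]
widths = [4, 8, 6]

# Shape (3, [2, 4, 3], 1): rows are numbered consecutively across images
starts = [0] + list(accumulate(heights))[:-1]   # first row index per image: [0, 2, 6]
row_indexes = [[[s + r] for r in range(h)] for s, h in zip(starts, heights)]

# Expand the trailing singleton to each image's width
row_indexes_exp = [
    [row * w for row in rows] for rows, w in zip(row_indexes, widths)
]
print(row_indexes_exp[0])   # [[0, 0, 0, 0], [1, 1, 1, 1]]
```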
