[![Run in Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bhosmer/fold/blob/main/notebooks/transpose.ipynb)

#### Colab setup (skip if running locally)

In [63]:
!git clone https://github.com/bhosmer/fold.git
import sys
sys.path.insert(0,'/content/fold')

Cloning into 'fold'...
remote: Enumerating objects: 102, done.
remote: Counting objects: 100% (98/98), done.
remote: Compressing objects: 100% (69/69), done.
remote: Total 102 (delta 51), reused 67 (delta 28), pack-reused 4
Receiving objects: 100% (102/102), 112.35 KiB | 5.62 MiB/s, done.
Resolving deltas: 100% (51/51), done.


In [2]:
from fold import *

## Broadcasting ragged arrays

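As in PyTorch, broadcast is a wrapper around `expand`, which is defined as usual: singleton dimensions of the target array can be expanded to any size, and every other dimension needs to match the corresponding entry of the expansion shape. In the examples below, a `-1` in the expansion shape leaves that dimension unchanged.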
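
For the plain rectangular case, the rule can be sketched in a few lines of pure Python. This is only an illustration of the convention used by PyTorch's `Tensor.expand` (right-aligned dimensions, `-1` meaning "keep this size"); the helper name `expanded_shape` is hypothetical and is not part of fold, and fold's handling of ragged dimensions is not modeled here.

```python
def expanded_shape(shape, sizes):
    """Result shape of expanding a rectangular `shape` to `sizes` (PyTorch-style rule)."""
    extra = len(sizes) - len(shape)
    if extra < 0:
        raise ValueError("expansion shape has fewer dimensions than the array")
    # Right-align: extra leading entries of `sizes` become new outer dimensions.
    padded = (1,) * extra + tuple(shape)
    result = []
    for i, (have, want) in enumerate(zip(padded, sizes)):
        if want == -1:
            if i < extra:
                raise ValueError("-1 is not allowed for a new leading dimension")
            result.append(have)      # keep the existing size
        elif have == 1 or have == want:
            result.append(want)      # expand a singleton, or sizes already match
        else:
            raise ValueError(f"cannot expand size {have} to {want} at dim {i}")
    return tuple(result)

print(expanded_shape((3, 1, 4), (-1, 5, -1)))  # (3, 5, 4)
print(expanded_shape((4,), (2, 3, -1)))        # (2, 3, 4)
```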

Expansion *outward* of a ragged dimension can have different effects, depending on whether the ragged dimension nests within the subarray being expanded.

When it does, the ragged dimension repeats as a whole. For example:

In [58]:
a = arange(1, 3, [1, 2, 3])
print("a:")
print("shape:", a.shape)
print(a)

print("\nexpanded at dim 0:")
ex = a.expand(2, -1, -1)
print("shape:", ex.shape)
print(ex)

a:
shape: (1, 3, [1, 2, 3])
[[[0],
  [1, 2],
  [3, 4, 5]]]

expanded at dim 0:
shape: (2, 3, Repeat([1, 2, 3], 2))
[[[0],
  [1, 2],
  [3, 4, 5]],

 [[0],
  [1, 2],
  [3, 4, 5]]]


Here, on the other hand, the ragged dimension does not nest within the dimension being expanded. In such cases, each *component* of the ragged dimension repeats in place, as dictated by the array's outward shape. For example:

In [73]:
a = arange(3, 1, [1, 2, 3])
print("a:")
print("shape:", a.shape)
print(a)

print("\nexpanded at dim 1:")
ex = a.expand(-1, 3, -1)
print("shape:", ex.shape)
print(ex)

a:
shape: (3, 1, [1, 2, 3])
[[[0]],

 [[1, 2]],

 [[3, 4, 5]]]

expanded at dim 1:
shape: (3, 3, Runs([1, 2, 3], 3))
[[[0],
  [0],
  [0]],

 [[1, 2],
  [1, 2],
  [1, 2]],

 [[3, 4, 5],
  [3, 4, 5],
  [3, 4, 5]]]
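
The two shape annotations printed so far, `Repeat([1, 2, 3], 2)` and `Runs([1, 2, 3], 3)`, can be contrasted with a small pure-Python sketch. The reading below is inferred from the printed arrays rather than from fold's source, and the helper names `repeat_lengths` and `run_lengths` are hypothetical: each returns the flat list of innermost row lengths that the annotation appears to describe.

```python
def repeat_lengths(lengths, n):
    """Tile the whole list of row lengths n times (the Repeat case)."""
    return list(lengths) * n

def run_lengths(lengths, n):
    """Repeat each row length n times in place (the Runs case)."""
    return [length for length in lengths for _ in range(n)]

print(repeat_lengths([1, 2, 3], 2))  # [1, 2, 3, 1, 2, 3]
print(run_lengths([1, 2, 3], 3))     # [1, 1, 1, 2, 2, 2, 3, 3, 3]
```

The first list matches the row lengths of the array expanded at dim 0, and the second matches those of the array expanded at dim 1.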


There can also be ragged expansions (where the expansion shape contains ragged dimensions):

In [82]:
batch_size = 2
seq_lens = [3, 4]
embed_dim = 4 # leave num_heads * head_dim flattened for example

print(f"{batch_size=} {seq_lens=} {embed_dim=}")

print("\nseqs:")
seqs = rand(batch_size, seq_lens, embed_dim)
print("shape:", seqs.shape)
print(seqs)

print("\nembedding dim indexes:")
channels = arange(embed_dim)
print("shape:", channels.shape)
print(channels)

print("\nembedding dim indexes expanded to (batch_size, seq_lens):")
seq_channels = channels.expand(batch_size, seq_lens, -1)
print("shape:", seq_channels.shape)
print(seq_channels)

batch_size=2 seq_lens=[3, 4] embed_dim=4

seqs:
shape: (2, [3, 4], 4)
[[[0.5988, 0.5491, 0.5487, 0.7435],
  [0.7356, 0.7945, 0.3120, 0.9949],
  [0.2556, 0.6538, 0.7603, 0.0410]],

 [[0.0383, 0.8300, 0.3204, 0.2771],
  [0.7986, 0.3528, 0.8540, 0.3388],
  [0.6433, 0.6445, 0.9027, 0.3966],
  [0.7495, 0.5815, 0.6112, 0.9893]]]

embedding dim indexes:
shape: (4)
[0, 1, 2, 3]

embedding dim indexes expanded to (batch_size, seq_lens):
shape: (2, [3, 4], 4)
[[[0, 1, 2, 3],
  [0, 1, 2, 3],
  [0, 1, 2, 3]],

 [[0, 1, 2, 3],
  [0, 1, 2, 3],
  [0, 1, 2, 3],
  [0, 1, 2, 3]]]
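
The effect of this ragged expansion can be reproduced with nested Python lists. This is a sketch of the result only, not of how fold computes it: the hypothetical helper `expand_rows` materializes copies of the row, whereas PyTorch's `expand` returns a view and fold may well do something similar.

```python
def expand_rows(row, seq_lens):
    """Copy `row` once per position of each sequence, giving shape (batch, seq_lens, len(row))."""
    return [[list(row) for _ in range(n)] for n in seq_lens]

channels = [0, 1, 2, 3]
seq_channels = expand_rows(channels, [3, 4])

# One entry per batch element; the number of rows follows seq_lens.
for seq in seq_channels:
    print(len(seq), seq)
```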


And ragged dimensions can interact with ragged expansions:

In [83]:
batch_size = 3
image_heights = [2, 4, 3]
image_wids = Runs([4, 8, 6], [2, 4, 3])

print(f"{batch_size=} {image_heights=} {image_wids=}")

print("\nimages:")
images = rand(batch_size, image_heights, image_wids)
print("shape:", images.shape)
print(images)

batch_size=3 image_heights=[2, 4, 3] image_wids=Runs(Seq([4, 8, 6]), Seq([2, 4, 3]))

images:
shape: (3, [2, 4, 3], Runs([4, 8, 6], [2, 4, 3]))
[[[0.5691, 0.0916, 0.5850, 0.6781],
  [0.6419, 0.6692, 0.4972, 0.2750]],

 [[0.4235, 0.6628, 0.7416, 0.3235, 0.0835, 0.2422, 0.2960, 0.7915],
  [0.8519, 0.4101, 0.5175, 0.8461, 0.7037, 0.9381, 0.9369, 0.2711],
  [0.7517, 0.5805, 0.2613, 0.7376, 0.8623, 0.4929, 0.8284, 0.4649],
  [0.7643, 0.7088, 0.7352, 0.9607, 0.1916, 0.3719, 0.7557, 0.9902]],

 [[0.0956, 0.6787, 0.5618, 0.1333, 0.1318, 0.7532],
  [0.8570, 0.9438, 0.5124, 0.7663, 0.9801, 0.5243],
  [0.2268, 0.1099, 0.1318, 0.7765, 0.8735, 0.4063]]]
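
Judging from the printed array, `Runs([4, 8, 6], [2, 4, 3])` reads as a run-length encoding: width 4 for 2 rows, width 8 for 4 rows, width 6 for 3 rows. A minimal sketch of that decoding, using a hypothetical helper `decode_runs` that is not part of fold:

```python
from itertools import chain, repeat

def decode_runs(values, counts):
    """Expand run-length pairs into a flat list: values[i] appears counts[i] times."""
    return list(chain.from_iterable(repeat(v, c) for v, c in zip(values, counts)))

image_heights = [2, 4, 3]
row_widths = decode_runs([4, 8, 6], image_heights)

print(row_widths)       # [4, 4, 8, 8, 8, 8, 6, 6, 6]
print(sum(row_widths))  # 58, the total number of elements in `images`
```

Using the heights as the run counts is what ties each image's width to all of its rows.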


In [88]:
print("\nimage indexes (unsqueezed):")
image_indexes = arange(batch_size, 1, 1)
print("shape:", image_indexes.shape)
print(image_indexes)

print("\nimage indexes expanded to image shapes:")
image_indexes_exp = image_indexes.expand(batch_size, image_heights, image_wids)
print("shape:", image_indexes_exp.shape)
print(image_indexes_exp)


image indexes (unsqueezed):
shape: (3, 1, 1)
[[[0]],

 [[1]],

 [[2]]]

image indexes expanded to image shapes:
shape: (3, [2, 4, 3], Runs([4, 8, 6], [2, 4, 3]))
[[[0, 0, 0, 0],
  [0, 0, 0, 0]],

 [[1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1]],

 [[2, 2, 2, 2, 2, 2],
  [2, 2, 2, 2, 2, 2],
  [2, 2, 2, 2, 2, 2]]]
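
As a cross-check on the output above, the same nested structure can be built by hand: each image contributes one value, which is spread over that image's height and width. The helper `expand_per_image` is hypothetical and only mirrors the printed result with plain lists.

```python
def expand_per_image(values, heights, widths):
    """Spread one value per image over an (h, w) block, with h and w varying per image."""
    return [[[v] * w for _ in range(h)]
            for v, h, w in zip(values, heights, widths)]

image_indexes_exp = expand_per_image([0, 1, 2], [2, 4, 3], [4, 8, 6])
for image in image_indexes_exp:
    print(len(image), "rows of width", len(image[0]), "filled with", image[0][0])
```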


In [89]:
print("\nrow indexes (unsqueezed):")
row_indexes = arange(batch_size, image_heights, 1)
print("shape:", row_indexes.shape)
print(row_indexes)

print("\nrow indexes expanded to image shapes:")
row_indexes_exp = row_indexes.expand(batch_size, image_heights, image_wids)
print("shape:", row_indexes_exp.shape)
print(row_indexes_exp)


row indexes (unsqueezed):
shape: (3, [2, 4, 3], 1)
[[[0],
  [1]],

 [[2],
  [3],
  [4],
  [5]],

 [[6],
  [7],
  [8]]]

row indexes expanded to image shapes:
shape: (3, [2, 4, 3], Runs([4, 8, 6], [2, 4, 3]))
[[[0, 0, 0, 0],
  [1, 1, 1, 1]],

 [[2, 2, 2, 2, 2, 2, 2, 2],
  [3, 3, 3, 3, 3, 3, 3, 3],
  [4, 4, 4, 4, 4, 4, 4, 4],
  [5, 5, 5, 5, 5, 5, 5, 5]],

 [[6, 6, 6, 6, 6, 6],
  [7, 7, 7, 7, 7, 7],
  [8, 8, 8, 8, 8, 8]]]
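
The row-index case differs in that the source already varies along the ragged height dimension, so only the innermost singleton is stretched to each row's width. A plain-list sketch of that result, with a hypothetical helper `expand_row_indexes` that numbers rows consecutively across images as in the output above:

```python
def expand_row_indexes(heights, widths):
    """Give every row a running index and repeat it across that row's width."""
    result, next_row = [], 0
    for h, w in zip(heights, widths):
        result.append([[next_row + r] * w for r in range(h)])
        next_row += h  # continue numbering in the next image
    return result

row_indexes_exp = expand_row_indexes([2, 4, 3], [4, 8, 6])
for image in row_indexes_exp:
    print([row[0] for row in image], "width", len(image[0]))
```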
