In [22]:
from fold import *

In [23]:
from functools import reduce
from operator import mul

import torch

from fold import *

def product(seq):
    """Return the product of all items in `seq` (1 for an empty sequence)."""
    total = 1
    for item in seq:
        total = total * item
    return total

# check rectangular shapes against PT
def check_rect(shape, x, y):
    """Transpose dims `x` and `y` of a rectangular fold tensor and compare with PyTorch.

    Builds `arange(*shape)` with fold and the matching
    `torch.arange(product(shape)).reshape(*shape)`, transposes both,
    prints the shape change, and asserts that the torch result equals
    the fold result (converted via `tolist()`).

    Args:
        shape: tuple of ints; every dimension is rectangular.
        x, y: the two dimensions to swap.

    Raises:
        AssertionError: if fold's transpose disagrees with torch's, with a
            message naming the case and showing both results.
    """
    a = arange(*shape)
    at = a.transpose(x, y)
    print(f"{a.shape}.transpose({x}, {y}) -> {at.shape}")
    t = torch.arange(product(shape)).reshape(*shape)
    tt = t.transpose(x, y)
    actual = torch.tensor(at.tolist())
    # Include the failing case and both tensors so a mismatch is diagnosable.
    assert tt.equal(actual), (
        f"fold transpose({x}, {y}) of shape {shape} disagrees with torch:\n"
        f"torch:\n{tt}\nfold:\n{actual}"
    )
    
# for now, check others by eye
def check(shape, x, y):
    """Print `arange(*shape)` and its transpose of dims `x` and `y`.

    Each tensor is preceded by its unpacked shape; the two are separated by
    a '->' line. No automatic comparison is performed.
    """
    def show(tensor):
        # shape first, then the tensor contents
        print(f"shape (unpacked): {tensor.shape.unpack()}")
        print(tensor)

    source = arange(*shape)
    show(source)
    print('->')
    show(source.transpose(x, y))

### all rectangular dims

In [24]:
# every adjacent pair of a 4-d shape, two non-adjacent pairs, and the outermost/innermost swap
check_rect(shape=(2, 4, 3, 2), x=0, y=1)
check_rect(shape=(2, 4, 3, 2), x=1, y=2)
check_rect(shape=(2, 4, 3, 2), x=2, y=3)
check_rect(shape=(2, 4, 3, 2), x=0, y=2)
check_rect(shape=(2, 4, 3, 2), x=1, y=3)
check_rect(shape=(1, 2, 3, 4), x=0, y=3)


(2, 4, 3, 2).transpose(0, 1) -> (4, 2, 3, 2)
(2, 4, 3, 2).transpose(1, 2) -> (2, 3, 4, 2)
(2, 4, 3, 2).transpose(2, 3) -> (2, 4, 2, 3)
(2, 4, 3, 2).transpose(0, 2) -> (3, 4, 2, 2)
(2, 4, 3, 2).transpose(1, 3) -> (2, 2, 3, 4)
(1, 2, 3, 4).transpose(0, 3) -> (4, 2, 3, 1)


### adjacent

In [25]:
# rect/rect, ragged tail: swap the two rectangular dims above a ragged last dim
# (Ed's second example here: https://github.com/pytorch/pytorch/pull/118405#discussion_r1501562977)
check(shape=(2, 4, [1, 2, 3, 4]), x=0, y=1)

shape (unpacked): ([2], [4, 4], [1, 2, 3, 4, 1, 2, 3, 4])
[[[ 0],
  [ 1,  2],
  [ 3,  4,  5],
  [ 6,  7,  8,  9]],

 [[10],
  [11, 12],
  [13, 14, 15],
  [16, 17, 18, 19]]]
->
shape (unpacked): ([4], [2, 2, 2, 2], [1, 1, 2, 2, 3, 3, 4, 4])
[[[ 0],
  [10]],

 [[ 1,  2],
  [11, 12]],

 [[ 3,  4,  5],
  [13, 14, 15]],

 [[ 6,  7,  8,  9],
  [16, 17, 18, 19]]]


In [26]:
# rect/rect, non-nested ragged tail: one width per (dim 0, dim 1) pair, 2 * 3 = 6 widths
check(shape=(2, 3, [1, 2, 3, 4, 5, 6]), x=0, y=1)

shape (unpacked): ([2], [3, 3], [1, 2, 3, 4, 5, 6])
[[[ 0],
  [ 1,  2],
  [ 3,  4,  5]],

 [[ 6,  7,  8,  9],
  [10, 11, 12, 13, 14],
  [15, 16, 17, 18, 19, 20]]]
->
shape (unpacked): ([3], [2, 2, 2], [1, 4, 2, 5, 3, 6])
[[[ 0],
  [ 6,  7,  8,  9]],

 [[ 1,  2],
  [10, 11, 12, 13, 14]],

 [[ 3,  4,  5],
  [15, 16, 17, 18, 19, 20]]]


In [27]:
# ragged/rect, no tail: swap a ragged dim with the rectangular innermost dim
check(shape=(3, [1, 2, 3], 2), x=1, y=2)

shape (unpacked): ([3], [1, 2, 3], [2, 2, 2, 2, 2, 2])
[[[ 0,  1]],

 [[ 2,  3],
  [ 4,  5]],

 [[ 6,  7],
  [ 8,  9],
  [10, 11]]]
->
shape (unpacked): ([3], [2, 2, 2], [1, 1, 2, 2, 3, 3])
[[[ 0],
  [ 1]],

 [[ 2,  4],
  [ 3,  5]],

 [[ 6,  8, 10],
  [ 7,  9, 11]]]


In [28]:
# ragged/rect, rect tail (sdpa layout; presumably B=batch, S=ragged seq lengths, H=heads, D=head dim)
B = 3
S = [1, 2, 3]
H = 2
D = 3
check(shape=(B, S, H, D), x=1, y=2)

shape (unpacked): ([3], [1, 2, 3], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
[[[[ 0,  1,  2],
   [ 3,  4,  5]]],


 [[[ 6,  7,  8],
   [ 9, 10, 11]],

  [[12, 13, 14],
   [15, 16, 17]]],


 [[[18, 19, 20],
   [21, 22, 23]],

  [[24, 25, 26],
   [27, 28, 29]],

  [[30, 31, 32],
   [33, 34, 35]]]]
->
shape (unpacked): ([3], [2, 2, 2], [1, 1, 2, 2, 3, 3], [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
[[[[ 0,  1,  2]],

  [[ 3,  4,  5]]],


 [[[ 6,  7,  8],
   [12, 13, 14]],

  [[ 9, 10, 11],
   [15, 16, 17]]],


 [[[18, 19, 20],
   [24, 25, 26],
   [30, 31, 32]],

  [[21, 22, 23],
   [27, 28, 29],
   [33, 34, 35]]]]


In [29]:
# another sdpa layout: two sequences of lengths 3 and 2, 2 heads, last dim 8
B = 2
S = [3, 2]
H = 2
D = 8
check(shape=(B, S, H, D), x=1, y=2)

shape (unpacked): ([2], [3, 2], [2, 2, 2, 2, 2], [8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
[[[[ 0,  1,  2,  3,  4,  5,  6,  7],
   [ 8,  9, 10, 11, 12, 13, 14, 15]],

  [[16, 17, 18, 19, 20, 21, 22, 23],
   [24, 25, 26, 27, 28, 29, 30, 31]],

  [[32, 33, 34, 35, 36, 37, 38, 39],
   [40, 41, 42, 43, 44, 45, 46, 47]]],


 [[[48, 49, 50, 51, 52, 53, 54, 55],
   [56, 57, 58, 59, 60, 61, 62, 63]],

  [[64, 65, 66, 67, 68, 69, 70, 71],
   [72, 73, 74, 75, 76, 77, 78, 79]]]]
->
shape (unpacked): ([2], [2, 2], [3, 3, 2, 2], [8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
[[[[ 0,  1,  2,  3,  4,  5,  6,  7],
   [16, 17, 18, 19, 20, 21, 22, 23],
   [32, 33, 34, 35, 36, 37, 38, 39]],

  [[ 8,  9, 10, 11, 12, 13, 14, 15],
   [24, 25, 26, 27, 28, 29, 30, 31],
   [40, 41, 42, 43, 44, 45, 46, 47]]],


 [[[48, 49, 50, 51, 52, 53, 54, 55],
   [64, 65, 66, 67, 68, 69, 70, 71]],

  [[56, 57, 58, 59, 60, 61, 62, 63],
   [72, 73, 74, 75, 76, 77, 78, 79]]]]


In [30]:
# ragged/rect, ragged tail: widths [1, 2] repeat under every (dim 0, dim 1, dim 2) index
check(shape=(3, [1, 2, 3], 2, [1, 2]), x=1, y=2)

shape (unpacked): ([3], [1, 2, 3], [2, 2, 2, 2, 2, 2], [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2])
[[[[ 0],
   [ 1,  2]]],


 [[[ 3],
   [ 4,  5]],

  [[ 6],
   [ 7,  8]]],


 [[[ 9],
   [10, 11]],

  [[12],
   [13, 14]],

  [[15],
   [16, 17]]]]
->
shape (unpacked): ([3], [2, 2, 2], [1, 1, 2, 2, 3, 3], [1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 2])
[[[[ 0]],

  [[ 1,  2]]],


 [[[ 3],
   [ 6]],

  [[ 4,  5],
   [ 7,  8]]],


 [[[ 9],
   [12],
   [15]],

  [[10, 11],
   [13, 14],
   [16, 17]]]]


In [31]:
# ragged/rect, non-nested ragged tail: Affine(1, 12, 1) yields widths 1..12 (see unpacked shape below)
check(shape=(3, [1, 2, 3], 2, Affine(1, 12, 1)), x=1, y=2)

shape (unpacked): ([3], [1, 2, 3], [2, 2, 2, 2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
[[[[ 0],
   [ 1,  2]]],


 [[[ 3,  4,  5],
   [ 6,  7,  8,  9]],

  [[10, 11, 12, 13, 14],
   [15, 16, 17, 18, 19, 20]]],


 [[[21, 22, 23, 24, 25, 26, 27],
   [28, 29, 30, 31, 32, 33, 34, 35]],

  [[36, 37, 38, 39, 40, 41, 42, 43, 44],
   [45, 46, 47, 48, 49, 50, 51, 52, 53, 54]],

  [[55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65],
   [66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77]]]]
->
shape (unpacked): ([3], [2, 2, 2], [1, 1, 2, 2, 3, 3], [1, 2, 3, 5, 4, 6, 7, 9, 11, 8, 10, 12])
[[[[ 0]],

  [[ 1,  2]]],


 [[[ 3,  4,  5],
   [10, 11, 12, 13, 14]],

  [[ 6,  7,  8,  9],
   [15, 16, 17, 18, 19, 20]]],


 [[[21, 22, 23, 24, 25, 26, 27],
   [36, 37, 38, 39, 40, 41, 42, 43, 44],
   [55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]],

  [[28, 29, 30, 31, 32, 33, 34, 35],
   [45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
   [66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77]]]]


### non-adjacent

In [32]:
# rect/.../rect, ragged tail: swap dims 0 and 2 across rectangular dim 1
check(shape=(2, 3, 3, [1, 2, 3]), x=0, y=2)

shape (unpacked): ([2], [3, 3], [3, 3, 3, 3, 3, 3], [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
[[[[ 0],
   [ 1,  2],
   [ 3,  4,  5]],

  [[ 6],
   [ 7,  8],
   [ 9, 10, 11]],

  [[12],
   [13, 14],
   [15, 16, 17]]],


 [[[18],
   [19, 20],
   [21, 22, 23]],

  [[24],
   [25, 26],
   [27, 28, 29]],

  [[30],
   [31, 32],
   [33, 34, 35]]]]
->
shape (unpacked): ([3], [3, 3, 3], [2, 2, 2, 2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3])
[[[[ 0],
   [18]],

  [[ 6],
   [24]],

  [[12],
   [30]]],


 [[[ 1,  2],
   [19, 20]],

  [[ 7,  8],
   [25, 26]],

  [[13, 14],
   [31, 32]]],


 [[[ 3,  4,  5],
   [21, 22, 23]],

  [[ 9, 10, 11],
   [27, 28, 29]],

  [[15, 16, 17],
   [33, 34, 35]]]]


In [33]:
# rect/.../rect, non-nested ragged tail: Affine(1, 8, 1) yields widths 1..8 (see unpacked shape below)
check(shape=(2, 2, 2, Affine(1, 8, 1)), x=0, y=2)

shape (unpacked): ([2], [2, 2], [2, 2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8])
[[[[ 0],
   [ 1,  2]],

  [[ 3,  4,  5],
   [ 6,  7,  8,  9]]],


 [[[10, 11, 12, 13, 14],
   [15, 16, 17, 18, 19, 20]],

  [[21, 22, 23, 24, 25, 26, 27],
   [28, 29, 30, 31, 32, 33, 34, 35]]]]
->
shape (unpacked): ([2], [2, 2], [2, 2, 2, 2], [1, 5, 3, 7, 2, 6, 4, 8])
[[[[ 0],
   [10, 11, 12, 13, 14]],

  [[ 3,  4,  5],
   [21, 22, 23, 24, 25, 26, 27]]],


 [[[ 1,  2],
   [15, 16, 17, 18, 19, 20]],

  [[ 6,  7,  8,  9],
   [28, 29, 30, 31, 32, 33, 34, 35]]]]


In [34]:
# ragged/.../rect, no tail: swap ragged dim 1 with innermost rectangular dim 3
check(shape=(2, [2, 3], 2, 3), x=1, y=3)

shape (unpacked): ([2], [2, 3], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
[[[[ 0,  1,  2],
   [ 3,  4,  5]],

  [[ 6,  7,  8],
   [ 9, 10, 11]]],


 [[[12, 13, 14],
   [15, 16, 17]],

  [[18, 19, 20],
   [21, 22, 23]],

  [[24, 25, 26],
   [27, 28, 29]]]]
->
shape (unpacked): ([2], [3, 3], [2, 2, 2, 2, 2, 2], [2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3])
[[[[ 0,  6],
   [ 3,  9]],

  [[ 1,  7],
   [ 4, 10]],

  [[ 2,  8],
   [ 5, 11]]],


 [[[12, 18, 24],
   [15, 21, 27]],

  [[13, 19, 25],
   [16, 22, 28]],

  [[14, 20, 26],
   [17, 23, 29]]]]


In [35]:
# ragged/.../rect, rect tail: swap ragged dim 1 with rectangular dim 3, rectangular dim 4 below
check(shape=(2, [1, 2], 2, 3, 2), x=1, y=3)

shape (unpacked): ([2], [1, 2], [2, 2, 2], [3, 3, 3, 3, 3, 3], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
[[[[[ 0,  1],
    [ 2,  3],
    [ 4,  5]],

   [[ 6,  7],
    [ 8,  9],
    [10, 11]]]],



 [[[[12, 13],
    [14, 15],
    [16, 17]],

   [[18, 19],
    [20, 21],
    [22, 23]]],


  [[[24, 25],
    [26, 27],
    [28, 29]],

   [[30, 31],
    [32, 33],
    [34, 35]]]]]
->
shape (unpacked): ([2], [3, 3], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
[[[[[ 0,  1]],

   [[ 6,  7]]],


  [[[ 2,  3]],

   [[ 8,  9]]],


  [[[ 4,  5]],

   [[10, 11]]]],



 [[[[12, 13],
    [24, 25]],

   [[18, 19],
    [30, 31]]],


  [[[14, 15],
    [26, 27]],

   [[20, 21],
    [32, 33]]],


  [[[16, 17],
    [28, 29]],

   [[22, 23],
    [34, 35]]]]]


In [36]:
# ragged/.../rect, ragged tail: swap ragged dim 1 with rectangular dim 3, ragged dim 4 below
check(shape=(2, [1, 2], 2, 3, [1, 2, 3]), x=1, y=3)

shape (unpacked): ([2], [1, 2], [2, 2, 2], [3, 3, 3, 3, 3, 3], [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
[[[[[ 0],
    [ 1,  2],
    [ 3,  4,  5]],

   [[ 6],
    [ 7,  8],
    [ 9, 10, 11]]]],



 [[[[12],
    [13, 14],
    [15, 16, 17]],

   [[18],
    [19, 20],
    [21, 22, 23]]],


  [[[24],
    [25, 26],
    [27, 28, 29]],

   [[30],
    [31, 32],
    [33, 34, 35]]]]]
->
shape (unpacked): ([2], [3, 3], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2], [1, 1, 2, 2, 3, 3, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3])
[[[[[ 0]],

   [[ 6]]],


  [[[ 1,  2]],

   [[ 7,  8]]],


  [[[ 3,  4,  5]],

   [[ 9, 10, 11]]]],



 [[[[12],
    [24]],

   [[18],
    [30]]],


  [[[13, 14],
    [25, 26]],

   [[19, 20],
    [31, 32]]],


  [[[15, 16, 17],
    [27, 28, 29]],

   [[21, 22, 23],
    [33, 34, 35]]]]]


In [37]:
# ragged/.../rect, non-nested ragged tail: Affine(1, 12, 1) yields widths 1..12 (see unpacked shape below)
check(shape=(2, [1, 2], 2, 2, Affine(1, 12, 1)), x=1, y=3)

shape (unpacked): ([2], [1, 2], [2, 2, 2], [2, 2, 2, 2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
[[[[[ 0],
    [ 1,  2]],

   [[ 3,  4,  5],
    [ 6,  7,  8,  9]]]],



 [[[[10, 11, 12, 13, 14],
    [15, 16, 17, 18, 19, 20]],

   [[21, 22, 23, 24, 25, 26, 27],
    [28, 29, 30, 31, 32, 33, 34, 35]]],


  [[[36, 37, 38, 39, 40, 41, 42, 43, 44],
    [45, 46, 47, 48, 49, 50, 51, 52, 53, 54]],

   [[55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65],
    [66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77]]]]]
->
shape (unpacked): ([2], [2, 2], [2, 2, 2, 2], [1, 1, 1, 1, 2, 2, 2, 2], [1, 3, 2, 4, 5, 9, 7, 11, 6, 10, 8, 12])
[[[[[ 0]],

   [[ 3,  4,  5]]],


  [[[ 1,  2]],

   [[ 6,  7,  8,  9]]]],



 [[[[10, 11, 12, 13, 14],
    [36, 37, 38, 39, 40, 41, 42, 43, 44]],

   [[21, 22, 23, 24, 25, 26, 27],
    [55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]]],


  [[[15, 16, 17, 18, 19, 20],
    [45, 46, 47, 48, 49, 50, 51, 52, 53, 54]],

   [[28, 29, 30, 31, 32, 33, 34, 35],
    [66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77]]]]]

### transposing ragged dimensions outward causes shear

`fold` disallows transposes/permutes that move ragged dimensions outward.
For ragged shapes in which some row is wider than an earlier row (i.e. widths are not in non-increasing order),
such transpositions would cause *shear*: a shift in alignment among elements. For example:

In [38]:
# widths 1, 2, 3 increase along dim 1, so swapping dims 1 and 2 would shear (fold rejects it)
a = arange(2, 3, [1, 2, 3])
print(a)

[[[ 0],
  [ 1,  2],
  [ 3,  4,  5]],

 [[ 6],
  [ 7,  8],
  [ 9, 10, 11]]]


`a.transpose(1, 2)` raises an error, but if it were allowed, the result would look like this:
```
[[[0, 1, 3],
  [2, 4],
  [5]],
 
 [[6, 7, 9],
  [8, 10],
  [11]]]
```
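Note how the rows shear when transposed to columns. For example, `2` sits at `a[0][1][1]`, so an alignment-preserving transpose would leave it at `[0][1][1]`, but here it lands at `[0][1][0]`.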

### transposing across ragged dimensions causes shear

Shearing also occurs in transpositions *across* a ragged dimension: moving an outer dimension to a position inside the ragged dimension has the same structural effect as moving the ragged dimension outward directly.

For example, Ed's first example here: https://github.com/pytorch/pytorch/pull/118405#discussion_r1501562977

In [39]:
# ragged dim 1 (widths 1, 2, 3) lies between dims 0 and 2, so transpose(0, 2) crosses it (fold rejects it)
a = arange(3, [1, 2, 3], 5)
print(a)

[[[ 0,  1,  2,  3,  4]],

 [[ 5,  6,  7,  8,  9],
  [10, 11, 12, 13, 14]],

 [[15, 16, 17, 18, 19],
  [20, 21, 22, 23, 24],
  [25, 26, 27, 28, 29]]]


`a.transpose(0, 2)` raises an error, but if it were allowed, the result would look like this:
```
[[[0, 5, 15],
  [10, 20],
  [25]],
 [[1, 6, 16],
  [11, 21],
  [26]],
 [[2, 7, 17],
  [12, 22],
  [27]],
 [[3, 8, 18],
  [13, 23],
  [28]],
 [[4, 9, 19],
  [14, 24],
  [29]]]
```