In [None]:
"""
Tensor views and indexing
"""

import torch

# Sliding-window maximum: `unfold` exposes every length-3 window (stride 1)
# along dim 0 as a row of a view, and `amax` reduces each row to its largest
# entry.

x = torch.tensor([1, 3, 5, 3, 4, 8, -3, 4, 1, 6, 10, -2, 4, 7, 9, 2, 4, 8])
print("Original tensor:", x, "\n", sep="\n")

x = x.unfold(0, 3, 1).amax(dim=1)
print("Max value per window:", x, "\n", sep="\n")


# switching the main diagonal and anti-diagonal of each square matrix in a batch

def switch_diag(x):
    """Swap the main diagonal and the anti-diagonal of square matrices.

    Within every column ``j`` the element on the main diagonal (row ``j``)
    trades places with the element on the anti-diagonal (row ``N - 1 - j``).
    All other elements are left where they are; for odd ``N`` the centre
    element lies on both diagonals and therefore stays put.

    Args:
        x: Tensor of shape ``(..., N, N)``. Any number of leading batch
            dimensions (including none) is accepted.

    Returns:
        A new tensor with the same shape, dtype and device as ``x`` with the
        two diagonals exchanged. ``x`` itself is not modified.

    Raises:
        ValueError: If ``x`` has fewer than two dimensions or its last two
            dimensions are not equal.
    """
    if x.dim() < 2 or x.shape[-1] != x.shape[-2]:
        raise ValueError(
            "switch_diag expects square matrices of shape (..., N, N), "
            f"got {tuple(x.shape)}"
        )
    n = x.shape[-1]

    # Build the index grid on the input's device; gather requires the index
    # tensor and the input to live on the same device.
    positions = torch.arange(n, device=x.device)
    row_idx = positions.unsqueeze(1).expand(n, n)
    col_idx = positions.unsqueeze(0).expand(n, n)

    # Cells that belong to either diagonal.
    on_diagonal = (row_idx == col_idx) | (row_idx == n - 1 - col_idx)

    # Diagonal cells read from the vertically mirrored row of the same
    # column; every other cell reads from its own row.
    source_rows = torch.where(on_diagonal, n - 1 - row_idx, row_idx)

    # Gather along the row axis (-2); the (N, N) index grid is broadcast over
    # all leading batch dimensions.
    return torch.gather(x, -2, source_rows.expand_as(x))

# Demo: draw a random batch of three 3x3 matrices, show it, then show the
# result of exchanging each matrix's main diagonal and anti-diagonal.
x = torch.rand(3, 3, 3)
print("Original tensor:", x, "\n", sep="\n")

x = switch_diag(x)

print("Diagonals swapped:", x, "\n", sep="\n")

# circularly shift rows in a 2D tensor given a 1D tensor of shift amounts
# (row i is rotated to the right by shifts[i] positions)

x = torch.rand(4, 6)
print("Original tensor:")
print(x)
print("\n")

shifts = torch.tensor([1, 2, 3, 5])
# Take the row length from the tensor itself rather than repeating the
# literal, so the index arithmetic below always matches x's actual width.
base = torch.arange(x.size(1))

# turn into horizontal and vertical tensors so that (base - shifts)
# broadcasts to one row of column indices per row of x
shifts = shifts.unsqueeze(1)
base = base.unsqueeze(0)


# Output column j of row i reads input column (j - shifts[i]) mod width;
# the modulo wraps negative positions around to the end of the row.
indexes = (base - shifts) % x.size(1)

x = torch.gather(x, 1, indexes)
print("Circularly shifted rows:")
print(x)



Original tensor:
tensor([ 1,  3,  5,  3,  4,  8, -3,  4,  1,  6, 10, -2,  4,  7,  9,  2,  4,  8])


Max value per window:
tensor([ 5,  5,  5,  8,  8,  8,  4,  6, 10, 10, 10,  7,  9,  9,  9,  8])


Original tensor:
tensor([[[0.2305, 0.0887, 0.2495],
         [0.2267, 0.2038, 0.6351],
         [0.2012, 0.2304, 0.3218]],

        [[0.3466, 0.5151, 0.8706],
         [0.6147, 0.9274, 0.8861],
         [0.2861, 0.1986, 0.6178]],

        [[0.0638, 0.1078, 0.5703],
         [0.8292, 0.3267, 0.4190],
         [0.4318, 0.1250, 0.0242]]])


Diagonals swapped:
tensor([[[0.2012, 0.0887, 0.3218],
         [0.2267, 0.2038, 0.6351],
         [0.2305, 0.2304, 0.2495]],

        [[0.2861, 0.5151, 0.6178],
         [0.6147, 0.9274, 0.8861],
         [0.3466, 0.1986, 0.8706]],

        [[0.4318, 0.1078, 0.0242],
         [0.8292, 0.3267, 0.4190],
         [0.0638, 0.1250, 0.5703]]])


Original tensor:
tensor([[0.4054, 0.8881, 0.2399, 0.4808, 0.7113, 0.2123],
        [0.8397, 0.8563, 0.3416, 0.2179, 0.096