<a href="https://colab.research.google.com/github/leideng/AI-primer/blob/main/pytorch/tensor_slice.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to efficiently slice a PyTorch tensor in different ways

## Python's slicing

## Boolean index

## Advanced index

In [11]:
# Python's slicing
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randint(low=-5, high=5, size=(5,8), device=device)
print(f"x={x}")

x1 = x[1,:]
print(f"row 1: x1=x[1,:]={x1}")
print(f"x1.shape={x1.shape}")

y1 = x[:,1]
print(f"column 1: y1=x[:,1]={y1}")
print(f"y1.shape={y1.shape}")

x2_3 = x[2,3]
x2_35 = x[2,3:5]
print(f"x2_3=x[2,3]={x2_3}")
print(f"x2_3.shape={x2_3.shape}")
print(f"x2_35=x[2,3:5]={x2_35}")
print(f"x2_35.shape={x2_35.shape}")


x=tensor([[-5,  3, -5, -2,  3, -4,  3,  1],
        [-3,  3, -3,  4, -2,  1,  2, -4],
        [-3,  0,  0,  1,  0, -2, -4, -3],
        [ 0,  2, -3,  0,  1, -5,  3,  4],
        [-4, -1,  1, -1, -5,  3,  3,  3]])
row 1: x1=x[1,:]=tensor([-3,  3, -3,  4, -2,  1,  2, -4])
x1.shape=torch.Size([8])
column 1: y1=x[:,1]=tensor([ 3,  3,  0,  2, -1])
y1.shape=torch.Size([5])
x2_3=x[2,3]=1
x2_3.shape=torch.Size([])
x2_35=x[2,3:5]=tensor([1, 0])
x2_35.shape=torch.Size([2])
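
One property of the slices above is worth keeping in mind: basic slicing returns a view that shares memory with the original tensor, so writing to the slice also changes the source. The sketch below illustrates this on a small hand-made tensor (the name `base` and its values are illustrative, not taken from the run above) and shows `clone()` as one way to get an independent copy.

```python
import torch

# Illustrative tensor: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
base = torch.arange(12).reshape(3, 4)

row = base[1, :]     # basic slicing: a view, no data is copied
row[0] = 100         # writes through to base[1, 0]

independent = base[1, :].clone()   # clone() makes an independent copy
independent[1] = -1                # base is left untouched

# A view of row 1 starts at the same memory address as base[1]
shares_memory = row.data_ptr() == base[1].data_ptr()

print(base[1].tolist())   # [100, 5, 6, 7]
print(shares_memory)      # True
```

If the original tensor must stay unchanged, cloning the slice before modifying it is the safer choice.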


In [15]:
# Boolean Index
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randint(low=1,high=100, size=(5,10), device=device)
print(f"x={x}")
print(f"x.shape={x.shape}")

idx = (x >= 80)
print(f"idx=(x >= 80)={idx}")
print(f"idx.shape={idx.shape}")

y = x[idx]
print(f"y={y}")  # flattened to a 1D tensor
print(f"y.shape={y.shape}")

x=tensor([[44, 95, 71, 60, 23, 31, 36, 81, 68, 22],
        [12,  9, 32, 64, 99, 66, 72, 49, 12, 16],
        [10, 66, 10, 46, 60, 48, 46, 98, 35, 22],
        [74,  1, 37, 43, 22, 84, 84, 13, 47, 29],
        [86,  2, 12, 20, 32, 51, 61, 90, 34,  1]])
x.shape=torch.Size([5, 10])
idx=(x >= 80)=tensor([[False,  True, False, False, False, False, False,  True, False, False],
        [False, False, False, False,  True, False, False, False, False, False],
        [False, False, False, False, False, False, False,  True, False, False],
        [False, False, False, False, False,  True,  True, False, False, False],
        [ True, False, False, False, False, False, False,  True, False, False]])
idx.shape=torch.Size([5, 10])
y=tensor([95, 81, 99, 98, 84, 84, 86, 90])
y.shape=torch.Size([8])
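
Unlike basic slicing, boolean indexing returns a copy, and the result is 1D because the number of selected elements depends on the data. The sketch below uses a small illustrative tensor rather than the random one above. It shows how `torch.nonzero` recovers the positions of the selected elements and how the same mask can be used for assignment; the names `selected`, `positions`, and `clipped` are introduced here for illustration only.

```python
import torch

x = torch.tensor([[44, 95, 71],
                  [12, 99, 66]])
mask = x >= 80

selected = x[mask]                # 1D copy of the elements where mask is True
positions = torch.nonzero(mask)   # one (row, col) pair per True entry

selected[0] = 0                   # modifying the copy does not touch x

clipped = x.clone()
clipped[mask] = 80                # mask assignment writes into `clipped` in place

print(positions.tolist())   # [[0, 1], [1, 1]]
print(x.tolist())           # [[44, 95, 71], [12, 99, 66]]
print(clipped.tolist())     # [[44, 80, 71], [12, 80, 66]]
```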


In [18]:
# Advanced Index
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
# Define the block_table tensor
block_table = torch.tensor([[17, 0, 14, 12, 15, 29, 8, -1, -1, -1],
                            [6, 24, 28, 22, 27, 25, -1, -1, -1, -1],
                            [2, 4, 21, 1, 9, 11, 13, 3, 16, -1]], device=device)

# Define the index tensor
idx = torch.tensor([[0, 2, 5, 6, 9],
                    [1, 2, 3, 4, 8],
                    [0, 2, 3, 5, 7]])

# Get the number of rows from block_table and the number of columns from idx
num_rows = block_table.shape[0]
num_cols_idx = idx.shape[1]

# Create a row index tensor that broadcasts across the columns of idx
# torch.arange(num_rows) creates [0, 1, 2]
# .unsqueeze(1) changes it to [[0], [1], [2]]
# .expand(-1, num_cols_idx) expands it to [[0,0,0,0,0], [1,1,1,1,1], [2,2,2,2,2]]
row_indices = torch.arange(num_rows).unsqueeze(1).expand(-1, num_cols_idx)

# Use advanced indexing to get the sliced tensor
sliced_tensor = block_table[row_indices, idx]

print("Original block_table:")
print(block_table)
print("\nRow idx:")
print(row_indices)
print("\nOriginal (column) idx:")
print(idx)
print("\nSliced tensor:")
print(sliced_tensor)

Original block_table:
tensor([[17,  0, 14, 12, 15, 29,  8, -1, -1, -1],
        [ 6, 24, 28, 22, 27, 25, -1, -1, -1, -1],
        [ 2,  4, 21,  1,  9, 11, 13,  3, 16, -1]])

Row idx:
tensor([[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2]])

Original (column) idx:
tensor([[0, 2, 5, 6, 9],
        [1, 2, 3, 4, 8],
        [0, 2, 3, 5, 7]])

Sliced tensor:
tensor([[17, 14, 29,  8, -1],
        [24, 28, 22, 27, -1],
        [ 2, 21,  1, 11,  3]])


**Explanation**
- `num_rows = block_table.shape[0]`: The number of rows in `block_table` (3).
- `num_cols_idx = idx.shape[1]`: The number of columns in `idx` (5).
- `torch.arange(num_rows)`: This creates the 1D tensor `[0, 1, 2]`. These are the row indices we want to select from `block_table`.
- `.unsqueeze(1)`: This reshapes the row indices from `[0, 1, 2]` (shape `(3,)`) to `[[0], [1], [2]]` (shape `(3, 1)`), so that they can be broadcast across the columns of `idx`.
- `.expand(-1, num_cols_idx)`: This expands the row indices to the same shape as `idx` (3x5) without copying data. The `-1` means "keep the size of this dimension as is". So, `[[0], [1], [2]]` becomes `[[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2]]`.
- `block_table[row_indices, idx]`: This is the advanced indexing step. PyTorch takes corresponding elements from row_indices and idx to form (row, col) pairs for indexing.
  - For the first row, it uses `(0, idx[0,0]), (0, idx[0,1]), ..., (0, idx[0,4])`.
  - For the second row, it uses `(1, idx[1,0]), (1, idx[1,1]), ..., (1, idx[1,4])`.
  - And so on.
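
The pairing described above can be checked against an explicit Python loop. The sketch below also drops the `expand` step: PyTorch broadcasts index tensors against each other, so a `(3, 1)` row index combined with a `(3, 5)` column index gives the same `(3, 5)` result. The tensors are declared again so the snippet is self-contained, and the names `rows`, `via_broadcast`, and `via_loop` are introduced here for illustration.

```python
import torch

block_table = torch.tensor([[17, 0, 14, 12, 15, 29, 8, -1, -1, -1],
                            [6, 24, 28, 22, 27, 25, -1, -1, -1, -1],
                            [2, 4, 21, 1, 9, 11, 13, 3, 16, -1]])
idx = torch.tensor([[0, 2, 5, 6, 9],
                    [1, 2, 3, 4, 8],
                    [0, 2, 3, 5, 7]])

# Row index of shape (3, 1); it is broadcast against idx of shape (3, 5)
rows = torch.arange(block_table.shape[0]).unsqueeze(1)
via_broadcast = block_table[rows, idx]

# Reference result, built one (row, col) pair at a time
via_loop = torch.empty_like(idx)
for i in range(idx.shape[0]):
    for j in range(idx.shape[1]):
        via_loop[i, j] = block_table[i, idx[i, j]]

print(torch.equal(via_broadcast, via_loop))   # True
```

The loop is only a correctness check; for real workloads the indexed form is the one to use.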

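This method is efficient because the whole lookup runs as a single vectorized indexing call in PyTorch's C++ backend instead of a Python loop over rows, and `expand` returns a view, so building `row_indices` does not copy data.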
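
As an alternative that this notebook does not use, `torch.gather` expresses the same per-row column lookup directly, without building a row index at all. A minimal sketch, assuming the same `block_table` and `idx` as above (declared again here so the snippet runs on its own):

```python
import torch

block_table = torch.tensor([[17, 0, 14, 12, 15, 29, 8, -1, -1, -1],
                            [6, 24, 28, 22, 27, 25, -1, -1, -1, -1],
                            [2, 4, 21, 1, 9, 11, 13, 3, 16, -1]])
idx = torch.tensor([[0, 2, 5, 6, 9],
                    [1, 2, 3, 4, 8],
                    [0, 2, 3, 5, 7]])

# With dim=1: gathered[i][j] = block_table[i][idx[i][j]]
gathered = torch.gather(block_table, dim=1, index=idx)

print(gathered)
```

`torch.gather` requires `idx` to be an integer (`int64`) tensor with the same number of dimensions as `block_table`, which holds here.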