# Tensor operations

While implementing an attention mechanism, I ran into a lot of tensor arithmetic that is common in neural-network code.
This notebook serves one simple purpose: to get used to some common operators and patterns used in PyTorch for tensor transformations.

## Einstein summation notation
To use `torch.einsum` it is practical to keep the following in mind:
- **Sizes.** When performing Einstein summation, each axis of every input tensor gets a label. If the same label appears in several inputs, those dimensions must have the same size.  
  E.g. for two rank-2 tensors labeled `ij,jk->`, the second dimension of the first tensor must equal the first dimension of the second tensor.
- **Reduction.** If a label appears on the left side (in the inputs) but not on the right side (in the output), it is summed over and the axis is reduced. This holds whether the label occurs in one input or in several.  
  E.g. `i,i->`.
- **Free indices.** Any label that appears in the output is kept, and its dimension is preserved. The order of the labels in the output defines the order of the output axes.  
  E.g.: `ij,jk->ik`, `ij->ji`.

**Intuition**:
1. Repeat a label and omit it from the output - the axis is contracted (summed away).
2. List a label in the output - the axis is preserved and placed according to the label order.
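
The three rules can be checked on a small example. This is a minimal sketch: the tensor names and sizes are chosen here for illustration, and the last part assumes that `torch.einsum` reports mismatched label sizes as a `RuntimeError`.

```python
import torch

torch.manual_seed(0)
A = torch.randn(2, 3)
B = torch.randn(3, 4)

# Sizes: the shared label j has length 3 in both inputs.
# Reduction: j is absent from the output, so it is summed over.
# Free indices: i and k appear in the output, in that order.
out = torch.einsum("ij,jk->ik", A, B)       # (2, 4)

# Keeping every label in the output performs no summation.
no_sum = torch.einsum("ij,jk->ijk", A, B)   # (2, 3, 4)

# Swapping the output labels permutes the result axes.
swapped = torch.einsum("ij,jk->ki", A, B)   # (4, 2)

# A shared label with different sizes (3 vs. 5) is rejected.
try:
    torch.einsum("ij,jk->ik", A, torch.randn(5, 4))
    size_error = False
except RuntimeError:
    size_error = True

print(out.shape, no_sum.shape, swapped.shape, size_error)
```

Summing `no_sum` over its middle axis gives `out` again, which is the contraction rule written out by hand.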

Interpretation of `torch.einsum("ab,ac->ac", A, X)` with $A \in \mathbb{R}^{a \times b}$, $X \in \mathbb{R}^{a \times c}$:
- Indices: `a` = row, `b` = column of $A$, `c` = column of $X$.
- Formula:  
  $$\text{out}_{a c}=\sum_b A_{a b}\,X_{a c} =\big(\sum_b A_{a b}\big)\,X_{a c}.$$
- So `b` is summed over and disappears: each row of $X$ is scaled by the corresponding row sum of $A$.
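
A quick numerical check of this interpretation, with small sizes picked for illustration (the second operand is called `X`, as in the einsum call above):

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 3)   # (a, b)
X = torch.randn(4, 5)   # (a, c)

out = torch.einsum("ab,ac->ac", A, X)

# Row sums of A, kept as a column so they broadcast over the columns of X.
row_sums = A.sum(dim=1, keepdim=True)   # (a, 1)
expected = row_sums * X                 # (a, c)

print(out.shape, torch.allclose(out, expected, atol=1e-5))
```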

A few more operations:
- `ab,ac->bc`: `a` is omitted from the output, so we sum over it. This is a matrix product, $\sum_a A_{ab}X_{ac}$, which is the same as $A^T X$.
- `ab,ac->cb`: `a` is omitted from the output, so we sum over it. This is the same as $X^T A$.
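
Both identities can be verified directly. The sketch below uses arbitrary small shapes and writes the second operand as `X`, as in the einsum call above:

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 3)   # (a, b)
X = torch.randn(4, 5)   # (a, c)

bc = torch.einsum("ab,ac->bc", A, X)   # (b, c)
cb = torch.einsum("ab,ac->cb", A, X)   # (c, b)

# Contracting over a is a matrix product with one operand transposed.
print(torch.allclose(bc, A.T @ X, atol=1e-5))
print(torch.allclose(cb, X.T @ A, atol=1e-5))

# The two results are transposes of each other.
print(torch.allclose(cb, bc.T, atol=1e-5))
```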

### Examples

In [None]:
import torch

# Outer product
a = torch.randn(3)
b = torch.randn(2)
torch.einsum("i,j->ij", a, b) # (3,2) 

# Inner product
a = torch.randn(3)  # (3,)
b = torch.randn(3)  # (3,)
torch.einsum("i,i->", a, b)

# Matrix-vector product
A = torch.randn(4, 2)
b = torch.randn(2)
torch.einsum("ij,j->i", A, b) # (4,)

# Batch matrix-vector product
A = torch.randn(100, 4, 2)
B = torch.randn(100, 2)
torch.einsum("ijk,ik->ij", A, B) # (100,4) == torch.bmm(A, B.unsqueeze(-1)).squeeze(-1)

# Attention score computation
Q = torch.randn(100, 32, 15, 64)       # (batch_size, num_heads, seq_len, d_k)
K = torch.randn(100, 32, 15, 64)       # (batch_size, num_heads, seq_len, d_k)
torch.einsum("bhid, bhjd->bhij", Q, K) # (batch_size, num_heads, seq_len, seq_len) == Q @ K.transpose(-2, -1)

# Weighted sum over tokens
X = torch.randn(100, 15, 3)
w = torch.randn(100, 15)
torch.einsum("ijk,ij->ik", X, w)  # (100, 3)
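
The equivalences claimed in the comments above can be checked numerically. This is a sketch with smaller tensors than in the examples; the broadcast-and-sum form of the weighted sum is an equivalent written out here for comparison, not something taken from the cells above.

```python
import torch

torch.manual_seed(0)

# Batch matrix-vector product vs. torch.bmm
A = torch.randn(8, 4, 2)
B = torch.randn(8, 2)
mv_einsum = torch.einsum("ijk,ik->ij", A, B)
mv_bmm = torch.bmm(A, B.unsqueeze(-1)).squeeze(-1)

# Attention scores vs. batched matmul with the last two dims of K swapped
Q = torch.randn(8, 2, 5, 16)
K = torch.randn(8, 2, 5, 16)
scores_einsum = torch.einsum("bhid,bhjd->bhij", Q, K)
scores_matmul = Q @ K.transpose(-2, -1)

# Weighted sum over tokens vs. broadcast multiply followed by a sum
X = torch.randn(8, 5, 3)
w = torch.randn(8, 5)
ws_einsum = torch.einsum("ijk,ij->ik", X, w)
ws_manual = (X * w.unsqueeze(-1)).sum(dim=1)

print(torch.allclose(mv_einsum, mv_bmm, atol=1e-5))
print(torch.allclose(scores_einsum, scores_matmul, atol=1e-5))
print(torch.allclose(ws_einsum, ws_manual, atol=1e-5))
```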

## Useful tensor operations

`squeeze`

**What:** remove size-1 dimensions.

**Shape effect:** drops axes where length == 1.

**Examples:**

In [29]:
x = torch.randn(2, 1, 3, 1)
x.shape            # (2, 1, 3, 1)
x.squeeze().shape  # (2, 3)
x.squeeze(1).shape # (2, 3, 1)   # only if dim 1 is size-1 (else unchanged)

torch.Size([2, 3, 1])

`unsqueeze`

**What:** insert a size-1 dimension at a position.

**Shape effect:** rank + 1.

**Examples:**

In [30]:
x = torch.randn(4, 5)
x.shape                # (B, C)
x.unsqueeze(1).shape   # (B, 1, C)
x[:, None, :].shape    # same as unsqueeze(1)

torch.Size([4, 1, 5])

`permute`

**What:** reorder dimensions (general transpose).

**Shape effect:** same sizes, different order; returns a **view**.

**Examples:**

In [34]:
x = torch.randn(2, 4, 3, 12)

x.shape                         # (B, H, W, C)
x.permute(0, 3, 1, 2).shape     # (B, C, H, W)

# Special cases:
torch.equal(x.transpose(1, 2), x.permute(0, 2, 1, 3))  # True: transpose is a permute that swaps two dims
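
That `permute` returns a view can be seen from the shared storage and the changed strides. A minimal sketch on a small tensor chosen for illustration:

```python
import torch

x = torch.zeros(2, 3)
y = x.permute(1, 0)          # shape (3, 2), no data is copied

# Writing through the view changes the original tensor.
y[0, 1] = 7.0
print(x[1, 0].item())        # 7.0

# Same underlying memory, different strides.
print(x.data_ptr() == y.data_ptr())   # True
print(x.stride(), y.stride())         # (3, 1) (1, 3)
print(y.is_contiguous())              # False
```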

`chunk`

**What:** split a tensor into N parts along a dim (equal-sized when the size is divisible by N).

**Return:** a tuple of tensors, each a view of the input.

**Gotcha:** if the size is not divisible by N, the last chunk is smaller than the others, and you may get fewer than N chunks.

**Examples:**

In [36]:
x = torch.randn(5, 10, 12)
x.shape                          # (B, T, 3*D)
q, k, v = x.chunk(3, dim=-1)     # (B, T, D) each
y1, y2, y3 = x.chunk(3, dim=1)   # split time steps along dim=1
print(y1.shape, y2.shape, y3.shape)  # (5, 4, 12), (5, 4, 12), (5, 2, 12): 10 is not divisible by 3

torch.Size([5, 4, 12]) torch.Size([5, 4, 12]) torch.Size([5, 2, 12])
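
A short sketch of how `chunk` behaves when the size is not divisible by the number of chunks, and of the fact that the chunks are views. The sizes are picked for illustration.

```python
import torch

# Size 10 into 3 chunks: the chunk size is ceil(10 / 3) = 4, so the last chunk is shorter.
sizes_10 = [c.shape[0] for c in torch.arange(10).chunk(3)]
print(sizes_10)              # [4, 4, 2]

# Size 13 into 6 chunks: the chunk size is ceil(13 / 6) = 3, which uses up the data after 5 chunks.
parts = torch.arange(13).chunk(6)
sizes_13 = [c.shape[0] for c in parts]
print(len(parts), sizes_13)  # 5 [3, 3, 3, 3, 1]

# Chunks are views: writing into one changes the original tensor.
x = torch.zeros(4, 6)
q, k, v = x.chunk(3, dim=-1)
k.fill_(1.0)
print(x[0])                  # tensor([0., 0., 1., 1., 0., 0.])
```

Unpacking into a fixed number of names, as in `q, k, v = x.chunk(3, dim=-1)`, is therefore only safe when the number of returned chunks is known in advance.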


`.contiguous()`

**What:** ensure memory is laid out contiguously (needed for `.view` after operations such as `permute`).

**Shape:** unchanged.

**Gotcha:** may allocate and copy. Check `.is_contiguous()` first.

**Examples:**

In [None]:
# x is the (5, 10, 12) tensor from the chunk example
y = x.permute(0, 2, 1)         # (5, 12, 10), non-contiguous
z = y.contiguous().view(1, -1) # OK: (1, 600)
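
The sketch below spells out what goes wrong without `.contiguous()` and when a copy is made. It recreates a tensor of the same shape so that it runs on its own; `reshape` is shown as a related alternative that is not covered above.

```python
import torch

x = torch.randn(5, 10, 12)
y = x.permute(0, 2, 1)            # (5, 12, 10)

print(x.is_contiguous(), y.is_contiguous())   # True False

# .view on the permuted tensor fails because the strides are incompatible.
try:
    y.view(1, -1)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)                # True

# .contiguous() copies into fresh memory when needed...
z = y.contiguous()
print(z.data_ptr() != y.data_ptr())             # True
# ...and reuses the existing memory when the tensor is already contiguous.
print(x.contiguous().data_ptr() == x.data_ptr())  # True

print(z.view(1, -1).shape)        # torch.Size([1, 600])

# reshape copies only if it has to, so it also works on y directly.
r = y.reshape(1, -1)
print(torch.equal(r, z.view(1, -1)))            # True
```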

`torch.flatten(x, start_dim=0, end_dim=-1)`

**What:** merge a dim range into one.

**Example**:

In [40]:
x = torch.randn(2, 3, 4, 5)
y = torch.flatten(x, start_dim=1, end_dim=-1)
y.shape  # (2, 60)

torch.Size([2, 60])

`torch.cat()` and `torch.stack()`

**What:** **concat** along an **existing** axis or **stack** along a **new** one

**Example**:

In [None]:
x = torch.randn(2, 3)
y = torch.randn(2, 3)

z_cat = torch.cat([x, y], dim=1)   # (2, 6)
z_stack = torch.stack([x, y], dim=2) # (2, 3, 2)
z_cat.shape, z_stack.shape

(torch.Size([2, 6]), torch.Size([2, 3, 2]))
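
One way to relate the two: `stack` behaves like `unsqueeze` on every input followed by `cat` along the new axis. A minimal sketch with illustrative shapes:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 3)
y = torch.randn(2, 3)

stacked = torch.stack([x, y], dim=0)                          # (2, 2, 3), new leading axis
via_cat = torch.cat([x.unsqueeze(0), y.unsqueeze(0)], dim=0)  # same result built with cat

print(stacked.shape, torch.equal(stacked, via_cat))

# cat grows an existing axis instead of adding one.
print(torch.cat([x, y], dim=0).shape)                         # torch.Size([4, 3])
```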

`x.unbind(dim=0)`

**What:** split into a tuple of slices, **removing** that dim.

**Example:**

In [59]:
x = torch.randn(2, 3, 4)
a, b, c = x.unbind(dim=1)  # three tensors of shape (2, 4)
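
`unbind` can be read as the inverse of `torch.stack` along the same dim. A small sketch to confirm the shapes and the round trip:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 3, 4)

slices = x.unbind(dim=1)             # tuple of 3 tensors, each (2, 4)
print(len(slices), slices[0].shape)  # 3 torch.Size([2, 4])

# Each slice matches plain indexing along that dim.
print(torch.equal(slices[0], x[:, 0, :]))   # True

# Stacking the slices along the same dim rebuilds the original tensor.
rebuilt = torch.stack(slices, dim=1)
print(torch.equal(rebuilt, x))              # True
```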