## Einsum Operator
`torch.einsum` is a **powerful and compact way to express tensor operations** (like matrix multiplication, dot products, outer products, transposes, or reductions) using **Einstein summation notation**.

It allows you to **specify exactly how indices should align, reduce, or broadcast** — without writing loops or calling multiple PyTorch ops.

---

### 1. **Einstein summation rule**

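When the output is written explicitly (the part after `->`), every index listed there is **kept** and every index that appears only on the input side is **summed over**.
In the classic convention, and in `einsum`'s implicit mode (no `->`), an index that appears **twice** is summed and an index that appears **once** is kept, with the kept indices ordered alphabetically.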

Example of notation:

$$
C_{ij} = \sum_k A_{ik} B_{kj}
$$

is written in `einsum` as:

```python
torch.einsum('ik,kj->ij', A, B)
```

which is standard **matrix multiplication**.
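To see the rule in action, the sketch below compares `einsum` against `@` on small deterministic tensors. The shapes `(2, 3)` and `(3, 4)` are arbitrary choices for illustration. It also tries implicit mode (no `->`), where PyTorch keeps the indices that appear exactly once, in alphabetical order.

```python
import torch

# Illustrative inputs; any shapes (m, k) and (k, n) would work.
A = torch.arange(6.0).reshape(2, 3)
B = torch.arange(12.0).reshape(3, 4)

# Explicit mode: the output subscripts 'ij' follow '->', so k is summed.
explicit = torch.einsum('ik,kj->ij', A, B)

# Implicit mode: no '->'. k appears twice and is summed; i and j appear once
# and are kept in alphabetical order, which again gives an 'ij' output.
implicit = torch.einsum('ik,kj', A, B)

print(explicit.shape)                      # torch.Size([2, 4])
print(torch.allclose(explicit, A @ B))     # True
print(torch.allclose(implicit, explicit))  # True
```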

---

### 2. **Simple numerical examples**

#### **Example 1 — Dot product**

$$
c = \sum_i a_i b_i
$$

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# Dot product
c = torch.einsum('i,i->', a, b)
print(c)
```

```
tensor(32)
```


Explanation:
$$
1×4 + 2×5 + 3×6 = 32
$$
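The same subscripts can be checked against PyTorch's dedicated operations. This sketch uses the vectors from the example above; the row-wise form `'ni,ni->n'` (one dot product per row) is an illustrative extension, and `X` and `Y` are hypothetical inputs.

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# Three ways to write the same dot product.
via_einsum = torch.einsum('i,i->', a, b)
via_dot = torch.dot(a, b)
via_mul_sum = (a * b).sum()
print(via_einsum.item(), via_dot.item(), via_mul_sum.item())  # 32 32 32

# Row-wise dot products: n is kept in the output, i is summed within each row.
X = torch.tensor([[1, 2, 3],
                  [0, 1, 0]])
Y = torch.tensor([[4, 5, 6],
                  [7, 8, 9]])
print(torch.einsum('ni,ni->n', X, Y).tolist())  # [32, 8]
```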

---



#### **Example 2: Sum of all pairwise products**

`torch.einsum('i,j->', a, b)`

Now the indices are **different** (`i` and `j`).

Each of them appears **only once**, and both **disappear** in the output (no index after `->`),
so both are **summed over**.

That means:

$$
c = \sum_i \sum_j a_i b_j
$$

This is the **sum of all pairwise products** between elements of `a` and `b`.




```python
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = torch.einsum('i,j->', a, b)
print(c)
```

```
tensor(90)
```


Manual calculation:

$$
(1×4 + 1×5 + 1×6) + (2×4 + 2×5 + 2×6) + (3×4 + 3×5 + 3×6) = 90
$$
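Because neither index survives, the double sum factors into the product of the two individual sums: $\sum_i \sum_j a_i b_j = (\sum_i a_i)(\sum_j b_j) = 6 \times 15 = 90$. The sketch below checks this, and also that the result equals summing every entry of the outer product.

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

pairwise = torch.einsum('i,j->', a, b)            # sum over both i and j
factored = a.sum() * b.sum()                      # (1+2+3) * (4+5+6) = 6 * 15
from_outer = torch.einsum('i,j->ij', a, b).sum()  # build the 3x3 table, then sum it

print(pairwise.item(), factored.item(), from_outer.item())  # 90 90 90
```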


---



| Expression  | Meaning                      | Formula                | Result type |
| ----------- | ---------------------------- | ---------------------- | ----------- |
| `'i,i->'`   | Dot product                  | $\sum_i a_i b_i$       | Scalar      |
| `'i,j->ij'` | Outer product                | $a_i b_j$              | Matrix      |
| `'i,j->'`   | Sum of all pairwise products | $\sum_i\sum_j a_i b_j$ | Scalar      |

---


#### **Example 3 — Outer product**

$$
C_{ij} = a_i b_j
$$

```python
# Reuses a and b from the previous examples
c = torch.einsum('i,j->ij', a, b)
print(c)
```

```
tensor([[ 4,  5,  6],
        [ 8, 10, 12],
        [12, 15, 18]])
```


No summation here: both `i` and `j` appear in the output, so each product $a_i b_j$ is stored directly in $C_{ij}$.
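For comparison, `torch.outer` and plain broadcasting produce the same matrix. This sketch redefines the example vectors so it runs on its own.

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

via_einsum = torch.einsum('i,j->ij', a, b)
via_outer = torch.outer(a, b)
# Broadcasting: a column vector (3, 1) times a row vector (1, 3) gives (3, 3).
via_broadcast = a[:, None] * b[None, :]

print(torch.equal(via_einsum, via_outer))      # True
print(torch.equal(via_einsum, via_broadcast))  # True
```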

---

#### **Example 4 — Matrix multiplication**

$$
C_{ij} = \sum_k A_{ik} B_{kj}
$$



```python
A = torch.tensor([[1, 2],
                  [3, 4]])
B = torch.tensor([[5, 6],
                  [7, 8]])

C = torch.einsum('ik,kj->ij', A, B)
print(C)
```

```
tensor([[19, 22],
        [43, 50]])
```


Same as `A @ B`.
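Adding a leading index that appears in every operand and in the output turns this into a batched matrix multiplication, and reordering the output indices gives a transpose. A minimal sketch, assuming an arbitrary batch size of 4 and random float inputs:

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 2, 3)  # (batch, i, k)
B = torch.randn(4, 3, 5)  # (batch, k, j)

# b appears in both inputs and in the output, so it is kept rather than summed:
# one independent matrix product per batch element.
C = torch.einsum('bik,bkj->bij', A, B)
print(C.shape)                                        # torch.Size([4, 2, 5])
print(torch.allclose(C, torch.bmm(A, B), atol=1e-6))  # True

# Transpose: no index is summed, the output order is just swapped.
M = torch.tensor([[1, 2],
                  [3, 4]])
print(torch.equal(torch.einsum('ij->ji', M), M.T))    # True
```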

---

#### **Example 5 — Sum over specific axes**

You can sum over an axis without writing `torch.sum`.

```python
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# Sum over i (the row index) and keep j: one total per column
print(torch.einsum('ij->j', x))
```

```
tensor([5, 7, 9])
```

Equivalent to `x.sum(dim=0)`.
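Other reductions follow the same pattern. The sketch below uses the same `x`, and also shows that repeating an index within a single operand selects its diagonal, which becomes the trace when that index is summed away. The square matrix `S` is an illustrative input.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

row_sums = torch.einsum('ij->i', x)  # sum over j: one total per row
total = torch.einsum('ij->', x)      # sum over both indices

S = torch.tensor([[1, 2],
                  [3, 4]])
diag = torch.einsum('ii->i', S)      # repeated index in one operand: the diagonal
trace = torch.einsum('ii->', S)      # diagonal, then summed

print(row_sums.tolist(), total.item(), diag.tolist(), trace.item())
# [6, 15] 21 [1, 4] 5
```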

---

### 3. **Why we need `einsum`**

- ✅ **Expressive and compact**: One line replaces multiple `matmul`, `transpose`, and `sum` operations.
- ✅ **Clear intent**: The subscript notation shows how indices interact.
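- ✅ **Efficient**: `torch.einsum` reduces contractions to (batched) matrix multiplications, which run on BLAS or cuBLAS kernels; when the `opt_einsum` package is installed, PyTorch can also choose a cheaper contraction order for three or more operands.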
- ✅ **Flexible**: Works for batched operations, attention mechanisms, and tensor contractions in transformers.
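As a concrete case of the attention use mentioned above, single-head scaled dot-product attention can be written as two `einsum` calls. This is a minimal sketch: the names `Q`, `K`, `V` and the shapes are assumptions chosen for illustration, and the result is checked against the equivalent `matmul`/`transpose` formulation.

```python
import math
import torch

torch.manual_seed(0)
batch, n_q, n_k, d = 2, 4, 6, 8
Q = torch.randn(batch, n_q, d)
K = torch.randn(batch, n_k, d)
V = torch.randn(batch, n_k, d)

# Scores: contract over the feature index d; b, q and k are kept.
scores = torch.einsum('bqd,bkd->bqk', Q, K) / math.sqrt(d)
weights = scores.softmax(dim=-1)
# Weighted sum of values: contract over the key index k.
out = torch.einsum('bqk,bkd->bqd', weights, V)

# The same computation with matmul and an explicit transpose.
ref_weights = (Q @ K.transpose(-2, -1) / math.sqrt(d)).softmax(dim=-1)
ref_out = ref_weights @ V

print(out.shape)                                # torch.Size([2, 4, 8])
print(torch.allclose(out, ref_out, atol=1e-5))  # True
```

The `einsum` version names the contracted index directly (`d`, then `k`), so no explicit `transpose` of `K` is needed.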

---

