## Rough Notebook for Understanding Terms in Llama 2 Notebook

In [1]:
import torch

In [4]:
a = torch.arange(1, 5)
b = torch.arange(6, 10)
print("Matrix A : ")
print(a)
print(f"Shape of A is : {a.shape}")
print("Matrix B : ")
print(b)
print(f"Shape of B is : {b.shape}")

Matrix A : 
tensor([1, 2, 3, 4])
Shape of A is : torch.Size([4])
Matrix B : 
tensor([6, 7, 8, 9])
Shape of B is : torch.Size([4])


In [5]:
out = torch.outer(a, b).float()
print("Output matrix : ")
print(out)
print(f"Output matrix shape : {out.shape}")

Output matrix : 
tensor([[ 6.,  7.,  8.,  9.],
        [12., 14., 16., 18.],
        [18., 21., 24., 27.],
        [24., 28., 32., 36.]])
Output matrix shape : torch.Size([4, 4])


In [6]:
ones_like = torch.ones_like(out)
print("Ones like matrix : ")
print(ones_like)
print(f"Ones like matrix shape : {ones_like.shape}")

Ones like matrix : 
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
Ones like matrix shape : torch.Size([4, 4])


In [7]:
polar = torch.polar(ones_like, out)
print("Polar matrix : ")
print(polar)
print(f"Polar matrix shape : {polar.shape}")

Polar matrix : 
tensor([[ 0.9602-0.2794j,  0.7539+0.6570j, -0.1455+0.9894j, -0.9111+0.4121j],
        [ 0.8439-0.5366j,  0.1367+0.9906j, -0.9577-0.2879j,  0.6603-0.7510j],
        [ 0.6603-0.7510j, -0.5477+0.8367j,  0.4242-0.9056j, -0.2921+0.9564j],
        [ 0.4242-0.9056j, -0.9626+0.2709j,  0.8342+0.5514j, -0.1280-0.9918j]])
Polar matrix shape : torch.Size([4, 4])


### Polar Matrix Calculation

`torch.polar(abs, angle)` builds a complex tensor from a tensor of magnitudes and a tensor of angles. The `polar` tensor above is produced in these steps:

### 1. **Outer Product Calculation**  
The outer product of vectors `a` and `b` creates a matrix where each element `(i, j)` is `a[i] * b[j]`. This matrix (`out`) contains **angles in radians** for the polar form.  
**Example**:  
If `a = [1, 2, 3, 4]` and `b = [6, 7, 8, 9]`, the first row of `out` is:  
`[1*6, 1*7, 1*8, 1*9] = [6, 7, 8, 9]` (angles in radians).
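As a quick check, the outer product can be reproduced with plain broadcasting. Indexing with `None` turns `a` into a column of shape `(4, 1)` and `b` into a row of shape `(1, 4)`, and their elementwise product is the same 4×4 matrix. A minimal sketch (`out_broadcast` is just an illustrative name):

```python
import torch

a = torch.arange(1, 5)   # tensor([1, 2, 3, 4])
b = torch.arange(6, 10)  # tensor([6, 7, 8, 9])

# torch.outer: out[i, j] = a[i] * b[j]
out = torch.outer(a, b).float()

# Same result via broadcasting: (4, 1) * (1, 4) -> (4, 4)
out_broadcast = (a[:, None] * b[None, :]).float()

print(torch.equal(out, out_broadcast))  # True
print(out[0])                           # first row: 6, 7, 8, 9
```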

---

### 2. **Magnitude Matrix**  
`torch.ones_like(out)` generates a matrix of the same shape as `out`, filled with `1.0`. This sets the **magnitude** of every complex number to 1.  

---

### 3. **Polar Conversion**  
For each element in the matrices:  
- **Magnitude**: Taken from `ones_like` (always `1.0` here).  
- **Angle**: Taken from `out` (values like 6, 7, 8, 9 in radians).  

The polar form is computed as:  
$
\text{complex number} = \text{magnitude} \cdot \left( \cos(\text{angle}) + i \cdot \sin(\text{angle}) \right)
$

**Example Calculations**:  
- For `angle = 6` (first element of `out`):  
  $
  \cos(6) \approx 0.9602, \quad \sin(6) \approx -0.2794 \implies 0.9602 - 0.2794i
  $
- For `angle = 7` (second element):  
  $
  \cos(7) \approx 0.7539, \quad \sin(7) \approx 0.6570 \implies 0.7539 + 0.6570i
  $
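The per-element formula can be checked directly. Building the real part as `abs * cos(angle)` and the imaginary part as `abs * sin(angle)` with `torch.complex` should reproduce what `torch.polar` returns. A small sketch (`manual` is an illustrative name):

```python
import torch

out = torch.outer(torch.arange(1, 5), torch.arange(6, 10)).float()  # angles in radians
ones_like = torch.ones_like(out)                                     # magnitudes

polar = torch.polar(ones_like, out)

# polar(abs, angle) = abs * cos(angle) + i * abs * sin(angle)
manual = torch.complex(ones_like * torch.cos(out), ones_like * torch.sin(out))

print(polar.dtype)                    # torch.complex64
print(torch.allclose(polar, manual))  # True
print(polar[0, 0])                    # approximately 0.9602-0.2794j
```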

---

### 4. **Result**  
The final `polar` tensor is a **4×4 complex tensor** (dtype `torch.complex64`, since both inputs are `float32`). Every element lies on the unit circle (magnitude 1), and its angle is the corresponding entry of the outer-product matrix.
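Two properties of the result are easy to confirm: every element has magnitude 1, and `Tensor.angle()` recovers the original angles only up to a multiple of $2\pi$, because it returns values in $(-\pi, \pi]$. A sketch (`wrapped` is an illustrative name for the angles folded into that range):

```python
import math
import torch

out = torch.outer(torch.arange(1, 5), torch.arange(6, 10)).float()
polar = torch.polar(torch.ones_like(out), out)

# Every element lies on the unit circle
print(torch.allclose(polar.abs(), torch.ones_like(out)))  # True

# angle() wraps to (-pi, pi], so compare against the wrapped input angles
wrapped = torch.remainder(out + math.pi, 2 * math.pi) - math.pi
print(torch.allclose(polar.angle(), wrapped, atol=1e-4))  # True
```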

---

### Key Formula:  
$
\text{Polar element} = \cos(\theta) + i \cdot \sin(\theta) \quad \text{where } \theta \text{ is from the outer product matrix}
$  
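The same outer-product-then-`torch.polar` pattern is how the Llama 2 reference code precomputes its rotary position embedding (RoPE) factors, in a function named `precompute_freqs_cis`. The sketch below follows that pattern under stated assumptions: one frequency per pair of features, $\theta^{-2k/\text{dim}}$ with base `theta=10000.0`, and positions `0 .. end-1`. It is an illustration of the idea, not a verified copy of the reference implementation.

```python
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Sketch: unit complex numbers e^{i * m * freq_k} for position m, pair k."""
    # One frequency per feature pair: theta^(-2k/dim), k = 0 .. dim/2 - 1
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end)                                 # positions 0 .. end-1
    angles = torch.outer(t, freqs).float()                # (end, dim/2), like `out` above
    return torch.polar(torch.ones_like(angles), angles)   # magnitude 1, angle from outer product

freqs_cis = precompute_freqs_cis(dim=8, end=4)
print(freqs_cis.shape)  # torch.Size([4, 4])
print(freqs_cis[0])     # position 0: every angle is 0, so every entry is 1+0j
```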

In [10]:
mat = torch.randn(5,2)
print("Matrix : ")
print(mat)
comp = torch.view_as_complex(mat.float())
print("Complex matrix : ")
print(comp)

Matrix : 
tensor([[ 0.1296,  0.3259],
        [ 1.3146, -1.8053],
        [-0.5212, -0.5961],
        [ 0.7684,  1.0464],
        [ 0.3083, -1.0735]])
Complex matrix : 
tensor([ 0.1296+0.3259j,  1.3146-1.8053j, -0.5212-0.5961j,  0.7684+1.0464j,
         0.3083-1.0735j])


In [15]:
print(f"Currently shape of comp is : {comp.shape}")
comp_reshaped = torch.view_as_complex(mat.float().reshape(*mat.shape[:-1], -1, 2))
print("Comp reshaped : ")
print(comp_reshaped)
print(f"Shape of comp reshaped : {comp_reshaped.shape}")
intermediate_mat = mat.float().reshape(*mat.shape[:-1], -1, 2)
print("Intermediate mat : ")
print(intermediate_mat)
print(f"Shape of intermediate mat : {intermediate_mat.shape}")

Currently shape of comp is : torch.Size([5])
Comp reshaped : 
tensor([[ 0.1296+0.3259j],
        [ 1.3146-1.8053j],
        [-0.5212-0.5961j],
        [ 0.7684+1.0464j],
        [ 0.3083-1.0735j]])
Shape of comp reshaped : torch.Size([5, 1])
Intermediate mat : 
tensor([[[ 0.1296,  0.3259]],

        [[ 1.3146, -1.8053]],

        [[-0.5212, -0.5961]],

        [[ 0.7684,  1.0464]],

        [[ 0.3083, -1.0735]]])
Shape of intermediate mat : torch.Size([5, 1, 2])
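The reshape to `(*mat.shape[:-1], -1, 2)` groups the last dimension into consecutive pairs, and `torch.view_as_complex` reads each pair as `(real, imag)`. `torch.view_as_real` is the inverse, so a round trip followed by `flatten` restores the original layout. A sketch on a hypothetical tensor `x` with 6 features per row (3 complex numbers per row):

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(2, 6)  # 2 rows, 6 features

# Group features into pairs: (2, 6) -> (2, 3, 2), then read each pair as complex
pairs = x.reshape(*x.shape[:-1], -1, 2)
xc = torch.view_as_complex(pairs)           # shape (2, 3), complex64
print(xc[0])                                # row 0 holds 0+1j, 2+3j, 4+5j

# Inverse: complex -> (..., 2) real pairs -> merge the last two dims
back = torch.view_as_real(xc).flatten(-2)   # shape (2, 6)
print(torch.equal(back, x))                 # True
```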


In [21]:
print(intermediate_mat)
print(f"Shape of Intermediate mat : {intermediate_mat.shape}")
# Add a dimension of size 1 at position 0,
# so the shape goes from [5, 1, 2] -> [1, 5, 1, 2]
copy = intermediate_mat.unsqueeze(0)
print("Copy : ")
print(copy)
print(f"New copy shape : {copy.shape}")
# Add a dimension of size 1 at position 2,
# so the shape goes from [1, 5, 1, 2] -> [1, 5, 1, 1, 2]
new_copy = copy.unsqueeze(2)
print("New Copy : ")
print(new_copy)
print(f"Shape of new copy : {new_copy.shape}")

tensor([[[ 0.1296,  0.3259]],

        [[ 1.3146, -1.8053]],

        [[-0.5212, -0.5961]],

        [[ 0.7684,  1.0464]],

        [[ 0.3083, -1.0735]]])
Shape of Intermediate mat : torch.Size([5, 1, 2])
Copy : 
tensor([[[[ 0.1296,  0.3259]],

         [[ 1.3146, -1.8053]],

         [[-0.5212, -0.5961]],

         [[ 0.7684,  1.0464]],

         [[ 0.3083, -1.0735]]]])
New copy shape : torch.Size([1, 5, 1, 2])
New Copy : 
tensor([[[[[ 0.1296,  0.3259]]],


         [[[ 1.3146, -1.8053]]],


         [[[-0.5212, -0.5961]]],


         [[[ 0.7684,  1.0464]]],


         [[[ 0.3083, -1.0735]]]]])
Shape of new copy : torch.Size([1, 5, 1, 1, 2])
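Inserting size-1 dimensions is what lets a small tensor broadcast against a larger one. In Llama 2's rotary embedding, per-position factors of shape `(seq_len, head_dim/2)` are reshaped to `(1, seq_len, 1, head_dim/2)` so they can multiply a complex query/key tensor of shape `(batch, seq_len, n_heads, head_dim/2)`. The sketch below shows that broadcast with hypothetical sizes and frequencies; `apply_rotary` is an illustrative name, not the reference implementation. It uses the same `unsqueeze(0)` then `unsqueeze(2)` sequence as the cell above.

```python
import torch

batch, seq_len, n_heads, head_dim = 2, 5, 3, 4  # hypothetical sizes

# Unit-magnitude rotation factors, one per (position, feature pair)
angles = torch.outer(torch.arange(seq_len).float(), torch.tensor([1.0, 0.01]))
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # (seq_len, head_dim/2)

def apply_rotary(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # (B, S, H, D) -> (B, S, H, D/2) complex, pairing consecutive features
    xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    # (S, D/2) -> (1, S, 1, D/2): the size-1 dims broadcast over batch and heads
    f = freqs_cis.unsqueeze(0).unsqueeze(2)
    # Complex multiplication rotates each pair; convert back to (B, S, H, D)
    return torch.view_as_real(xq_c * f).flatten(3).type_as(xq)

xq = torch.randn(batch, seq_len, n_heads, head_dim)
xq_rot = apply_rotary(xq, freqs_cis)
print(xq_rot.shape)  # torch.Size([2, 5, 3, 4])
```

Because every factor has magnitude 1, the rotation changes direction but not length: each head vector keeps its norm, and position 0 (angle 0) is left unchanged.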


### RMS Norm

In [43]:
import torch
import torch.nn as nn

class RMSNorm(nn.Module):

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        
    def _norm(self, x : torch.Tensor):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdims=True) + self.eps)
    
    def forward(self, x: torch.Tensor):
        return self.weight * self._norm(x.float()).type_as(x)
    
# input of (batch_size=2, seq_len=3, dim=4)
x = torch.tensor([
    [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0], [5.0, 6.0, 7.0, 8.0]],
    [[-1.0, -2.0, -3.0, -4.0], [-2.0, -3.0, -4.0, -5.0], [-5.0, -6.0, -7.0, -8.0]]
])
norm = RMSNorm(dim=4)
print("Input tensor : ")
print(x)
print("After norm : ")
print(norm(x))

Input tensor : 
tensor([[[ 1.,  2.,  3.,  4.],
         [ 2.,  3.,  4.,  5.],
         [ 5.,  6.,  7.,  8.]],

        [[-1., -2., -3., -4.],
         [-2., -3., -4., -5.],
         [-5., -6., -7., -8.]]])
After norm : 
tensor([[[ 0.3651,  0.7303,  1.0954,  1.4606],
         [ 0.5443,  0.8165,  1.0887,  1.3608],
         [ 0.7581,  0.9097,  1.0613,  1.2130]],

        [[-0.3651, -0.7303, -1.0954, -1.4606],
         [-0.5443, -0.8165, -1.0887, -1.3608],
         [-0.7581, -0.9097, -1.0613, -1.2130]]], grad_fn=<MulBackward0>)
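To check the first output row by hand: for `[1, 2, 3, 4]` the mean of squares is $(1+4+9+16)/4 = 7.5$, so each element is multiplied by $1/\sqrt{7.5 + 10^{-6}} \approx 0.3651$. The learnable `weight` is still at its initial value of 1, so it does not change the result. A sketch reproducing that row:

```python
import torch

row = torch.tensor([1.0, 2.0, 3.0, 4.0])
eps = 1e-6

mean_sq = row.pow(2).mean()         # (1 + 4 + 9 + 16) / 4 = 7.5
scale = torch.rsqrt(mean_sq + eps)  # 1 / sqrt(7.5 + eps), about 0.3651
normed = row * scale

print(mean_sq.item())  # 7.5
print(normed)          # approximately [0.3651, 0.7303, 1.0954, 1.4606]
```

Since only squares enter the mean, the negated rows in the second batch get the same scale, which is why their outputs are the exact negatives of the first batch.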


### **Step-by-Step Explanation of `x.pow(2).mean(-1, keepdims=True)`**

---

#### **1. `x.pow(2)`**  
- **Purpose**: Squares every element in the tensor `x`.  
- **Example**:  
  If `x = [[1.0, 2.0], [3.0, 4.0]]`, then:  
  ```python
  x.pow(2) = [[1.0, 4.0], [9.0, 16.0]]
  ```

---

#### **2. `.mean(-1, keepdims=True)`**  
- **Purpose**:  
  - Computes the **mean** across the **last dimension** (axis `-1`).  
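  - `keepdims=True` retains the original number of dimensions (e.g., keeps a 3D tensor as 3D) by leaving the reduced dimension in place with size 1. PyTorch documents this argument as `keepdim`; the `keepdims` spelling used in this notebook is accepted as well, as the RMSNorm cell above runs without error.  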

- **Example**:  
  Suppose `x.pow(2)` has shape `(2, 3, 4)` (batch size=2, sequence length=3, features=4).  
  After `.mean(-1, keepdims=True)`:  
  - **Shape**: `(2, 3, 1)` (mean computed over the last dimension, size 4 → 1).  
  - **Result**: Each token's feature vector (the last dimension) is reduced to a single value: the mean of its squared entries.  
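A quick shape check with a random `(2, 3, 4)` tensor. The sketch spells the argument `keepdim`, the name in PyTorch's documentation:

```python
import torch

x = torch.randn(2, 3, 4)  # (batch, seq_len, dim)

kept = x.pow(2).mean(-1, keepdim=True)  # last dim reduced to size 1
dropped = x.pow(2).mean(-1)             # last dim removed entirely

print(kept.shape)     # torch.Size([2, 3, 1])
print(dropped.shape)  # torch.Size([2, 3])

# The kept version broadcasts back over the feature dimension
print((x * torch.rsqrt(kept + 1e-6)).shape)  # torch.Size([2, 3, 4])
```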

---

#### **3. Full Example**  
Let’s use a concrete tensor:  
```python
x = torch.tensor([
    [[1.0, 2.0, 3.0, 4.0],  # Batch 1, Token 1
     [2.0, 3.0, 4.0, 5.0]], # Batch 1, Token 2
    
    [[-1.0, -2.0, -3.0, -4.0],  # Batch 2, Token 1
     [-2.0, -3.0, -4.0, -5.0]]  # Batch 2, Token 2
])  # Shape: (2, 2, 4)
```

**Step 1: `x.pow(2)`**  
```python
tensor([
    [[1.0, 4.0, 9.0, 16.0],
     [4.0, 9.0, 16.0, 25.0]],
    
    [[1.0, 4.0, 9.0, 16.0],
     [4.0, 9.0, 16.0, 25.0]]
])
```

**Step 2: `.mean(-1, keepdims=True)`**  
For each token vector (last dimension):  
- Batch 1, Token 1: $(1 + 4 + 9 + 16)/4 = 7.5$  
- Batch 1, Token 2: $(4 + 9 + 16 + 25)/4 = 13.5$  
- Batch 2, Token 1: $(1 + 4 + 9 + 16)/4 = 7.5$  
- Batch 2, Token 2: $(4 + 9 + 16 + 25)/4 = 13.5$  

**Result**:  
```python
tensor([
    [[7.5],  # Shape: (2, 2, 1)
     [13.5]],
    
    [[7.5],
     [13.5]]
])
```

---

#### **4. Why is this used in RMSNorm?**  
- **RMS (Root Mean Square)**:  
  $
  \text{RMS} = \sqrt{\text{mean}(x^2)}
  $ 
  `x.pow(2).mean(-1, keepdims=True)` computes only the mean of the squared values. The square root and the reciprocal are applied afterwards in a single step by `torch.rsqrt`, which returns $1/\sqrt{\cdot}$.  

- **Normalization**:  
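  Each token is scaled by $\frac{1}{\sqrt{\text{mean}(x^2) + \epsilon}}$, so every token vector ends up with an RMS of approximately 1 before the learnable `weight` rescales it. The small $\epsilon$ (`eps`) prevents division by zero for all-zero vectors.  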
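To see the effect, the sketch below normalizes token vectors of very different scales (hypothetical factors 0.1, 1 and 100 applied to the same base vector) and checks that each comes out with an RMS close to 1:

```python
import torch

eps = 1e-6
base = torch.tensor([1.0, 2.0, 3.0, 4.0])
x = torch.stack([0.1 * base, base, 100.0 * base])  # 3 tokens, dim = 4

ms = x.pow(2).mean(-1, keepdim=True)  # mean(x^2) per token, shape (3, 1)
y = x * torch.rsqrt(ms + eps)         # RMS-normalized tokens

rms_after = y.pow(2).mean(-1).sqrt()  # RMS of each normalized token
print(rms_after)                               # every entry is close to 1
print(torch.allclose(y[0], y[2], atol=1e-4))   # True: the original scale is gone
```

Only the overall scale is removed; the direction of each vector (the ratios between its features) is preserved.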

---

#### **Key Formula**  
$
\text{mean}(x^2) = \frac{1}{d} \sum_{i=1}^d x_i^2 \quad \text{(computed along the last dimension)}
$  
where $d$ is the feature dimension (`dim` in the code).  

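**Note**: `keepdims=True` keeps the reduced dimension with size 1 (e.g., `(2, 3, 4)` → `(2, 3, 1)`), so the result has the same number of dimensions as `x` and broadcasts across the feature dimension when multiplied with `x`.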
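Keeping the reduced dimension matters most when shapes happen to line up. In the sketch below (a hypothetical square `(4, 4)` input: 4 tokens of dimension 4), dropping `keepdim` raises no error; the `(4,)` vector of means broadcasts along the wrong axis, so each column is scaled by some other token's statistic:

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                  [2.0, 3.0, 4.0, 5.0],
                  [5.0, 6.0, 7.0, 8.0],
                  [0.5, 0.5, 0.5, 0.5]])  # 4 tokens, dim = 4
eps = 1e-6

right = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)  # (4, 1): one scale per row
wrong = x * torch.rsqrt(x.pow(2).mean(-1) + eps)                # (4,): one scale per column

print(right.shape == wrong.shape)    # True: the shapes match
print(torch.allclose(right, wrong))  # False: the values do not
```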