# Training a Transformer Model 101

**References**
- *Transformer Math 101: [Blog](https://blog.eleuther.ai/transformer-math/)*
- *Scaling Laws - CS 886 Presentation - Carter Blair & Cole Wyeth: [YouTube Video](https://youtu.be/_09CeX44tqg?si=U0Pkb7Hr0gb9imOR)*

##### Einsum

`torch.einsum(equation, *operands) -> Tensor`

**equation** (str): The equation string assigns a subscript letter in `[a-zA-Z]` to each dimension of each input operand, in the same order as the dimensions. The subscripts for each operand are separated by a comma (`,`); for example, `'ij,jk->ik'` specifies subscripts for two 2D operands. Any subscript that does not appear in the output is summed over. If the equation has no arrow (implicit mode), the output consists of the subscripts that appear exactly once in the equation, sorted in increasing alphabetical order. To define the output explicitly, add an arrow (`->`) at the end of the equation followed by the output subscripts.
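
The difference between implicit and explicit output subscripts can be sketched as follows. This is a minimal example with arbitrary small tensors; the variable names are illustrative only.

```python
import torch

A = torch.arange(6.0).reshape(2, 3)
B = torch.arange(12.0).reshape(3, 4)

# Implicit mode (no '->'): 'j' appears twice, so it is summed over.
# 'i' and 'k' appear once and form the output in alphabetical order: 'ik'.
implicit = torch.einsum('ij,jk', A, B)       # shape (2, 4)

# Explicit mode: the same contraction with the output spelled out.
explicit = torch.einsum('ij,jk->ik', A, B)   # shape (2, 4)

# Explicit mode also allows a different output order, here the transpose.
transposed = torch.einsum('ij,jk->ki', A, B) # shape (4, 2)
```

Here `implicit` and `explicit` hold the same values, and `transposed` is their transpose.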

Note that the order of the output subscripts is up to you. For Q, for example, you can ask for `bhkn` instead of `bhnk`: the summation is exactly the same and only the order of the output dimensions changes, so the result is a transposed view of the same values (a `permute`, not a `reshape`).
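
A small sketch of this point, assuming a hypothetical per-head query projection: the shapes, the weight tensor `W_q`, and the subscript meanings (`b` batch, `n` sequence, `d` model dimension, `h` heads, `k` head dimension) are assumptions made for illustration.

```python
import torch

b, n, d, h, k = 2, 5, 16, 4, 8   # illustrative sizes
x = torch.rand(b, n, d)          # input activations
W_q = torch.rand(h, d, k)        # hypothetical per-head query weights

# Same contraction over 'd', two different output layouts.
Q_bhnk = torch.einsum('bnd,hdk->bhnk', x, W_q)  # shape (2, 4, 5, 8)
Q_bhkn = torch.einsum('bnd,hdk->bhkn', x, W_q)  # shape (2, 4, 8, 5)

# The second result is the first with its last two dimensions swapped.
same = torch.allclose(Q_bhkn, Q_bhnk.permute(0, 1, 3, 2), atol=1e-5)
```

A `reshape` to `(b, h, k, n)` would keep the memory order and generally place values at the wrong indices, which is why `permute` is the right mental model here.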

`torch.einsum()` lets you express many multi-dimensional linear-algebra operations in a single, concise call. Here are some common use cases:

1. **Element-wise multiplication**: You can use `einsum` to perform element-wise multiplication of two tensors, similar to the `*` operator.
    ```python
    import torch

    A = torch.rand(3, 3)
    B = torch.rand(3, 3)
    result = torch.einsum('ij,ij->ij', A, B)  # same values as A * B
    ```

2. **Matrix multiplication**: `einsum` can be used to perform matrix multiplication, similar to `torch.matmul()` or the `@` operator.
    ```python
    A = torch.rand(3, 4)
    B = torch.rand(4, 5)
    result = torch.einsum('ij,jk->ik', A, B)
    ```

3. **Batch matrix multiplication**: If you have two batches of matrices and want to multiply them pairwise, `einsum` does this in a single operation, similar to `torch.bmm()`.
    ```python
    A = torch.rand(10, 3, 4)
    B = torch.rand(10, 4, 5)
    result = torch.einsum('bij,bjk->bik', A, B)
    ```

4. **Dot product**: `einsum` can be used to compute the dot product of two vectors.
    ```python
    a = torch.rand(5)
    b = torch.rand(5)
    result = torch.einsum('i,i->', a, b)
    ```

5. **Sum along a dimension**: Leaving a subscript out of the output sums over that dimension, similar to `torch.sum()` with a `dim` argument. Below, `j` is dropped, so the result matches `A.sum(dim=1)`.
    ```python
    A = torch.rand(3, 4, 5)
    sum_along_dim1 = torch.einsum('ijk->ik', A)
    ```

6. **Complex operations**: `einsum` really shines for operations that would otherwise need several transposes, reshapes, and matrix multiplications, because the subscripts state directly which dimensions are matched, summed, and kept.
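
As one illustration of such a multi-step operation, here is a minimal sketch of scaled dot-product attention written with two `einsum` calls. The shapes, the names `Q`, `K`, `V`, and the `bhnk` layout (batch, heads, sequence, head dimension) are assumptions for illustration.

```python
import torch

b, h, n, k = 2, 4, 5, 8          # illustrative sizes
Q = torch.rand(b, h, n, k)
K = torch.rand(b, h, n, k)
V = torch.rand(b, h, n, k)

# scores[b,h,n,m] = sum_k Q[b,h,n,k] * K[b,h,m,k], scaled by sqrt(k)
scores = torch.einsum('bhnk,bhmk->bhnm', Q, K) / k ** 0.5
weights = torch.softmax(scores, dim=-1)

# out[b,h,n,k] = sum_m weights[b,h,n,m] * V[b,h,m,k]
out = torch.einsum('bhnm,bhmk->bhnk', weights, V)

# Reference version with explicit transpose and matmul, for comparison.
ref = torch.softmax(Q @ K.transpose(-2, -1) / k ** 0.5, dim=-1) @ V
```

The `einsum` version avoids the explicit `transpose` on `K`: using `m` for the key position and `n` for the query position makes the matching of dimensions visible in the equation itself.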
