<a href="https://colab.research.google.com/github/gnoejh/ict1022/blob/main/Transformer/9_training_equation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Transformer Training Process

This notebook gives a mathematical description of one Transformer training step, organized as a loop over the tokens of the target sequence.
At each step the model predicts the next target token, and the per-token losses are accumulated into a single sequence loss.
The loop is a conceptual device: implementations typically compute all target positions in parallel using a causal mask, which yields the same loss.

---



## 1. Input and Target Sequences

Let:
- $ X = \{x_1, x_2, \dots, x_n\} $ be the **input sequence** of length $ n $.
- $ Y = \{y_1, y_2, \dots, y_m\} $ be the **target sequence** of length $ m $.
- $ \hat{y}_t $ be the predicted token distribution at time step $ t $.

The model is trained with teacher forcing: at each time step $ t $, the ground-truth target tokens that precede position $ t $ are fed to the decoder,
and the model predicts $ y_t $.

### Teacher Forcing
Teacher forcing is a training strategy in which the true target tokens, rather than the decoder's own previous predictions, are used as decoder inputs. Because an early mistake cannot propagate into later inputs, training is more stable and typically converges faster.
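
As a rough illustration of teacher forcing, the sketch below builds the (decoder input, label) pair for each time step from a toy target sequence. The helper name `teacher_forcing_pairs` and the `<sos>` start token are assumptions made for this example; the notebook itself does not name them.

```python
def teacher_forcing_pairs(target_tokens, sos="<sos>"):
    """Return one (decoder_input, label) pair per target position.

    The decoder input at step t is the ground-truth prefix y_1..y_{t-1},
    preceded by a hypothetical start token so that step 1 is not empty.
    """
    pairs = []
    for t in range(len(target_tokens)):
        decoder_input = [sos] + list(target_tokens[:t])  # true tokens, not predictions
        label = target_tokens[t]                         # token to predict at this step
        pairs.append((decoder_input, label))
    return pairs


pairs = teacher_forcing_pairs(["le", "chat", "dort"])
for decoder_input, label in pairs:
    print(decoder_input, "->", label)
```

Each label is predicted from ground-truth tokens only, so the pairs can be built before the model has produced any output.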



## 2. Forward Pass

### Encoder Output (Static for All Time Steps)
Encode the input sequence $ X $ once for the entire target sequence:
$$
Z_{\text{encoder}} = \text{Encoder}(X)
$$

---

### Loop Over Target Sequence
For $ t = 1 $ to $ m $ (the length of the target sequence):

#### a. Decoder Input at Time Step $ t $
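- The decoder receives all previous tokens $ \{y_1, y_2, \dots, y_{t-1}\} $ as input to predict the next token $ y_t $. For $ t = 1 $ there are no previous tokens, so a start-of-sequence token $ y_0 $ is prepended.
- Formally, let $ Y_{\text{input}}^{(t)} = \{y_0, y_1, \dots, y_{t-1}\} $, where $ y_0 $ is the start-of-sequence token.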

#### b. Decoder Output at Time Step $ t $
- The decoder generates an output representation based on $ Y_{\text{input}}^{(t)} $ and the encoder output $ Z_{\text{encoder}} $:
$$
Z_{\text{decoder}}^{(t)} = \text{Decoder}(Y_{\text{input}}^{(t)}, Z_{\text{encoder}})
$$

#### c. Prediction at Time Step $ t $
- The decoder output $ Z_{\text{decoder}}^{(t)} $ is transformed into a probability distribution over the vocabulary to predict the next token:
$$
\hat{y}_t = \text{softmax}(Z_{\text{decoder}}^{(t)} W_O)
$$
where $ W_O $ is the learned output weight matrix, which maps the decoder output to one score (logit) per vocabulary entry.
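
The projection and softmax can be sketched in plain Python. The vector `hidden` stands in for $ Z_{\text{decoder}}^{(t)} $, and both it and the small matrix `W_O` are made-up toy values chosen only to keep the example self-contained.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def project(hidden, W_O):
    """Multiply a length-d vector by a d x V matrix, giving V logits."""
    vocab_size = len(W_O[0])
    return [sum(h * W_O[i][v] for i, h in enumerate(hidden))
            for v in range(vocab_size)]


hidden = [1.0, -0.5]            # toy decoder output, d = 2 (assumed values)
W_O = [[0.2, 0.0, -0.1],
       [0.4, 0.3, 0.1]]         # toy d x V output matrix, V = 3 (assumed values)

y_hat = softmax(project(hidden, W_O))
print(y_hat, sum(y_hat))
```

The result is a list of $ V $ non-negative values that sum to 1, which is what the cross-entropy loss expects as input.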

#### d. Cross-Entropy Loss at Time Step $ t $
- The loss for predicting $ y_t $ is computed using cross-entropy between the predicted distribution $ \hat{y}_t $ and the true token $ y_t $:
$$
\mathcal{L}^{(t)} = - \sum_{v=1}^V y_{t, v} \log \hat{y}_{t, v}
$$
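where $ V $ is the vocabulary size and $ y_{t, v} $ is 1 for the correct token and 0 otherwise. Because only one term of the sum is nonzero, the loss equals the negative log-probability that the model assigns to the correct token.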
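
A small numerical check may help here. The sketch below evaluates the cross-entropy sum with an explicit one-hot vector and compares it with the negative log-probability of the correct token. The function names and the distribution `y_hat` are illustrative assumptions.

```python
import math


def cross_entropy_full(y_hat, target_index):
    """Cross-entropy as the full sum over the vocabulary with a one-hot target."""
    one_hot = [1.0 if v == target_index else 0.0 for v in range(len(y_hat))]
    return -sum(y * math.log(p) for y, p in zip(one_hot, y_hat))


def cross_entropy_shortcut(y_hat, target_index):
    """Same loss, using only the probability of the correct token."""
    return -math.log(y_hat[target_index])


y_hat = [0.7, 0.2, 0.1]   # toy predicted distribution (assumed values)
target_index = 0          # index of the correct token

loss_full = cross_entropy_full(y_hat, target_index)
loss_short = cross_entropy_shortcut(y_hat, target_index)
print(loss_full, loss_short)
```

Both functions return the same value, and the loss shrinks toward 0 as the probability of the correct token approaches 1.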



## 3. Total Loss for the Sequence

The total loss for the sequence is the sum of losses over all time steps:
$$
\mathcal{L} = \sum_{t=1}^m \mathcal{L}^{(t)}
$$
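
The accumulation over time steps can be sketched as follows, using made-up per-step distributions in place of real decoder outputs. The helper name `sequence_loss` is an assumption for this example. The mean is printed as well, since many implementations report the per-token average rather than the sum.

```python
import math


def sequence_loss(predictions, targets):
    """Sum the per-step cross-entropy losses over the target sequence."""
    return sum(-math.log(y_hat[y_t]) for y_hat, y_t in zip(predictions, targets))


# Toy predicted distributions for m = 3 time steps (assumed values).
predictions = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.25, 0.25, 0.5],
]
targets = [0, 1, 2]   # index of the correct token at each step

total = sequence_loss(predictions, targets)
mean = total / len(targets)
print(total, mean)
```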



## 4. Backpropagation and Parameter Update

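Using **backpropagation**, the gradient of the loss $ \mathcal{L} $ with respect to each parameter is computed. The optimizer then updates the parameters
to reduce the loss. Optimizers such as Adam rescale the step using running estimates of the gradient's first and second moments; the plain gradient-descent form of the update is: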
$$
\theta \leftarrow \theta - \eta \nabla_{\theta} \mathcal{L}
$$
where:
- $ \theta $ represents all model parameters,
- $ \eta $ is the learning rate,
- $ \nabla_{\theta} \mathcal{L} $ is the gradient of the loss with respect to $ \theta $.

This completes the training step for one sequence pair.
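
To make the update rule concrete, here is a minimal sketch in which the parameters $ \theta $ are simply the logits of a single prediction, so the gradient of the cross-entropy loss is the predicted distribution minus the one-hot target. This toy setup, the function names, and the learning rate are assumptions for illustration. It uses plain gradient descent, not Adam.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def loss_and_grad(theta, target):
    """Cross-entropy of softmax(theta) and its gradient with respect to theta."""
    y_hat = softmax(theta)
    loss = -math.log(y_hat[target])
    grad = [p - (1.0 if v == target else 0.0) for v, p in enumerate(y_hat)]
    return loss, grad


def sgd_step(theta, grad, eta):
    """theta <- theta - eta * grad, applied elementwise."""
    return [th - eta * g for th, g in zip(theta, grad)]


theta = [0.0, 0.0, 0.0]   # toy parameters: one logit per vocabulary entry
target = 1                # index of the correct token
eta = 0.5                 # learning rate (assumed value)

losses = []
for _ in range(5):
    loss, grad = loss_and_grad(theta, target)
    losses.append(loss)
    theta = sgd_step(theta, grad, eta)

print(losses)
```

Each step raises the logit of the correct token and lowers the others, so the recorded losses decrease from one step to the next.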



## 5. Overall Training Process

The overall training process involves iterating over multiple epochs, where each epoch processes the entire training dataset. For each sequence pair in the dataset, the model performs the forward pass, computes the loss, and updates the parameters using backpropagation.

### Training Loop
For each epoch:
1. Shuffle the training dataset.
2. For each sequence pair (X, Y) in the dataset:
   - Perform the forward pass to compute predictions.
   - Compute the loss for the sequence.
   - Perform backpropagation to compute gradients.
   - Update the model parameters using the optimizer.

Repeat the process for the desired number of epochs or until convergence.
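
The loop structure can be sketched end to end with a deliberately tiny stand-in model. The "model" below is only a lookup table of logits indexed by a crude summary of $ X $ (its last token) and the previous target token. It is not a Transformer, and every name and hyperparameter in it is an assumption for this example. The table is simple enough that the gradients can be written by hand, so the epoch, shuffle, forward, loss, gradient, and update steps are all visible.

```python
import math
import random

VOCAB = 4   # toy vocabulary size; token ids are 0..3
SOS = 0     # id 0 doubles as a hypothetical start-of-sequence token


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def train(dataset, epochs=50, lr=0.5, seed=0):
    rng = random.Random(seed)
    theta = {}      # (encoder summary, previous token) -> logits over the vocabulary
    history = []    # total loss per epoch, measured before each update
    data = list(dataset)
    for _ in range(epochs):
        rng.shuffle(data)                       # 1. shuffle the training data
        epoch_loss = 0.0
        for X, Y in data:                       # 2. loop over sequence pairs
            z_enc = X[-1]                       # stand-in for Encoder(X)
            prev, loss, grads = SOS, 0.0, {}
            for y_t in Y:                       # forward pass with teacher forcing
                key = (z_enc, prev)
                y_hat = softmax(theta.setdefault(key, [0.0] * VOCAB))
                loss += -math.log(y_hat[y_t])   # accumulate cross-entropy
                g = grads.setdefault(key, [0.0] * VOCAB)
                for v in range(VOCAB):          # gradient: y_hat - one_hot
                    g[v] += y_hat[v] - (1.0 if v == y_t else 0.0)
                prev = y_t                      # feed the true token, not the prediction
            for key, g in grads.items():        # gradient-descent parameter update
                theta[key] = [th - lr * gv for th, gv in zip(theta[key], g)]
            epoch_loss += loss
        history.append(epoch_loss)
    return theta, history


dataset = [([1, 2], [1, 2, 3]), ([2, 1], [3, 2, 1])]   # toy (X, Y) pairs
theta, history = train(dataset)
print(history[0], history[-1])
```

In a real setup the table lookup would be replaced by the encoder and decoder, the hand-written gradient by automatic differentiation, and the per-pair update by mini-batch updates. The order of operations inside the loop stays the same.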
