# **Custom Training Logic**

| | |
|-|-|
| Author(s) | [Keeyana Jones](https://github.com/keeyanajones/) |

## **Overview**

Custom training logic, often referred to as a custom training loop, is a fundamental concept in machine learning and deep learning. It matters most when you need more control and flexibility over the model training process than pre-built, high-level APIs (like Keras's `model.fit()`) offer.

Here is a breakdown of what it means and why it's used:

### **What is Custom Training Logic?**

At its core, a custom training loop involves writing the entire training process from scratch, step by step. Instead of calling a single `fit()` method, you explicitly define:

1. **Iteration over Epochs:** The number of times you want to iterate through your entire dataset.

2. **Iteration over Batches:** Within each epoch, how you'll process your data in smaller chunks (batches). 

3. **Forward Pass:** How the input data is fed through your model to generate predictions.  

4. **Loss Calculation:** How the model's predictions are compared to the actual target values to quantify the error (loss).

5. **Backward Pass (Gradient Calculation):** How the gradients of the loss with respect to the model's trainable parameters (weights and biases) are computed. This is typically done using automatic differentiation libraries (like `torch.autograd` in PyTorch or `tf.GradientTape` in TensorFlow).

6. **Optimizer Step:** How the model's parameters are updated using an optimization algorithm (e.g., Stochastic Gradient Descent (SGD), Adam, RMSprop) and the calculated gradients.

7. **Metrics Calculation and Logging:** How you track and report performance metrics (e.g., accuracy, precision, recall, F1-score) during training.

8. **Validation/Evaluation:** Periodically evaluating the model's performance on a separate validation dataset to monitor for overfitting and track generalization.  

9. **Checkpointing:** Saving the model's state (weights, optimizer state) at regular intervals or when performance improves, so you can resume training later or load the best model.  

10. **Learning Rate Scheduling (Optional):** Dynamically adjusting the learning rate during training.
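To make the steps above concrete without depending on any framework, here is a minimal sketch that fits a line `y = w*x + b` with mini-batch SGD. Everything in it is an illustrative assumption: the toy data, the hyperparameters, the helper names (`make_batches`, `mse`, `train`), and the hand-derived gradients that stand in for automatic differentiation. "Checkpointing" here is simply keeping the best parameters in memory.

```python
import random

def make_batches(data, batch_size):
    """Yield consecutive mini-batches from a list of (x, y) pairs."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]

def mse(w, b, data):
    """Mean squared error of the line y = w*x + b on the given pairs."""
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)

def train(train_data, val_data, num_epochs=200, batch_size=8, lr=0.1, seed=0):
    rng = random.Random(seed)
    w, b = 0.0, 0.0  # trainable parameters
    best = {"val_loss": float("inf"), "w": w, "b": b, "epoch": -1}
    for epoch in range(num_epochs):                      # 1. iterate over epochs
        shuffled = train_data[:]
        rng.shuffle(shuffled)
        for batch in make_batches(shuffled, batch_size):  # 2. iterate over batches
            # 3./4. forward pass and per-sample error (prediction - target)
            errors = [(w * x + b) - y for x, y in batch]
            # 5. backward pass: gradients of the batch MSE, derived by hand
            grad_w = 2 * sum(e * x for e, (x, _) in zip(errors, batch)) / len(batch)
            grad_b = 2 * sum(errors) / len(batch)
            # 6. optimizer step: plain SGD update
            w -= lr * grad_w
            b -= lr * grad_b
        # 8. validation on held-out data
        val_loss = mse(w, b, val_data)
        # 9. "checkpoint": remember the best parameters seen so far
        if val_loss < best["val_loss"]:
            best = {"val_loss": val_loss, "w": w, "b": b, "epoch": epoch}
    return best

# Toy data on the line y = 3x - 1, split into train and validation sets.
xs = [i / 20 - 1 for i in range(41)]
pairs = [(x, 3 * x - 1) for x in xs]
train_data, val_data = pairs[0::2], pairs[1::2]

best = train(train_data, val_data)
print(f"w={best['w']:.3f}, b={best['b']:.3f}, val_loss={best['val_loss']:.2e}")
```

In a real project, steps 3 to 6 would normally be delegated to a framework's autograd and optimizer classes; writing them out once by hand mainly helps show what those calls are doing.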

### **Why Use Custom Training Logic?**

While high-level APIs are excellent for rapid prototyping and standard tasks, custom training logic becomes necessary or highly beneficial for several reasons:

1. **Non-Standard Training Procedures** 
   - **Generative Adversarial Networks (GANs):** GANs involve training two networks (generator and discriminator) in an adversarial manner. This requires a specific, interleaved training loop that high-level `fit()` methods don't natively support.
   - **Reinforcement Learning:** Training agents in RL often involves complex interactions with environments, value functions, and policy updates that go beyond simple supervised learning loops.  
   - **Meta-Learning:** Training models to learn to learn requires nested optimization loops.
   - **Multi-task Learning:** When a single model is trained to perform multiple tasks simultaneously, with potentially different loss functions or optimization schemes for each task.  

2. **Advanced Optimization Techniques**    
   - Implementing custom learning rate schedules that are not built into standard optimizers (e.g., cyclical learning rates, warm-up schedules).
   - Applying custom regularization techniques directly in the loss computation or gradient updates. 
   - Developing novel optimization algorithms.

3. **Complex Loss Functions**
   - When the loss function is highly custom, involves multiple components, or requires specific intermediate computations that are not easily expressed within standard `compile()` or `add_loss()` methods.

4. **Debugging and Granular Control** 
   - For researchers and advanced practitioners, a custom loop offers unparalleled visibility into the training process. You can inspect gradients, activations, and losses at any point, making debugging complex models much easier.
   - It allows you to control every single step, which is invaluable for fine-tuning performance or implementing experimental ideas.

5. **Resource Management and Distributed Training** 
   - When training on multiple GPUs, TPUs, or across a cluster of machines, you often need to explicitly manage data distribution, gradient aggregation, and synchronization, which is more easily done with a custom loop.

6. **Integration with Custom Data Pipelines**
   - If your data loading and preprocessing pipeline is highly specialized and doesn't fit neatly into standard `DataLoader` or `tf.data.Dataset` patterns, a custom loop provides the flexibility to integrate it seamlessly.

7. **Research and Experimentation**
   - For cutting-edge research, new model architectures, and novel training methodologies, pre-built abstractions often fall short. Custom training logic is the bread and butter of ML research.
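As one illustration of the custom learning rate schedules mentioned under "Advanced Optimization Techniques", the sketch below combines a linear warm-up with a triangular cyclical schedule. The function name `warmup_triangular_lr` and all default values are hypothetical choices made for this example, not part of any library.

```python
def warmup_triangular_lr(step, base_lr=1e-4, max_lr=1e-2,
                         warmup_steps=100, cycle_steps=400):
    """Learning rate for a given global step (0-indexed).

    Phase 1 (warm-up): ramp linearly up to base_lr over warmup_steps.
    Phase 2 (cyclical): oscillate linearly between base_lr and max_lr,
    completing one up-and-down triangle every cycle_steps.
    """
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    pos = (step - warmup_steps) % cycle_steps   # position inside the current cycle
    half = cycle_steps / 2
    # Fraction of the way from base_lr to max_lr: rises for half a cycle, then falls.
    frac = pos / half if pos < half else (cycle_steps - pos) / half
    return base_lr + (max_lr - base_lr) * frac

# Print the schedule at a few steps: warm-up start, warm-up end, peak, cycle end.
for step in (0, 99, 300, 500):
    print(step, f"{warmup_triangular_lr(step):.6f}")
```

In a PyTorch custom loop, one way to apply such a function is to assign its result to `param_group["lr"]` for each entry in `optimizer.param_groups` before calling `optimizer.step()`. Treat that as one possible wiring rather than the only one.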

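```python
# EXAMPLE (CONCEPTUAL): Custom Training Loop

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Placeholder setup (random data, tiny linear classifier) so the loop runs
# end to end. Replace these with your own objects:
# model, optimizer, loss_fn, train_dataloader, val_dataloader
torch.manual_seed(0)
features, labels = torch.randn(256, 20), torch.randint(0, 3, (256,))
train_dataloader = DataLoader(TensorDataset(features[:192], labels[:192]), batch_size=32, shuffle=True)
val_dataloader = DataLoader(TensorDataset(features[192:], labels[192:]), batch_size=32)
model = nn.Linear(20, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
num_epochs = 3

for epoch in range(num_epochs):
    # Training Phase
    model.train()  # enable training behavior (dropout, batch norm updates)
    total_train_loss = 0.0
    correct_train_predictions = 0
    total_train_samples = 0

    for inputs, targets in train_dataloader:
        # 1. Forward pass
        predictions = model(inputs)
        # 2. Calculate loss
        loss = loss_fn(predictions, targets)
        # 3. Zero gradients (clear previous gradients)
        optimizer.zero_grad()
        # 4. Backward pass (compute gradients)
        loss.backward()
        # 5. Optimizer step (update model parameters)
        optimizer.step()

        # Accumulate loss and accuracy statistics
        total_train_loss += loss.item()
        # torch.max returns (values, indices); the indices are the predicted classes
        _, predicted = torch.max(predictions, 1)
        total_train_samples += targets.size(0)
        correct_train_predictions += (predicted == targets).sum().item()

    avg_train_loss = total_train_loss / len(train_dataloader)
    train_accuracy = correct_train_predictions / total_train_samples
    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f}")

    # Validation Phase (runs once per epoch, inside the epoch loop)
    model.eval()  # switch to evaluation behavior
    total_val_loss = 0.0
    correct_val_predictions = 0
    total_val_samples = 0

    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, targets in val_dataloader:
            predictions = model(inputs)
            loss = loss_fn(predictions, targets)
            total_val_loss += loss.item()
            _, predicted = torch.max(predictions, 1)
            total_val_samples += targets.size(0)
            correct_val_predictions += (predicted == targets).sum().item()

    avg_val_loss = total_val_loss / len(val_dataloader)
    val_accuracy = correct_val_predictions / total_val_samples
    print(f"Epoch {epoch+1}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}\n")
```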
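The per-epoch validation loss computed in the loop above is also a natural signal for checkpointing and early stopping. The sketch below shows one way to track it. The `EarlyStopping` class, its `patience` and `min_delta` parameters, and the list of validation losses are all hypothetical and exist only for illustration.

```python
class EarlyStopping:
    """Track the best validation loss and signal when to stop training."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best_loss = float("inf")
        self.best_epoch = None
        self.bad_epochs = 0

    def update(self, epoch, val_loss):
        """Return (improved, should_stop) for this epoch's validation loss."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
            return True, False
        self.bad_epochs += 1
        return False, self.bad_epochs >= self.patience

# Made-up validation losses standing in for avg_val_loss from each epoch.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.68, 0.67, 0.50]

stopper = EarlyStopping(patience=3)
saved_epochs = []    # epochs at which a checkpoint would be written
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    improved, should_stop = stopper.update(epoch, val_loss)
    if improved:
        # In PyTorch this is where you might call
        # torch.save(model.state_dict(), "best_model.pt")  (the path is an assumption)
        saved_epochs.append(epoch)
    if should_stop:
        stopped_at = epoch
        break

print(f"best epoch: {stopper.best_epoch}, stopped at: {stopped_at}, saved: {saved_epochs}")
```

Note the trade-off this exposes: with `patience=3`, training stops before the later, lower loss of 0.50 is ever seen, so the patience value is worth tuning for noisy validation curves.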

While frameworks provide convenient abstractions, understanding and implementing custom training logic is essential for deep learning practitioners who need maximum control, flexibility, and the ability to innovate beyond standard training paradigms.

----