## 06 Transfer Learning with PyTorch

### What is Transfer Learning?

Transfer learning is a machine learning technique in which knowledge gained from solving one problem is applied to a different but related problem. Instead of training from scratch, you start from a pre-trained model that has already learned useful features from a large dataset and fine-tune it on your specific task. This typically requires less data and fewer computational resources than training a model from scratch.

Here's a breakdown of the key concepts:

**1. Pre-trained Model:** A model that has already been trained on a large dataset, usually for a general task such as image classification on ImageNet. Such models have learned a rich set of features that can be reused for related tasks.

**2. Feature Extraction:** Pre-trained models can serve as feature extractors. The early layers learn general features (such as edges and textures), while later layers learn more specific features (such as shapes and objects). You can pass your data through a pre-trained model with its weights frozen and use the resulting activations as features, without retraining the model.

**3. Fine-tuning:** This involves adapting a pre-trained model to your specific task. You typically replace the final layer(s) with layers suited to your task (e.g., a new classification layer with the correct number of classes) and then train the modified model on your dataset. You can either freeze the weights of the early layers (to preserve the general features) and train only the later layers, or train the entire model with a lower learning rate so the pre-trained weights change only gradually.

**4. When to Use Transfer Learning:**

* **Limited Data:** Transfer learning is especially beneficial when you have a small dataset for your target task. Training a complex model from scratch on limited data often leads to overfitting.
* **Computational Constraints:** Fine-tuning a pre-trained model is usually faster and requires fewer resources than training from scratch, because most of the features have already been learned.
* **Improved Performance:** Transfer learning often leads to better performance than training from scratch, especially when the source and target tasks are closely related.

**Example:**

Imagine you want to build an image classifier to identify different types of flowers, but you only have a small dataset of flower images. Instead of training a CNN from scratch, you can use a pre-trained model like ResNet, which has been trained on ImageNet (a large dataset of general images). You remove ResNet's final classification layer (which outputs the 1,000 ImageNet classes) and replace it with a new classification layer that has one output per flower type. Then, you fine-tune the modified ResNet on your flower image dataset. The features ResNet learned on ImageNet help the model learn to classify flowers more efficiently and accurately.

**Key Benefits:**

* Faster training
* Reduced data requirements
* Improved performance (often)
* Easier prototyping

Transfer learning is a widely used and highly effective technique in deep learning, allowing you to leverage the power of large pre-trained models to solve a wide range of problems.


In [1]:
import torch
import torchvision

print(torch.__version__)
print(torchvision.__version__)

2.4.1+cu121
0.19.1+cu121


In [2]:
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision.transforms import v2 as transforms

# Try to import torchinfo; install it if the import fails
try:
    from torchinfo import summary
except ImportError:
    print("torchinfo not found, installing it")
    %pip install torchinfo
    from torchinfo import summary

# Try to import the going_modular directory, if it doesn't work, clone the repo
try:
    from going_modular import data_setup, engine
except ImportError:
    print("going_modular not found, cloning the repo")
    !git clone https://gi