<a href="https://colab.research.google.com/github/ioannis-toumpoglou/pytorch-repo/blob/main/06_pytorch_transfer_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 06. PyTorch Transfer Learning

What is transfer learning?

Transfer learning takes the parameters (weights) that a model has already learned on one dataset and applies them to a different problem.

* Pretrained models are also referred to as foundation models

In [1]:
import torch
import torchvision

# Report the installed versions: want torch 1.12+ and torchvision 0.13+
for library in (torch, torchvision):
    print(library.__version__)

2.0.1+cu118
0.15.2+cu118


In [2]:
# Continue with regular imports
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision import transforms

# Try to get torchinfo, install it if it doesn't work
try:
    from torchinfo import summary
except:
    print("[INFO] Couldn't find torchinfo... installing it.")
    !pip install -q torchinfo
    from torchinfo import summary

# Try to import the going_modular directory, download it from GitHub if it doesn't work
try:
    from going_modular.going_modular import data_setup, engine
except:
    # Get the going_modular scripts
    print("[INFO] Couldn't find going_modular scripts... downloading them from GitHub.")
    !git clone https://github.com/mrdbourke/pytorch-deep-learning
    !mv pytorch-deep-learning/going_modular .
    !rm -rf pytorch-deep-learning
    from going_modular.going_modular import data_setup, engine

[INFO] Couldn't find torchinfo... installing it.
[INFO] Couldn't find going_modular scripts... downloading them from GitHub.
Cloning into 'pytorch-deep-learning'...
remote: Enumerating objects: 3824, done.[K
remote: Counting objects: 100% (467/467), done.[K
remote: Compressing objects: 100% (261/261), done.[K
remote: Total 3824 (delta 246), reused 410 (delta 199), pack-reused 3357[K
Receiving objects: 100% (3824/3824), 650.63 MiB | 38.08 MiB/s, done.
Resolving deltas: 100% (2202/2202), done.
Updating files: 100% (248/248), done.


In [3]:
# Device-agnostic setup: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [4]:
# Shell out to nvidia-smi to print the status table of the GPU attached to this runtime
# NOTE(review): presumably unavailable on a CPU-only runtime — confirm before relying on it
!nvidia-smi

Mon Jun 19 09:36:37 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P8    12W /  70W |      3MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## 1. Get data

Download the pizza, steak, and sushi image data that the transfer learning model will be built on.

In [6]:
import os
import zipfile
from pathlib import Path
import requests

# Setup data path
data_path = Path('data/')
image_path = data_path / 'pizza_steak_sushi'
zip_path = data_path / 'pizza_steak_sushi.zip'
data_url = 'https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip'

# Download and extract the data only when it is not already on disk
if image_path.is_dir():
  print(f'{image_path} directory exists, skipping download')
else:
  print(f'{image_path} not found, downloading')
  # Create only the parent folder here: image_path itself is created after a
  # successful download, so a failed attempt is not mistaken for existing data
  # by the is_dir() check on the next run
  data_path.mkdir(parents=True, exist_ok=True)

  try:
    # download data
    print('Downloading data...')
    response = requests.get(data_url, timeout=60)  # don't hang forever on a stalled connection
    response.raise_for_status()  # fail loudly instead of saving an HTTP error page as the zip
    zip_path.write_bytes(response.content)

    # unzip data
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
      print('Unzipping data...')
      image_path.mkdir(parents=True, exist_ok=True)
      zip_ref.extractall(image_path)
  finally:
    # remove zip file, even when the download or extraction failed
    if zip_path.exists():
      os.remove(zip_path)

data/pizza_steak_sushi not found, downloading
Downloading data...
Unzipping data...


In [7]:
# Train and test splits are the 'train' and 'test' sub-folders of the dataset folder
train_dir, test_dir = image_path.joinpath('train'), image_path.joinpath('test')

(train_dir, test_dir)

(PosixPath('data/pizza_steak_sushi/train'),
 PosixPath('data/pizza_steak_sushi/test'))