# 06. PyTorch Transfer Learning

Let's take a well-performing pretrained model and adapt it to one of our own problems.

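Instead of training from scratch, we can improve our model's performance by reusing what an existing model has already learned: we start from the weights of a model that was trained on another (often much larger) dataset.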

**What we're going to cover**
- take a pretrained model from `torchvision.models`
- and customize it to work on (and hopefully improve) our `foodvision` problem

The steps we'll take:
1. Setting up
2. Get data
3. Create datasets and dataloaders
4. Get and customize a pretrained model
5. Train model
6. Evaluate the model by plotting loss curves
7. Make predictions on images from the test set

## 1. Setting up

In [16]:
try:
    import torch
    import torchvision
    assert int(torch.__version__.split('.')[0]) >= 2, 'torch version should be 2.0+'
    assert int(torchvision.__version__.split('.')[1]) >= 17, 'torchvision version should be 0.17+'
    print(f'torch version: {torch.__version__}')
    print(f'torchvision version: {torchvision.__version__}')
except (ImportError, AssertionError):
    print('[INFO] torch/torchvision version not as required, installing the latest stable versions')
    !pip install -U torch torchvision
    # Note: if torch was already imported, restart the kernel so the upgraded version gets loaded
    import torch
    import torchvision
    print(f'torch version: {torch.__version__}')
    print(f'torchvision version: {torchvision.__version__}')
    

torch version: 2.2.2
torchvision version: 0.17.2
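One caveat with the check above: `torchvision.__version__.split('.')[1]` only looks at the minor number, so a future `1.0` release would fail the `>= 17` test. A sturdier check compares `(major, minor)` tuples. This is a minimal sketch; the helper names `parse_major_minor` and `meets_minimum` are hypothetical, and it only handles version strings that begin with two dot-separated integers (such as `2.2.2+cu121`).

```python
import re


def parse_major_minor(version: str) -> tuple[int, int]:
    """Extract (major, minor) from strings like '2.2.2+cu121' or '0.17.2'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def meets_minimum(version: str, minimum: tuple[int, int]) -> bool:
    # Tuple comparison checks the major number first, then the minor number
    return parse_major_minor(version) >= minimum


print(meets_minimum("2.2.2+cu121", (2, 0)))  # True
print(meets_minimum("0.17.2", (0, 17)))      # True
print(meets_minimum("1.0.0", (0, 17)))       # True
print(meets_minimum("0.16.0", (0, 17)))      # False
```

In the notebook this would be used as `meets_minimum(torchvision.__version__, (0, 17))`.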


In [19]:
import matplotlib.pyplot as plot
import torch
import torchvision
from torch import nn
from torchvision import transforms

# Try to import torchinfo, install it if it's not available
try:
    from torchinfo import summary
except ImportError:
    print("[INFO] Couldn't find torchinfo, installing it.")
    !pip install -q torchinfo
    import torchinfo
    print(f'Installed torchinfo {torchinfo.__version__}')
    from torchinfo import summary

# Set up device-agnostic code
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'
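With `device` defined, the usual pattern is to create tensors on it and move models onto it with `.to(device)`, so the same notebook runs on a GPU when one is available and falls back to the CPU otherwise. A small sketch; the `nn.Linear` model is just a stand-in for the real model:

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Create a tensor directly on the target device
x = torch.randn(4, 10, device=device)

# Module.to() moves the parameters in place and also returns the module
model = nn.Linear(in_features=10, out_features=3).to(device)

# Inputs and model must live on the same device, otherwise PyTorch raises a RuntimeError
y = model(x)
print(y.device, y.shape)
```

Note that `Tensor.to()` is not in place: for tensors you need to keep the returned value, e.g. `x = x.to(device)`.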

## 2. Getting Data

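Before we start transfer learning, we need a dataset. We'll download a small image dataset with three food classes (pizza, steak and sushi) from the `pytorch-deep-learning` GitHub repository.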


In [23]:
from pathlib import Path
from zipfile import ZipFile
import requests
import os

data_url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"

data_path = Path('data')
image_data_path = data_path / 'pizza_steak_sushi'

# create folders if not already available
if image_data_path.is_dir():
    print('Folder already exists, skipping download.')
else:
    print("Folder doesn't exist, creating it...")
    image_data_path.mkdir(parents=True, exist_ok=True)

    zip_path = data_path / 'pizza_steak_sushi.zip'

    # download data (fetch first, so a failed request doesn't leave an empty zip file behind)
    print('Downloading data...')
    response = requests.get(data_url, timeout=60)
    response.raise_for_status()
    with open(zip_path, 'wb') as f:
        f.write(response.content)
    print('Download done.')

    # unzip
    with ZipFile(zip_path, 'r') as zip_file:
        print('Unzipping...')
        zip_file.extractall(image_data_path)
        print('Extracted.')

    # remove zip file
    os.remove(zip_path)


Folder doesn't exist, creating it...
Downloading data...
Download done.
Unzipping...
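To check what the archive actually contained, it helps to walk the extracted folder and count files per directory. A minimal sketch using only `os.walk`; the helper name `walk_through_dir` is hypothetical. For an image-classification dataset we'd expect something like `train/` and `test/` folders with one subfolder per class, but the function simply reports whatever is on disk.

```python
import os
from pathlib import Path


def walk_through_dir(dir_path):
    """Print, and return, the number of subdirectories and files in each directory under dir_path."""
    counts = {}
    for dirpath, dirnames, filenames in os.walk(dir_path):
        counts[dirpath] = (len(dirnames), len(filenames))
        print(f"There are {len(dirnames)} directories and {len(filenames)} files in '{dirpath}'.")
    return counts


image_data_path = Path("data") / "pizza_steak_sushi"
if image_data_path.is_dir():
    walk_through_dir(image_data_path)
```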


## 3. Create datasets and dataloaders
Now that we've downloaded the data, let's turn it into PyTorch `Dataset`s and their corresponding `DataLoader`s.