<a href="https://colab.research.google.com/github/christophergaughan/PyTorch/blob/main/PyTorch_Custom_Datasets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## PyTorch domain libraries for custom datasets

PyTorch has several domain libraries, and each one contains data loading functions for a different kind of data source. Before writing your own, look into the relevant domain library for existing data loading functions and customizable data loading functions:

* Vision - `torchvision.datasets`
* Text - `torchtext.datasets`
* Audio - `torchaudio.datasets`
* Recommendation system - `torchrec.datasets`

We've used some datasets with PyTorch so far.

BUT, how do you get your own data into PyTorch?

One way to do this is via *custom datasets*.

## Importing PyTorch and setting up device-agnostic code

In [None]:
import torch
from torch import nn

torch.__version__

In [None]:
# set up device agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device

In [None]:
!nvidia-smi

## Get data: food images from the Food-101 dataset
* We'll start with just three food categories and use only 10% of the data.
* Our dataset is a subset of the full Food-101 dataset.
* 3 classes, 100 images/class
* When starting an ML project, it's important to try things on a small scale first and only *then* scale the dataset up.
* At this stage, experiments run faster because the dataset is smaller.
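
The "start small, then scale up" idea can be sketched with the standard library alone. This is only an illustration of taking a fixed fraction per class, not how the pizza/steak/sushi subset was actually built; the helper name `sample_subset`, its `fraction` and `seed` parameters, and the file names are all hypothetical.

```python
import random

def sample_subset(files_by_class, fraction=0.1, seed=42):
    """Return a random `fraction` of the file names for each class.

    files_by_class: dict mapping class name -> list of file names (hypothetical layout).
    """
    rng = random.Random(seed)  # local generator, so the global random state is untouched
    subset = {}
    for class_name, files in files_by_class.items():
        # Keep at least one file per class, but never more than the class contains
        k = min(len(files), max(1, int(len(files) * fraction)))
        subset[class_name] = rng.sample(files, k)
    return subset

# Hypothetical file listing: 3 classes with 1,000 images each
all_files = {
    name: [f"{name}_{i}.jpg" for i in range(1000)]
    for name in ("pizza", "steak", "sushi")
}
small = sample_subset(all_files, fraction=0.1)
print({name: len(files) for name, files in small.items()})
```

Raising `fraction` later is then a one-argument change, which is the point of experimenting on a small subset first.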

In [None]:
import requests
import zipfile
from pathlib import Path

# Setup path to data folder
data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"

# If the image folder doesn't exist, download it and prepare it...
if image_path.is_dir():
    print(f"{image_path} directory exists.")
else:
    print(f"Did not find {image_path} directory, creating one...")
    image_path.mkdir(parents=True, exist_ok=True)

    # Download pizza, steak, sushi data
    with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
        request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
        print("Downloading pizza, steak, sushi data...")
        f.write(request.content)

    # Unzip pizza, steak, sushi data
    with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
        print("Unzipping pizza, steak, sushi data...")
        zip_ref.extractall(image_path)

## Becoming one with the data (data prep and data exploration)

In [None]:
import os
def walk_through_dir(dir_path):
    """
    Walks through dir_path, printing the number of sub-directories and files in each directory.

    Args:
        dir_path (str or pathlib.Path): Target directory.

    Returns:
        None. The counts are printed as os.walk visits each directory in the tree.
    """
    for dirpath, dirnames, filenames in os.walk(dir_path):
        print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'")

In [None]:
walk_through_dir(image_path)
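
As a complement to the directory walk, a per-class tally makes it easier to spot class imbalance. The sketch below assumes the `root/split/class/image.jpg` layout used here; the helper name `count_images_per_class` is hypothetical, and the demo runs on a throwaway directory of empty placeholder files rather than the real images.

```python
import tempfile
from collections import Counter
from pathlib import Path

def count_images_per_class(root, pattern="*/*/*.jpg"):
    """Count images per (split, class), assuming a root/split/class/image.jpg layout."""
    counts = Counter()
    for path in Path(root).glob(pattern):
        split = path.parent.parent.name   # e.g. "train" or "test"
        class_name = path.parent.name     # e.g. "pizza"
        counts[(split, class_name)] += 1
    return counts

# Demo on a temporary directory with empty placeholder files (not real images)
with tempfile.TemporaryDirectory() as tmp:
    layout = [("train", "pizza", 3), ("train", "sushi", 2), ("test", "pizza", 1)]
    for split, class_name, n in layout:
        class_dir = Path(tmp) / split / class_name
        class_dir.mkdir(parents=True)
        for i in range(n):
            (class_dir / f"{i}.jpg").touch()
    demo_counts = count_images_per_class(tmp)

for (split, class_name), n in sorted(demo_counts.items()):
    print(f"{split}/{class_name}: {n} images")
```

Once the data is downloaded, calling `count_images_per_class(image_path)` should give the same kind of breakdown for the real folders.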

In [None]:
# Setup our training and testing paths
train_dir = image_path / "train"
test_dir = image_path / "test"

train_dir, test_dir

### Visualizing an image

Let's write some code to:
1. Get all the image paths
2. Pick a random image path using Python's `random.choice()`
3. Get the image class name using `pathlib.Path.parent.stem`
4. Since we're working with images, open the image with Python's PIL
5. Show the image and print its metadata
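
Step 3 relies on the class name being the name of the folder that holds the image. A quick sketch on a made-up path shows what `parent.stem` returns; the path below is hypothetical, and `PurePosixPath` is used only so the example behaves the same on any operating system.

```python
from pathlib import PurePosixPath

# Hypothetical path following the data/<dataset>/<split>/<class>/<file> layout
example_path = PurePosixPath("data/pizza_steak_sushi/train/sushi/12345.jpg")

class_name = example_path.parent.stem          # name of the directory holding the image
split_name = example_path.parent.parent.stem   # one level further up

print(class_name)  # sushi
print(split_name)  # train
```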

In [None]:
import random
from PIL import Image

# Set seed
random.seed(42) # <- try changing this and see what happens

# 1. Get all image paths (* means "any combination")
image_path_list = list(image_path.glob("*/*/*.jpg"))

# 2. Get random image path
random_image_path = random.choice(image_path_list)

# 3. Get image class from path name (the image class is the name of the directory where the image is stored)
image_class = random_image_path.parent.stem

# 4. Open image
img = Image.open(random_image_path)

# 5. Print metadata
print(f"Random image path: {random_image_path}")
print(f"Image class: {image_class}")
print(f"Image height: {img.height}")
print(f"Image width: {img.width}")
img

### That doesn't look like 'Pizza'

## Visualize with matplotlib

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Turn the image into an array
img_as_array = np.asarray(img)

# Plot the image with matplotlib
plt.figure(figsize=(10, 7))
plt.imshow(img_as_array)
plt.title(f"Image class: {image_class} | Image shape: {img_as_array.shape} -> (height, width, color_channels)")
plt.axis(False);

## Transform all the images into `torch.Tensor`s

Before we can use our image data with PyTorch, we need to:
1. Turn the target data into tensors
2. Turn it into a `torch.utils.data.Dataset` and subsequently a `torch.utils.data.DataLoader`; we'll call these `Dataset` and `DataLoader`

NOTE: we will be using `ImageFolder` from `torchvision.datasets`, which accepts a `transform` argument

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

### Transforming data with `torchvision.transforms`
Turn JPEG images into `torch.Tensor`s.

In [None]:
# Write a transform for image
data_transform = transforms.Compose([
    # Resize the images to 64x64
    transforms.Resize(size=(64, 64)),
    # Flip the images horizontally at random (data augmentation)
    transforms.RandomHorizontalFlip(p=0.5),
    # Turn the image into a torch.Tensor, scaling pixel values from [0, 255] to [0.0, 1.0]
    transforms.ToTensor()
])

In [None]:
data_transform(img).shape
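
For an RGB image, the shape reported above should be channels-first, `[3, 64, 64]`. To make clear what `ToTensor()` does to an 8-bit image, here is a rough NumPy equivalent of its two effects, scaling and axis reordering. It is a sketch of the behaviour, not torchvision's implementation, and the random array stands in for a real image.

```python
import numpy as np

# Stand-in for an RGB image: height 4, width 6, 3 colour channels, uint8 values in [0, 255]
rng = np.random.default_rng(0)
hwc_uint8 = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)

# 1. Scale to floats in [0.0, 1.0]
hwc_float = hwc_uint8.astype(np.float32) / 255.0

# 2. Move the channel axis first: (H, W, C) -> (C, H, W)
chw_float = hwc_float.transpose(2, 0, 1)

print(hwc_uint8.shape, "->", chw_float.shape)
print(chw_float.dtype, chw_float.min() >= 0.0, chw_float.max() <= 1.0)
```

Going back the other way, `(C, H, W) -> (H, W, C)`, is what `permute(1, 2, 0)` is used for before plotting with matplotlib.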

### Visualizing transformed images
Transforms help you get your images ready for use with a model and perform data augmentation. See
https://pytorch.org/vision/stable/transforms.html

In [None]:
def plot_transformed_images(image_paths, transform, n=3, seed=42):
    """Plots a series of random images from image_paths.

    Will open n image paths from image_paths, transform them
    with transform and plot them side by side.

    Args:
        image_paths (list): List of target image paths.
        transform (PyTorch Transforms): Transforms to apply to images.
        n (int, optional): Number of images to plot. Defaults to 3.
        seed (int, optional): Random seed for the random generator. Defaults to 42.
    """
    random.seed(seed)
    random_image_paths = random.sample(image_paths, k=n)
    for image_path in random_image_paths:
        with Image.open(image_path) as f:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(f)
            ax[0].set_title(f"Original \nSize: {f.size}")
            ax[0].axis("off")

            # Transform and plot image
            # Note: permute() will change shape of image to suit matplotlib
            # (PyTorch default is [C, H, W] but Matplotlib is [H, W, C])
            transformed_image = transform(f).permute(1, 2, 0)
            ax[1].imshow(transformed_image)
            ax[1].set_title(f"Transformed \nSize: {transformed_image.shape}")
            ax[1].axis("off")

            fig.suptitle(f"Class: {image_path.parent.stem}", fontsize=16)

plot_transformed_images(image_path_list,
                        transform=data_transform,
                        n=3)