# PyTorch Custom Datasets 

### Import PyTorch and set up device-agnostic code

In [None]:
import torch
from torch import nn

# Record the installed PyTorch version so results can be reproduced later.
installed_version = torch.__version__
print(installed_version)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

# Data

The dataset is a subset of the Food-101 dataset containing three classes: pizza, steak, and sushi.


In [None]:
import requests
import zipfile
from pathlib import Path

data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"

if image_path.is_dir():
    print(f"{image_path} directory already exists...skipping download")
else:
    print(f"{image_path} does not exist, creating one...")
    image_path.mkdir(parents=True, exist_ok=True)

    # Download and extract only when the data is missing. These steps used to
    # sit outside the `else`, so the archive was re-downloaded on every run even
    # after printing "skipping download".
    zip_path = data_path / "pizza_steak_sushi.zip"

    print('downloading pizza, steak, and sushi data')
    request = requests.get(
        "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
        timeout=60,  # don't hang the notebook forever on a stalled connection
    )
    # Fail loudly on an HTTP error instead of saving an error page as the zip.
    # Fetching before writing also avoids leaving an empty/truncated archive
    # behind when the request fails.
    request.raise_for_status()
    zip_path.write_bytes(request.content)

    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        print("unzipping...")
        zip_ref.extractall(image_path)

# Exploring data

In [None]:
import os
def walk_through_dir(dir_path):
    """Print how many sub-directories and files each directory under ``dir_path`` holds.

    The tree is traversed with ``os.walk``, printing one line per directory
    (starting with ``dir_path`` itself). Every file is counted regardless of
    its type, even though the message calls them "images".

    Args:
        dir_path: Root directory to walk (a ``str`` or path-like object).

    Returns:
        None; the summary is written to stdout.
    """
    for current_dir, subdirectories, files in os.walk(dir_path):
        dir_count = len(subdirectories)
        file_count = len(files)
        print(f"There are {dir_count} directories and {file_count} images in '{current_dir}'.")

In [None]:
# Summarise the extracted dataset: one line per directory in the tree.
walk_through_dir(dir_path=image_path)

In [None]:
# Locations of the training and testing splits inside the extracted dataset.
train_dir, test_dir = image_path / "train", image_path / "test"

train_dir, test_dir

# Visualizing images

1. Get all image paths.
2. Pick a random image path using `random.choice()`.
3. Get the image's class name from the name of its parent directory using `pathlib.Path.parent.stem`.
4. Open the image with Pillow.
5. Show the image metadata.

In [None]:
import random
from PIL import Image

# Uncomment to make the random pick reproducible between runs.
# random.seed(42)

# Images sit at <image_path>/<split>/<class>/<name>.jpg, hence three glob levels.
image_path_list = list(image_path.glob("*/*/*.jpg"))
if not image_path_list:
    # random.choice would otherwise fail with an opaque IndexError.
    raise FileNotFoundError(f"No .jpg images found under {image_path}")

random_image_path = random.choice(image_path_list)

# The class label is the name of the folder that contains the image.
image_class = random_image_path.parent.stem

img = Image.open(random_image_path)

print(f"random_image_path: {random_image_path}")
print(f"Image class: {image_class}")
print(f"Image height: {img.height}")
print(f"Image width: {img.width}")
img

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Convert the PIL image to a NumPy array so its shape can be reported and plotted.
img_as_array = np.asarray(img)

plt.figure(figsize=(10, 7))
plt.imshow(img_as_array)
plt.title(
    f"Image class: {image_class} | Image shape: {img_as_array.shape} "
    "-> [height, width, color_channels]"
)
plt.axis(False)
# Render the figure explicitly. Ending the cell on plt.axis(...) would also echo
# its returned axis-limits tuple as stray cell output.
plt.show()

In [None]:
# Display the pixel array produced from the PIL image by np.asarray in the cell above.
img_as_array