In [1]:
import os
import requests
import zipfile
import pathlib

In [2]:
import torch
import torch.nn as nn
import torchvision

In [3]:
# Where the dataset lives on disk and where it is fetched from.
data_path = pathlib.Path("data/")
image_path = data_path / "pizza_steak_sushi"
zip_path = data_path / "pizza_steak_sushi.zip"
DATA_URL = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"

if image_path.is_dir():
  print(f"{image_path} directory exist.")
else:
  image_path.mkdir(parents=True, exist_ok=True)
  print(f"{image_path} directory created.")

# Only hit the network when the image directory is still empty, so re-running
# this cell (e.g. Restart & Run All) does not download and unzip again.
if not any(image_path.iterdir()):
  print("Downloading pizza, steak, sushi data...")
  # Fetch first and validate the response *before* touching the zip on disk:
  # a failed request must not leave an empty/corrupt archive behind.
  request = requests.get(DATA_URL, timeout=60)
  request.raise_for_status()

  with open(zip_path, "wb") as f:
    f.write(request.content)

  with zipfile.ZipFile(zip_path, "r") as zip_ref:
    print("Unzipping pizza, steak, sushi data...")
    zip_ref.extractall(image_path)

data/pizza_steak_sushi directory created.
Downloading pizza, steak, sushi data...
Unzipping pizza, steak, sushi data...


In [5]:
%%writefile going_modular/data_setup.py
import os
import torch
import torchvision

# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader requires an integer; fall back to 0 (load data in the main process).
NUM_WORKERS = os.cpu_count() or 0

def create_dataLoader(
    train_dir,
    test_dir,
    transform,
    batch_size,
    num_workers : int = NUM_WORKERS
):
  """Builds training and testing DataLoaders from two image directories.

  Both directories are wrapped in `torchvision.datasets.ImageFolder` using the
  same transform, then in `torch.utils.data.DataLoader`. Only the training
  loader shuffles its samples.

  Args:
    train_dir: Path to training directory.
    test_dir: Path to testing directory.
    transform: torchvision transforms to perform on training and testing data.
    batch_size: Number of samples per batch in each of the DataLoaders.
    num_workers: An integer for number of workers per DataLoader. `None` is
      treated as 0.

  Returns:
    A tuple `(train_dataLoader, test_dataLoader, classes)`, where `classes`
    is the `classes` attribute of the training ImageFolder dataset.
  """
  # Guard against callers forwarding os.cpu_count() directly when it is None.
  if num_workers is None:
    num_workers = 0

  train_set = torchvision.datasets.ImageFolder(train_dir, transform)
  test_set = torchvision.datasets.ImageFolder(test_dir, transform)

  # Settings shared by both loaders; only `shuffle` differs between them.
  loader_kwargs = dict(batch_size = batch_size, num_workers = num_workers, pin_memory = True)

  train_dataLoader = torch.utils.data.DataLoader(dataset = train_set, shuffle = True, **loader_kwargs)
  test_dataLoader = torch.utils.data.DataLoader(dataset = test_set, shuffle = False, **loader_kwargs)

  classes = train_set.classes

  return train_dataLoader, test_dataLoader, classes

Writing going_modular/data_setup.py
