# Data Preparation and Exploration

In [356]:
import os
import random

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder

from src.config import PROCESSED_DATA_DIR, RAW_DATA_DIR
# Root folder of the raw image dataset (train/ and test/ are derived from it below).
image_path = RAW_DATA_DIR

# Number of samples per batch handed out by the DataLoaders.
BATCH_SIZE = 32

# Seed Python's and PyTorch's RNGs so the random image picks, the random
# horizontal flips and the DataLoader shuffling repeat on every
# "Restart Kernel -> Run All".
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [312]:
def walk_through_directory(directory):
    """Print a per-folder summary of everything under ``directory``.

    Recursively visits ``directory`` with ``os.walk`` and, for every folder
    visited, prints how many immediate sub-directories and files it holds.
    Every file is reported as an "image" in the message, whatever its
    extension. Nothing is returned.

    Args:
        directory: Root folder to traverse (string or path-like).
    """
    for current_dir, subdirs, files in os.walk(directory):
        n_dirs, n_files = len(subdirs), len(files)
        print(f"There are {n_dirs} directories and {n_files} images in {current_dir}")

In [313]:
# Print the directory/file counts for every folder under the raw-data root.
walk_through_directory(image_path)

In [314]:
# Locate the train and test split folders inside the dataset root.
train_dir, test_dir = image_path / "train", image_path / "test"
test_dir, train_dir

## 2. Visualize an Image

Let's write some code to:
1. Get all the image paths
2. Pick a random image path using Python's `random.choice()`
3. Get the image class name using `pathlib.Path.parent.stem`
4. Since we're working with images, let's open the image with Python's PIL
5. We'll then show the image and print metadata

In [315]:
# 1. Get all .jpg image paths laid out as <split>/<class>/<name>.jpg.
#    Sorted so the candidate order does not depend on filesystem enumeration order.
image_path_list = sorted(image_path.glob("*/*/*.jpg"))
# Fail with a readable message instead of random.choice's IndexError on an empty list.
assert image_path_list, f"No .jpg images found under {image_path} (expected <split>/<class>/<name>.jpg)"

# 2. Select a random image path
random_image_path = random.choice(image_path_list)

# 3. Get the image class from the name of the folder containing the image
image_class = random_image_path.parent.stem

# 4. Open the image with PIL
img = Image.open(random_image_path)

# 5. Print metadata (PIL reports size as (width, height)) and display the image
print(f"Random image path: {random_image_path}")
print(f"Image class: {image_class}")
print(f"Image size: {img.size}")
img

## 3. Transforming Data

Before we can use our image data with PyTorch:
1. Turn the image data into tensors (a numerical representation of our images).
2. Wrap it in a `torch.utils.data.Dataset`.

In [323]:
# Image pipeline: resize to 224x224, randomly flip horizontally, then convert
# the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

In [324]:
# Sanity check on the sample image: after Resize + ToTensor the shape should be
# (channels, 224, 224); the channel count depends on the image mode.
transform(img).shape

In [336]:
# Show the training folder that ImageFolder will read from.
train_dir

In [340]:
# Evaluation data must be deterministic: apply the same resize and tensor
# conversion as training, but without the random horizontal flip.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# ImageFolder derives each sample's label from the sub-folder it is stored in.
train_data = ImageFolder(train_dir, transform=transform)
test_data = ImageFolder(test_dir, transform=test_transform)


In [344]:
# Class names found by ImageFolder, and the name -> integer label mapping.
class_names = train_data.classes
class_names_to_idx = train_data.class_to_idx

# Only a cell's last expression is rendered, so show both values as one tuple.
class_names, class_names_to_idx

In [349]:
# random.randint includes its upper bound, so randint(0, len(train_data)) could
# yield an index one past the end; randrange(n) draws from 0 .. n-1.
random_index = random.randrange(len(train_data))

# NOTE: this rebinds `img` from the PIL image opened earlier to the transformed
# sample returned by the dataset.
img, label = train_data[random_index]
print("Label:", class_names[label])
img

In [357]:
# Wrap both datasets in batched loaders; only the training loader shuffles.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

train_loader, test_loader

In [None]:
# len() of a DataLoader is its number of batches, not its number of samples.
len(train_loader), len(test_loader)