# Import 

In [None]:
# Import torch
import torch
import torch.nn as nn
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt

## Set up device-agnostic code

In [None]:
# Device-agnostic setup: parse the run options, then pick the compute device.
from argparse import ArgumentParser


def _str_to_bool(value: str) -> bool:
    """Convert a command-line token such as "true", "0" or "no" into a bool.

    argparse's ``type=bool`` cannot be used for this: ``bool("False")`` is
    ``True`` because every non-empty string is truthy.

    Raises:
        ValueError: if the token is not a recognised boolean spelling
            (argparse reports this as a normal usage error).
    """
    token = value.strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "f", "no", "n", "off", ""}:
        return False
    raise ValueError(f"invalid boolean value: {value!r}")


parser = ArgumentParser(description="computer vision model argument")
# Accepts both the bare flag (`--disable-cuda`) and an explicit value
# (`--disable-cuda False`), so existing invocations keep working.
parser.add_argument("--disable-cuda", type=_str_to_bool, nargs="?", const=True, default=False,
                    help="Choose cuda device to train model?")
parser.add_argument("--learning-rate", "-lr", type=float, default=.01, help="Learning rate")
parser.add_argument("--epochs", "-e", type=int, default=40, help="Epochs")
parser.add_argument("--MODEL-PATH", type=str, default="../../Module/models", help="Model save path")
# NOTE(review): presumably here to absorb the `-f <connection file>` argument
# that a Jupyter kernel is launched with -- confirm before removing.
parser.add_argument('--file', '-f', type=str)
args = parser.parse_args()

# Use the GPU only when it is both allowed and actually available.
args.device = "cuda" if (not args.disable_cuda and torch.cuda.is_available()) else "cpu"

## Download Dataset

In [None]:
import os
import zipfile
import requests
from pathlib import Path


# Download and unpack the pizza/steak/sushi image dataset (skipped when the
# target folder already exists).
url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"
data_path = Path("data")
image_path = data_path / "pizza_steak_sushi"

if image_path.is_dir():
    print(f"{image_path} has already been downloaded...")
else:
    print(f"Folder {image_path} wasn't founded. Creating one...")

    # Fetch and validate the archive BEFORE creating anything on disk: the
    # check above only looks for the folder, so a folder left behind by a
    # failed download would be mistaken for a complete dataset on the next run.
    print("Downloading pizza_steak_sushi.zip...")
    response = requests.get(url, timeout=60)  # never hang forever on a dead connection
    response.raise_for_status()  # fail loudly instead of saving an HTTP error page as a .zip

    image_path.mkdir(parents=True, exist_ok=True)
    archive_path = image_path / "pizza_steak_sushi.zip"
    archive_path.write_bytes(response.content)

    with zipfile.ZipFile(archive_path) as zip_file:
        print("Extracting file...")
        zip_file.extractall(image_path)

    print("Downloaded Successfully")

## Data Exploration

### Explore shape, visualize image

In [None]:
def walk_through_path(folder_path):
    """Print one summary line per directory found under ``folder_path``.

    The tree is traversed with ``os.walk``; each line reports how many
    sub-directories and files the visited directory directly contains.

    Arg:
        - folder_path: directory to start walking from
    """
    walker = os.walk(folder_path)
    for root, dirs, files in walker:
        summary_line = f"There are {len(dirs)} directories and {len(files)} files in {root}"
        print(summary_line)
        
walk_through_path(image_path)

In [None]:
list(image_path.glob("*/*/*.jpg"))

In [None]:
from PIL import Image

Image.open(list(image_path.glob("*/*/*.jpg"))[50])

### Get random image and visualize

1. Get all of the image paths using `pathlib.Path.glob()` to find all of the files ending in .jpg.
2. Pick a random image path using Python's `random.choice()`.
3. Get the image class name using `pathlib.Path.parent.stem`.
4. And since we're working with images, we'll open the random image path using `PIL.Image.open()` (`PIL` stands for **Python Imaging Library**).
5. We'll then show the image and print some metadata.

In [None]:
import random
from PIL import Image

# Fix the seed so the same "random" image is picked on every run
random.seed(10)

# Collect every .jpg that sits two folder levels below image_path
image_list = list(image_path.glob("*/*/*.jpg"))

# Pick one image path at random
random_image = random.choice(image_list)

# The class name is taken from the name of the folder that holds the image
class_random_image = random_image.parent.stem

img = Image.open(random_image)

# Print the class, the path of the chosen file and its pixel dimensions
print(f"Class name image: {class_random_image}")
print(f"Root directory: {random_image}")
print(f"Width x Height: {img.width} x {img.height}")
# Bare expression: as the last line of a notebook cell it renders the image
img

## 3. Transforming data
Now what if we wanted to load our image data into **PyTorch**?

Before we can use our image data with PyTorch we need to:

- Turn it into tensors (numerical representations of our images).
- Turn it into a `torch.utils.data.Dataset` and subsequently a `torch.utils.data.DataLoader`, we'll call these Dataset and *DataLoader* for short.
- There are several different kinds of pre-built datasets and dataset loaders for **PyTorch**, depending on the problem you're working on.

### 3.1 Transforming data with torchvision.transforms
We've got folders of images but before we can use them with PyTorch, we need to convert them into tensors.

One of the ways we can do this is by using the torchvision.transforms module.

`torchvision.transforms` contains many pre-built methods for formatting images, turning them into tensors and even manipulating them for data augmentation purposes (the practice of altering data to make it harder for a model to learn; we'll see this later on).

To get experience with `torchvision.transforms`, let's write a series of transform steps that:

 - Resize the images using `transforms.Resize()` (from about *512x512* to *64x64*, the same shape as the images on the [CNN Explainer website](https://poloclub.github.io/cnn-explainer/)).
 - Flip our images randomly on the horizontal using `transforms.RandomHorizontalFlip()` (this could be considered a form of data augmentation because it will artificially change our image data).
 - Turn our images from a `PIL` image to a PyTorch tensor using `transforms.ToTensor()`.
 - We can compile all of these steps using `torchvision.transforms.Compose()`.

In [None]:
from torchvision.transforms import Compose
from torchvision.transforms import Resize
from torchvision.transforms import RandomHorizontalFlip
from torchvision.transforms import ToTensor

data_transforms = Compose([
    Resize(size=(64, 64)),
    RandomHorizontalFlip(p=.7),
    ToTensor()
])

In [None]:
random.seed(82)

def plot_transform_image(image_paths, transforms, n=4):
    """
    Plot n randomly chosen images next to their transformed version.
    Each figure is titled with the class name; each panel shows its size.

    Arg:
        - image_paths: path to dataset, this must be Path instance
          (images are searched with the pattern "*/*/*.jpg")
        - transforms: data transform pipeline; its output must support
          .permute(1, 2, 0), i.e. be a [C, H, W] tensor
        - n: number of images that will be plotted (capped at the number
          of images found)
    """
    image_list = list(image_paths.glob("*/*/*.jpg"))
    # random.sample raises ValueError when k is larger than the population
    sampled_paths = random.sample(image_list, k=min(n, len(image_list)))
    for sample_path in sampled_paths:
        with Image.open(sample_path) as img:
            fig, (ax1, ax2) = plt.subplots(1, 2)
            ax1.imshow(img)
            ax1.set_title(f"Original \nSize: {img.size}")
            ax1.axis(False)

            # Transform and plot image
            # Note: permute() will change shape of image to suit matplotlib
            # (PyTorch default is [C, H, W] but Matplotlib is [H, W, C])
            transformed_image = transforms(img).permute(1, 2, 0)
            ax2.imshow(transformed_image)
            ax2.set_title(f"Transformed \nSize: {transformed_image.shape}")
            ax2.axis(False)

            # .name keeps the full folder name (.stem would cut it at a dot)
            fig.suptitle(f"Class name {sample_path.parent.name}", fontsize=16)
    
    
plot_transform_image(image_path, data_transforms)


### 4. Turn loaded images into `DataLoader`'s

In [None]:
from torch.utils.data import DataLoader

train_transform = Compose([
    Resize(size=(64, 64)),
    RandomHorizontalFlip(p=.7),
    ToTensor()
])

test_transform = Compose([
    Resize(size=(64, 64)),
    ToTensor()
])

In [None]:
from torchvision.datasets import ImageFolder

train_dir = image_path / "train"
test_dir = image_path / "test"

train_dataset = ImageFolder(root=train_dir,
                            transform=train_transform,
                            target_transform=None)


test_dataset = ImageFolder(root=test_dir,
            transform=test_transform,
            target_transform=None)


In [None]:
os.cpu_count()

In [None]:
# One image per batch, loaded by two worker processes
batch_size = 1
num_workers = 2

# Training batches are reshuffled on every pass over the data
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              shuffle=True)

# Test batches keep the dataset order
test_dataloader = DataLoader(dataset=test_dataset,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              shuffle=False)




In [None]:
train_dataset.class_to_idx, train_dataset.classes

In [None]:
class_names = train_dataset.classes
class_names

In [None]:
# Pull one batch from the training loader and preview its first sample
image, label = next(iter(train_dataloader))

# PyTorch image batches are laid out as [N, C, H, W] (height before width).
# Indexing [0] instead of squeeze() keeps this cell working for any batch size.
print(f"Shape {image.shape} -> [batch_size, color_channel, height, width] \nClass {class_names[label[0]]}")
# Matplotlib expects [H, W, C], hence the permute
plt.imshow(image[0].permute(1, 2, 0))
plt.axis("off")

## Loading Image Data with a Custom Dataset

To see this in action, let's work towards replicating `torchvision.datasets.ImageFolder()` by subclassing `torch.utils.data.Dataset` (the base class for all Dataset's in PyTorch).

We'll start by importing the modules we need:

- Python's `os` for dealing with directories (our data is stored in directories).
- Python's `pathlib` for dealing with filepaths (each of our images has a unique filepath).
- `torch` for all things PyTorch.
- `PIL`'s Image class for loading images.
- `torch.utils.data.Dataset` to subclass and create our own custom Dataset.
- `torchvision.transforms` to turn our images into tensors.
- Various types from Python's `typing` module to add type hints to our code.
> Note: You can customize the following steps for your own dataset. The premise remains: write code to load your data in the format you'd like it.

In [None]:
import os
import pathlib
import torch

from torch.utils.data import Dataset
from torchvision import transforms
from typing import Tuple, List, Dict

### Create helper function to get class names

Let's write a helper function capable of creating a list of class names and a dictionary of class names and their indexes given a directory path.

To do so, we'll:

- Get the class names using `os.scandir()` to traverse a target directory (ideally the directory is in standard image classification format).
- Raise an error if the class names aren't found (if this happens, there might be something wrong with the directory structure).
- Turn the class names into a dictionary of numerical labels, one for each class.

Let's see a small example of step 1 before we write the full function.

In [None]:
target_directory = train_dir
print(f"Target dir {target_directory}")

In [None]:
for row in os.scandir(target_directory): 
    print(row.name)

In [None]:
def scan_class_names(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Discover class names from the sub-directories of ``directory``.

    Every immediate sub-directory counts as one class; plain files are ignored.

    Args:
        directory: folder whose sub-directories are named after the classes.

    Returns:
        A tuple ``(class_names, class_to_idx)``: the alphabetically sorted
        class names, and a dict mapping each name to its position in that list.

    Raises:
        FileNotFoundError: if ``directory`` contains no sub-directory.
    """
    found = []
    with os.scandir(directory) as entries:
        for entry in entries:
            if entry.is_dir():
                found.append(entry.name)
    found.sort()

    if len(found) == 0:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}")

    index_by_name = dict(zip(found, range(len(found))))
    return found, index_by_name

In [None]:
class_names, class_to_idx = scan_class_names(target_directory)

In [None]:
class_names, class_to_idx

In [None]:
class CustomDataset(Dataset):
    """Image dataset read from a ``root/<class_name>/<image>.jpg`` folder layout.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, root: str, transform=None) -> None:
        """Index the images below ``root`` and discover the class names.

        Args:
            root: directory holding one sub-directory per class.
            transform: optional callable applied to each loaded PIL image.

        Raises:
            FileNotFoundError: raised by ``scan_class_names`` when ``root``
                has no class sub-directory.
        """
        # Sorted so that index -> sample is reproducible: the order in which
        # glob yields paths is not guaranteed.
        self.paths = sorted(pathlib.Path(root).glob("*/*.jpg"))

        self.transform = transform

        # classes: sorted class names; class_to_idx: class name -> integer label
        self.classes, self.class_to_idx = scan_class_names(root)

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.paths)

    def load_image(self, index: int) -> Image.Image:
        """Open (without transforming) the image stored at position ``index``."""
        img_path = self.paths[index]
        return Image.open(img_path)

    def __getitem__(self, index: int) -> Tuple[Image.Image, int]:
        """Return ``(sample, class_idx)`` for position ``index``.

        ``sample`` is the output of ``transform`` when one was given,
        otherwise the untouched PIL image.
        """
        image = self.load_image(index)
        # .name, not .stem: class_to_idx is keyed by the full directory name,
        # and .stem would drop everything after a dot ("v1.5" -> "v1").
        class_name = self.paths[index].parent.name
        class_idx = self.class_to_idx[class_name]

        if self.transform is not None:
            return self.transform(image), class_idx
        return image, class_idx

In [None]:
train_customdataset = CustomDataset(root=train_dir,
                                    transform=train_transform)

test_customdataset = CustomDataset(root=test_dir,
                                    transform=test_transform)

In [None]:
img, label = train_customdataset[0]
print(class_names[label])
plt.imshow(img.squeeze().permute(1, 2, 0));

In [None]:
len(train_customdataset), len(test_customdataset)

In [None]:
# Sanity check: the custom Dataset should agree with ImageFolder on
# sample counts, class list and class -> index mapping
same_train_size = len(train_customdataset) == len(train_dataset)
same_test_size = len(test_customdataset) == len(test_dataset)
print(same_train_size and same_test_size)
print(train_customdataset.classes == train_dataset.classes)
print(train_customdataset.class_to_idx == train_dataset.class_to_idx)

In [None]:
len(test_customdataset)
sorted(list(test_customdataset), key=lambda x: x[1])

In [None]:
def visualize_random_image_customdataset(dataset : Dataset,
                                         n : int = 4,
                                         seed: int = 82):
    """Show ``n`` randomly drawn samples of ``dataset`` in a 4-column grid.

    Args:
        dataset: dataset yielding ``(image, label)`` pairs; ``image`` must
            support ``.squeeze().permute(1, 2, 0)`` (a [C, H, W] tensor).
        n: number of samples to draw (with replacement).
        seed: seed passed to ``random.seed`` so the selection is repeatable.

    Raises:
        ValueError: raised by ``random.randrange`` when the dataset is empty.
    """
    random.seed(seed)
    size = len(dataset)

    # Prefer the dataset's own class list; fall back to the notebook-level
    # `class_names` for datasets that do not expose one.
    label_names = getattr(dataset, "classes", None)
    if label_names is None:
        label_names = class_names

    # Ceiling division: just enough rows of 4, no empty trailing row
    n_rows = max(1, -(-n // 4))

    plt.figure(figsize=(12, 6))
    for idx in range(n):
        # randrange excludes `size`; randint(0, size) could return `size`
        # itself and index one past the end of the dataset.
        sample_idx = random.randrange(size)
        image, label = dataset[sample_idx]

        plt.subplot(n_rows, 4, idx + 1)
        plt.imshow(image.squeeze().permute(1, 2, 0))
        plt.title(label_names[label])
        plt.axis(False)

    plt.show()


In [None]:
visualize_random_image_customdataset(test_customdataset)

In [None]:
from torchvision.transforms import Compose
from torchvision.transforms import Resize
from torchvision.transforms import TrivialAugmentWide
from torchvision.transforms import ToTensor

train_transforms = Compose([
    Resize(size=(256, 256)),
    TrivialAugmentWide(num_magnitude_bins=31),
    ToTensor()
])

test_transforms = Compose([
    Resize(size=(256, 256)),
    ToTensor()
])

In [None]:
plot_transform_image(image_path, train_transforms)

## Model TinyVGG without data augmentation

In [None]:
simple_transform = Compose([
    Resize(size=(64, 64)),
    ToTensor()
])

In [None]:
test_dir

In [None]:
train_simple_dataset = ImageFolder(root=train_dir,
                                   transform=simple_transform,
                                   target_transform=None)

test_simple_dataset = ImageFolder(root=test_dir,
                                   transform=simple_transform,
                                   target_transform=None)

train_simple_dataset, test_simple_dataset

In [None]:
# Rebinds the notebook-level batch_size / num_workers used by earlier cells
batch_size = 32
num_workers = 8

# Training batches are reshuffled on every pass over the data
train_simple_dataloader = DataLoader(dataset=train_simple_dataset,
                           batch_size=batch_size,
                           num_workers=num_workers,
                           shuffle=True)

# Test batches keep the dataset order
test_simple_dataloader = DataLoader(dataset=test_simple_dataset, 
                                    batch_size=batch_size,
                                    num_workers=num_workers,
                                    shuffle=False)


# Bare expression: displays both loaders as the cell output
train_simple_dataloader, test_simple_dataloader

In [None]:
class TinyVGG(nn.Module):
    """
    Model architecture copying TinyVGG from:
    https://poloclub.github.io/cnn-explainer/

    Two convolutional blocks (conv -> ReLU -> conv -> ReLU -> 2x2 max-pool)
    followed by a flatten + linear classifier.

    Args:
        input_shape: number of channels of the input images.
        hidden_units: number of feature maps produced by every conv layer.
        output_shape: number of output classes (logits).
        image_size: spatial size of the input images, either one int for
            square images or a ``(height, width)`` pair. Defaults to 64,
            which reproduces the original hard-coded ``16*16`` classifier.
    """
    def __init__(self, input_shape: int, hidden_units: int, output_shape: int,
                 image_size=64) -> None:
        super().__init__()
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape,
                      out_channels=hidden_units,
                      kernel_size=3, # how big is the square that's going over the image?
                      stride=1, # default
                      padding='same'), # options = "valid" (no padding) or "same" (output has same shape as input) or int for specific number
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1,
                      padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,
                         stride=2) # default stride value is same as kernel_size
        )
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,
                         stride=2) # default stride value is same as kernel_size
        )

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size
        # The 'same'-padded convs keep H and W; each of the two max-pools
        # halves them (rounding down), so the feature maps reaching the
        # classifier are (H // 4) x (W // 4) -- 16 x 16 for 64 x 64 inputs.
        flat_features = hidden_units * (height // 4) * (width // 4)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=flat_features,
                      out_features=output_shape)
        )

    def forward(self, x: torch.Tensor):
        """Return the class logits for a batch ``x`` of shape [N, C, H, W]."""
        # Single nested expression, as the original note suggested, to
        # leverage the benefits of operator fusion.
        return self.classifier(self.conv_block_2(self.conv_block_1(x)))

torch.manual_seed(42)
model_0 = TinyVGG(input_shape=3, # number of color channels (3 for RGB) 
                  hidden_units=10, 
                  output_shape=len(train_simple_dataset.classes)).to(args.device)
model_0

In [None]:
# Smoke test: push one training batch through the untrained model
image, label = next(iter(train_simple_dataloader))

print(f"Size input image: {image.shape}")

# inference_mode: no gradients are needed for this shape check
with torch.inference_mode():
    infer_res = model_0(image.type(torch.float).to(args.device))

    
print(f"Size output result: {infer_res.shape}")

In [None]:
try:
    from torchinfo import summary
except:
    !pip install -q torchinfo
    from torchinfo import summary
    
summary(model_0, input_size=[1, 3, 64, 64])

In [None]:
def train_step(model: nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: nn.Module,
               optimizer: torch.optim.Optimizer,
               accuracy_fn,
               device: torch.device = args.device) -> Dict[str, float]:
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: network to optimise (switched to train mode).
        dataloader: yields ``(X, y)`` batches.
        loss_fn: called as ``loss_fn(logits, y)``.
        optimizer: optimiser stepping the model's parameters.
        accuracy_fn: called as ``accuracy_fn(predicted_labels, y)``.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``{"loss_score": ..., "acc_score": ...}``: loss and accuracy averaged
        over the batches, as plain Python floats.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batch.
    """
    model.train()

    total_loss, total_acc = 0.0, 0.0
    samples_seen = 0
    for idx, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        y_pred = model(X)

        loss = loss_fn(y_pred, y)
        acc = accuracy_fn(y_pred.argmax(dim=1), y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate plain numbers: summing the `loss` tensor itself would
        # keep every batch's autograd graph alive until the end of the epoch.
        total_loss += loss.item()
        total_acc += float(acc)

        # Running count stays correct even when the last batch is smaller
        samples_seen += len(X)
        if idx % 4 == 0:
            print(f"Trained on {samples_seen}/{len(dataloader.dataset)}")

    total_loss /= len(dataloader)
    total_acc /= len(dataloader)

    print(f"Train loss {total_loss:.4f} | Train accuracy {total_acc:.4f}")

    return {"loss_score": total_loss, "acc_score": total_acc}
    

In [None]:
def test_step(model: nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: nn.Module,
              accuracy_fn,
              device: torch.device = args.device) -> Dict[str, float]:
    """Evaluate ``model`` on one pass over ``dataloader`` (no weight update).

    Args:
        model: network to evaluate (switched to eval mode).
        dataloader: yields ``(X, y)`` batches.
        loss_fn: called as ``loss_fn(logits, y)``.
        accuracy_fn: called as ``accuracy_fn(predicted_labels, y)``.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``{"loss_score": ..., "acc_score": ...}``: loss and accuracy averaged
        over the batches, as plain Python floats.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batch.
    """
    model.eval()
    total_loss, total_acc = 0.0, 0.0
    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            y_pred = model(X)
            loss = loss_fn(y_pred, y)
            acc = accuracy_fn(y_pred.argmax(dim=1), y)

            # Convert to plain numbers so the returned values match the
            # declared Dict[str, float] instead of being device tensors.
            total_loss += loss.item()
            total_acc += float(acc)

    total_loss /= len(dataloader)
    total_acc /= len(dataloader)

    print(f"Test loss {total_loss:.4f} | Test accuracy {total_acc:.4f}")

    return {"loss_score": total_loss, "acc_score": total_acc}


In [None]:
def accuracy_fn(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Return the fraction of entries of ``y_pred`` that equal ``y_true``.

    The number of matching positions is divided by ``len(y_pred)``.
    Note: for tensor inputs the result is a 0-dim tensor, not a Python float.
    """
    matches = y_pred == y_true
    n_correct = matches.sum()
    n_total = len(y_pred)
    return n_correct / n_total

In [None]:
# Multi-class classification loss computed on the raw logits
loss_fn = nn.CrossEntropyLoss()

# Take the step size from the parsed --learning-rate option (default .01)
# instead of repeating the literal, so the option actually has an effect.
optimizer_v0 = torch.optim.SGD(params=model_0.parameters(), lr=args.learning_rate)

acc_fn = accuracy_fn

In [None]:
from tqdm.auto import tqdm
from timeit import default_timer as timer

def print_train_time(start: float, end: float, device: torch.device = args.device) -> float:
    """Print and return the elapsed time between two timer readings.

    Args:
        start: timer value taken before training.
        end: timer value taken after training.
        device: device name shown in the printed message.

    Returns:
        ``end - start``.
    """
    total_time = end - start
    report = f"Total train time on {device}: {total_time:.3f}"
    print(report)
    return total_time

In [None]:
# Overrides whatever --epochs value was parsed earlier: this run always uses 5
args.epochs = 5

epochs = args.epochs

# Wall-clock timer around the whole training run
start_train = timer()

for epoch in tqdm(range(epochs)):
    print(f"Epoch {epoch}:\n-----------\n")
    
    # One optimisation pass over the training data ...
    train_step(model=model_0, 
               dataloader=train_simple_dataloader,
               loss_fn=loss_fn,
               accuracy_fn=acc_fn,
               optimizer=optimizer_v0)
    
    # ... followed by one evaluation pass over the test data
    test_step(model=model_0,
              dataloader=test_simple_dataloader,
              loss_fn=loss_fn,
              accuracy_fn=acc_fn)
    
end_train = timer()
   
# Elapsed time (end - start) is printed and kept for later comparison
time_train_v0 = print_train_time(start=start_train,
                                 end=end_train)    