# Dataset and DataLoader PyTorch Tutorial

## In summary, the neural network training pipeline in PyTorch is as follows:

### 1. Data definition
* Dataset: create a Dataset class (torch.utils.data.Dataset) designed to return a single instance when indexed. 
* DataLoader: create a DataLoader object (torch.utils.data.DataLoader) that samples batches of instances, optionally using parallel workers.

### 2. Model Definition
* Model: build a neural network architecture using Module classes (torch.nn.Module).
* Loss: a loss function is also a PyTorch module, and hence you can create your own loss function, or use a default implementation from the framework (e.g., torch.nn.CrossEntropyLoss). More details at: https://pytorch.org/docs/stable/nn.html#loss-functions.
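
As a minimal sketch of the two bullets above, the cell below defines a small model and applies a built-in loss. `TinyNet`, its layer sizes, and the random input batch are illustrative assumptions, not part of this tutorial's pipeline.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical two-layer classifier, used only for illustration."""

    def __init__(self, in_features=28 * 28, hidden=32, n_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),                   # (B, 1, 28, 28) -> (B, 784)
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes),   # raw class scores (logits)
        )

    def forward(self, x):
        return self.net(x)


model = TinyNet()
criterion = nn.CrossEntropyLoss()           # takes logits and integer class labels

images = torch.randn(4, 1, 28, 28)          # a fake batch of 4 grayscale images
labels = torch.tensor([0, 3, 7, 9])         # fake integer targets
logits = model(images)
loss = criterion(logits, labels)
print(logits.shape, loss.item())
```

Because the loss is itself an `nn.Module`, a custom loss can be written the same way as `TinyNet`: subclass `nn.Module` and implement `forward`.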

### 3. Training
* Hyper-parameter tuning: try several hyper-parameter combinations in order to find what works best. 
* Training: train your models for several epochs. Use an early stopping criterion (e.g., stop training if the validation accuracy has not increased for 5 epochs).
* Evaluation: evaluate your best model on the test set.
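
The early stopping rule mentioned above can be sketched in plain Python. The `EarlyStopping` class and the list of validation accuracies are hypothetical and only illustrate the "no improvement for 5 epochs" criterion; a real loop would compute the accuracy on a validation set after each epoch.

```python
class EarlyStopping:
    """Hypothetical helper: signal a stop after `patience` epochs without improvement."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best = float('-inf')
        self.bad_epochs = 0

    def step(self, metric):
        """Record one epoch's validation metric; return True if training should stop."""
        if metric > self.best:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation accuracies, one per epoch (stand-in for a real training loop).
val_accuracies = [0.60, 0.70, 0.75, 0.74, 0.75, 0.73, 0.74, 0.75, 0.72, 0.71]

stopper = EarlyStopping(patience=5)
stopped_at = None
for epoch, acc in enumerate(val_accuracies):
    # ... train for one epoch and compute `acc` on the validation set here ...
    if stopper.step(acc):
        stopped_at = epoch
        break

print(f'best accuracy: {stopper.best}, stopped at epoch: {stopped_at}')
```

Note that this sketch only counts a strictly higher accuracy as an improvement, so an epoch that ties the best value still counts against the patience budget.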

**Note**: there is a plethora of standard datasets and neural network models at: https://pytorch.org/docs/stable/torchvision/index.html

**First, let's get started by importing useful stuff, and defining cool image plot functions.**

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import models
from torchvision import transforms

In [None]:
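def plot_image(sample_img, sample_label):
    """Show a single image with its label as the title."""
    plt.imshow(sample_img, cmap='gray')  # cmap only affects single-channel images
    plt.title(f'Label: {sample_label}')
    plt.xticks([])
    plt.yticks([])

def plot_multiple_images(images, labels, n=6):
    """Show up to n images side by side, each titled with its label."""
    # squeeze=False keeps `ax` two-dimensional, so n=1 also works.
    fig, ax = plt.subplots(1, n, squeeze=False)
    for axis, img, label in zip(ax[0], images, labels):
        axis.imshow(img)
        axis.set_title(f'Label: {label}')
        axis.set_xticks([])
        axis.set_yticks([])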

# 1 Data definition

## 1.1 Using a pre-defined torchvision dataset

### Example: MNIST

In [None]:
mnist = datasets.MNIST('mnist/', download=True)

In [None]:
sample_img, sample_label = mnist[0]

In [None]:
plot_image(sample_img, sample_label)

### Example: CIFAR-10

In [None]:
cifar = datasets.CIFAR10('cifar/', download=True)

In [None]:
sample_img, sample_label = cifar[0]

In [None]:
plot_image(sample_img, sample_label)

### Example: SVHN (Street View House Numbers)

In [None]:
svhn = datasets.SVHN('svhn/', download=True)

In [None]:
sample_data = [svhn[i] for i in range(6)]
sample_images, sample_labels = zip(*sample_data)

In [None]:
plot_multiple_images(sample_images, sample_labels, n=6)

## 1.2 Using your own images 

The easiest way to use your own images as a PyTorch dataset is the ImageFolder class (torchvision.datasets.ImageFolder).

ImageFolder expects a path to a root directory in which the images are organized as follows:

* root/dog/xxx.png
* root/dog/xxy.png
* root/dog/xxz.png

* root/cat/123.png
* root/cat/nsdf3.png
* root/cat/asd932_.png

*I.e.*, each folder inside the provided root directory is used as a class. The dataset itself defines the number of classes based on the number of folders that contain images.
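
ImageFolder exposes the resulting mapping through its `class_to_idx` attribute. The sketch below mimics that mapping in plain Python, assuming classes are indexed by alphabetically sorted folder name. `folder_to_class_index` is a hypothetical helper written for illustration, not torchvision's own code, and the temporary directory holds empty placeholder files rather than real images.

```python
import os
import tempfile


def folder_to_class_index(root):
    """Map each sub-folder name of `root` to an integer index, in sorted order."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


# Build a throwaway directory tree that mirrors the layout above.
with tempfile.TemporaryDirectory() as root:
    for class_name, file_name in [('dog', 'xxx.png'), ('cat', '123.png')]:
        os.makedirs(os.path.join(root, class_name), exist_ok=True)
        # Empty placeholder files; a real dataset would hold actual images.
        open(os.path.join(root, class_name, file_name), 'w').close()

    class_to_idx = folder_to_class_index(root)

print(class_to_idx)
```

The integer label returned for each image is therefore determined by the folder names, not by the order in which the folders were created.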

### Example using the reduced RPS (Rock Paper Scissors) dataset:

In [None]:
! wget https://github.com/YoussefAch/RPS-classification-using-PyTorch/raw/master/RPS.tar.gz
! tar xf RPS.tar.gz
! ls -l RPS

In [None]:
rps = datasets.ImageFolder('RPS')

In [None]:
sample_data = [rps[i] for i in range(3)]
sample_images, sample_labels = zip(*sample_data)
plot_multiple_images(sample_images, sample_labels, n=3)

## 1.3 Building your customized dataset

Keep in mind:
* You have to override the `__init__`, `__getitem__` and `__len__` methods.
* Your class has to handle the train / val / test splits itself.
* You have to apply the transforms yourself.
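
One possible way to handle the splits is to shuffle the list of image paths once, with a fixed seed, and cut it into three parts. This is only a sketch: `split_paths`, the 10% fractions, and the made-up file names are assumptions for illustration.

```python
import random


def split_paths(paths, val_fraction=0.1, test_fraction=0.1, seed=0):
    """Hypothetical helper: shuffle `paths` reproducibly and cut it into three lists."""
    paths = sorted(paths)                   # fixed starting order, independent of the input order
    random.Random(seed).shuffle(paths)      # same seed -> same split on every run
    n_val = int(len(paths) * val_fraction)
    n_test = int(len(paths) * test_fraction)
    return {
        'val': paths[:n_val],
        'test': paths[n_val:n_val + n_test],
        'train': paths[n_val + n_test:],
    }


# Made-up file names standing in for the result of a glob over the image folder.
all_paths = [f'RPS/rock/{i:03d}.jpg' for i in range(20)]
splits = split_paths(all_paths)
print({name: len(items) for name, items in splits.items()})
```

A dataset class could then accept a `split` argument in `__init__` and keep only the paths in `splits[split]`, so the three splits never share an image.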

In [None]:
import os
from glob import glob

from PIL import Image


def load_image(path):
    """Open an image file and force three RGB channels."""
    img = Image.open(path)
    return img.convert('RGB')


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, path, transform=None):
        super().__init__()
        self.transform = transform
        # Sort the paths so that indices are stable across runs and platforms.
        self.images = sorted(glob(os.path.join(path, '*', '*.jpg')))
        # The parent folder name of each image is its class label.
        self.classes_str = [os.path.basename(os.path.dirname(x)) for x in self.images]
        self.class_set = sorted(set(self.classes_str))

    def __getitem__(self, idx):
        image_path = self.images[idx]
        label_str = self.classes_str[idx]
        # The integer label is the position of the class name in the sorted class list.
        label_int = self.class_set.index(label_str)
        image = load_image(image_path)

        if self.transform is not None:
            image = self.transform(image)

        return image, label_int, label_str

    def __len__(self):
        return len(self.images)

In [None]:
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
dataset = MyDataset('RPS/', transform=transform)

In [None]:
sample_img, sample_label, sample_label_str = dataset[60]
plot_image(sample_img.permute(1, 2, 0), f'{sample_label} - {sample_label_str}')

**Note**: If your dataset is static and fits entirely into RAM, you can load the whole dataset once, convert it to tensors, and then use the TensorDataset class (`torch.utils.data.TensorDataset`).
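
A minimal sketch of the TensorDataset route follows. The random tensors stand in for data that has already been loaded into memory; the shapes and the three-class labels are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import TensorDataset

# Fake data standing in for a small dataset already loaded in memory:
# 100 RGB images of 32x32 pixels and one integer label per image.
all_images = torch.rand(100, 3, 32, 32)
all_labels = torch.randint(0, 3, (100,))

# All tensors must have the same size along the first dimension.
tensor_dataset = TensorDataset(all_images, all_labels)

image, label = tensor_dataset[0]            # indexing returns a tuple of tensors
print(len(tensor_dataset), image.shape, label.item())
```

Since TensorDataset only indexes tensors that are already built, any preprocessing has to be applied before the tensors are wrapped.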

# 2 Use a DataLoader to sample batches from your dataset

DataLoaders can handle any `torch.utils.data.Dataset` object.

In [None]:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=6, num_workers=2)

In [None]:
for images, labels, labels_str in train_loader:
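    # Move channels last, (B, C, H, W) -> (B, H, W, C), and turn label tensors into plain ints for the titles.
    plot_multiple_images(images.permute(0, 2, 3, 1), labels.tolist(), n=6)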
    break
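
To make the batching behaviour easier to see, here is a small sketch on a toy dataset rather than on the RPS images. The tensors, `toy_dataset`, and the choice of `shuffle=True` with `num_workers=0` are assumptions for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten fake samples with 2 features each, plus labels 0..9 so samples are easy to track.
features = torch.arange(20, dtype=torch.float32).reshape(10, 2)
targets = torch.arange(10)
toy_dataset = TensorDataset(features, targets)

# num_workers=0 loads data in the main process, which keeps the sketch simple.
loader = DataLoader(toy_dataset, batch_size=4, shuffle=True, num_workers=0)

batch_sizes = []
seen = []
for x, y in loader:
    batch_sizes.append(x.shape[0])
    seen.extend(y.tolist())

print(batch_sizes)       # the last batch is smaller because 10 is not a multiple of 4
print(sorted(seen))      # every sample is visited exactly once per epoch
```

With 10 samples and `batch_size=4`, one pass yields batches of 4, 4, and 2 samples. Passing `drop_last=True` would discard the incomplete final batch, which can help when a model expects a fixed batch size.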