# Preprocesamiento y carga de datos

In [1]:
import torch

## PyTorch Dataset

### Creación de Datasets

##### Obtener sub Dataset (omitir clases)

In [None]:
from torch.utils.data import Dataset

class SubDataset(Dataset):
    """View over a dataset that keeps only the samples of selected classes.

    Args:
        original_dataset: Iterable, indexable dataset whose samples carry the
            label at position 1.
        keep_classes: Container of labels to keep (tested with ``in``).
    """

    def __init__(self, original_dataset, keep_classes):
        super().__init__()
        self.original_dataset = original_dataset
        # One full pass over the wrapped dataset: remember the positions of
        # the samples whose label is one of the requested classes.
        kept_positions = []
        for position, sample in enumerate(original_dataset):
            if sample[1] in keep_classes:
                kept_positions.append(position)
        self.indices = kept_positions

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.indices)

    def __getitem__(self, index):
        """Return the wrapped dataset's sample for the ``index``-th kept position."""
        return self.original_dataset[self.indices[index]]

##### Particionar dataset para entrenamiento y validación

In [None]:
from torch.utils.data import random_split

def split_dataset(dataset, test_ratio):
    """Randomly split ``dataset`` into a training and a validation subset.

    Args:
        dataset: Dataset supporting ``len()``.
        test_ratio: Fraction of the samples assigned to the validation subset.

    Returns:
        The result of ``random_split``: the training subset followed by the
        validation subset.
    """
    total = len(dataset)
    val_size = int(total * test_ratio)
    # The training subset takes whatever the truncated validation size leaves.
    train_size = total - val_size
    return random_split(dataset, [train_size, val_size])
    
# Usar como dataset_train, dataset_val = split_dataset(dataset, test_ratio)

In [None]:
from torch.utils.data import random_split

def split_dataset(dataset, train_ratio, dev_ratio):
    """Randomly split ``dataset`` into train, dev and test subsets.

    Args:
        dataset: Dataset supporting ``len()``.
        train_ratio: Fraction of the samples assigned to the training subset.
        dev_ratio: Fraction of the samples assigned to the dev subset.

    Returns:
        The result of ``random_split``: train, dev and test subsets, in that
        order. The test subset receives every sample not taken by the others.
    """
    total = len(dataset)
    train_size = int(total * train_ratio)
    dev_size = int(total * dev_ratio)
    # Whatever is left after the two truncated sizes goes to the test subset.
    test_size = total - train_size - dev_size
    return random_split(dataset, [train_size, dev_size, test_size])

# usar como dataset_train, dataset_dev, _ = split_dataset(dataset, train_ratio=0.7, dev_ratio=0.3)

##### Dataset a partir de tensores

In [3]:
from torch.utils.data import Dataset

class DatasetFromTensors(Dataset):
    """Dataset built from in-memory images and their labels.

    Args:
        images: Indexable collection of images; its length is the dataset size.
        labels: Labels, converted with ``torch.LongTensor``.
        transform: Optional callable applied to each image on access.
        target_transform: Optional callable applied to each label on access.
    """

    def __init__(self, images, labels, transform=None, target_transform=None):
        super().__init__()
        self.transform = transform
        self.target_transform = target_transform
        self.images = images
        self.labels = torch.LongTensor(labels)

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the (optionally transformed) image and label at ``index``."""
        sample, target = self.images[index], self.labels[index]

        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            target = self.target_transform(target)

        return sample, target

##### Dataset a partir de un CSV con nombres de archivos

In [4]:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset

class LocalFilesDataset(Dataset):
    """Dataset of image files listed in a CSV file.

    Args:
        imgs_dir: Directory the file names in the CSV are relative to.
        data_info: CSV source passed to ``pd.read_csv``; each row holds
            ``file_name.jpg, target``. With the ``read_csv`` defaults the
            first line of the file is consumed as the header.
        transform: Optional callable applied to each loaded image.
        target_transform: Optional callable applied to each label.
    """

    def __init__(self, imgs_dir, data_info, transform=None, target_transform=None):
        super().__init__()
        self.imgs_dir = imgs_dir
        self.transform = transform
        self.target_transform = target_transform
        # Table with one row per sample: (file name, target).
        self.data_info = pd.read_csv(data_info)

    def __len__(self):
        """Return the number of rows in the CSV table."""
        return len(self.data_info)

    def __getitem__(self, index):
        """Read the image of row ``index`` from disk and return it with its label."""
        row = self.data_info.iloc[index]
        filename, target = row
        picture = read_image(os.path.join(self.imgs_dir, filename))

        if self.transform:
            picture = self.transform(picture)
        if self.target_transform:
            target = self.target_transform(target)

        return picture, target

##### Dataset de acuerdo a una lista de carpetas

In [None]:
from PIL import Image
import os
from torch.utils.data import Dataset

class FoldersDataset(Dataset):
    """Image dataset where each folder holds the images of one class.

    The label of an image is the position of its folder in ``images_folders``.
    Files that cannot be opened are skipped and reported on standard output.

    Args:
        images_folders: List of directory paths, one per class.
        transform: Optional callable applied to each PIL image on access.
        format_filter: File extensions (with the leading dot) to accept.
    """

    def __init__(self,
                 images_folders: list,
                 transform=None,
                 format_filter=('.jpg', '.jpeg', '.png')):
        super().__init__()

        self.data_filepaths = []
        self.data_labels = []

        for num_class, class_folder in enumerate(images_folders):

            corrupted_images = []
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping reproducible across runs and machines.
            for image_filename in sorted(os.listdir(class_folder)):
                image_path = os.path.join(class_folder, image_filename)

                if os.path.splitext(image_path)[-1] not in format_filter:
                    continue

                try:
                    # Opened only to check that the file is readable; the
                    # context manager releases the file handle right away.
                    with Image.open(image_path):
                        pass
                except Exception:
                    # Deliberate best effort: any failure to open marks the
                    # file as corrupted (KeyboardInterrupt still propagates).
                    corrupted_images.append(image_filename)
                else:
                    self.data_filepaths.append(image_path)
                    self.data_labels.append(num_class)

            if corrupted_images:
                print(f'{len(corrupted_images)} corrupted image(s) in {class_folder}:')
                for file in corrupted_images:
                    print(f'- {file}')

        self.transform = transform

    def __len__(self):
        """Return the number of readable images found."""
        return len(self.data_filepaths)

    def __getitem__(self, index):
        """Open the image at ``index`` and return it (transformed) with its label."""
        image = Image.open(self.data_filepaths[index])
        label = self.data_labels[index]

        if self.transform:
            image = self.transform(image)

        return image, label
    
# También se podría usar torchvision.datasets.ImageFolder.

##### Dataset normalizado por canales

In [None]:
from torch.utils.data import Dataset
from torchvision import transforms

class NormalizedDataset(Dataset):
    """Wrap a dataset so that its images are normalized per channel.

    The per-channel mean and standard deviation are computed once, at
    construction time, from every image of the wrapped dataset. All images
    are stacked in memory to do so, so this only suits small datasets.

    Args:
        dataset: Iterable, indexable dataset of ``(image, label)`` pairs whose
            images are equally sized tensors with channels first.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

        # Stack along a new trailing axis -> (C, H, W, len(dataset)); the
        # channel axis stays first because statistics are taken per channel.
        stacked = torch.stack([sample[0] for sample in dataset], dim=3)
        channels = stacked.shape[0]
        # One row per channel holding all H * W * len(dataset) values.
        per_channel = stacked.view(channels, -1)

        self.mean = per_channel.mean(dim=1)
        self.std = per_channel.std(dim=1)

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the normalized image at ``index`` together with its label."""
        image, label = self.dataset[index]
        normalize = transforms.Normalize(self.mean, self.std)
        return normalize(image), label
    
# Usar solo en datasets pequeños.

##### Otros

In [None]:
# Read a file that lists image names and labels.
# NOTE: this cell is a collection of independent fragments; names such as
# dataset_file, self, std, mean, np and tensor must be defined by the caller.

with open(dataset_file, 'r') as f:
    lines = f.readlines()
    
images, labels = [], []

# Each line must hold exactly two fields separated by a single space
# (image name, label); any other layout makes the unpacking below fail.
for line in lines:
    line = line.strip()
    image, label = line.split(' ')
    images.append(image)
    labels.append(label)
    
# Map class names (str) to integer ids (fragment meant for a method, hence self):
self.classes_num = dict()
self.classes_num['clase1'] = 0
self.classes_num['clase2'] = 1
self.classes_num['clase3'] = 2
# then use label = self.classes_num[label]

# Transformation to display a (monochrome) image; presumably undoes a
# (x - mean) / std normalization and clamps to [0, 1] - confirm with the caller:
image = std * image + mean
image = np.clip(image, 0, 1)

# Pillow:
img = transforms.ToPILImage()(tensor)  # tensor to Pillow image.
img.save('imagen.jpg')  # save from Pillow.

### Visualización de Datasets

##### Mostrar lista de imágenes

In [2]:
import matplotlib.pyplot as plt

def show_images(images, titles=None, rows=None, permute=False):
    """Display a sequence of images in a grid of matplotlib subplots.

    Args:
        images: Sequence of images accepted by ``plt.imshow``.
        titles: Optional sequence with one title per image.
        rows: Number of grid rows. By default, enough rows to place at most
            10 images per row.
        permute: If true, each image is reordered with ``.permute(1, 2, 0)``
            (channels-first to channels-last) before being drawn.

    Returns:
        None. Does nothing when ``images`` is empty.
    """
    n = len(images)
    if n == 0:
        # Nothing to draw; this also avoids dividing by zero rows below.
        return

    # -(a // -b) is integer ceiling division, i.e. math.ceil(a / b).
    if rows is None:
        rows = -(n // -10)
    cols = -(n // -rows)

    fig = plt.figure(figsize=(2 * cols, 2 * rows))
    fig.subplots_adjust(wspace=0.5)

    for i in range(n):
        fig.add_subplot(rows, cols, i + 1)
        image = images[i].permute(1, 2, 0) if permute else images[i]
        plt.imshow(image)
        plt.axis('off')
        if titles is not None:
            plt.title(titles[i])

    plt.show(block=True)

##### Mostrar instancias de un dataset

In [1]:
from torchvision.utils import make_grid
from matplotlib import pyplot as plt

def dataset_show(dataset, n=24, shuffle=True):
    """Show up to ``n`` images of ``dataset`` as a single image grid.

    Args:
        dataset: Indexable dataset supporting ``len()`` whose samples carry
            the image tensor at position 0.
        n: Maximum number of images to show; capped at ``len(dataset)``.
        shuffle: If true, show distinct randomly chosen samples; otherwise
            the first ones.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    total = len(dataset)
    if total == 0:
        raise ValueError('dataset_show() needs a non-empty dataset')

    # Never ask for more samples than exist (range(n) would run past the end).
    n = min(n, total)

    if shuffle:
        # A prefix of a random permutation yields n distinct samples, whereas
        # independent random draws could show the same image several times.
        indices = torch.randperm(total)[:n].tolist()
    else:
        indices = range(n)

    images = [dataset[i][0] for i in indices]
    images_grid = make_grid(images, normalize=True)

    plt.figure(figsize=(15, 15))
    # make_grid returns a channels-first image; imshow expects channels last.
    plt.imshow(images_grid.permute(1, 2, 0))
    plt.title('Dataset samples')
    plt.show()

##### Mostrar una imagen de cada clase de un dataset

In [None]:
def one_sample_per_class(dataset):
    """Show the first image found for each distinct label of ``dataset``.

    The dataset is traversed once. Images are displayed with ``show_images``
    in order of first appearance, using the labels as titles.

    Args:
        dataset: Iterable of ``(image, label)`` pairs. Labels may be hashable
            Python values or single-element tensors.
    """
    first_by_label = {}

    for img, label in dataset:
        # Tensors hash by identity, so equal tensor labels would count as
        # different classes; compare them by their Python value instead.
        if isinstance(label, torch.Tensor) and label.numel() == 1:
            label = label.item()
        first_by_label.setdefault(label, img)

    if not first_by_label:
        # Empty dataset: nothing to display.
        return

    show_images(list(first_by_label.values()), list(first_by_label), permute=True)

### Transformaciones

##### Obtener estadísticas (media y desviación estándar por canal) de un dataset

In [None]:
from torch.utils.data import DataLoader

def dataloader_statistics(dataset, batch_size=None):
    """Compute the per-channel mean and standard deviation of an image dataset.

    Args:
        dataset: Dataset of ``(image, label)`` pairs that the default
            DataLoader collation can batch into a (B, C, H, W) tensor.
        batch_size: Number of samples per batch. Defaults to the whole
            dataset in a single batch; pass a smaller value to bound memory.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel.
        ``std`` is the population standard deviation, sqrt(E[x^2] - E[x]^2).

    Raises:
        ValueError: If no batch is produced (empty dataset).
    """
    if batch_size is None:
        batch_size = len(dataset)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    cnt = 0  # number of pixels per channel accumulated so far
    fst_moment = None  # running E[x] per channel
    snd_moment = None  # running E[x^2] per channel

    for images, _ in dataloader:
        if not images.is_floating_point():
            # Integer images would overflow when squared.
            images = images.float()

        b, c, h, w = images.shape
        if fst_moment is None:
            # Start from zeros sized by the real channel count; uninitialized
            # memory could hold NaN/inf, and 0 * NaN is still NaN.
            fst_moment = images.new_zeros(c)
            snd_moment = images.new_zeros(c)

        nb_pixels = b * h * w
        sum_ = torch.sum(images, dim=[0, 2, 3])
        sum_of_square = torch.sum(images ** 2, dim=[0, 2, 3])
        # Weighted update of the running moments with this batch.
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)
        cnt += nb_pixels

    if fst_moment is None:
        raise ValueError('cannot compute statistics of an empty dataset')

    # Rounding can push the variance slightly below zero; clamp before sqrt.
    variance = torch.clamp(snd_moment - fst_moment ** 2, min=0)
    mean, std = fst_moment, torch.sqrt(variance)

    return mean, std

##### Codificación one-hot

In [None]:
def one_hot(n_classes):
    """Build a target transform that one-hot encodes a class index.

    The returned ``transforms.Compose`` first wraps the label in a
    one-element ``LongTensor`` and then applies
    ``torch.nn.functional.one_hot`` with ``n_classes`` classes.
    """
    def to_long_tensor(x):
        return torch.LongTensor([x])

    def encode(x):
        return torch.nn.functional.one_hot(x, n_classes)

    return transforms.Compose([to_long_tensor, encode])

# Usar como one_hot(n_classes) en el target_transform.

##### Pasar a imagen RGB

In [None]:
def to_rgb(image):
    """Return a 3-channel version of a channels-first image tensor.

    Args:
        image: Tensor of shape (C, H, W).

    Returns:
        - C == 1 (grayscale): the channel repeated three times (as a view).
        - C == 2 (grayscale + alpha): alpha dropped, gray channel repeated.
        - C == 3: the input itself, unchanged.
        - C > 3 (e.g. RGBA): the first three channels.
    """
    channels = image.shape[0]
    if channels in (1, 2):
        # Keep only the intensity channel, then broadcast it to three.
        return image[0:1].expand(3, -1, -1)
    if channels > 3:
        return image[0:3]
    return image

# Usar como to_rgb en el transform.

##### Esquema de transformación común

In [None]:
# Transform pipelines keyed by phase. Training adds random augmentation;
# validation uses a deterministic resize followed by a center crop.
# Both end with the same per-channel normalization constants.
data_transforms = dict(
    train=transforms.Compose(
        [
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    val=transforms.Compose(
        [
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
)

### Datasets clásicos

In [None]:
from torchvision import datasets, transforms

# Fake data: 1000 randomly generated 3x224x224 images spread over 10 classes.
from torchvision.datasets import FakeData
dataset = FakeData(
    size=1000,
    image_size=(3, 224, 224),
    num_classes=10,
    transform=transforms.ToTensor(),
)

# MNIST: single-channel mean/std and the ten digit labels.
mean = (0.1307,)
std = (0.3081,)
labels = [*range(10)]

# Fashion MNIST: single-channel mean/std and the ten class names.
mean, std = (0.2859,), (0.3530,)
# Split on commas rather than whitespace: 'Ankle boot' is one class name.
labels = 'T-shirt/top,Trouser,Pullover,Dress,Coat,Sandal,Shirt,Sneaker,Bag,Ankle boot'.split(',')

# CIFAR-10: per-channel mean/std and the ten class names.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2470, 0.2435, 0.2616)
labels = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Caltech 101:
# Per-channel normalization constants (presumably the usual ImageNet
# statistics - confirm they suit this dataset).
mean, std =(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
# download=True fetches the data into 'data/Caltech101' when it is missing.
# The transform order matters: the image is converted to a tensor first, then
# resized to 227x227, then expanded to 3 channels (a 1-channel image is
# repeated, a 3-channel image is left as is) so the 3-value Normalize applies.
dataset = datasets.Caltech101('data/Caltech101',
                              download=True,
                              transform=transforms.Compose([transforms.ToTensor(),
                                                            transforms.Resize((227, 227)),
                                                            lambda x: x.expand(3, -1, -1),
                                                            transforms.Normalize(mean, std)]))
# Delete every regular file matching '*.DS_Store' under the current directory.
# The find options must be written without a space after the dash.
!find . -name '*.DS_Store' -type f -delete

# ImageNet: download the class names (one per line of the response body).
import requests
# A timeout keeps the cell from hanging forever on a stalled connection.
response = requests.get('https://git.io/JJkYN', timeout=10)
# Fail loudly on an HTTP error instead of parsing an error page as labels.
response.raise_for_status()
# splitlines() does not leave an empty trailing label after the last newline.
labels = response.text.splitlines()

## Dataloaders

### Imágenes de un batch a partir de un dataloader

In [1]:
def show_batch(dataloader):
    """Display the first batch produced by ``dataloader``.

    Each image is titled with its label, read via ``.item()``, and the batch
    is drawn with ``show_images`` using ``permute=True``.
    """
    batch = next(iter(dataloader))
    images, labels = batch
    titles = []
    for label in labels:
        titles.append(f'Label: {label.item()}')
    show_images(images, titles, permute=True)

### Cargar un batch de imágenes (para inferencia)

##### Cargar una sola imagen

In [None]:
from torchvision.io import read_image
from torchvision import transforms

def img2batch(img_path, force_grayscale=False, image_resize=False):
    """Load one image file as a batch holding a single image.

    Args:
        img_path: Path of the image, read with ``read_image``.
        force_grayscale: If true, apply ``transforms.Grayscale()``.
        image_resize: Size passed to ``transforms.Resize``; a falsy value
            keeps the original size.

    Returns:
        The image tensor with a new leading dimension of size 1.
    """
    picture = read_image(img_path)

    if force_grayscale:
        to_gray = transforms.Grayscale()
        picture = to_gray(picture)

    if image_resize:
        resizer = transforms.Resize(image_resize)
        picture = resizer(picture)

    # Indexing with None prepends the batch dimension.
    return picture[None]

##### Cargar un directorio con imágenes

In [None]:
import os
from torchvision.io import read_image
from torchvision import transforms

def dir2batch(dir_path, image_resize=False):
    """Load every ``.jpg`` file of a directory into a single uint8 batch.

    Args:
        dir_path: Directory to scan (not recursive).
        image_resize: ``(height, width)`` every image is resized to. When
            falsy, images keep their size, which must then be the same for
            all files (it is taken from the first image).

    Returns:
        Tensor of shape (N, 3, H, W) and dtype ``torch.uint8``, with the
        images in alphabetical file-name order. For a directory without
        ``.jpg`` files N is 0 (and H = W = 0 unless ``image_resize`` is given).
    """
    # os.listdir order is arbitrary; sort so the batch order is reproducible.
    filenames = sorted(name for name in os.listdir(dir_path)
                       if os.path.splitext(name)[-1] == '.jpg')

    if not filenames:
        height_width = image_resize if image_resize else (0, 0)
        return torch.zeros(0, 3, *height_width, dtype=torch.uint8)

    # Build the resize transform once instead of once per image.
    resize = transforms.Resize(image_resize) if image_resize else None

    batch = None
    for i, filename in enumerate(filenames):
        image = read_image(os.path.join(dir_path, filename))
        if resize is not None:
            image = resize(image)
        if batch is None:
            # Preallocating is more efficient than stacking a list of
            # tensors; the spatial size comes from the first image so that
            # the default image_resize=False works too.
            batch = torch.zeros(len(filenames), 3, *image.shape[-2:], dtype=torch.uint8)
        batch[i] = image

    return batch