In [None]:
#default_exp data

In [None]:
#hide
from IPython.display import clear_output
from nbdev.export import notebook2script
from dotenv import load_dotenv
%reload_ext autoreload
%autoreload 2

# Load environment variables from a .env file if one is found (presumably used
# for DATA_ROOT, which get_data_root reads); the returned flag is not needed.
_ = load_dotenv()

In [None]:
#export
import os
import torch
import logging
import multiprocessing
from easydict import EasyDict as edict
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

# getLogger with no name returns the *root* logger, so this INFO level applies
# process-wide as soon as the module is imported.
logger = logging.getLogger(None)
logger.setLevel(logging.INFO)

# data

> 負責下載、載入並前處理數據，以及建立 Dataset 和 DataLoader 的模組。

## 取得數據根目錄

In [None]:
#export
def get_data_root(data_root=None):
    """Return the root directory for datasets, creating it if needed.

    Resolution order: the explicit `data_root` argument if truthy, otherwise the
    `DATA_ROOT` environment variable if set and non-empty, otherwise ".".

    Args:
        data_root: Optional explicit directory path.

    Returns:
        The resolved directory path, exactly as given (not normalized).

    Raises:
        FileExistsError: If the resolved path exists but is not a directory.
    """
    # `or` also treats an empty DATA_ROOT (e.g. `DATA_ROOT=` in a .env file) as
    # unset; otherwise os.makedirs("") would raise FileNotFoundError.
    data_root = data_root or os.getenv("DATA_ROOT") or "."
    # exist_ok=True avoids the exists()/makedirs() race and is a no-op for an
    # existing directory, while still failing loudly if the path is a file.
    os.makedirs(data_root, exist_ok=True)
    return data_root

In [None]:
import tempfile

# Without an argument, get_data_root() falls back to $DATA_ROOT, then to ".".
predefined_data_root = os.getenv("DATA_ROOT")
assert get_data_root() == (predefined_data_root if predefined_data_root else ".")

# An explicit root always wins and is created on demand. Use a throwaway
# directory so re-running this cell does not leave a stray folder behind.
with tempfile.TemporaryDirectory() as scratch_dir:
    explicit_root = os.path.join(scratch_dir, "data")
    assert get_data_root(explicit_root) == explicit_root
    assert os.path.isdir(explicit_root)

In [None]:
# Display the resolved data root (see get_data_root for the fallback order).
get_data_root()

'data'

## Generic Dataset / DataLoader

In [None]:
#export
class ImageOnlyDataset(Dataset):
    """Expose only the image of each (image, label, ...) sample of a dataset.

    Commonly used for generative tasks, where labels are not needed.

    Args:
        img_label_dataset: Indexable dataset whose items are sequences
            (e.g. `(image, label)` tuples).
        img_idx: Position of the image inside each item (default 0).
    """

    def __init__(self, img_label_dataset, img_idx=0):
        self.orig_dataset, self.img_idx = img_label_dataset, img_idx

    def __len__(self):
        # Same number of items as the wrapped dataset.
        return len(self.orig_dataset)

    def __getitem__(self, idx):
        sample = self.orig_dataset[idx]
        image = sample[self.img_idx]
        return image

## 數據集

In [None]:
#export
def get_dataset(dataset, split="full", size=None, transform=None, return_label=True,
                **kwargs):
    """Build a torchvision dataset by name, optionally resized and label-free.

    Only "mnist" (case-insensitive) is currently supported.

    Args:
        dataset: Dataset name, e.g. "mnist".
        split: "train", "test" or "full". NOTE: for MNIST, "full" returns the
            60,000-image training split (torchvision's default), not train+test.
        size: Target (height, width), or a single int for a square image.
            Defaults to (28, 28) for MNIST.
        transform: Optional custom transform. If given, it replaces the default
            Resize/ToTensor/Normalize pipeline. `input_shape` is still derived
            from `size`, so keep the two consistent.
        return_label: If False, wrap the dataset in `ImageOnlyDataset` so each
            item is only the image.
        **kwargs: Accepted and ignored, so a whole config dict can be passed.

    Returns:
        The dataset, with an `input_shape` attribute of (C, H, W).

    Raises:
        NotImplementedError: If `dataset` is not supported.
        ValueError: If `split` is not one of "full", "train", "test".
    """
    name = dataset.lower()
    # Fail fast, before any download or file I/O happens.
    if name != "mnist":
        raise NotImplementedError(f"Dataset {dataset!r} is not supported yet.")
    # Map split names to torchvision's `train` flag. Previously any unknown
    # value (e.g. a typo) silently fell back to the training split.
    split_to_train = {"full": True, "train": True, "test": False}
    if split not in split_to_train:
        raise ValueError(
            f"Unknown split {split!r}; expected one of {sorted(split_to_train)}."
        )

    size = size if size else (28, 28)
    if isinstance(size, int):
        # Allow a single int for square images; `input_shape` needs two dims.
        size = (size, size)
    logging.info(f"MNIST will be resized to {size}.")

    if not transform:
        transform = transforms.Compose([
            transforms.Resize(size=size),
            transforms.ToTensor(),
            # Maps ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])

    # download=True is always safe: torchvision skips the download when the
    # files are already present. Unlike checking for a "MNIST" folder, this
    # still completes an interrupted, incomplete earlier download.
    ds = datasets.MNIST(root=get_data_root(), train=split_to_train[split],
                        transform=transform, download=True)

    if not return_label:
        ds = ImageOnlyDataset(ds)

    ds.input_shape = (1, *size)  # CHW; MNIST is single-channel
    return ds

### 指定數據集名稱

In [None]:
mnist = get_dataset("mnist")
sample = mnist[0]

assert mnist.input_shape == (1, 28, 28)  # CHW
assert len(mnist) == 60_000  # split="full" yields the MNIST training split
assert len(sample) == 2  # (image, label)
image, label = sample
# The declared input_shape must match what the dataset actually yields.
assert tuple(image.shape) == mnist.input_shape
assert 0 <= int(label) <= 9
mnist.input_shape

INFO:root:MNIST will be resized to (28, 28).


(1, 28, 28)

### 選取測試集、調整圖片大小且不回傳標籤
生成任務有時不需要標籤資訊；設定 `return_label=False` 後，每筆資料只回傳圖片張量。


In [None]:
size = (32, 32)
mnist_test = get_dataset("mnist", split="test", size=size, return_label=False)
image = mnist_test[0]

assert mnist_test.input_shape == (1, *size)
assert len(mnist_test) == 10_000
# Without labels each item is the bare image tensor, not an (image, label)
# tuple. Checking the full shape also verifies the 32x32 resize; `len(image)`
# alone would be 1 only because MNIST has a single channel.
assert isinstance(image, torch.Tensor)
assert tuple(image.shape) == mnist_test.input_shape
mnist_test.input_shape

INFO:root:MNIST will be resized to (32, 32).


(1, 32, 32)

## DataLoader

In [None]:
#export
def get_data_loader(dataset, batch_size, shuffle=True, collate_fn=None, drop_last=True, **kwargs):
    """Wrap `dataset` in a DataLoader with device-aware worker/pinning defaults.

    Defaults: with CUDA available, one worker per usable CPU and pinned host
    memory. Without CUDA, a single worker and no pinning.

    Args:
        dataset: Any dataset accepted by `torch.utils.data.DataLoader`.
        batch_size: Samples per batch.
        shuffle: Reshuffle the data every epoch (default True).
        collate_fn: Optional custom batch-collation function.
        drop_last: Drop the final incomplete batch (default True).
        **kwargs: `num_workers` and `pin_memory` override the defaults above
            (e.g. `num_workers=0` for debugging). Any other keys are ignored,
            so a whole config dict can be passed.

    Returns:
        A `torch.utils.data.DataLoader`.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # sched_getaffinity respects the CPU affinity mask (taskset / cpuset
        # limits), unlike cpu_count(). It is unavailable on macOS and Windows,
        # so fall back to the machine-wide count there.
        affinity = getattr(os, "sched_getaffinity", None)
        default_workers = len(affinity(0)) if affinity else multiprocessing.cpu_count()
    else:
        default_workers = 1

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=kwargs.get("num_workers", default_workers),
        collate_fn=collate_fn,
        drop_last=drop_last,
        pin_memory=kwargs.get("pin_memory", use_cuda),
    )

In [None]:
batch_size = 32
mnist_data_loader = get_data_loader(mnist, batch_size=batch_size)
batch = next(iter(mnist_data_loader))

assert len(batch) == 2  # (images, labels)
images, labels = batch
# Each batch stacks `batch_size` images of the dataset's declared CHW shape.
assert tuple(images.shape) == (batch_size, *mnist.input_shape)
assert tuple(labels.shape) == (batch_size,)

In [None]:
#hide
# Export the `#export` cells to the library module (`#default_exp data`), then
# clear this cell's output so the saved notebook stays clean.
notebook2script()
clear_output()