- Title: Tips on Dataset in PyTorch
- Slug: python-pytorch-dataset
- Date: 2020-02-26 22:40:06
- Category: Programming
- Tags: programming, Python, AI, data science, machine learning, deep learning, PyTorch, Dataset
- Author: Ben Du

## Tips

1. When you implement your own Dataset class,
    you need to inherit from 
    [torch.utils.data.Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)
    (or one of its subclasses).
    You must override the 2 methods `__len__` and `__getitem__`.

2. When you implement your own Dataset class for image classification,
    it is best to inherit from 
    [torchvision.datasets.vision.VisionDataset](https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py#L6).
    For example, 
    [torchvision.datasets.MNIST](https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py)
    subclasses 
    [torchvision.datasets.vision.VisionDataset](https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py#L6),
    so you can use it as a template.
    Notice that you still only have to override the 2 methods `__len__` and `__getitem__`
    (even though the implementation of 
    [torchvision.datasets.MNIST](https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py)
    is much more complicated than that).
    [torchvision.datasets.MNIST](https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py)
    downloads data into the directory `MNIST/raw` 
    and saves a ready-to-use copy of the data in the directory `MNIST/processed`. 
    It doesn't matter whether you follow this convention or not
    as long as you override the 2 methods `__len__` and `__getitem__`.
    What's more, the parameter `root` of the constructor of 
    [torchvision.datasets.vision.VisionDataset](https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py#L6)
    is not critical 
    as long as your Dataset subclass knows where and how to load the data
    (e.g., you can pass the full path of the data file as a parameter to your Dataset subclass). 
    You can set `root` to `None` if you like. 
  

3. When you implement a Dataset class for image classification,
    it is best to have the method `__getitem__` work with `(PIL.Image, target)`
    and then use `torchvision.transforms.ToTensor` as a transform to convert the `PIL.Image` to a tensor
    before the DataLoader batches it
    (the default collate function of the DataLoader cannot batch `PIL.Image` objects).
    The reason is that the transforms in `torchvision.transforms` 
    behave differently on a `PIL.Image` 
    and on its equivalent numpy array. 
    You might get surprises if you have `__getitem__` build `(torch.Tensor, target)` directly from raw arrays.
    If you do have `__getitem__` return `(torch.Tensor, target)` this way,
    make sure to double check that the tensors are as expected 
    before feeding them into your model for training/prediction.

4. `torchvision.transforms.ToTensor` (referred to as `ToTensor` in the following) 
    converts a PIL image to a numerical tensor with each value in the range [0, 1].
    `ToTensor` on a boolean numpy array returns a tensor of booleans. 
    It does not automatically convert booleans to numerical values (0/1 or 0/255). 
    In some cases a boolean numpy array can even end up as a tensor of 0/255 instead of 0/1
    (this is not fully understood yet),
    so always check the dtype and the value range of the result.


    
5. The results of `Tensor.float` and `Tensor.double` on a boolean tensor can be surprising, so check them explicitly.

6. A boolean numpy array (a black/white image) can be read in as a grayscale image (`mode="L"`), which gives values of 0/255,
    but the result does not look right when `mode="1"` is used.
    
7. There is no need to return the target as a tensor, as the DataLoader (with its default collate function) converts it automatically when it builds a batch.
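
The first tip (inherit from `torch.utils.data.Dataset` and implement `__len__` and `__getitem__`) can be sketched as follows. This is a minimal toy example rather than a recommended design; the class name `SquaresDataset` and its parameter `n` are made up for illustration.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples in the dataset.
        return self.n

    def __getitem__(self, index):
        # Return one (feature, target) pair.
        x = torch.tensor([float(index)])
        y = index * index
        return x, y


dataset = SquaresDataset(8)
x, y = dataset[3]  # indexing calls __getitem__
```

With only these 2 methods in place, `len(dataset)` and `dataset[i]` work, which is what a DataLoader with the default sampler relies on.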



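The point about targets not needing to be tensors can be checked with a short sketch. It assumes the default `collate_fn` of the DataLoader; the class `ParityDataset` is a made-up example whose `__getitem__` returns a plain Python `int` as the target.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ParityDataset(Dataset):
    """Hypothetical dataset whose target is a plain Python int."""

    def __len__(self):
        return 8

    def __getitem__(self, index):
        features = torch.zeros(3)
        target = index % 2  # plain int, not a tensor
        return features, target


loader = DataLoader(ParityDataset(), batch_size=4, shuffle=False)
features, targets = next(iter(loader))
# The default collate function stacks the int targets into a tensor.
print(type(targets), targets.dtype, targets.tolist())
```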

## References

[torchvision.datasets.MNIST](https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py)

[torchvision.datasets.vision.VisionDataset](https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py#L6)

[torch.utils.data.Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)

[torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)

[torch.utils.data](https://pytorch.org/docs/stable/data.html)