

# Learning About Dataloading

## Installing and Importing Libraries

In [None]:
%%capture
!pip install pytorch-lightning wandb

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets

mnist = torchvision.datasets.MNIST(".", download=True)

## Producing a `DataLoader`

In [None]:
class PerceptronDataModule(pl.LightningDataModule):

  def __init__(self, batch_size=64):
    super().__init__()  # ⚡: we inherit from LightningDataModule
    self.batch_size = batch_size

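  def prepare_data(self):  # ⚡: how do we set up the data?
    # download the data from the internet
    mnist = torchvision.datasets.MNIST(".", train=True, download=True)

    # set up shapes and types: pixels become floats, and targets become
    # an (N, 1) float column marking whether each digit is a 5
    # (note: Lightning recommends assigning state like this in `setup`,
    # because `prepare_data` is not run on every process in distributed training)
    self.digits = mnist.data.float()
    self.is_5 = (mnist.targets == 5)[:, None].float()
    self.dataset = torch.utils.data.TensorDataset(self.digits, self.is_5)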

  def train_dataloader(self):  # ⚡: how do we go from dataset to dataloader?
    """The DataLoaders returned by a DataModule produce data for a model.
    
    This DataLoader is used during training."""
    return DataLoader(self.dataset, batch_size=self.batch_size)

In [None]:
dmodule = PerceptronDataModule()
dmodule.prepare_data()
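
The target construction inside `prepare_data` packs three steps into one expression. As a minimal sketch, assuming a small hand-written label tensor in place of `mnist.targets` (the name `targets` and its values are made up for illustration), the steps can be unpacked like this:

```python
import torch

# Hypothetical stand-in for mnist.targets: integer class labels.
targets = torch.tensor([5, 0, 4, 5, 9])

# Comparing against 5 gives a boolean tensor of shape (5,).
is_5_bool = targets == 5

# Indexing with [:, None] adds a trailing axis, giving shape (5, 1);
# .float() then converts True/False into 1.0/0.0.
is_5 = is_5_bool[:, None].float()

print(is_5.shape)               # torch.Size([5, 1])
print(is_5.flatten().tolist())  # [1.0, 0.0, 0.0, 1.0, 0.0]
```

The trailing axis gives each example its own row, so the targets line up one-to-one with the images when both are wrapped in a `TensorDataset`.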

## Examining the Data

The raw data is attached to the `DataModule` as tensors:

In [None]:
dmodule.digits[0], dmodule.is_5[0]  # first digit image and its "is it a 5?" target

In [None]:
def show(im):
  plt.imshow(im, cmap="Greys")
  plt.axis("off")

show(dmodule.digits[0])
print(dmodule.is_5[0])  # try idx != 0

But the model doesn't see the data this way.

In order to orchestrate the loading and processing of the data effectively,
PyTorch/Lightning uses `DataLoader`s:


In [None]:
trainloader = dmodule.train_dataloader()

To use `DataLoader`s, we iterate over them.
Each iteration produces a "batch".

```python
for batch in dataloader:
  x, y = batch
  # do stuff to x and y with Model
```
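
To make the idea of a batch concrete, here is a small sketch that uses toy data rather than MNIST. The names `xs`, `ys`, `toy_dataset`, and `toy_loader`, along with the sizes chosen, are assumptions for illustration only. It shows how a `DataLoader` splits a dataset of 10 examples into batches of at most 4:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 examples with 3 features each,
# plus one target value per example.
xs = torch.arange(30, dtype=torch.float32).reshape(10, 3)
ys = torch.zeros(10, 1)
toy_dataset = TensorDataset(xs, ys)

# Same pattern as train_dataloader above, with a small batch size.
toy_loader = DataLoader(toy_dataset, batch_size=4)

batch_sizes = []
for x, y in toy_loader:
    # The first axis of each tensor indexes the examples in the batch.
    batch_sizes.append(x.shape[0])

print(len(toy_loader))  # 3
print(batch_sizes)      # [4, 4, 2]
```

Because 10 is not a multiple of 4, the final batch is smaller than the rest. Passing `drop_last=True` to `DataLoader` would discard that partial batch instead.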

#### Q Use the `.shape` attribute to determine the shape and the `.dtype` attribute to determine the type of the tensors in the `batch`.

In [None]:
for batch in trainloader:
  digit, target = batch
  print("") # YOUR CODE HERE
  break # only iterate once

show(digit[0])