In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Dataset
- `torch.utils.data.Dataset` is an **abstract** class that represents a dataset
    - Subclasses override the `__len__` and `__getitem__` methods
- `torch.utils.data.DataLoader` wraps a dataset and serves its samples in (optionally shuffled) mini-batches

# Custom dataset

In [2]:
class MyData(Dataset):
    """Toy map-style dataset that pairs each stored item with its index.

    Every sample is the triple ``(item, index, index ** 2)``, which makes it
    easy to see how ``DataLoader`` groups and shuffles samples into batches.

    Args:
        input_data: a container supporting ``len()`` and ``[index]`` lookup
            with integer indices ``0 .. len(input_data) - 1`` (this notebook
            passes a dict keyed by those integers).
    """

    def __init__(self, input_data):
        self.data = input_data

    def __len__(self):
        # Number of samples the DataLoader will draw indices from.
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        squared = index ** 2
        return item, index, squared

# DataLoader

In [6]:
SEED = 0        # fixes the shuffle order so Restart & Run All yields the same batches
BATCH_SIZE = 3

input_dict = {0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e', 5: 'f', 6: 'g', 7: 'h'}
train_set = MyData(input_dict)
# A dedicated generator seeds only this loader's shuffling and leaves the
# global torch RNG untouched.
train_loader = DataLoader(
    dataset=train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
# 8 samples with batch_size 3 -> 3 batches; the last one is short (2 samples).
print('Total batches:', len(train_loader))
print('Example')
next(iter(train_loader))

Total batches:  3


[('e', 'h', 'd'), tensor([4, 7, 3]), tensor([16, 49,  9])]

# Train

In [7]:
N_EPOCHS = 7

# Messages go through tqdm.write (and the inner bar uses leave=False) so the
# printed batches do not interleave with and corrupt the progress bars.
for epoch in tqdm(range(N_EPOCHS), desc='epochs'):
    tqdm.write(f'Epoch: {epoch}')
    # Each MyData sample is (item, index, index ** 2), so batch_x holds the
    # items, batch_y their indices (not labels), batch_pos the squared indices.
    for batch_x, batch_y, batch_pos in tqdm(train_loader, desc='batches', leave=False):
        tqdm.write('-' * 10)
        tqdm.write(str(batch_x))
        tqdm.write(str(batch_y))
        tqdm.write(str(batch_pos))

  0%|          | 0/7 [00:00<?, ?it/s]
  0%|          | 0/3 [00:00<?, ?it/s][A
100%|██████████| 3/3 [00:00<00:00, 1016.31it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
100%|██████████| 3/3 [00:00<00:00, 1170.50it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
100%|██████████| 3/3 [00:00<00:00, 972.78it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
100%|██████████| 3/3 [00:00<00:00, 1332.09it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
100%|██████████| 3/3 [00:00<00:00, 1588.95it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
100%|██████████| 3/3 [00:00<00:00, 653.73it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
100%|██████████| 7/7 [00:00<00:00, 67.86it/s]][A

Epoch:  0
----------
('h', 'd', 'b')
tensor([7, 3, 1])
tensor([49,  9,  1])
----------
('a', 'e', 'f')
tensor([0, 4, 5])
tensor([ 0, 16, 25])
----------
('g', 'c')
tensor([6, 2])
tensor([36,  4])
Epoch:  1
----------
('h', 'c', 'e')
tensor([7, 2, 4])
tensor([49,  4, 16])
----------
('a', 'b', 'd')
tensor([0, 1, 3])
tensor([0, 1, 9])
----------
('g', 'f')
tensor([6, 5])
tensor([36, 25])
Epoch:  2
----------
('c', 'h', 'a')
tensor([2, 7, 0])
tensor([ 4, 49,  0])
----------
('g', 'b', 'e')
tensor([6, 1, 4])
tensor([36,  1, 16])
----------
('f', 'd')
tensor([5, 3])
tensor([25,  9])
Epoch:  3
----------
('g', 'c', 'b')
tensor([6, 2, 1])
tensor([36,  4,  1])
----------
('h', 'd', 'e')
tensor([7, 3, 4])
tensor([49,  9, 16])
----------
('f', 'a')
tensor([5, 0])
tensor([25,  0])
Epoch:  4
----------
('f', 'b', 'c')
tensor([5, 1, 2])
tensor([25,  1,  4])
----------
('e', 'g', 'h')
tensor([4, 6, 7])
tensor([16, 36, 49])
----------
('a', 'd')
tensor([0, 3])
tensor([0, 9])
Epoch:  5
----------
('a'


