In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from tqdm import tqdm

# Dataset
- `torch.utils.data.Dataset` is an **abstract** class that represents a dataset
    - Override `__len__` and `__getitem__` methods
- `torch.utils.data.DataLoader` wraps a dataset and provides access to the data

# Custom dataset

In [2]:
class MyData(Dataset):
    """Minimal map-style dataset wrapping any indexable container.

    Each sample is a 3-tuple ``(value, index, index ** 2)``, which makes it
    easy to see how a DataLoader batches several fields at once.
    """

    def __init__(self, input_data):
        """Store the container; it must support ``len()`` and ``[index]``."""
        self.data = input_data

    def __len__(self):
        """Return the number of samples, i.e. the length of the container."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the stored value at ``index``, the index, and its square."""
        value = self.data[index]
        squared = index ** 2
        return value, index, squared

# DataLoader

In [3]:
# Seed PyTorch's global RNG so the shuffled batch order (in this cell and in
# any later iteration over train_loader) is the same on every
# Restart Kernel -> Run All. Without a seed the printed batches change per run.
torch.manual_seed(0)

# Toy data: integer index -> letter. MyData looks samples up with the integer
# indices requested by the DataLoader, so the keys have to be 0..len-1.
input_dict = {0:'a', 1:'b', 2:'c', 3:'d', 4:'e', 5:'f', 6:'g', 7:'h'}
train_set = MyData(input_dict)
# 8 samples with batch_size=3 -> 3 batches; the last one holds the 2 leftovers.
train_loader = DataLoader(dataset=train_set, batch_size=3, shuffle=True)
print('Total batches: ', len(train_loader))
print('Example')
# Peek at one batch: the (value, index, index ** 2) samples are collated
# field by field, giving a tuple of letters and two integer tensors.
next(iter(train_loader))

Total batches:  3
Example


[('f', 'c', 'a'), tensor([5, 2, 0]), tensor([25,  4,  0])]

# Train

In [4]:
# Iterate over the whole loader for 7 epochs. Because the loader was built with
# shuffle=True, every epoch visits the samples in a different order.
for epoch in tqdm(range(7)):
    print('Epoch: ', epoch)
    # Each batch is the collated (value, index, index ** 2) triple produced by
    # MyData.__getitem__: batch_x holds the letters, batch_y the sample
    # indices and batch_pos the squared indices.
    for batch_x, batch_y, batch_pos in tqdm(train_loader):
        # One print call, newline-separated: same text as four separate prints.
        print('-' * 10, batch_x, batch_y, batch_pos, sep='\n')

  0%|          | 0/7 [00:00<?, ?it/s]
  0%|          | 0/3 [00:00<?, ?it/s][A
100%|██████████| 3/3 [00:00<00:00, 353.53it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
100%|██████████| 3/3 [00:00<00:00, 229.67it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
100%|██████████| 3/3 [00:00<00:00, 683.97it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
 57%|█████▋    | 4/7 [00:00<00:00, 35.87it/s]][A
  0%|          | 0/3 [00:00<?, ?it/s][A
100%|██████████| 3/3 [00:00<00:00, 324.10it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
100%|██████████| 3/3 [00:00<00:00, 394.54it/s][A
  0%|          | 0/3 [00:00<?, ?it/s][A
100%|██████████| 3/3 [00:00<00:00, 348.85it/s]

Epoch:  0
----------
('a', 'c', 'g')
tensor([0, 2, 6])
tensor([ 0,  4, 36])
----------
('e', 'b', 'd')
tensor([4, 1, 3])
tensor([16,  1,  9])
----------
('f', 'h')
tensor([5, 7])
tensor([25, 49])
Epoch:  1
----------
('h', 'g', 'c')
tensor([7, 6, 2])
tensor([49, 36,  4])
----------
('b', 'e', 'd')
tensor([1, 4, 3])
tensor([ 1, 16,  9])
----------
('f', 'a')
tensor([5, 0])
tensor([25,  0])
Epoch:  2
----------
('a', 'c', 'b')
tensor([0, 2, 1])
tensor([0, 4, 1])
----------
('d', 'h', 'g')
tensor([3, 7, 6])
tensor([ 9, 49, 36])
----------
('e', 'f')
tensor([4, 5])
tensor([16, 25])
Epoch:  3
----------
('d', 'f', 'c')
tensor([3, 5, 2])
tensor([ 9, 25,  4])
----------
('e', 'b', 'g')
tensor([4, 1, 6])
tensor([16,  1, 36])
----------
('h', 'a')
tensor([7, 0])
tensor([49,  0])
Epoch:  4
----------
('d', 'c', 'g')
tensor([3, 2, 6])
tensor([ 9,  4, 36])
----------
('f', 'e', 'a')
tensor([5, 4, 0])
tensor([25, 16,  0])
----------
('h', 'b')
tensor([7, 1])
tensor([49,  1])
Epoch:  5
----------
('

100%|██████████| 7/7 [00:00<00:00, 42.20it/s]


# TensorDataset
- `torch.utils.data.TensorDataset` wraps tensors (i.e., features, labels)
- Use it as input to DataLoader

In [5]:
# Adapted from https://discuss.pytorch.org/t/make-a-tensordataset-and-dataloader-with-multiple-inputs-parameters/26605/2
nb_samples = 10

# Four tensors that share the same first dimension (one row per sample).
# torch.randn fills a tensor with draws from a standard normal distribution.
features = torch.randn(nb_samples, 10)
# One integer label per sample: shape (nb_samples,), values in [0, 10).
labels = torch.empty(nb_samples, dtype=torch.long).random_(10)
adjacency = torch.randn(nb_samples, 5)
laplacian = torch.randn(nb_samples, 7)

# TensorDataset indexes every wrapped tensor along its first dimension, so
# sample i is (features[i], labels[i], adjacency[i], laplacian[i]).
dataset = TensorDataset(features, labels, adjacency, laplacian)
loader = DataLoader(dataset, batch_size=2)

# 10 samples with batch_size=2 -> 5 batches, each unpacking into four tensors.
for batch_idx, (x, y, a, l) in enumerate(loader):
    print(*(t.shape for t in (x, y, a, l)))

torch.Size([2, 10]) torch.Size([2]) torch.Size([2, 5]) torch.Size([2, 7])
torch.Size([2, 10]) torch.Size([2]) torch.Size([2, 5]) torch.Size([2, 7])
torch.Size([2, 10]) torch.Size([2]) torch.Size([2, 5]) torch.Size([2, 7])
torch.Size([2, 10]) torch.Size([2]) torch.Size([2, 5]) torch.Size([2, 7])
torch.Size([2, 10]) torch.Size([2]) torch.Size([2, 5]) torch.Size([2, 7])
