In [2]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import random

In [3]:
def make_random_len_data_list(min_len, max_len, num_data):
    """Generate ``num_data`` lists of random digits with random lengths.

    Args:
        min_len: Smallest possible sample length (inclusive).
        max_len: Upper bound on the sample length. It is exclusive because
            the length is drawn with ``random.randrange(min_len, max_len)``.
        num_data: Number of samples to generate.

    Returns:
        A list containing ``num_data`` lists. Every inner list holds
        integers drawn with ``random.randint(0, 9)``, i.e. 0 through 9.

    Raises:
        ValueError: If ``min_len >= max_len`` while ``num_data > 0``, since
            ``random.randrange`` rejects an empty range.
    """
    # For each sample the length is drawn first and the digits afterwards,
    # so the order of calls into ``random`` is: randrange, then one randint
    # per digit.
    return [
        [random.randint(0, 9) for _ in range(random.randrange(min_len, max_len))]
        for _ in range(num_data)
    ]

In [4]:
# Smoke test: ten samples whose lengths fall in [10, 20).
make_random_len_data_list(min_len=10, max_len=20, num_data=10)

[[2, 8, 5, 5, 1, 3, 9, 6, 6, 1, 9, 2, 9, 6, 2, 1, 4, 5],
 [6, 9, 3, 3, 3, 5, 7, 6, 2, 8, 4, 6, 1],
 [6, 9, 8, 7, 2, 8, 4, 0, 1, 5, 4, 1, 4, 4, 8, 9, 2, 7],
 [2, 4, 9, 1, 5, 0, 6, 4, 9, 2, 9, 1, 8, 8, 3, 5, 0],
 [5, 0, 8, 4, 3, 1, 5, 9, 5, 7, 2, 9, 6, 1, 9, 6, 6, 6, 7],
 [4, 4, 0, 2, 9, 9, 8, 5, 7, 6, 8],
 [5, 0, 5, 3, 8, 7, 0, 8, 7, 6, 1, 8, 5, 6, 0],
 [0, 9, 1, 7, 1, 6, 1, 3, 5, 6, 5, 1, 6, 3, 6, 6],
 [0, 2, 1, 3, 6, 1, 9, 1, 2, 3, 2, 3, 8, 8, 6, 9],
 [6, 1, 6, 7, 6, 1, 1, 4, 8, 8]]

In [5]:
class Dataset_custom(Dataset):
    """Minimal map-style dataset wrapping an in-memory container of samples."""

    def __init__(self, data):
        """Keep a reference to ``data`` (no copy is made) as ``self.x``."""
        self.x = data

    def __len__(self):
        """Return how many samples the wrapped container holds."""
        n_samples = len(self.x)
        return n_samples

    def __getitem__(self, idx):
        """Return ``self.x[idx]``.

        ``idx`` is forwarded unchanged to the wrapped container, so when
        that container is a list an integer yields one sample and a slice
        yields a list of samples.
        """
        selected = self.x[idx]
        return selected

In [6]:
def make_same_len(batch, pad_id=0):
    """Right-pad every sample in ``batch`` to the length of the longest one.

    Args:
        batch: Re-iterable collection of samples (lists or tuples). It is
            traversed twice: once to find the longest sample and once to
            pad. May be empty.
        pad_id: Value appended to samples that are shorter than the longest
            one. Defaults to 0; pass a different value when 0 can also occur
            as real data and padding must stay distinguishable.

    Returns:
        A new list of new lists, all of the same length. The input samples
        are never modified. An empty ``batch`` yields ``[]``.
    """
    # ``default=0`` keeps max() from raising ValueError on an empty batch.
    max_len = max((len(sample) for sample in batch), default=0)

    # ``list(sample)`` copies the sample and lets tuples be padded as well.
    return [
        list(sample) + [pad_id] * (max_len - len(sample))
        for sample in batch
    ]

In [7]:
def collate_fn_custom(batch):
    """Collate variable-length samples into a single tensor.

    The samples are first brought to a common length by ``make_same_len``;
    the resulting nested list is then converted with ``torch.tensor``.

    Args:
        batch: Collection of samples, as accepted by ``make_same_len``.

    Returns:
        The tensor built from the padded samples, one row per sample.
    """
    return torch.tensor(make_same_len(batch))

In [11]:
# Toy data: ten random-length samples, wrapped in the custom Dataset.
rd = make_random_len_data_list(min_len=10, max_len=20, num_data=10)
ds = Dataset_custom(data=rd)

In [12]:
# Check the dataset size, then peek at the first three (still unpadded) samples.
print(len(ds))
ds[:3]

10


[[4, 4, 9, 1, 4, 2, 3, 5, 3, 2, 8, 9, 8, 1],
 [9, 4, 8, 7, 5, 6, 2, 9, 3, 9, 8, 0, 5],
 [4, 5, 8, 2, 9, 0, 5, 0, 0, 3]]

In [13]:
# Collate three samples by hand to see the padding before using a DataLoader.
collate_fn_custom(ds[:3])

tensor([[4, 4, 9, 1, 4, 2, 3, 5, 3, 2, 8, 9, 8, 1],
        [9, 4, 8, 7, 5, 6, 2, 9, 3, 9, 8, 0, 5, 0],
        [4, 5, 8, 2, 9, 0, 5, 0, 0, 3, 0, 0, 0, 0]])

In [14]:
# Batches of two samples in shuffled order; each batch is padded by the collate function.
dl = DataLoader(dataset=ds, batch_size=2, shuffle=True, collate_fn=collate_fn_custom)

In [15]:
# Each batch is padded only up to its own longest sample, so the tensor
# width can differ from one batch to the next.
for batch in dl:
    print(batch)

tensor([[8, 9, 7, 8, 1, 7, 5, 0, 1, 6, 7, 6, 8, 0, 0, 0, 0, 0, 0],
        [3, 8, 3, 2, 2, 2, 8, 9, 0, 4, 8, 1, 2, 7, 0, 1, 3, 8, 9]])
tensor([[7, 8, 6, 7, 2, 4, 7, 7, 9, 6, 7, 2, 0, 0, 0, 0, 0],
        [6, 5, 9, 1, 5, 1, 1, 5, 7, 1, 7, 3, 1, 4, 4, 9, 4]])
tensor([[9, 4, 8, 7, 5, 6, 2, 9, 3, 9, 8, 0, 5, 0, 0, 0],
        [7, 7, 3, 5, 3, 4, 4, 1, 1, 9, 1, 2, 0, 4, 5, 3]])
tensor([[6, 8, 2, 9, 8, 2, 9, 4, 0, 1, 8, 5, 0, 0],
        [4, 4, 9, 1, 4, 2, 3, 5, 3, 2, 8, 9, 8, 1]])
tensor([[4, 5, 8, 2, 9, 0, 5, 0, 0, 3, 0, 0],
        [0, 2, 4, 7, 2, 1, 3, 9, 2, 3, 3, 1]])
