In [1]:
import torch
from torch.utils.data import DataLoader
import numpy as np


data = np.array([
    [0.1, 7.4, 0],
    [-0.2, 5.3, 0],
    [0.2, 8.2, 1],
    [0.2, 7.7, 1]
])
loader = DataLoader(data, batch_size=2, shuffle=False)
batch = next(iter(loader))
batch

tensor([[ 0.1000,  7.4000,  0.0000],
        [-0.2000,  5.3000,  0.0000]], dtype=torch.float64)

In [2]:
dict_data = [
    {'x1': 0.1, 'x2': 7.4, 'y': 0},
    {'x1': -0.2, 'x2': 5.3, 'y': 0},
    {'x1': 0.2, 'x2': 8.2, 'y': 1},
    {'x1': 0.2, 'x2': 7.7, 'y': 10},
]
loader = DataLoader(dict_data, batch_size=2, shuffle=False)
batch = next(iter(loader))
batch

{'x1': tensor([ 0.1000, -0.2000], dtype=torch.float64),
 'x2': tensor([7.4000, 5.3000], dtype=torch.float64),
 'y': tensor([0, 0])}

In [3]:
nlp_data = [
    {'tokenized_input': [1, 4, 5, 9, 3, 2], 'label':0},
    {'tokenized_input': [1, 7, 3, 14, 48, 7, 23, 154, 2], 'label':0},
    {'tokenized_input': [1, 30, 67, 117, 21, 15, 2], 'label':1},
    {'tokenized_input': [1, 17, 2], 'label':0},
]
loader = DataLoader(nlp_data, batch_size=2, shuffle=False)
batch = next(iter(loader))

RuntimeError: each element in list of batch should be of equal size

In [4]:
from torch.nn.utils.rnn import pad_sequence


def custom_collate(data):
    inputs = [torch.tensor(d['tokenized_input']) for d in data]
    labels = [d['label'] for d in data]

    inputs = pad_sequence(inputs, batch_first=True)
    labels = torch.tensor(labels)

    return {'tokenized_input': inputs,'label': labels}


loader = DataLoader(nlp_data, batch_size=2, shuffle=False, collate_fn=custom_collate)

iter_loader = iter(loader)
batch1 = next(iter_loader)
print(batch1)
batch2 = next(iter_loader)
print(batch2)

{'tokenized_input': tensor([[  1,   4,   5,   9,   3,   2,   0,   0,   0],
        [  1,   7,   3,  14,  48,   7,  23, 154,   2]]), 'label': tensor([0, 0])}
{'tokenized_input': tensor([[  1,  30,  67, 117,  21,  15,   2],
        [  1,  17,   2,   0,   0,   0,   0]]), 'label': tensor([1, 0])}


In [5]:
img = torch.rand([100,100,3])

caption_data = [
    {'tokenized_input': torch.Tensor([1, 4, 5, 9, 3, 2]), 'image': img},
    {'tokenized_input': torch.Tensor([1, 7, 3, 14, 48, 7, 23, 154, 2]), 'image': img},
    {'tokenized_input': torch.Tensor([1, 30, 67, 117, 21, 15, 2]), 'image': img},
    {'tokenized_input': torch.Tensor([1, 17, 2]), 'image': img},
]

In [15]:
def collate_v2(batch):
    imgs = [item['image'].unsqueeze(0) for item in batch]
    img = torch.cat(imgs, dim=0)
    targets = [item['tokenized_input'] for item in batch]
    targets = pad_sequence(targets, batch_first=False)
    return img, targets


loader = DataLoader(caption_data, batch_size=2, shuffle=False, collate_fn=collate_v2)
batch1 = next(iter(loader))
batch1

tensor([[[0.9793, 0.7228, 0.1015],
         [0.8943, 0.6690, 0.6087],
         [0.9201, 0.5518, 0.0158],
         ...,
         [0.7315, 0.6927, 0.6763],
         [0.3446, 0.1579, 0.1253],
         [0.9792, 0.3437, 0.3341]],

        [[0.1959, 0.9009, 0.3502],
         [0.7227, 0.5673, 0.2331],
         [0.4654, 0.3096, 0.3922],
         ...,
         [0.7267, 0.2307, 0.9683],
         [0.0573, 0.4756, 0.9980],
         [0.1970, 0.6830, 0.4833]],

        [[0.9497, 0.2714, 0.9619],
         [0.2788, 0.5174, 0.0272],
         [0.5832, 0.5106, 0.9080],
         ...,
         [0.0483, 0.9736, 0.2991],
         [0.7753, 0.1790, 0.8592],
         [0.8276, 0.4556, 0.1739]],

        ...,

        [[0.0342, 0.5419, 0.6770],
         [0.1287, 0.0173, 0.7120],
         [0.3442, 0.6893, 0.6555],
         ...,
         [0.4775, 0.5405, 0.8124],
         [0.7603, 0.0475, 0.9764],
         [0.3287, 0.0549, 0.4437]],

        [[0.2611, 0.6442, 0.5079],
         [0.4211, 0.2508, 0.9381],
         [0.

(tensor([[[[0.9793, 0.7228, 0.1015],
           [0.8943, 0.6690, 0.6087],
           [0.9201, 0.5518, 0.0158],
           ...,
           [0.7315, 0.6927, 0.6763],
           [0.3446, 0.1579, 0.1253],
           [0.9792, 0.3437, 0.3341]],
 
          [[0.1959, 0.9009, 0.3502],
           [0.7227, 0.5673, 0.2331],
           [0.4654, 0.3096, 0.3922],
           ...,
           [0.7267, 0.2307, 0.9683],
           [0.0573, 0.4756, 0.9980],
           [0.1970, 0.6830, 0.4833]],
 
          [[0.9497, 0.2714, 0.9619],
           [0.2788, 0.5174, 0.0272],
           [0.5832, 0.5106, 0.9080],
           ...,
           [0.0483, 0.9736, 0.2991],
           [0.7753, 0.1790, 0.8592],
           [0.8276, 0.4556, 0.1739]],
 
          ...,
 
          [[0.0342, 0.5419, 0.6770],
           [0.1287, 0.0173, 0.7120],
           [0.3442, 0.6893, 0.6555],
           ...,
           [0.4775, 0.5405, 0.8124],
           [0.7603, 0.0475, 0.9764],
           [0.3287, 0.0549, 0.4437]],
 
          [[0.2611,