In [178]:
# Imports for the `default_collate` examples below.
from collections import namedtuple
import torch
from torch.utils.data import default_collate  # the function torch.utils.data.DataLoader uses by default as its collate_fn

In [179]:
# A batch of Python scalars is collated into a single 1-D tensor:
# * `int`   -> :class:`torch.Tensor` (dtype torch.int64, so no dtype suffix is printed)
# * `float` -> :class:`torch.Tensor` (dtype torch.float64, not torch's default float32;
#   note the explicit dtype in the output)
for scalar_batch in ([0, 1, 2, 3], [0., 1., 2., 3.]):
    print(default_collate(scalar_batch))

tensor([0, 1, 2, 3])
tensor([0., 1., 2., 3.], dtype=torch.float64)


In [180]:
# * `str` -> `str` (unchanged): a batch of strings comes back as a plain
#   Python list; nothing is converted to a tensor.
default_collate(list('abc'))

['a', 'b', 'c']

In [181]:
# * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
#   The values are grouped by key across the batch, and each group is collated on its own.
default_collate([dict(A=0, B=1), dict(A=100, B=100)])

{'A': tensor([  0, 100]), 'B': tensor([  1, 100])}

In [182]:
# * `NamedTuple[V1_i, V2_i, ...]` ->
#   `NamedTuple[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...]`
# The result is an instance of the same namedtuple type, holding one collated tensor per field.
Point = namedtuple('Point', 'x y')
default_collate([Point(x=0, y=0), Point(x=1, y=1)])

Point(x=tensor([0, 1]), y=tensor([0, 1]))

In [183]:
# * `Sequence[V1_i, V2_i, ...]` ->
#   `Sequence[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...]`
# In other words, the batch is transposed: output[j] collects element j of every sample.
samples = [[0, 1], [2, 3], [4, 5]]
default_collate(samples)

[tensor([0, 2, 4]), tensor([1, 3, 5])]

In [184]:
# Collation recurses into nested sequences. Each inner `[k, k]` list is transposed
# as well, so that position becomes a list of two 1-D tensors, not one 2-D tensor.
default_collate([[[k, k], 11 * k] for k in (1, 2, 3)])

[[tensor([1, 2, 3]), tensor([1, 2, 3])], tensor([11, 22, 33])]

In [185]:
# Nested Python lists collate to a list of tensors. Tensor elements instead are
# merged into one tensor with a new leading batch dimension:
# 3 samples of shape (2,) -> shape (3, 2).
# The tuple samples come back as a list.
default_collate([(torch.tensor([k, k]), 11 * k) for k in (1, 2, 3)])

[tensor([[1, 1],
         [2, 2],
         [3, 3]]),
 tensor([11, 22, 33])]

In [186]:
# Stacking works the same for higher-rank tensors:
# three (2, 2) samples -> one (3, 2, 2) tensor.
default_collate([(torch.tensor([[k, k], [-k, -k]]), 11 * k) for k in (1, 2, 3)])

[tensor([[[ 1,  1],
          [-1, -1]],
 
         [[ 2,  2],
          [-2, -2]],
 
         [[ 3,  3],
          [-3, -3]]]),
 tensor([11, 22, 33])]

In [187]:
# Likewise, three (1, 2, 2) samples -> one (3, 1, 2, 2) tensor.
# The batch dimension is prepended; existing dimensions are left unchanged.
default_collate([(torch.tensor([[[k, k], [-k, -k]]]), 11 * k) for k in (1, 2, 3)])

[tensor([[[[ 1,  1],
           [-1, -1]]],
 
 
         [[[ 2,  2],
           [-2, -2]]],
 
 
         [[[ 3,  3],
           [-3, -3]]]]),
 tensor([11, 22, 33])]