In [1]:
import torch
from torchtext import data

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# TEXT is a sequential, lower-cased field; LABEL is a non-sequential field that
# is created without a vocabulary (use_vocab=False).
TEXT = data.Field(lower=True, sequential=True)
LABEL = data.Field(use_vocab=False, sequential=False)

# (column name, field) pairs handed to TabularDataset. Per the torchtext docs a
# field of None means the column is ignored, so only Phrase and Sentiment are kept.
csv_columns = [
    ('PhraseId', None),
    ('SentenceId', None),
    ('Phrase', TEXT),
    ('Sentiment', LABEL),
]

# Load the training and validation CSV files from test_text/, skipping the header row.
train, val = data.TabularDataset.splits(
    path='test_text',
    train='train.csv',
    validation='val.csv',
    format='csv',
    skip_header=True,
    fields=csv_columns,
)

# Build the TEXT vocabulary from the training split and attach the
# 'glove.6B.100d' vectors, using vector_cache/ as the vector cache directory.
TEXT.build_vocab(
    train,
    vectors='glove.6B.100d',
    vectors_cache='vector_cache/',
)



###  ★★★★★Iterator输出的<font color='red'>所有</font>设置方法(<font color='red'>根据需求合理选择其中某一个</font>):

In [3]:
# Iterator setting: output the (entire) dataset sorted by sort_key.
#
# How sort_within_batch / sort_key / sort relate to each other inside the
# Iterator constructor:
#
#   if sort_within_batch is None:
#       self.sort_within_batch = self.sort
#   else:
#       self.sort_within_batch = sort_within_batch
#   if sort_key is None:
#       self.sort_key = dataset.sort_key
#   else:
#       self.sort_key = sort_key
train_iterator_sort = data.Iterator(
    dataset=train,
    batch_size=12,
    device=device,
    # sort: whether to sort examples according to self.sort_key.
    sort=True,
    # sort_key: a key to use for sorting examples in order to batch together
    # examples with similar lengths and minimize padding. The sort_key provided
    # to the Iterator constructor overrides the sort_key attribute of the
    # Dataset, or defers to it if None.
    sort_key=lambda example: len(example.Phrase),
)

# Inspect the first batch only: its type, its summary, the Phrase tensor and
# its shape (each followed by a blank line), then the Sentiment labels.
for batch in train_iterator_sort:
    for shown in (type(batch), batch, batch.Phrase, batch.Phrase.shape):
        print(shown, end='\n\n')
    print(batch.Sentiment)
    break

<class 'torchtext.data.batch.Batch'>


[torchtext.data.batch.Batch of size 12]
	[.Phrase]:[torch.LongTensor of size 1x12]
	[.Sentiment]:[torch.LongTensor of size 12]

tensor([[13658,  1854,   640, 12904, 13753, 14087, 11736, 14194, 12555, 13138,
         12541,     1]])

torch.Size([1, 12])

tensor([2, 2, 3, 2, 2, 2, 2, 3, 1, 3, 2, 1])




In [4]:
# Iterator setting: output the examples in their original order (shuffle=False).
train_iterator_no_shuff = data.Iterator(
    dataset=train,
    batch_size=12,
    device=device,
    shuffle=False,
)

# Inspect the first batch only: its type, its summary, the Phrase tensor and
# its shape (each followed by a blank line), then the Sentiment labels.
for batch in train_iterator_no_shuff:
    for shown in (type(batch), batch, batch.Phrase, batch.Phrase.shape):
        print(shown, end='\n\n')
    print(batch.Sentiment)
    break

<class 'torchtext.data.batch.Batch'>


[torchtext.data.batch.Batch of size 12]
	[.Phrase]:[torch.LongTensor of size 17x12]
	[.Sentiment]:[torch.LongTensor of size 12]

tensor([[    2,   666,   142,   177,    13,    90,   271, 12541,   121,     4,
            11,     4],
        [ 1685, 11049,    10,   554,     9,     7,    76,     1,    23,  1120,
             2, 16459],
        [  537, 15341,  2464,    43, 13876,   116,     1,     1,    21,   681,
            32,  1289],
        [    5,     1,     8,  6316,   227,     1,     1,     1,    54,     1,
           388,     1],
        [ 4198,     1,     1,  2164,    50,     1,     1,     1,     1,     1,
             6,     1],
        [    6,     1,     1,     1,     2,     1,     1,     1,     1,     1,
         11439,     1],
        [12652,     1,     1,     1,   268,     1,     1,     1,     1,     1,
           146,     1],
        [  479,     1,     1,     1,    18,     1,     1,     1,     1,     1,
             1,     1],
        

In [5]:
# Iterator setting: output the examples in fully random order (shuffle=True).
train_iterator_shuff = data.Iterator(
    dataset=train,
    batch_size=12,
    device=device,
    shuffle=True,
)

# Inspect the first batch only: its type, its summary, the Phrase tensor and
# its shape (each followed by a blank line), then the Sentiment labels.
for batch in train_iterator_shuff:
    for shown in (type(batch), batch, batch.Phrase, batch.Phrase.shape):
        print(shown, end='\n\n')
    print(batch.Sentiment)
    break

<class 'torchtext.data.batch.Batch'>


[torchtext.data.batch.Batch of size 12]
	[.Phrase]:[torch.LongTensor of size 19x12]
	[.Sentiment]:[torch.LongTensor of size 12]

tensor([[  180,    50,     2,     2,    11,     4,     7,     5,    19,    49,
            41,    85],
        [ 1246,   100,   145,   120,    21,    62,  2487,     2,  2122,     5,
            24,    72],
        [    1,    27,    11, 13184,   117,   128,    47,   106,   127,  3379,
           211,  2270],
        [    1,  8437, 14618,   227,   419,   235,    71,     1,    60,     9,
           186,    19],
        [    1,    34,  2923,    47,    16,    37,    54,     1,    58,   117,
             4,   185],
        [    1,  2243,     1,    10,   122,    27,   307,     1,  2577,     1,
            20,   518],
        [    1,     1,     1,  1610,     1,  1259,    22,     1,  1663,     1,
           383,     1],
        [    1,     1,     1,  2138,     1,     3,    33,     1,    11,     1,
             5,     1],
        

In [6]:
# Iterator setting: pick each batch at random, then sort the examples inside
# that batch by sort_key before output.
train_iterator_com = data.Iterator(
    dataset=train,
    batch_size=12,
    device=device,
    shuffle=True,
    sort_key=lambda example: len(example.Phrase),
    # sort_within_batch: whether to sort (in descending order according to
    # self.sort_key) within each batch.
    sort_within_batch=True,
)

# Inspect the first batch only: its type, its summary, the Phrase tensor and
# its shape (each followed by a blank line), then the Sentiment labels.
for batch in train_iterator_com:
    for shown in (type(batch), batch, batch.Phrase, batch.Phrase.shape):
        print(shown, end='\n\n')
    print(batch.Sentiment)
    break




<class 'torchtext.data.batch.Batch'>


[torchtext.data.batch.Batch of size 12]
	[.Phrase]:[torch.LongTensor of size 19x12]
	[.Sentiment]:[torch.LongTensor of size 12]

tensor([[    4,     7,    19,    41,     2,    50,    11,    85,     2,    49,
             5,   180],
        [   62,  2487,  2122,    24,   120,   100,    21,    72,   145,     5,
             2,  1246],
        [  128,    47,   127,   211, 13184,    27,   117,  2270,    11,  3379,
           106,     1],
        [  235,    71,    60,   186,   227,  8437,   419,    19, 14618,     9,
             1,     1],
        [   37,    54,    58,     4,    47,    34,    16,   185,  2923,   117,
             1,     1],
        [   27,   307,  2577,    20,    10,  2243,   122,   518,     1,     1,
             1,     1],
        [ 1259,    22,  1663,   383,  1610,     1,     1,     1,     1,     1,
             1,     1],
        [    3,    33,    11,     5,  2138,     1,     1,     1,     1,     1,
             1,     1],
        