In [30]:
import torch
from torchtext import data

In [31]:
# Label column: a single categorical value per example, not a token sequence.
# If the Sentiment column already holds numbers, use_vocab=False could be used instead.
LABEL = data.Field(use_vocab=True, sequential=False)

# Text column: lower-cased token sequences.
TEXT = data.Field(
    lower=True,
    sequential=True,
    # dtype of the Example data; the default is torch.long and it normally does
    # not need to be set.
    # NOTE(review): with float32 the vocabulary ids come out as floats, which
    # presumably cannot be fed to an embedding layer as-is -- confirm intended.
    dtype=torch.float32,
    # A fixed length that all examples using this field will be padded to,
    # or None for flexible sequence lengths. Default: None.
    fix_length=10,
    # Whether to return a tuple of a padded minibatch and a list containing the
    # lengths of each example, or just a padded minibatch. Default: False.
    include_lengths=True,
)

# Load both CSV splits; columns mapped to None (PhraseId, SentenceId) are not
# attached to any field.
train, val = data.TabularDataset.splits(
    path='test_text',
    train='train.csv',
    validation='val.csv',
    format='csv',
    skip_header=True,
    fields=[
        ('PhraseId', None),
        ('SentenceId', None),
        ('Phrase', TEXT),
        ('Sentiment', LABEL),
    ],
)

# Build the text vocabulary from the training split and attach GloVe vectors.
TEXT.build_vocab(train, vectors_cache='vector_cache/', vectors='glove.6B.100d')

In [32]:
# Build the label vocabulary from the training split.
LABEL.build_vocab(train)

In [33]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [34]:
# Create the train/validation iterators in one call.
train_iterator, valid_iterator = data.Iterator.splits(
    datasets=(train, val),
    # Number of Examples (sentences) per batch: 11 for train_iterator and 5 for
    # valid_iterator. Passing batch_size=11 instead would use 11 for both.
    batch_sizes=(11, 5),
    device=device,
    # Whether to shuffle examples between epochs.
    shuffle=False,
)

In [35]:
# Alternative to the splits() call: build each iterator separately. This form is
# required when the datasets need different parameter settings.
# NOTE(review): splits() may treat only the first dataset as the training set,
# while Iterator's own default differs -- confirm the two forms are truly equivalent.
train_iterator_ = data.Iterator(
    dataset=train,
    batch_size=11,
    shuffle=False,
    device=device,
)
valid_iterator_ = data.Iterator(
    dataset=val,
    batch_size=5,
    shuffle=False,
    device=device,
)

In [36]:
# The string below quotes how Iterator implements __len__: the number of
# batches is ceil(len(dataset) / batch_size).
'''
# Iterator __len__魔法方法的实现:
def __len__(self):
    if self.batch_size_fn is not None:
        raise NotImplementedError
    return math.ceil(len(self.dataset) / self.batch_size)

'''
# Number of examples in each dataset ...
print(len(train), len(val), sep='\n')
print('*****************')
# ... versus number of batches produced by each iterator.
print(len(train_iterator), len(valid_iterator), sep='\n')

124848
31212
*****************
11350
6243


In [39]:
# Inspect the first training batch only.
for batch in train_iterator:
    print(
        type(batch),
        batch,
        # A tuple, because the field was created with include_lengths=True.
        batch.Phrase,
        # Vocabulary ids of the words; id 1 corresponds to '<pad>' (padding).
        batch.Phrase[0],
        # Lengths of each example (absent when include_lengths=False).
        batch.Phrase[1],
        # fix_length=10 and batch_size=11, so batch.Phrase[0].shape == (10, 11).
        batch.Phrase[0].shape,
        sep='\n\n',
        end='\n\n',
    )
    print(batch.Sentiment, batch.Sentiment.dtype, sep='\n')
    break

<class 'torchtext.data.batch.Batch'>


[torchtext.data.batch.Batch of size 11]
	[.Phrase]:('[torch.cuda.FloatTensor of size 10x11 (GPU 0)]', '[torch.cuda.FloatTensor of size 11 (GPU 0)]')
	[.Sentiment]:[torch.cuda.LongTensor of size 11 (GPU 0)]

(tensor([[2.0000e+00, 6.6600e+02, 1.4200e+02, 1.7700e+02, 1.3000e+01, 9.0000e+01,
         2.7100e+02, 1.2541e+04, 1.2100e+02, 4.0000e+00, 1.1000e+01],
        [1.6850e+03, 1.1049e+04, 1.0000e+01, 5.5400e+02, 9.0000e+00, 7.0000e+00,
         7.6000e+01, 1.0000e+00, 2.3000e+01, 1.1200e+03, 2.0000e+00],
        [5.3700e+02, 1.5341e+04, 2.4640e+03, 4.3000e+01, 1.3876e+04, 1.1600e+02,
         1.0000e+00, 1.0000e+00, 2.1000e+01, 6.8100e+02, 3.2000e+01],
        [5.0000e+00, 1.0000e+00, 8.0000e+00, 6.3160e+03, 2.2700e+02, 1.0000e+00,
         1.0000e+00, 1.0000e+00, 5.4000e+01, 1.0000e+00, 3.8800e+02],
        [4.1980e+03, 1.0000e+00, 1.0000e+00, 2.1640e+03, 5.0000e+01, 1.0000e+00,
         1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 6.0000e+00]

In [38]:
# Each column of batch.Phrase[0] above holds the vocabulary ids of one Example's
# words: the whole example when fix_length=None, otherwise only the part that
# fits in fix_length. Print the ids of the first four training examples to compare.
for row in range(4):
    tokens = vars(train.examples[row])['Phrase']
    print(''.join(f'{TEXT.vocab.stoi[tok]},' for tok in tokens))


2,1685,537,5,4198,6,12652,479,15589,2,13247,1305,28,2,113,4308,6,
666,11049,15341,
142,10,2464,8,
177,554,43,6316,2164,
