In [1]:
import os
from torchtext import data, datasets
import pandas as pd
import pickle

In [2]:
# The ATIS CSV files (atis.train.csv / atis.test.csv) are read from an `atis`
# sub-directory of the kernel's current working directory.
base_dir = os.getcwd()
atis_data = os.path.join(base_dir, 'atis')

In [3]:
def _make_sequence_field(fix_length):
    """Create a whitespace-tokenized, lower-cased sequence Field.

    Shared by the word sequence (``source``) and the slot-tag sequence
    (``target``) so both are configured identically. Every example is wrapped
    in ``<sos>``/``<eos>`` and padded to ``fix_length``. ``include_lengths=True``
    makes the Field also return per-example lengths, for later use with
    torch's ``pack_padded_sequence``.
    """
    return data.Field(sequential=True,
                      # str.split behaves like `lambda s: s.split()`. Unlike a
                      # local lambda, it can be pickled together with the Field.
                      tokenize=str.split,
                      lower=True, use_vocab=True,
                      init_token='<sos>', eos_token='<eos>',
                      pad_token='<pad>', unk_token='<unk>',
                      batch_first=True, fix_length=fix_length,
                      include_lengths=True)


def build_dataset(data_dir=None, batch_size=128, fix_length=50,
                  device=None, return_fields=False):
    """Build the training and validation iterators for ATIS.

    Reads ``atis.train.csv`` (training split) and ``atis.test.csv`` (used as the
    validation split) from ``data_dir``. Each CSV is expected to have a header
    row followed by the columns ``index, intent, source, target`` in that order.
    The ``index`` column is discarded. Basic dataset and vocabulary statistics
    are printed along the way.

    Args:
        data_dir: directory containing the CSV files. Defaults to the
            module-level ``atis_data`` path.
        batch_size: training batch size. The validation iterator always yields
            the whole validation set as a single batch.
        fix_length: fixed length that ``source``/``target`` sequences are
            padded to.
        device: forwarded to ``data.Iterator.splits``. ``None`` keeps
            torchtext's default.
        return_fields: if True, also return ``(SOURCE, TARGET, LABEL)`` so the
            built vocabularies can be reused, e.g. to size embedding layers.

    Returns:
        ``(train_iter, val_iter)``. When ``return_fields`` is True, returns
        ``(train_iter, val_iter, (SOURCE, TARGET, LABEL))`` instead.
    """
    if data_dir is None:
        data_dir = atis_data

    SOURCE = _make_sequence_field(fix_length)  # word tokens
    TARGET = _make_sequence_field(fix_length)  # slot tags
    LABEL = data.Field(
                    sequential=False,  # a single intent label per example
                    use_vocab=True)

    train, val = data.TabularDataset.splits(
                                            path=data_dir,
                                            skip_header=True,
                                            train='atis.train.csv',
                                            validation='atis.test.csv',
                                            format='csv',
                                            fields=[('index', None), ('intent', LABEL), ('source', SOURCE), ('target', TARGET)])
    print('train data info:')
    print(len(train))
    print(vars(train[0]))
    print('val data info:')
    print(len(val))
    print(vars(val[0]))

    # NOTE(review): the vocabularies are built from train AND val. Tokens that
    # occur only in the validation data therefore get their own ids instead of
    # mapping to <unk>. This is kept so vocab sizes and ids stay unchanged;
    # consider building from `train` alone for a cleaner evaluation.
    SOURCE.build_vocab(train, val)
    TARGET.build_vocab(train, val)
    LABEL.build_vocab(train, val)

    print('vocab info:')
    print('source vocab size:{}'.format(len(SOURCE.vocab)))
    print('target vocab size:{}'.format(len(TARGET.vocab)))
    print('label vocab size:{}'.format(len(LABEL.vocab)))

    train_iter, val_iter = data.Iterator.splits(
                                                (train, val),
                                                # training batches of `batch_size`; the whole validation set in one batch
                                                batch_sizes=(batch_size, len(val)),
                                                shuffle=True,
                                                # sort the examples inside each batch by sort_key in descending order,
                                                # i.e. by source length, so batches are ready for pack_padded_sequence
                                                sort_within_batch=True,
                                                sort_key=lambda x: len(x.source),
                                                device=device)

    if return_fields:
        return train_iter, val_iter, (SOURCE, TARGET, LABEL)
    return train_iter, val_iter



In [4]:
# Build both iterators and report how many batches each split yields.
train_iter, val_iter = build_dataset()
print('train_iter size:{}'.format(len(train_iter)),
      'val_iter size:{}'.format(len(val_iter)),
      sep='\n')

train data info:
4978
{'intent': 'flight', 'source': ['i', 'want', 'to', 'fly', 'from', 'boston', 'at', '838', 'am', 'and', 'arrive', 'in', 'denver', 'at', '1110', 'in', 'the', 'morning'], 'target': ['o', 'o', 'o', 'o', 'o', 'b-fromloc.city_name', 'o', 'b-depart_time.time', 'i-depart_time.time', 'o', 'o', 'o', 'b-toloc.city_name', 'o', 'b-arrive_time.time', 'o', 'o', 'b-arrive_time.period_of_day']}
val data info:
893
{'intent': 'flight', 'source': ['i', 'would', 'like', 'to', 'find', 'a', 'flight', 'from', 'charlotte', 'to', 'las', 'vegas', 'that', 'makes', 'a', 'stop', 'in', 'st.', 'louis'], 'target': ['o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'b-fromloc.city_name', 'o', 'b-toloc.city_name', 'i-toloc.city_name', 'o', 'o', 'o', 'o', 'o', 'b-stoploc.city_name', 'i-stoploc.city_name']}
vocab info:
source vocab size:945
target vocab size:133
label vocab size:27
train_iter size:39
val_iter size:1


In [5]:
# Peek at the first training batch: the intent ids, then the source and target
# fields (each a (token ids, lengths) pair because include_lengths=True).
for batch in train_iter:
    print(batch.intent, batch.source, batch.target, sep='\n')
    break

tensor([ 1,  1,  1,  1,  2,  1,  1,  6, 26,  2,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  2,  1,  1,  2,  1,  1,  1,  2,  1,  1,  1,  1,  4,  1,  1,  1,  1,
         6,  1,  2,  1,  1,  1,  7,  2,  1,  1,  1,  1,  1,  1,  4,  1,  1,  1,
         1,  1,  2,  1,  1,  7,  3,  4,  8,  1,  1,  1,  1,  1, 11,  1,  1,  1,
         2,  1,  1,  2,  1,  1,  1,  1,  1,  1,  1,  8,  1,  2,  1,  2,  1,  1,
         1,  1,  1,  1,  8,  1,  4,  3,  1,  1,  1,  1,  2,  1,  1,  1,  1,  1,
         1,  4,  1,  1,  1,  3,  1,  1,  2,  1, 18,  1,  3,  1, 10,  1,  1,  1,
         1,  1])
(tensor([[  2,  13,  81,  ...,   1,   1,   1],
        [  2,  13,  40,  ...,   1,   1,   1],
        [  2,  13, 189,  ...,   1,   1,   1],
        ...,
        [  2,  38,  11,  ...,   1,   1,   1],
        [  2,   6,   5,  ...,   1,   1,   1],
        [  2,   6,   5,  ...,   1,   1,   1]]), tensor([24, 23, 23, 20, 20, 19, 19, 19, 19, 18, 18, 18, 18, 18, 18, 18, 17, 17,
        17, 17, 17, 17, 17, 17, 17, 16, 16, 16, 16,

1