In [13]:
%%writefile data_loader.py
from torchtext import data
from torchtext.data import TabularDataset

class DataLoader(object):
    """Load a sentence-pair TSV corpus and expose torchtext batch iterators.

    Reads ``train.txt`` and ``test.txt`` from the current directory. Each file
    is parsed as TSV with its first (header) row skipped, and its three
    columns are mapped, in order, to ``TEXT_A``, ``TEXT_B`` and ``LABEL``.

    After construction the instance exposes:

    * ``LABEL``, ``TEXT_A``, ``TEXT_B`` -- the torchtext fields, with
      vocabularies built from the training split only.
    * ``train_iter``, ``valid_iter`` -- bucketed batch iterators over the
      training and validation splits.
    * ``ntokens`` -- the size of the ``TEXT_A`` vocabulary.
    """

    def __init__(self,
                 batch_size=64,
                 device=-1,
                 max_vocab=999999,
                 min_freq=1,
                 use_eos=False,
                 shuffle=True):
        """Build the fields, datasets, iterators and vocabularies.

        :param batch_size: Number of examples per batch for both iterators.
        :param device: Device id to load data on; a negative value selects
            the CPU, otherwise ``cuda:<device>`` is used.
        :param max_vocab: Maximum vocabulary size for the two text fields.
        :param min_freq: Minimum frequency for a word to enter the text
            vocabularies.
        :param use_eos: If True, append ``<EOS>`` to the end of every
            sentence.
        :param shuffle: Passed through to the iterators to control random
            shuffling of the input data.
        """
        super(DataLoader, self).__init__()

        # Both text columns share the same field configuration.
        eos_token = '<EOS>' if use_eos else None

        # Non-sequential label field; no <unk> entry in the label vocabulary.
        self.LABEL = data.Field(sequential=False,
                                use_vocab=True,
                                unk_token=None)

        self.TEXT_A = data.Field(use_vocab=True,
                                 batch_first=True,
                                 include_lengths=False,
                                 eos_token=eos_token)

        self.TEXT_B = data.Field(use_vocab=True,
                                 batch_first=True,
                                 # Whether to also return the lengths as a tuple.
                                 include_lengths=False,
                                 eos_token=eos_token)

        # A single column layout shared by both splits. Previously the
        # validation split named its second column 'TEXT_A' as well, which
        # left validation examples without a TEXT_B attribute.
        fields = [('TEXT_A', self.TEXT_A),
                  ('TEXT_B', self.TEXT_B),
                  ('LABEL', self.LABEL)]

        train_data = TabularDataset.splits(
            path='.',
            train='train.txt',
            format='tsv',
            fields=fields,
            skip_header=True
        )[0]

        valid_data = TabularDataset.splits(
            path='.',
            train='test.txt',
            format='tsv',
            fields=fields,
            skip_header=True
        )[0]

        self.train_iter, self.valid_iter = data.BucketIterator.splits(
            (train_data, valid_data),
            batch_size=batch_size,
            device='cuda:%d' % device if device >= 0 else 'cpu',
            shuffle=shuffle,
            # Bucket and sort examples by the token length of the first text.
            sort_key=lambda x: len(x.TEXT_A),
            sort_within_batch=True
        )

        # Vocabularies are built from the training split only.
        self.LABEL.build_vocab(train_data)
        self.TEXT_A.build_vocab(train_data, max_size=max_vocab, min_freq=min_freq)
        self.TEXT_B.build_vocab(train_data, max_size=max_vocab, min_freq=min_freq)

        self.ntokens = len(self.TEXT_A.vocab.stoi)  # size of the TEXT_A vocabulary

Overwriting data_loader.py
