In [81]:
import os
import random

import torch
from torch.autograd import Variable

In [3]:
# Global device switch: batches are moved to the GPU only when PyTorch sees one.
USE_CUDA = bool(torch.cuda.is_available())
USE_CUDA

False

In [4]:
# Location of the preprocessed Twitter dialogue data.
# Note: despite the `_dir` suffix, train_dir and vocab_dir are FILE paths.
base_dir = 'data/twitter'
train_dir, vocab_dir = (os.path.join(base_dir, fname)
                        for fname in ('twitter_clean.txt', 'twitter_vocab.txt'))

In [49]:
class Corpus(object):
    """
    Text preprocessing: load the vocabulary and convert each
    "question ==> answer" line of the training file into word-id sequences.

    Attributes:
        words: vocabulary list; a word's list index is its id.
        word_to_id: mapping word -> id (``<pad>`` is required to be id 0).
        x_train, y_train, x_test, y_test: lists of id sequences, each ending
            with the ``<eos>`` id. The first ``train_ratio`` fraction of lines
            (in file order, no shuffling) forms the training split.
    """

    def __init__(self, train_dir, vocab_dir, train_ratio=0.9):
        """
        Args:
            train_dir: path to the question/answer text file (one pair per line).
            vocab_dir: path to the vocabulary file (one word per line).
            train_ratio: fraction of lines used for training (default 0.9).

        Raises:
            AssertionError: if a file is missing or ``<pad>`` is not id 0.
        """
        assert os.path.exists(train_dir), 'File %s does not exist.' % train_dir
        assert os.path.exists(vocab_dir), 'File %s does not exist.' % vocab_dir

        # One word per line; the line index is the word id.
        with open(vocab_dir, encoding='utf-8') as f:
            words = f.read().strip().split('\n')
        word_to_id = {word: idx for idx, word in enumerate(words)}

        # Padding elsewhere in this notebook (DataSet.pad_seq) pads with a
        # literal 0, so <pad> must be id 0. Using .get() makes a missing <pad>
        # fail with this message instead of a bare KeyError.
        assert word_to_id.get('<pad>') == 0, "<pad> id should be 0."

        self.words = words
        self.word_to_id = word_to_id

        self.tokenize(train_dir, train_ratio)

    def tokenize(self, train_dir, train_ratio=0.9):
        """Read "question ==> answer" pairs and split them into train/test.

        The first ``int(train_ratio * n_pairs)`` pairs become
        ``x_train``/``y_train``; the rest become ``x_test``/``y_test``.

        Raises:
            ValueError: if a line is not exactly one "question ==> answer" pair.
            KeyError: if a token is not in the vocabulary.
        """
        with open(train_dir, encoding='utf-8') as f:
            lines = f.read().strip().split('\n')

        questions, answers = [], []
        for lineno, line in enumerate(lines, 1):
            parts = line.split(" ==> ")
            if len(parts) != 2:
                raise ValueError("%s:%d: expected 'question ==> answer', got %r"
                                 % (train_dir, lineno, line))
            question, answer = parts
            questions.append(self.text_to_ids(question))
            answers.append(self.text_to_ids(answer))

        train_num = int(train_ratio * len(questions))
        self.x_train, self.y_train = questions[:train_num], answers[:train_num]
        self.x_test, self.y_test = questions[train_num:], answers[train_num:]

    def text_to_ids(self, text):
        """Whitespace-tokenize ``text`` and map it to ids, appending ``<eos>``."""
        return [self.word_to_id[x] for x in (text.split() + ['<eos>'])]

    def ids_to_text(self, ids):
        """Map an iterable of ids back to a list of word strings."""
        return [self.words[x] for x in ids]

    def __repr__(self):
        return "Train length: %d, Test length: %d, Vocabulary size: %d" % (len(self.x_train),
                                                                           len(self.x_test),
                                                                           len(self.words))

In [50]:
# Build the vocabulary and the id-encoded train/test splits; the repr
# summarises split sizes and vocabulary size.
corpus = Corpus(train_dir=train_dir, vocab_dir=vocab_dir)
corpus

In [192]:
class DataSet(object):
    """
    Group aligned (question, answer) id sequences into zero-padded batches.

    Each stored batch is a pair ``(input_var, target_var)`` of LongTensor
    Variables with shape ``(max_seq_len, batch_size)`` (time-major, because
    ``pad_batch`` transposes the padded ``(batch, len)`` matrix). Trailing
    pairs that do not fill a whole batch are dropped.

    Args:
        data: list of id sequences (e.g. ``corpus.x_train``).
        labels: list of id sequences aligned one-to-one with ``data``.
        batch_size: number of pairs per batch; must be positive.
        use_cuda: move batches to the GPU. ``None`` (default) falls back to
            the notebook-level ``USE_CUDA`` flag.

    Raises:
        ValueError: if ``batch_size`` is not positive or the lengths of
            ``data`` and ``labels`` differ.
    """

    def __init__(self, data, labels, batch_size=64, use_cuda=None):
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got %r" % (batch_size,))
        if len(data) != len(labels):
            raise ValueError("data and labels differ in length: %d vs %d"
                             % (len(data), len(labels)))
        # Only consult the global flag when the caller did not decide explicitly.
        self.use_cuda = USE_CUDA if use_cuda is None else use_cuda

        # Integer division drops the incomplete final batch.
        num_batches = len(data) // batch_size
        self.x, self.y = [], []
        for i in range(num_batches):
            start, end = i * batch_size, (i + 1) * batch_size
            x_pad, y_pad = self.pad_batch(data[start:end], labels[start:end])
            self.x.append(x_pad)
            self.y.append(y_pad)

    def pad_batch(self, x_batch, y_batch):
        """Sort one batch by question length, pad it, and convert to tensors.

        Returns:
            (input_var, target_var), each of shape ``(max_len, batch)``.
        """
        # Longest question first (stable for ties) -- presumably so an encoder
        # can use pack_padded_sequence, which expects descending lengths.
        seq_pairs = sorted(zip(x_batch, y_batch), key=lambda p: len(p[0]), reverse=True)
        x_batch, y_batch = zip(*seq_pairs)

        x_maxlen = max(map(len, x_batch))
        x_pad = [self.pad_seq(s, x_maxlen) for s in x_batch]

        y_maxlen = max(map(len, y_batch))
        y_pad = [self.pad_seq(s, y_maxlen) for s in y_batch]

        # (batch, len) -> (len, batch)
        input_var = Variable(torch.LongTensor(x_pad)).transpose(0, 1)
        target_var = Variable(torch.LongTensor(y_pad)).transpose(0, 1)

        if self.use_cuda:
            input_var = input_var.cuda()
            target_var = target_var.cuda()
        return input_var, target_var

    def pad_seq(self, seq, max_len):
        """Right-pad ``seq`` with 0 (the ``<pad>`` id) up to ``max_len``."""
        return seq + [0] * (max_len - len(seq))

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)

In [193]:
# Small batches (10 pairs) so a sampled batch is easy to eyeball below.
train_data = DataSet(corpus.x_train, corpus.y_train, batch_size=10)

In [196]:
# Decode one random training batch back to text. DataSet stores batches
# time-major, (seq_len, batch), so each column is one question/answer pair.
x_rnd, y_rnd = random.choice(train_data)
for i in range(x_rnd.size(1)):
    # .cpu() keeps this working when USE_CUDA moved the batch to the GPU
    # (.numpy() on a CUDA tensor raises); .tolist() yields plain int ids.
    print(' '.join(corpus.ids_to_text(x_rnd[:, i].data.cpu().tolist())))
    print(' '.join(corpus.ids_to_text(y_rnd[:, i].data.cpu().tolist())))
    print()

[ trump's ] ability to repeat false statements with seemingly few consequences has become a point of political ... <eos>
<unk> : since <unk> 15 . 91 % od trumps statements are false . <eos> <pad> <pad> <pad> <pad> <pad> <pad>

losing 5 straight to braves is <unk> ? can it just be they are in mets head ? <eos> <pad>
they haven't lost 5 straight to atlanta . they just took 2 out of 3 at turner . <eos> <pad> <pad>

dear trump supporters , when did murdering innocent civilians become " american " ? yours truly , <eos> <pad> <pad>
oh lord , don't ask <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>

<unk> power <unk> rose pak <unk> lived the life of her choosing <unk> by brings back memories <eos> <pad> <pad>
. my mom went to college with rose pak . here's how she looked in <unk> <eos> <pad> <pad> <pad> <pad>

dear trump supporters , when did murdering innocent civilians become " american " ? yours truly , <eos> <pad> <pad>
... i think you mean