In [117]:
import torch
import torch.nn as nn
import numpy as np
# from skimage import io, transform
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import os
import random
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import re
import unicodedata # ??
import nltk
from nltk.tokenize import TweetTokenizer
import csv
import json
from torchvision import transforms
from torch.autograd import Variable
np.random.seed(1)
random.seed(1)


# define directory structure needed for data processing
RAW_DATA_DIR = os.path.join('..', 'data/', 'raw_data/')
FORMAL_DATA_DIR = os.path.join('..', 'data/', 'formal_data/')
UNKNOWN_TOKEN = "unk"

## Split data into `train`, `val` and `test`
Split and write raw data as `acsii` format. 

In [None]:
# Turn a Unicode string to plain ASCII, thanks to
# http://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )
def normalizeString(s):
    s = unicodeToAscii(s.strip()).replace("\t", "")
    return s


# precondition: two fields with name: "headline" and "text"
def splitData(fname, test_size=0.2, val_size=0.2): 
    df = pd.read_csv('../data/' + fname, encoding='latin-1')
    df = df[["headline", "text"]] # summary text, not the entire article
    df["headline"] = df["headline"].apply(normalizeString)
    df["text"] = df["text"].apply(normalizeString)

    df = df.sample(frac=1).reset_index(drop=True) # shuffle data

    df_train, df_test = train_test_split(df, test_size=test_size, random_state=1)
    df_train, df_val = train_test_split(df_train, test_size=val_size, random_state=1)
    df_train.to_csv(RAW_DATA_DIR + "train.csv", index=False, sep="\t")
    df_val.to_csv(RAW_DATA_DIR + "val.csv", index=False, sep="\t")
    df_test.to_csv(RAW_DATA_DIR + "test.csv", index=False, sep="\t")
    
# # TODO: check whether "text" is in fact the summary and corresponds to the headline

splitData("news_summary.csv")

## Preprocess `raw_data` to `formal_data`

In [None]:
# class to store string transformation
class Transform(object):
    tokenizer = TweetTokenizer(preserve_case=False, strip_handles=True, reduce_len=True)
    @staticmethod
    def word_tokenize(s): 
        return json.dumps(Transform.tokenizer.tokenize(s))
    
    @staticmethod
    def cap(s): 
        return s.upper()
    
# tf: a transformation apply to each individual headline and text
# Apply a transformation to the dataframe in "raw_data" 
# Results of the transformation is in "formal_data"
def preprocess(fname, tf, chunksize=1000, num_chunk=-1): 
    header = True
    mode = "w"
    i = 0
    for df in pd.read_csv(RAW_DATA_DIR + fname, sep="\t", header = 0, chunksize=chunksize):
        df["headlines"] = df["headlines"].apply(tf)
        df["text"] = df["text"].apply(tf)
        df.to_csv(FORMAL_DATA_DIR+fname, columns = ['headlines','text'], mode=mode,index=False, header=header, sep="\t")
        if header == True:  
            # no header, and write in append mode from the 2nd chunk
            header = False; 
            mode = "a"
        
        i = i + 1 if i >= 0 else -1
        if i == num_chunk: 
            break; 
    

preprocess("train.csv", tf=Transform.word_tokenize)
preprocess("val.csv", tf=Transform.word_tokenize)
preprocess("test.csv", tf=Transform.word_tokenize)

## Language model from train data

In [None]:
class GloVe():
    def __init__(self, path, dim):
        self.dim = dim
        self.word_embedding_dict = {}
        with open(path) as f:
            for line in f:
                values = line.split()
                embedding = values[-dim:]
                word = ''.join(values[:-dim])
                self.word_embedding_dict[word] = np.asarray(embedding, dtype=np.float32)
    
    def get_word_vector(self, word):
        if word not in self.word_embedding_dict.keys():
            embedding = np.random.uniform(low=-1, high=1, size=self.dim).astype(np.float32)
            self.word_embedding_dict[word] = embedding
            return embedding
        else:
            return self.word_embedding_dict[word]

In [None]:
glvmodel = GloVe(os.path.join('..', 'models', 'glove.twitter.27B.200d.txt'), dim=200)

In [172]:
PAD_token = 0
SOS_token = 1 # start of sentence
EOS_token = 2 # end of sentence
PAD_word = "<PAD>"


# Modified from: Sean Robertson <https://github.com/spro/practical-pytorch>
class Lang: # language model
    def __init__(self, glvmodel, fname=None):
        self.word2index = {}
        self.word2count = {}
        self.index2word = {}
        self.gloveEmbed = []
        self.glvmodel = glvmodel
        self.size = 0  
        
        self.addWord(PAD_word)
        self.addWord("SOS")
        self.addWord("EOS")
        
        self.addWord(UNKNOWN_TOKEN)
        if fname is not None: 
            self.addCSV(fname)
        
    # add words of a sentence into the language model
    # split by ' '
    def addSentence(self, sentence):
        for word in sentence: 
            self.addWord(word)
            
    def addGlove(self, glove): 
        self.glove = glove

    def getGloveLayer(self, embed_dim): 
        # Initialize word embeddings from our pre-training GloVe embeddings
        glove_embeddings = torch.from_numpy(self.gloveEmbed)
        return nn.Embedding(self.size, embed_dim).from_pretrained(glove_embeddings, freeze=False)
        
    # add a word to language model
    def addWord(self, word):
        if word not in self.word2index:
            # assign id for word
            self.word2index[word] = self.size
            
            # count word
            self.word2count[word] = 1
            
            self.index2word[self.size] = word
            self.size += 1
            
#             gloveEmbed.append(glvmodel.get_word_vector(word))
        else:
            self.word2count[word] += 1
            
    def addDataFrame(self, df): 
        # df: headline, text
        for index, row in df.iterrows():
            self.addSentence(json.loads(row['headlines']))
            self.addSentence(json.loads(row['text']))
            
    def addCSV(self, fname): 
        # construct language model based on a file
        # file is a dataframe csv file with "headlines" and "text"
        for df in pd.read_csv(FORMAL_DATA_DIR + fname, sep="\t", header = 0, chunksize=2000):
            self.addDataFrame(df)     
    
    def size(self): 
        return len(self.word2index)
    
    def wordSeq2IdxSeq(self, word_seqs):
        default_idx = lang.word2index[UNKNOWN_TOKEN]
        idxs = []
        for word_seq in word_seqs:
            idxs.append([lang.word2index.get(w, default_idx) for w in word_seq])
    #         idxs.append(torch.LongTensor([lang.word2index.get(w, default_idx) for w in word_seq]))
        return idxs

lang = Lang([])
lang.addCSV("train.csv")

## Neural Network model

In [118]:
def init_param(param, bias_std=0.0, weight_std=0.05): 
    for name, param in self.gru.named_parameters():
        if 'bias' in name:
            nn.init.constant_(param, bias_std)
        elif 'weight' in name:
            nn.init.normal_(param, std = weight_std)
            
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=1, dropout=0.1):
        super(EncoderRNN, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout
        
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=self.dropout, bidirectional=True)
        
    def forward(self, input_seqs, input_lengths, hidden=None):
        # Note: we run this all at once (over multiple batches of multiple sequences)
        embedded = self.embedding(input_seqs)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs) # unpack (back to padded)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:] # Sum bidirectional outputs
        return outputs, hidden
    
class Encoder(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, 
                    pretrained_embeddings, batch_size, num_layers):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.word_embeddings = gloveLayer
        
        # Initialize a Gated Recurrent Unit RNN
        self.gru = nn.GRU(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=num_layers)
        
#         self.hidden2label = nn.Linear(hidden_dim, n_classes)
        self.hidden = self.init_hidden()
        
        # Custom initialization of the weights and biases
        init_param(self.gru.named_parameters())
#         init_param(self.hidden2label.named_parameters())


    def init_hidden(self):
        return Variable(torch.zeros(self.num_layers, self.batch_size, self.hidden_dim).cuda())

    
#     def forward(self, input_seqs):
        
#         embedded = self.word_embeddings(input_seqs)
#         packed_embeds = nn.utils.rnn.PackedSequence(input_embed, packed_sequence.batch_sizes)
#         output, hidden = self.gru(packed_embeds, self.hidden)
#         return output, hidden
    
    def forward(self, input_seqs, input_lengths, hidden=None):
        # Note: we run this all at once (over multiple batches of multiple sequences)
        embedded = self.word_embeddings(input_seqs)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs) # unpack (back to padded)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:] # Sum bidirectional outputs
        return outputs, hidden

class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        
        self.gru = nn.GRU(hidden_size, hidden_size)
        
        # map to output space
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1) # ? what is LogSoftMax, dim = 1??

    def forward(self, input, hidden):
        # ? why the input is is 'output_size' dimension? 
        # ? why need embedding? 
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

### Training
#### Dataset

In [3]:
class SummaryDataset(Dataset):
    def __init__(self, fname, transform=None):
        self.df = pd.read_csv(FORMAL_DATA_DIR + fname, sep="\t", header = 0)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        headline = self.df["headlines"][idx]
        text = self.df["text"][idx]
        
        sample = {'headlines': headline, 'text': text}

        if self.transform:
            sample = self.transform(sample)

        return sample

# trainDataset = SummaryDataset(FORMAL_DATA_DIR + "train.csv")

In [192]:
# class ToTensor(object):
#     """Convert the pair of string lists to Tensors."""
#     def __init__(self, lang): 
#         self.lang = lang
        
#     def __call__(self, sample):
#         headlines, text = sample['headlines'], sample['text']
#         headlines = prepare_sequence_batch(headlines, self.lang)
#         text = prepare_sequence_batch(text, self.lang)
#         return (torch.Tensor(headlines), torch.Tensor(text))


# Pad a with the PAD symbol
def pad_seq(seq, max_length):
    seq += [PAD_token for i in range(max_length - len(seq))]
    return seq


"""
Create a batch ready for feeding into the network. 
Paddings are added to ensure all sequences are the same length. 
Precondition: 
    input_seqs: a list of sequences; each element is a sequencce which is a list of words
    target_seqs: same. 
    lang: the language model
Postcondition: 
    indices version of the input and target in tensor form. 
"""
def batch(input_seqs, target_seqs, lang):
    # input_seqs and target_seqs are in string format
    
    input_seqs = lang.wordSeq2IdxSeq(input_seqs)
    target_seqs = lang.wordSeq2IdxSeq(target_seqs)
    
    print(target_seqs)
    
    # Zip into pairs, sort by length (descending), unzip
    seq_pairs = sorted(zip(input_seqs, target_seqs), key=lambda p: len(p[0]), reverse=True)
    input_seqs, target_seqs = zip(*seq_pairs)
    
    # For input and target sequences, get array of lengths and pad with 0s to max length
    input_lengths = [len(s) for s in input_seqs]
    input_padded = [pad_seq(s, max(input_lengths)) for s in input_seqs]
    target_lengths = [len(s) for s in target_seqs]
    target_padded = [pad_seq(s, max(target_lengths)) for s in target_seqs]

    for s in input_padded: 
        print(s)
        
    for s in target_padded: 
        print(s)
    # Turn padded arrays into (batch_size x max_len) tensors, transpose into (max_len x batch_size)
    input_var = Variable(torch.LongTensor(input_padded)).transpose(0, 1)
    target_var = Variable(torch.LongTensor(target_padded)).transpose(0, 1)
    
#     input_var = input_var.cuda()
#     target_var = target_var.cuda()
    
    return input_var, input_lengths, target_var, target_lengths



# tweet_in = nn.utils.rnn.pack_sequence(prepare_sequence_batch(headlines, word2idx)).cuda()
# target = nn.utils.rnn.pack_sequence(prepare_sequence_batch(text, word2idx)).cuda()

In [197]:
# tf = transforms.Compose([ToTensor(lang)])
tf = None
dataset = SummaryDataset(fname="train.csv", transform=tf)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)


sb = next(iter(dataloader))
input_seqs = sb["text"]
target_seqs = sb["headlines"]
input_seqs = [json.loads(s) for s in input_seqs]
target_seqs = [json.loads(s) for s in target_seqs]
input_seqs, input_lengths, target_seqs, target_lengths = batch(input_seqs, target_seqs, lang)


for i_batch, sample_batched in enumerate(dataloader):
    # clear gradients, clear hidden state from the last timestep
#     gru_model.zero_grad()
#     gru_model.hidden = gru_model.init_hidden()

    input_seqs = sample_batched["text"]
    target_seqs = sample_batched["headlines"]
    
    input_seqs = [json.loads(s) for s in input_seqs]
    target_seqs = [json.loads(s) for s in target_seqs]
    
    input_seqs, input_lengths, target_seqs, target_lengths = batch(input_seqs, target_seqs, lang)
    
    input_seqs = nn.utils.rnn.pack_sequence(input_seqs) # .cuda()
    target_seqs = nn.utils.rnn.pack_sequence(target_seqs) #.cuda()

#     # forward pass
#     label_scores = gru_model(tweet_in)

#     # compute loss against true labels
#     loss = loss_function(label_scores, target)

#     # backprop the gradients and update the model parameters
#     loss.backward()
#     optimizer.step()

#     # keep track of the loss
#     running_loss += loss.item()
#     i += BATCH_SIZE
#     if i % 2000 == 0:
#         average_loss = running_loss/2000
#         if average_loss < lowest_loss:
#             lowest_loss = running_loss
#             # save our checkpoint if it is the current best
#             torch.save(gru_model.state_dict(), CHECKPOINT_FILE)
#         logging.info("running loss: %.3f @ batch %d", average_loss, batch_ind)
#         running_loss = 0.0

[[324, 1851, 863, 6360, 769, 6361, 191, 318, 1906, 59, 2068], [12754, 11440, 11182, 12755, 246, 899, 3724, 203, 120, 1727, 3520, 120], [4501, 2378, 59, 155, 2414, 59, 3712, 7316, 7, 2167, 1052], [5174, 11, 259, 4929, 17436, 10609, 17437]]
[3897, 12754, 11440, 11481, 217, 11441, 11442, 11, 90, 4873, 16, 387, 12755, 246, 899, 7, 30, 3897, 50, 81, 195, 120, 1727, 3520, 120, 37, 97, 11, 81, 4595, 59, 92, 12756, 132, 12757, 478, 30, 3897, 235, 47, 2305, 1249, 37, 11441, 11442, 50, 47, 83, 1421, 18, 30, 11481, 195, 1304, 29, 2157, 246, 12758, 59, 20, 30, 12759, 132, 8679, 7, 482, 5063, 20, 37]
[2101, 47, 105, 75, 1384, 274, 2378, 59, 3712, 7316, 7, 2167, 1052, 11, 651, 13282, 2668, 143, 3106, 14012, 990, 132, 120, 1203, 196, 9647, 120, 7316, 1009, 29, 7236, 7686, 143, 1270, 37, 83, 1676, 2704, 3021, 1471, 10083, 253, 11, 234, 75, 1644, 2378, 59, 3912, 218, 7, 6245, 1627, 11, 75, 92, 2317, 37, 30, 2767, 16, 387, 6393, 89, 30, 105, 558, 37, 0]
[1454, 4791, 651, 5174, 5175, 132, 261, 262, 263, 

[[3845, 3846, 3847, 3848, 191, 381, 692, 384, 986, 707, 294], [5860, 584, 5858, 196, 59, 60, 782, 44, 3474, 1460], [2714, 4560, 47, 287, 82, 83, 1968, 1076, 1789, 2037, 2175, 3678], [172, 2165, 143, 6203, 11388, 5994, 10698, 1811, 47, 969]]
[386, 382, 2187, 383, 384, 11, 153, 195, 294, 89, 388, 389, 11, 50, 20, 3849, 3850, 3851, 11, 3849, 525, 11, 3845, 3846, 3847, 3848, 11, 3852, 3853, 3854, 39, 1975, 38, 985, 253, 37, 3855, 41, 20, 246, 707, 985, 37, 19, 175, 18, 19, 1657, 3856, 11, 132, 211, 235, 47, 117, 3857, 75, 3858, 59, 28, 37, 20, 3850, 3859, 78, 3860, 3861, 39, 99, 1009, 782, 143, 3862, 41, 11, 20, 19, 175, 37]
[1906, 160, 5859, 5860, 16, 1316, 5858, 11, 30, 17145, 1147, 153, 195, 1586, 67, 819, 59, 252, 11, 196, 59, 60, 482, 328, 44, 3474, 1460, 37, 270, 604, 246, 481, 1201, 1989, 29, 68, 553, 143, 83, 5767, 1647, 67, 17146, 3971, 143, 5708, 132, 2081, 9697, 11, 399, 234, 83, 13187, 1573, 195, 1486, 132, 481, 195, 886, 217, 83, 3013, 37, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

[[15731, 12, 2716, 601, 482, 120, 15732, 120, 375, 3331], [441, 1399, 747, 4387, 11, 3014, 11418, 986, 14380], [498, 1015, 82, 1016, 47, 392, 132, 382, 191, 500], [3760, 3593, 83, 733, 59, 14597, 715, 191, 160]]
[2342, 193, 15731, 7216, 16, 17, 47, 83, 4612, 18, 2716, 601, 83, 120, 15732, 120, 561, 7, 15733, 67, 375, 37, 47, 30, 561, 11, 15731, 957, 92, 912, 4218, 82, 83, 2895, 59, 3841, 30, 31, 120, 2895, 752, 363, 2896, 120, 37, 481, 550, 11, 20, 3108, 675, 2227, 303, 3108, 3251, 47, 37, 3108, 599, 1482, 2227, 18, 303, 3270, 675, 808, 3268, 14815, 2228, 20, 213]
[30, 500, 29, 1017, 50, 18, 1018, 42, 498, 1015, 82, 243, 1019, 1016, 47, 392, 132, 382, 132, 30, 1020, 1021, 558, 235, 20, 491, 495, 20, 37, 20, 30, 23, 160, 235, 1022, 47, 1023, 541, 808, 1024, 42, 1025, 132, 30, 1026, 726, 37, 1018, 235, 498, 1027, 11, 20, 50, 500, 37, 30, 1016, 99, 97, 592, 500, 7, 1028, 81, 47, 30, 429, 458, 37, 0]
[83, 18211, 2787, 635, 11, 153, 195, 747, 4387, 89, 83, 7885, 18212, 217, 83, 3058, 47, 32

[[1658, 10897, 140, 30, 3038, 5514, 191, 2593, 2594, 59, 9720, 2566], [7735, 3774, 7739, 191, 6654, 7733, 2213, 1703, 287], [82, 17799, 31, 4238, 11, 698, 17800, 589, 143, 106, 1318], [8022, 5235, 2461, 1109, 5291, 217, 253, 2037, 916, 191, 724]]
[287, 726, 727, 724, 16, 421, 18, 5763, 5235, 6809, 8023, 20, 2461, 1109, 5291, 217, 253, 20, 1023, 37, 277, 11, 724, 50, 20, 30, 3404, 20, 16, 8024, 351, 7, 30, 287, 531, 8025, 532, 37, 117, 2432, 1686, 47, 30, 5183, 7, 8023, 1592, 2356, 7, 144, 5770, 771, 67, 30, 525, 234, 195, 20, 8026, 20, 29, 5647, 132, 8027, 7, 10, 4162, 59, 30, 456, 37]
[82, 17799, 31, 4238, 47, 1614, 17801, 8509, 11, 698, 16, 17802, 30, 589, 143, 30, 106, 1318, 57, 30, 1801, 1063, 7, 4238, 47, 12111, 6774, 217, 17803, 37, 30, 107, 1896, 7, 428, 11, 7404, 11, 7062, 11, 31, 132, 17804, 31, 1438, 29, 103, 1500, 3205, 270, 3198, 37, 47, 4475, 10, 4822, 11, 698, 2363, 8970, 4621, 246, 1945, 757, 1454, 4811, 395, 37, 0, 0, 0]
[2538, 351, 7, 117, 1390, 4539, 29, 30, 5902, 217

[[3217, 2539, 1329, 7, 9254, 12986, 12984, 17207], [105, 1099, 3468, 1482, 274, 3630, 246, 1114, 726], [287, 996, 59, 518, 424, 120, 7216, 7166, 120, 217, 9034], [6726, 8152, 4044, 143, 144, 9063, 647, 59, 4738, 144, 4839, 771]]
[83, 6726, 8152, 235, 647, 59, 4738, 4176, 14467, 39, 144, 4839, 771, 41, 217, 35, 3999, 47, 2710, 11, 995, 554, 246, 243, 3956, 4044, 81, 143, 4176, 145, 39, 144, 9063, 41, 11, 14468, 81, 195, 83, 960, 14469, 37, 30, 7365, 4107, 30, 14470, 14471, 6726, 143, 9446, 11, 1217, 4701, 243, 3111, 246, 83, 14472, 50, 81, 459, 92, 14473, 37, 30, 7365, 42, 20, 14474, 12307, 11, 20, 83, 14472, 50, 37]
[30, 287, 558, 16, 1106, 59, 518, 424, 83, 120, 7216, 7166, 120, 217, 9035, 9036, 1208, 59, 3841, 3973, 541, 47, 30, 7216, 8234, 11, 287, 3788, 132, 11130, 160, 1440, 1432, 50, 37, 30, 107, 558, 75, 1393, 327, 144, 1539, 146, 59, 30, 2118, 4986, 11, 19, 175, 37, 30, 7166, 75, 6836, 752, 30, 2679, 7, 8234, 11, 4011, 132, 3788, 4622, 7, 9034, 37, 0, 0, 0, 0, 0, 0, 0, 0]
[1720

[[10743, 795, 67, 5073, 1851, 196, 92, 10744, 191, 306], [1914, 7866, 59, 4252, 986, 13547, 1437, 1285], [75, 2640, 2641, 7, 2642, 2643, 82, 2644, 2645, 191, 674, 2646, 726], [15196, 59, 22, 1903, 1327, 1109, 143, 1390, 1318, 143, 16113]]
[8941, 23, 10743, 795, 57, 3800, 2909, 2197, 1851, 196, 92, 10744, 242, 81, 494, 196, 3997, 30, 341, 1958, 11, 778, 59, 306, 37, 5105, 16, 4687, 196, 707, 755, 59, 2912, 1060, 10691, 128, 47, 30, 3800, 2909, 57, 30, 46, 143, 482, 795, 37, 1013, 11, 9795, 10745, 10746, 18, 83, 10747, 1720, 10748, 83, 10749, 140, 46, 143, 5156, 37]
[2647, 2648, 2261, 1026, 726, 2649, 544, 50, 19, 339, 1947, 2650, 83, 428, 746, 59, 20, 1461, 83, 1425, 2640, 20, 29, 2651, 2652, 1770, 132, 2653, 2641, 82, 2644, 2645, 37, 30, 558, 195, 2654, 59, 932, 1194, 249, 1460, 2072, 47, 548, 11, 19, 175, 37, 2653, 674, 536, 936, 1325, 2655, 1536, 57, 20, 2656, 132, 2657, 20, 37, 0]
[2512, 1155, 679, 112, 593, 594, 132, 117, 1770, 75, 22, 1903, 1327, 1109, 143, 30, 245, 1318, 59, 3801

[[551, 552, 13657, 5847, 143, 6630, 771, 13658, 73, 74], [29, 283, 5495, 916, 11, 6950, 3888, 9682, 5495, 29, 666], [2787, 5514, 246, 707, 16882, 59, 16883, 143, 707, 29, 16884], [323, 7, 13445, 11, 13446, 59, 13447, 2379, 47, 287]]
[83, 1360, 342, 1046, 2787, 1190, 47, 4618, 246, 707, 16882, 59, 83, 16883, 34, 482, 6524, 6527, 11, 752, 83, 4532, 120, 16885, 120, 37, 778, 59, 81, 11, 16886, 1412, 438, 59, 5209, 47, 83, 16883, 219, 30, 1711, 57, 5069, 702, 16887, 37, 778, 59, 35, 1350, 11, 410, 5343, 10156, 438, 1190, 1650, 59, 270, 1302, 242, 8556, 11, 2556, 81, 707, 16888, 89, 30, 16889]
[105, 3888, 11, 153, 42, 2363, 29, 35, 3776, 3630, 47, 903, 249, 30, 2294, 9683, 11, 9684, 83, 5495, 9685, 6755, 83, 666, 29, 283, 5495, 916, 29, 115, 37, 30, 3888, 176, 3045, 9686, 9687, 29, 325, 132, 2734, 209, 9688, 132, 9689, 47, 30, 9687, 37, 78, 438, 387, 3875, 9690, 1538, 9691, 11, 9692, 8258, 132, 83, 6055, 8466, 11, 826, 2066, 37, 0, 0, 0]
[10325, 1981, 22, 29, 972, 1142, 30, 287, 558, 59, 16

[[4573, 5751, 11124, 15297, 47, 120, 15915, 120, 15916], [2068, 9309, 1499, 342, 1184, 1076, 249, 4406, 143, 14132, 83, 692], [11457, 4229, 1345, 8492, 82, 1929, 11458, 7, 11459, 47, 7376], [1792, 8968, 219, 396, 1098, 2007, 3710, 143, 30, 1345, 1318]]
[4573, 438, 7218, 601, 11124, 7434, 67, 30, 15917, 7, 120, 15915, 120, 15916, 11, 1864, 15918, 1732, 15919, 1499, 15920, 15921, 11, 83, 1834, 234, 47, 30, 2311, 1851, 865, 5751, 30, 11126, 47, 8246, 37, 30, 8667, 293, 4847, 59, 15922, 15923, 30, 1051, 15924, 11126, 4268, 342, 1507, 39, 11124, 342, 1507, 41, 47, 15916, 11, 18, 99, 387, 15925, 82, 1051, 11866, 132, 12204, 37]
[8969, 1792, 8968, 16, 1149, 219, 7, 30, 396, 1098, 8970, 143, 30, 106, 1318, 47, 117, 2405, 2039, 11, 8971, 8972, 47, 30, 3713, 283, 2007, 3710, 37, 8968, 11, 153, 195, 283, 1063, 1109, 143, 83, 1733, 8973, 3346, 11, 245, 521, 30, 396, 2933, 47, 1851, 579, 37, 30, 8974, 342, 1046, 235, 1495, 67, 1304, 3324, 11, 1178, 1217, 1109, 2931, 270, 184, 37, 0]
[819, 807, 801,

[[2677, 996, 2654, 59, 1175, 372, 120, 4485, 7, 1249, 191, 726], [445, 3470, 5050, 29, 929, 930, 2989, 82, 1614, 5030], [1056, 187, 59, 873, 2679, 1650, 59, 1430, 7, 4258, 2378], [818, 5249, 253, 254, 59, 6018, 1452, 191, 14531]]
[30, 452, 29, 304, 599, 83, 5050, 29, 456, 502, 112, 929, 930, 1114, 82, 30, 1614, 2975, 59, 252, 11, 5051, 5052, 11, 29, 68, 2188, 47, 287, 1306, 30, 1777, 5053, 1326, 4587, 37, 757, 30, 452, 99, 97, 2790, 30, 306, 7, 935, 1114, 5052, 57, 20, 4850, 3439, 20, 11, 30, 455, 456, 1206, 5054, 18, 30, 20, 5055, 1114, 20, 837, 208, 1165, 37]
[105, 23, 160, 6858, 6859, 29, 304, 50, 117, 558, 235, 2654, 59, 1717, 30, 4485, 7, 1249, 7, 30, 372, 7, 30, 107, 37, 20, 2037, 797, 89, 1256, 160, 543, 544, 132, 363, 558, 235, 1217, 143, 30, 453, 37, 369, 42, 491, 143, 1641, 11, 4514, 11, 1102, 29, 1782, 1101, 11, 20, 30, 105, 23, 160, 50, 757, 6856, 143, 500, 6860, 37, 0, 0]
[125, 5674, 14532, 14531, 16, 50, 18, 125, 801, 818, 806, 5249, 28, 254, 59, 6018, 1185, 29, 30, 5902,

[[261, 262, 16, 5172, 2629, 263, 83, 11892, 1151, 191, 306], [1903, 799, 1989, 12232, 9136, 217, 1355, 10], [7320, 5155, 59, 7321, 3053, 7322, 4667, 986, 2907, 7323, 1358], [1341, 75, 92, 1304, 1069, 78, 6865, 1946, 1947, 191, 1353]]
[778, 59, 306, 11, 261, 262, 263, 16, 5172, 2629, 263, 83, 11892, 1151, 59, 1404, 28, 143, 117, 8121, 4539, 47, 9967, 3230, 3125, 956, 31, 37, 83, 6845, 50, 11, 20, 2629, 195, 2711, 808, 206, 632, 19, 99, 3143, 117, 11678, 808, 29, 686, 83, 1343, 1626, 11, 259, 629, 59, 1262, 28, 2227, 37, 20, 2629, 235, 345, 1091, 83, 6187, 9968, 47, 30, 31, 37]
[773, 125, 618, 1182, 408, 1352, 1353, 50, 18, 83, 1269, 828, 120, 618, 2558, 42, 124, 83, 1425, 2830, 89, 1186, 59, 208, 611, 7, 30, 781, 14891, 18, 30, 1341, 235, 47, 37, 20, 363, 875, 59, 686, 202, 235, 11, 8742, 808, 270, 107, 7, 14892, 75, 196, 245, 143, 1247, 37, 1341, 75, 92, 1304, 1069, 78, 6865, 1946, 1946, 1947, 11, 20, 1353, 50, 37]
[30, 90, 287, 1699, 429, 39, 7320, 41, 16, 1996, 59, 7321, 3053, 7322, 

[[7057, 235, 35, 7058, 2463, 11, 303, 168, 482, 191, 7059, 7060], [500, 5088, 2149, 9220, 47, 7283, 53, 81, 729, 852, 47, 424], [672, 324, 83, 8835, 768, 143, 3041, 7, 30, 453, 191, 162], [17985, 889, 7, 193, 886, 47, 969, 11, 1608, 1467]]
[217, 30, 500, 71, 3733, 2989, 11, 159, 160, 161, 162, 29, 972, 2790, 30, 3215, 687, 768, 57, 83, 20, 8835, 20, 132, 20, 8836, 20, 797, 860, 30, 3041, 7, 30, 453, 47, 30, 373, 37, 2292, 18, 30, 768, 3595, 30, 8837, 7, 740, 1273, 132, 4850, 3744, 11, 162, 175, 11, 20, 687, 195, 83, 8838, 8839, 59, 740, 11, 4850, 3744, 672, 3081, 11, 132, 8840, 37, 20]
[500, 450, 451, 23, 9221, 9222, 29, 509, 50, 30, 2149, 9220, 217, 7283, 75, 92, 1581, 246, 458, 53, 500, 729, 83, 852, 47, 30, 107, 37, 270, 604, 351, 7, 30, 106, 7, 30, 9223, 458, 47, 424, 37, 514, 18, 2149, 9220, 235, 83, 1970, 7, 9224, 11, 9222, 50, 11, 20, 81, 235, 196, 309, 59, 92, 326, 47, 311, 1282, 37, 20, 0, 0, 0, 0, 0, 0]
[30, 17986, 889, 7, 16006, 193, 17987, 3192, 195, 886, 47, 482, 7863, 47,

[[2704, 2899, 8666, 59, 381, 202, 3507, 47, 7134, 3535, 187], [1775, 494, 864, 3062, 59, 4836, 3252, 144, 6317, 3797, 47, 145, 554, 144], [806, 235, 1121, 8997, 47, 6537, 191, 6825, 1336], [2792, 187, 5827, 637, 11, 614, 3934, 5828, 1119, 191, 306]]
[125, 3782, 8639, 6825, 1336, 29, 728, 50, 18, 801, 818, 806, 235, 1121, 8997, 47, 30, 6537, 132, 235, 2166, 59, 14115, 249, 2859, 47, 30, 5149, 807, 37, 20, 16636, 75, 14252, 5781, 246, 35, 7046, 37, 206, 303, 3306, 280, 2166, 59, 3572, 37, 30, 5596, 280, 8896, 217, 30, 6537, 81, 961, 280, 63, 518, 59, 1686, 1304, 11, 20, 50, 1336, 37]
[30, 864, 16, 163, 83, 1853, 7, 144, 6317, 771, 5504, 67, 16491, 243, 16492, 4810, 47, 30, 183, 145, 554, 37, 864, 16, 16018, 30, 16493, 5504, 39, 16494, 41, 1119, 59, 334, 4822, 67, 16495, 7, 1874, 132, 16491, 30, 16496, 3198, 37, 1872, 11, 327, 311, 146, 1425, 11889, 4238, 75, 92, 862, 424, 217, 1814, 1863, 59, 2157, 2692, 132, 5647, 37, 0]
[8667, 217, 30, 1208, 7, 761, 8668, 438, 2317, 35, 8669, 31, 18, 4

[[996, 2124, 410, 144, 2368, 771, 860, 3828, 3185, 2114], [203, 168, 12171, 12944, 1593, 191, 500, 29, 1796, 1536], [4848, 9628, 195, 83, 8056, 2343, 11, 742, 10248, 10249], [1036, 5847, 75, 7183, 7184, 1040, 29, 117, 7185, 59, 252, 144]]
[12532, 452, 120, 12, 7, 243, 2298, 3798, 707, 12945, 89, 500, 11, 1906, 160, 529, 8077, 50, 11, 20, 81, 235, 168, 83, 12171, 12944, 83, 1593, 37, 20, 2321, 1060, 452, 120, 768, 59, 1378, 243, 2298, 3798, 59, 2091, 59, 8007, 2154, 67, 12946, 678, 11, 8077, 50, 274, 12947, 293, 20, 6662, 5851, 20, 1018, 757, 535, 293, 9780, 820, 7535, 59, 30, 2294, 7534, 202, 37]
[30, 1653, 1715, 1671, 7, 252, 39, 3193, 41, 2124, 144, 17190, 771, 410, 30, 245, 1508, 554, 860, 3828, 3185, 2114, 37, 270, 8947, 35, 9032, 7, 223, 144, 17191, 771, 860, 4016, 132, 144, 17192, 771, 860, 15294, 168, 11381, 132, 17193, 7, 3185, 8526, 37, 466, 11, 83, 1801, 7, 1060, 17194, 771, 3185, 2114, 438, 387, 8399, 424, 59, 68, 4249, 11, 3723, 37, 0, 0, 0]
[107, 72, 7, 252, 1413, 10248, 1

[[11325, 2433, 59, 4443, 1205, 342, 1046, 1968, 1076, 923], [7038, 144, 7039, 16793, 202, 47, 3013, 57, 78, 1246, 191, 1274, 59, 16794], [303, 468, 837, 363, 7876, 191, 4673, 411, 10859, 4632], [7708, 7, 9873, 31, 29, 9874, 9875, 6043]]
[14841, 4345, 11, 30, 4632, 7, 4673, 10859, 234, 195, 2380, 89, 823, 11, 50, 18, 19, 468, 837, 117, 7876, 89, 196, 6203, 57, 30, 3404, 660, 37, 20, 6, 7, 30, 6735, 11, 81, 660, 4092, 5293, 144, 143, 83, 16796, 7, 83, 4621, 363, 805, 1201, 13157, 808, 303, 17097, 132, 3054, 7728, 11, 20, 19, 2447, 37, 19, 175, 1949, 2988, 59, 4673, 6932, 53, 462, 83, 10034, 37]
[30, 287, 1275, 1276, 29, 728, 358, 18, 16794, 600, 1393, 5098, 1109, 140, 311, 9214, 59, 16795, 82, 1288, 132, 600, 196, 16793, 2154, 757, 5498, 30, 5325, 37, 20, 3738, 1217, 83, 16796, 7, 83, 4621, 143, 83, 2463, 59, 9026, 30, 4399, 144, 2814, 30, 16796, 7, 83, 4621, 4876, 16797, 11, 20, 30, 1276, 50, 47, 16798, 59, 13587, 3013, 47, 287, 37, 0, 0, 0, 0, 0, 0]
[30, 7708, 7, 9876, 5543, 9877, 11, 

[[75, 832, 59, 8332, 1124, 954, 47, 125, 488, 191, 8646], [7460, 7461, 1182, 486, 589, 6272, 246, 591, 5792], [8988, 2474, 1345, 1124, 807, 8492, 59, 2055, 346, 4304, 12319, 11458, 47, 83, 4926], [12635, 143, 1370, 1412, 2613, 14067, 12919, 12920, 191, 2438]]
[8988, 2474, 939, 2309, 106, 807, 8492, 59, 2055, 311, 4304, 12319, 11458, 47, 83, 184, 11, 246, 8501, 15428, 8806, 39, 15621, 41, 249, 2859, 47, 30, 915, 916, 807, 29, 824, 37, 30, 69, 342, 1046, 939, 30, 1066, 1130, 59, 2055, 4837, 12319, 11, 26, 19, 6796, 15622, 8806, 39, 15623, 41, 249, 3099, 4439, 47, 2667, 37, 2474, 235, 176, 30, 3571, 1130, 8492, 59, 4446, 6583, 807, 5925, 47, 83, 184, 37]
[30, 105, 8649, 2196, 39, 8646, 41, 16, 50, 18, 81, 75, 832, 59, 8332, 125, 488, 7891, 1130, 954, 37, 17065, 17066, 11, 153, 4115, 8646, 120, 3298, 7257, 11, 50, 18, 362, 2269, 819, 16, 7720, 3865, 125, 488, 11, 30, 712, 2617, 75, 5781, 14021, 11, 3923, 819, 14811, 3913, 29, 125, 17067, 37, 466, 11, 1130, 193, 16891, 263, 968, 47, 261, 26

[[692, 6716, 1147, 29, 13258, 310, 11, 13259, 82, 144, 1507, 146, 8152], [5175, 6418, 1835, 5179, 5823, 67, 1576, 7692, 5439], [303, 3267, 5939, 486, 237, 82, 4768, 47, 2311, 191, 1654], [1109, 4387, 47, 6497, 7123, 47, 13888]]
[260, 1654, 485, 16, 50, 19, 10087, 18, 5939, 5943, 486, 237, 82, 4768, 47, 2311, 37, 19, 175, 11, 20, 5939, 16, 83, 4128, 143, 10088, 37, 303, 434, 628, 47, 10089, 83, 31, 140, 1452, 2256, 270, 808, 7205, 37, 20, 34, 30, 4611, 7, 274, 31, 120, 3505, 120, 11, 5939, 99, 50, 18, 81, 195, 83, 10090, 3266, 83, 10091, 82, 1654, 143, 30, 425, 5213, 37]
[83, 692, 1908, 83, 1147, 29, 83, 13258, 6638, 132, 973, 6815, 482, 2571, 8152, 158, 144, 1507, 146, 37, 78, 99, 1106, 59, 155, 477, 11, 132, 1908, 217, 83, 969, 8349, 1069, 30, 592, 50, 19, 629, 59, 154, 482, 83, 8152, 37, 481, 116, 28, 482, 945, 8152, 143, 5713, 11, 399, 234, 19, 1009, 59, 1461, 83, 3468, 132, 3108, 2601, 37, 0, 0, 0, 0, 0]
[5174, 5175, 16, 4841, 1835, 83, 5179, 5823, 67, 1576, 7692, 5439, 59, 4929, 3

[[4430, 1303, 346, 342, 3280, 5407, 143, 252, 1355, 1346, 449], [2386, 16357, 720, 59, 969, 17894, 246, 1854, 7676], [15217, 9454, 82, 5600, 11, 4285, 15218, 11, 12, 15219], [13692, 3200, 13693, 143, 196, 9065, 243, 688, 47, 8509]]
[1945, 408, 13694, 13692, 16, 1365, 1329, 12892, 1505, 13693, 143, 196, 9065, 243, 5068, 47, 8509, 37, 20, 13693, 16, 83, 13695, 13696, 59, 2337, 13697, 488, 808, 812, 7039, 815, 13698, 196, 362, 975, 47, 30, 2864, 11, 132, 303, 3306, 5069, 1989, 83, 3294, 5176, 11, 20, 19, 175, 37, 13692, 277, 50, 13693, 235, 707, 1395, 57, 83, 1042, 7, 5596, 59, 3259, 1835, 8509, 37]
[773, 3202, 4436, 4430, 16, 345, 1308, 83, 5408, 5407, 59, 1341, 143, 30, 449, 7, 1355, 1346, 7, 30, 125, 618, 1182, 37, 20, 5409, 132, 83, 1346, 7, 5410, 705, 70, 47, 30, 125, 904, 905, 132, 16, 5312, 82, 63, 735, 3528, 986, 11, 20, 117, 3829, 684, 37, 5411, 2262, 7, 4997, 16, 513, 52, 4430, 59, 1378, 83, 5412, 3829, 37, 0, 0, 0, 0]
[30, 17895, 17896, 2386, 234, 8266, 17897, 776, 5132, 7, 720

[[424, 218, 2018, 59, 7176, 7177, 3220, 47, 3675], [5090, 3013, 486, 144, 5091, 47, 5092, 3254, 246, 672, 324], [11699, 11700, 12480, 8083, 4850, 3439, 59, 12358, 1150], [3418, 13627, 409, 143, 606, 249, 590, 591]]
[30, 450, 451, 218, 29, 509, 1259, 2018, 59, 7176, 7177, 3220, 47, 2672, 7178, 7, 30, 3675, 5441, 59, 127, 7179, 37, 57, 305, 306, 11, 1839, 2172, 293, 971, 47, 3675, 132, 3528, 1936, 219, 4718, 120, 1633, 132, 1210, 293, 7180, 424, 47, 6452, 132, 27, 37, 20, 839, 752, 7181, 479, 75, 92, 1164, 249, 7182, 11, 20, 50, 30, 218, 37]
[5093, 3013, 47, 5090, 11, 5094, 451, 16, 580, 410, 144, 2175, 771, 47, 5095, 144, 224, 132, 144, 5096, 5097, 3254, 67, 5098, 57, 3016, 246, 30, 672, 1314, 5099, 99, 5100, 11, 83, 4142, 5101, 50, 37, 30, 3013, 1896, 16, 2259, 59, 30, 558, 132, 30, 4863, 72, 7, 252, 29, 30, 1683, 132, 235, 5102, 83, 5103, 11, 30, 1350, 175, 37]
[11699, 11700, 16, 1365, 40, 3983, 18, 2863, 29, 12481, 132, 3841, 4850, 3439, 132, 50, 202, 153, 12482, 686, 688, 42, 20, 12

[[17951, 6704, 8532, 497, 891, 89, 6944, 6945, 2378], [997, 7, 16842, 143, 1284, 681, 83, 2355, 7, 111, 191, 6249], [698, 10394, 9036, 29, 9342, 11, 9476, 870, 3843, 7, 2784], [8506, 73, 123, 165, 8598, 67, 144, 16032, 2561, 9024, 840]]
[30, 725, 1277, 29, 115, 3450, 18, 30, 40, 11, 731, 3019, 132, 15666, 11, 1720, 997, 16842, 132, 16843, 9367, 47, 30, 1647, 26, 1018, 235, 83, 324, 29, 8824, 7310, 532, 37, 20, 81, 16, 387, 358, 18, 1839, 7, 30, 1196, 4165, 16844, 8748, 1319, 14070, 30, 1063, 7, 471, 2078, 59, 92, 467, 89, 455, 678, 11, 20, 50, 30, 1284, 1285, 37]
[698, 29, 304, 1898, 243, 2550, 7, 252, 13316, 47, 243, 9115, 89, 2623, 30, 17213, 17214, 17105, 11, 132, 50, 18, 799, 106, 542, 9035, 9036, 20, 17215, 8330, 20, 81, 37, 1013, 11, 81, 9476, 1036, 9036, 99, 2259, 1060, 30, 17216, 17105, 29, 17217, 132, 30, 10750, 7, 30, 17218, 3280, 47, 17219, 47, 17220, 11, 47, 30, 1204, 2784, 47, 5826, 37, 0]
[3369, 29, 2561, 16033, 7, 144, 346, 146, 140, 84, 75, 196, 3830, 59, 8506, 73, 123,

[[1036, 42, 30, 90, 1667, 752, 30, 159, 123, 144], [125, 1125, 1182, 7518, 47, 7979, 7935, 47, 1772], [5906, 617, 13333, 47, 185, 70, 5949, 183, 191, 836], [4207, 1851, 9301, 53, 4632, 235, 1611, 83, 8570, 411]]
[4079, 4573, 42, 5366, 83, 90, 1228, 18, 957, 9301, 53, 83, 1151, 4632, 235, 1060, 59, 438, 83, 8570, 411, 37, 8667, 11, 47, 11794, 82, 9808, 4211, 11795, 11, 42, 491, 29, 3206, 18, 957, 3085, 132, 11796, 30, 11797, 7, 30, 4632, 37, 20, 30, 3145, 378, 11798, 82, 30, 11799, 841, 59, 6032, 4635, 132, 833, 30, 1151, 11, 20, 50, 83, 11800, 37]
[30, 159, 123, 3723, 11, 234, 16, 387, 2334, 89, 30, 2908, 2909, 11, 604, 82, 83, 1063, 7, 15366, 11, 651, 2147, 83, 840, 7, 144, 346, 146, 29, 2561, 6080, 37, 30, 123, 176, 2769, 3185, 6068, 143, 2549, 4371, 2155, 2595, 132, 143, 975, 83, 5813, 37, 277, 11, 30, 123, 15367, 4371, 2155, 1321, 59, 4369, 4427, 44, 820, 45, 46, 37, 0, 0]
[778, 59, 306, 11, 260, 5906, 745, 5908, 75, 92, 912, 1178, 83, 13333, 47, 185, 70, 408, 5909, 13334, 183, 31,

[[2188, 146, 202, 6971, 59, 996, 7111, 29, 6245, 7112], [252, 729, 2575, 217, 5476, 683, 342, 1223, 6936, 2098, 6937], [418, 7344, 7345, 75, 7346, 120, 870, 456, 120, 191, 1636], [1050, 5549, 1099, 59, 1166, 18119, 143, 754, 217, 541]]
[287, 1026, 726, 1635, 1636, 1365, 418, 3931, 143, 2538, 1217, 29, 1196, 132, 196, 5618, 30, 456, 4292, 1060, 30, 1777, 4489, 7347, 37, 144, 456, 890, 2267, 1946, 495, 234, 456, 808, 235, 7348, 67, 117, 1374, 7349, 11, 144, 19, 50, 37, 19, 176, 1663, 3931, 143, 887, 1185, 1109, 7, 30, 7350, 7, 30, 456, 11, 1370, 81, 195, 4694, 89, 7351, 7, 202, 37]
[252, 2028, 810, 5773, 6938, 47, 30, 5476, 683, 342, 1223, 6936, 2098, 6937, 47, 5395, 29, 115, 59, 189, 30, 2575, 2598, 47, 30, 2931, 37, 252, 4, 30, 1631, 47, 30, 4621, 6939, 399, 83, 3237, 5049, 89, 6940, 11, 757, 6941, 6942, 6796, 83, 6943, 59, 1461, 81, 6938, 37, 6937, 252, 163, 83, 2561, 2562, 7, 144, 1507, 146, 143, 1464, 1914, 37, 0, 0]
[30, 5145, 1642, 16, 580, 7113, 67, 7114, 146, 5850, 6, 7, 30, 122

[[13694, 14043, 5944, 120, 14044, 120, 1964, 30, 1965], [799, 10397, 196, 3400, 44, 3297, 191, 504, 505], [12611, 4971, 478, 4618, 814, 1851, 2055, 2998, 252, 1947, 191, 3877], [591, 4355, 4833, 143, 2071, 11900, 2660, 2490, 82, 4834, 11, 7315]]
[408, 13694, 14043, 31, 14044, 234, 199, 29, 728, 11, 20, 235, 1109, 7, 30, 3695, 349, 5068, 5158, 187, 11, 20, 374, 1971, 1338, 37, 30, 7401, 374, 11, 20, 203, 83, 14045, 8337, 7, 14046, 11, 20, 757, 1338, 7, 252, 1471, 81, 20, 30, 106, 2305, 13749, 143, 281, 561, 20, 217, 3749, 37, 30, 31, 16, 387, 1982, 14047, 39, 1971, 1338, 11, 30, 7401, 41, 132, 9145, 1732, 172, 39, 1338, 7, 252, 41, 37]
[246, 3297, 13474, 217, 30, 13475, 2091, 2995, 293, 1663, 11, 1906, 160, 504, 505, 50, 11, 20, 3297, 235, 48, 71, 8234, 37, 799, 10397, 4876, 3400, 44, 3297, 37, 20, 505, 50, 3297, 235, 799, 497, 7099, 132, 278, 600, 92, 496, 7, 81, 37, 19, 175, 11, 20, 13476, 13477, 13477, 3850, 13476, 805, 4647, 13478, 13479, 13480, 13481, 37, 13482, 5742, 3363, 2055, 4

[[192, 207, 1543, 363, 1544, 59, 1545, 488, 191, 1546, 1547], [125, 1125, 1768, 59, 2520, 243, 120, 17647, 120, 89, 1702], [4612, 89, 17706, 784, 18191, 775, 12, 481, 15627, 29, 28], [3331, 7, 2279, 67, 951, 7, 5600, 6375, 5062, 3205, 151]]
[192, 1548, 497, 1546, 1547, 16, 50, 192, 494, 196, 1543, 482, 1544, 59, 1545, 488, 37, 20, 812, 1549, 815, 280, 1550, 83, 1551, 1552, 153, 1553, 83, 1554, 1555, 808, 57, 482, 497, 11, 1556, 804, 1018, 11, 206, 192, 1557, 1036, 481, 1558, 731, 1559, 132, 1560, 11, 20, 481, 175, 37, 1546, 277, 50, 481, 195, 1561, 89, 1562, 1563, 47, 185, 70, 11, 333, 132, 1564, 1565, 37]
[399, 30, 421, 2486, 7, 17706, 784, 3897, 17710, 17712, 11, 4266, 67, 117, 775, 18192, 17711, 2478, 585, 588, 11, 234, 195, 345, 979, 11, 17, 481, 99, 387, 6065, 29, 482, 636, 82, 117, 18193, 13715, 18194, 37, 1109, 7, 30, 606, 3507, 11, 20, 19, 633, 1859, 1185, 11, 19, 195, 1121, 39, 4387, 41, 808, 303, 438, 1435, 20, 37, 30, 809, 7, 606, 438, 513, 387, 6387, 37]
[30, 125, 1125, 11,

[[806, 2897, 59, 4660, 1775, 6319, 833, 191, 7829], [498, 1109, 75, 2650, 2823, 53, 2986, 3196, 3504, 191, 3343, 3344, 2295, 726], [9143, 6827, 3452, 59, 832, 44, 3185, 191, 996], [3099, 3834, 996, 6159, 1801, 1063, 7, 736]]
[392, 3344, 382, 726, 1020, 8280, 29, 728, 50, 498, 1109, 75, 2650, 30, 2823, 53, 4301, 372, 120, 2207, 42, 3504, 37, 20, 381, 253, 92, 1946, 1692, 37, 89, 1091, 270, 39, 8452, 3674, 6121, 342, 83, 41, 11, 38, 42, 196, 7317, 30, 15309, 11, 20, 8280, 50, 37, 20, 274, 39, 15309, 120, 41, 13260, 235, 1337, 132, 1376, 15310, 206, 38, 42, 15311, 652, 7144, 234, 12011, 252, 11, 20, 481, 175, 37]
[618, 2152, 2262, 2105, 7828, 7829, 16, 50, 6479, 818, 806, 2897, 59, 4660, 1775, 6319, 833, 37, 1013, 11, 30, 773, 801, 50, 78, 75, 6150, 30, 90, 1346, 246, 10527, 806, 37, 20, 19, 16, 196, 1308, 45, 4964, 11, 19, 16, 498, 1849, 206, 369, 5899, 19, 235, 30, 801, 11, 234, 235, 30, 2049, 1873, 4128, 47, 125, 618, 11, 20, 50, 7829, 37, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[6160, 59, 2520,

[[2889, 6398, 29, 3196, 7, 48, 836, 29, 21, 22, 191, 2266], [283, 72, 1342, 799, 5934, 8918, 59, 4925, 173, 449, 672, 324], [3292, 7, 835, 7666, 3999, 7667, 191, 7668, 7669], [90, 849, 346, 7708, 1429, 849, 57, 83, 635, 11, 3867]]
[2264, 2265, 2266, 50, 30, 40, 600, 5195, 2692, 132, 8824, 39, 303, 3344, 2727, 41, 160, 6397, 6398, 29, 1036, 436, 59, 30, 836, 1433, 89, 30, 2986, 2262, 11, 234, 19, 4115, 11, 143, 16921, 30, 21, 22, 37, 2266, 175, 11, 20, 1271, 539, 78, 1164, 45, 839, 29, 30, 5662, 990, 144, 20, 30, 836, 99, 3070, 18, 22, 600, 196, 3236, 16922, 206, 1217, 15571, 5068, 37]
[30, 283, 72, 29, 824, 8919, 799, 8920, 5934, 8918, 143, 8921, 3954, 59, 4925, 173, 67, 243, 6707, 8114, 7, 8922, 173, 1650, 59, 687, 37, 20, 30, 8923, 8924, 47, 8925, 6, 788, 3744, 3254, 132, 8926, 2154, 82, 90, 8927, 1851, 607, 4219, 59, 30, 6058, 7, 870, 6559, 8928, 39, 8929, 3006, 132, 1679, 2155, 11, 1061, 11, 132, 272, 8928, 41, 11, 20, 81, 50, 37, 0]
[4778, 1080, 926, 2721, 7669, 16, 50, 30, 119, 3

[[2571, 195, 30, 1062, 3086, 8010, 47, 5771], [2910, 456, 508, 47, 424, 532, 5514, 757, 29, 847], [452, 729, 70, 11, 6766, 1109, 1343, 7, 6767, 852], [303, 339, 168, 59, 8546, 1648, 11, 498, 8547, 1060, 81, 191, 4344]]
[30, 452, 456, 16, 2168, 83, 6768, 2810, 47, 70, 457, 458, 11, 198, 5888, 471, 47, 30, 6668, 342, 2105, 222, 37, 452, 6769, 410, 6770, 173, 7, 30, 1801, 3426, 2222, 37, 30, 6248, 6253, 456, 39, 2886, 173, 3426, 2222, 41, 99, 30, 6771, 1063, 7, 471, 47, 30, 222, 206, 30, 1689, 6772, 6773, 5568, 6774, 4621, 47, 4163, 7, 3426, 2222, 217, 1791, 173, 37]
[2910, 712, 14592, 508, 11, 11340, 14593, 11, 1190, 399, 83, 8570, 411, 34, 35, 725, 847, 378, 29, 1017, 47, 13887, 880, 37, 19, 195, 10108, 59, 83, 3890, 3058, 1069, 19, 195, 747, 1586, 4387, 37, 466, 11, 30, 2910, 456, 99, 462, 14593, 30, 3737, 249, 243, 781, 2204, 14594, 3207, 14595, 37, 30, 2808, 75, 2988, 143, 9507, 29, 4861, 4048, 37, 0, 0, 0, 0]
[193, 4344, 4345, 16, 50, 18, 3738, 498, 8547, 1060, 30, 210, 18, 481, 132

[[2672, 1790, 47, 346, 1282, 47, 287, 4927, 1102, 230, 191, 7907], [11249, 2210, 12868, 82, 15557, 95, 15558, 191, 15559], [692, 521, 143, 7411, 7887, 9681, 59, 13374, 1077, 3254], [1036, 42, 14365, 8528, 59, 1180, 1124, 874, 2485, 144]]
[246, 6795, 59, 438, 669, 2743, 658, 4178, 47, 799, 2745, 1273, 2927, 132, 13229, 7, 15560, 252, 3635, 11, 11249, 5835, 15559, 9322, 47, 35, 15561, 59, 1907, 50, 11, 20, 30, 49, 2210, 12868, 82, 15557, 15558, 59, 2548, 120, 120, 37, 3334, 18, 30, 151, 13154, 3327, 16, 83, 703, 7, 2743, 1496, 4178, 11, 9322, 50, 11249, 16, 468, 12601, 30, 15562, 665, 1024, 1226, 37]
[83, 3799, 342, 1046, 692, 16, 387, 971, 89, 105, 218, 143, 1450, 35, 6305, 47, 83, 9723, 7411, 11, 2153, 202, 59, 2507, 28, 53, 78, 629, 59, 13374, 1077, 3744, 3254, 769, 90, 37, 30, 6305, 4595, 47, 30, 12737, 3497, 7, 30, 7411, 67, 103, 2029, 59, 1878, 37, 30, 218, 10581, 30, 592, 67, 30, 2507, 2287, 14003, 47, 30, 6305, 37, 0, 0, 0, 0]
[14747, 4042, 5009, 5297, 5301, 11, 153, 16, 387, 199

[11346, 11347, 11, 9286, 935, 4252, 29, 7708, 4139, 0, 0, 0, 0]
[[12464, 59, 11702, 31, 29, 12465, 1770, 970, 191, 836], [10289, 4846, 125, 5030, 410, 4469, 1710, 11657, 411], [53, 2044, 633, 541, 47, 14121, 11, 203, 3169, 253, 191, 14122], [5517, 6766, 67, 9625, 11, 9448, 9993, 143, 346, 1270]]
[193, 14122, 485, 16, 50, 18, 53, 1018, 235, 2044, 234, 633, 541, 47, 482, 638, 31, 120, 14121, 120, 11, 81, 235, 3169, 482, 37, 20, 303, 195, 30, 3817, 2457, 11372, 143, 120, 14121, 120, 632, 1953, 4779, 47, 30, 1182, 99, 808, 1121, 1798, 274, 158, 11, 20, 481, 175, 37, 466, 11, 120, 14121, 120, 11, 176, 415, 5956, 745, 11, 195, 14123, 4621, 3297, 31, 11, 246, 30, 576, 31, 120, 14124, 14125, 257, 120, 37]
[30, 3956, 7, 83, 9994, 47, 2859, 3084, 18, 83, 5517, 195, 2042, 47, 482, 9994, 143, 311, 1270, 37, 81, 99, 5983, 67, 30, 9625, 246, 7927, 83, 6827, 37, 30, 3956, 1006, 59, 2267, 1060, 30, 9995, 534, 246, 481, 195, 3186, 89, 30, 9994, 9996, 37, 20, 19, 195, 803, 495, 9997, 5182, 632, 19, 195,

[[185, 70, 186, 187, 82, 188, 59, 189, 71, 190, 191, 192], [252, 990, 243, 14144, 14145, 7, 14146], [1387, 6980, 6981, 59, 3320, 6982, 217, 30, 396, 29, 5982], [674, 3135, 1183, 3123, 1989, 246, 681, 11, 1445, 1467]]
[778, 59, 30, 90, 14145, 7, 14146, 1229, 89, 252, 11, 83, 2463, 4843, 59, 4805, 8411, 67, 1929, 332, 339, 92, 702, 120, 9769, 120, 57, 249, 30, 97, 14147, 7, 559, 332, 11, 1395, 242, 3967, 37, 30, 1834, 11, 47, 3280, 82, 153, 14146, 8683, 11, 235, 647, 59, 7678, 30, 1063, 7, 202, 702, 120, 9769, 120, 67, 14148, 771, 39, 4502, 2768, 41, 59, 1320, 146, 37]
[193, 192, 194, 16, 50, 18, 30, 31, 185, 70, 195, 196, 187, 82, 30, 197, 7, 198, 83, 71, 31, 190, 37, 20, 26, 30, 31, 199, 11, 81, 4, 18, 200, 132, 201, 67, 30, 202, 37, 203, 196, 18, 71, 204, 42, 196, 205, 11, 206, 81, 207, 208, 209, 30, 210, 39, 67, 41, 211, 30, 31, 16, 212, 11, 20, 175, 192, 37, 213, 0, 0, 0]
[30, 674, 22, 3135, 1183, 4453, 3123, 4454, 418, 235, 4455, 242, 30, 681, 293, 747, 29, 509, 37, 778, 59, 83, 83

[[4138, 521, 82, 1878, 1815, 2571, 217, 969, 1163], [14347, 2407, 106, 125, 1147, 2027, 59, 2628, 47, 2025], [8908, 975, 8136, 59, 4938, 153, 303, 3422, 363, 3928, 191, 8909], [434, 805, 975, 477, 362, 26, 1556, 1205, 140, 1539, 191, 15868]]
[193, 15868, 7447, 16, 50, 18, 481, 207, 805, 975, 477, 362, 217, 30, 4651, 7, 1205, 140, 1539, 37, 20, 1556, 196, 309, 59, 7210, 132, 14890, 410, 81, 140, 3062, 1264, 6, 11, 20, 175, 30, 6770, 342, 1046, 193, 37, 20, 303, 7038, 144, 7039, 628, 47, 5812, 47, 276, 1975, 11, 975, 477, 132, 53, 81, 207, 541, 11, 38, 2988, 944, 945, 5596, 11, 20, 15868, 277, 15869]
[773, 8029, 14347, 2407, 75, 718, 30, 106, 125, 1147, 2027, 47, 2025, 11, 246, 975, 4712, 143, 30, 6846, 2665, 2931, 143, 1412, 11, 14348, 4922, 1075, 37, 14347, 11, 153, 8416, 217, 30, 605, 14349, 13599, 7877, 47, 70, 11, 4595, 217, 30, 2025, 5968, 14350, 270, 184, 1069, 481, 195, 4712, 37, 481, 75, 2628, 217, 30, 3609, 342, 14351, 378, 47, 8668, 183, 3433, 37, 0, 0, 0, 0]
[193, 8909, 8910,

[[2068, 2018, 70, 59, 2140, 8438, 2494, 7675], [10541, 10542, 4839, 59, 92, 363, 245, 10541, 10542, 31, 191, 10543], [1772, 1250, 17769, 59, 1658, 3112, 67, 8660, 10123, 4744], [810, 287, 1699, 12257, 59, 1672, 12258, 3024]]
[2264, 10544, 10543, 16, 50, 18, 120, 10541, 10542, 4839, 120, 75, 92, 117, 245, 31, 47, 30, 120, 10541, 10542, 120, 31, 809, 37, 19, 175, 18, 246, 270, 31, 11, 3042, 4779, 459, 208, 30, 7103, 3606, 206, 81, 16, 59, 92, 10545, 1179, 37, 30, 31, 75, 1462, 30, 6948, 7, 5600, 6375, 57, 120, 10541, 10542, 120, 132, 10546, 10547, 57, 117, 10548, 120, 9665, 120, 37]
[30, 329, 1276, 29, 509, 1142, 70, 59, 2126, 243, 2222, 7, 344, 29, 30, 8439, 8438, 2494, 7675, 37, 20, 1775, 2631, 720, 600, 92, 3205, 957, 92, 1106, 1206, 11, 20, 30, 1276, 175, 132, 2589, 30, 183, 2081, 143, 1004, 4925, 37, 466, 11, 5248, 5450, 243, 8440, 7, 8441, 30, 7675, 11, 234, 75, 4329, 243, 2222, 7, 8442, 720, 2256, 70, 37, 0, 0, 0, 0]
[5029, 112, 16865, 1895, 16, 2543, 83, 15611, 111, 18, 1250, 30,

[[818, 806, 235, 30, 281, 7492, 47, 125, 1182, 191, 1168], [6164, 15085, 11, 15086, 175, 59, 15087, 5546], [3888, 631, 1217, 26, 457, 235, 47, 1369, 191, 160], [2258, 2218, 12, 35, 6338, 3439, 836, 195, 1060, 28]]
[2563, 8221, 160, 15189, 16819, 29, 1017, 50, 18, 3888, 2685, 10200, 1217, 26, 30, 457, 235, 47, 1369, 37, 20, 3270, 30, 1369, 235, 410, 11, 42, 686, 433, 1789, 144, 20, 19, 52, 37, 20, 15765, 3888, 460, 82, 16820, 37, 498, 4319, 16, 2654, 2486, 362, 47, 3335, 1338, 168, 30, 16821, 47, 16822, 140, 30, 5329, 16823, 47, 8691, 11, 20, 19, 175, 37]
[410, 5096, 90, 4339, 11, 651, 6164, 15085, 132, 15086, 11, 293, 709, 175, 59, 30, 15087, 5546, 37, 15088, 11, 9732, 59, 30, 2972, 7, 5039, 15089, 11, 132, 3440, 2510, 15090, 293, 176, 175, 37, 7083, 11, 30, 2392, 15091, 195, 462, 83, 90, 5545, 342, 59, 15092, 1349, 1482, 63, 2507, 82, 3042, 89, 498, 499, 3015, 140, 2724, 59, 2304, 132, 884, 37, 0]
[1172, 1174, 1168, 16, 17, 18, 6479, 818, 806, 132, 4441, 4435, 42, 30, 281, 1525, 132, 

[[3675, 3676, 3677, 11, 1467, 3217, 748, 294], [21, 22, 12321, 12322, 31, 410, 2392, 3231, 191, 836], [644, 1123, 1864, 2996, 2388, 57, 17168, 132, 4740], [500, 402, 12028, 1482, 47, 2618, 2646, 726, 2649, 544, 144, 1549, 12029]]
[57, 305, 306, 11, 30, 21, 22, 16, 12323, 30, 2127, 7, 30, 12324, 30, 12325, 125, 11, 234, 235, 29, 4007, 12326, 12322, 12327, 11, 410, 30, 997, 7, 4339, 120, 3231, 120, 11, 120, 2298, 120, 11, 120, 12328, 9026, 7, 252, 120, 132, 120, 7763, 252, 120, 37, 30, 12329, 583, 12330, 12331, 50, 18, 19, 339, 3108, 3939, 59, 12332, 140, 5863, 2044, 18, 12322, 16, 50, 47, 30, 12324, 37]
[500, 402, 293, 985, 12030, 1482, 34, 773, 674, 1026, 726, 2649, 418, 2655, 12031, 12032, 12033, 47, 3128, 29, 509, 234, 195, 521, 59, 2108, 4772, 2109, 37, 2618, 160, 12034, 12035, 132, 482, 699, 636, 353, 418, 745, 4478, 11370, 34, 30, 995, 342, 1736, 428, 746, 37, 12034, 99, 798, 682, 39, 683, 41, 132, 99, 12036, 472, 674, 457, 725, 29, 500, 3737, 37, 0, 0, 0, 0, 0, 0, 0]
[30, 5103, 5

[[369, 434, 2771, 2030, 3938, 4662, 191, 5891], [819, 3332, 5671, 3239, 5672, 246, 799, 5673, 447], [17706, 17707, 956, 1358, 3126, 246, 17708, 10036], [9105, 2788, 11, 5117, 6412, 7255, 246, 10440, 47, 589]]
[125, 5674, 3239, 5672, 11, 1502, 2325, 2334, 209, 709, 11, 5675, 5676, 246, 83, 1130, 618, 3332, 5677, 125, 3937, 399, 2309, 906, 2359, 189, 249, 252, 29, 1017, 37, 30, 1130, 3332, 195, 5598, 1370, 11, 20, 5678, 5679, 5680, 144, 20, 11, 234, 235, 26, 5672, 370, 117, 5681, 132, 4841, 1304, 860, 30, 1130, 3332, 132, 5682, 217, 28, 986, 802, 794, 521, 28, 37]
[5892, 29, 1356, 5893, 618, 1354, 3508, 209, 1563, 11, 773, 801, 5894, 5891, 50, 18, 5895, 1203, 196, 2771, 2030, 3938, 57, 2631, 57, 78, 837, 47, 30, 3537, 37, 20, 369, 600, 92, 4847, 59, 3572, 752, 45, 5896, 11, 196, 468, 217, 782, 11, 20, 19, 175, 37, 466, 11, 5895, 293, 5897, 5898, 89, 810, 811, 47, 30, 709, 2568, 807, 809, 37, 0, 0]
[30, 956, 2998, 12111, 1358, 7, 5335, 5336, 17706, 784, 11, 234, 195, 2636, 59, 4203, 183, 

[[9052, 996, 2210, 438, 719, 191, 2068, 29, 525, 996, 144, 1549, 2433], [1124, 9385, 3021, 443, 3368, 82, 10938, 120, 10939, 10940, 120], [2362, 2679, 443, 120, 2987, 9925, 120, 143, 267, 10008], [4782, 59, 2100, 15465, 120, 2691, 15466, 82, 9295, 14482, 66]]
[10941, 11, 35, 10942, 9385, 1505, 47, 819, 11, 29, 824, 448, 152, 83, 10034, 59, 2222, 274, 1749, 82, 83, 20, 10939, 10940, 20, 140, 138, 10938, 37, 10941, 50, 11, 20, 202, 957, 5184, 30, 10940, 274, 10943, 132, 30, 10940, 75, 1206, 155, 47, 8553, 82, 2154, 82, 83, 6505, 10939, 37, 20, 81, 97, 1486, 20, 10941, 3876, 11, 20, 9681, 59, 1820, 3145, 10944, 59, 3112, 120, 10945, 37]
[30, 15467, 4782, 558, 235, 2958, 59, 1683, 164, 14482, 66, 59, 15468, 15465, 11, 8664, 274, 9910, 59, 13662, 9275, 89, 1339, 37, 30, 768, 235, 20, 59, 127, 4850, 15465, 1864, 5803, 11, 140, 1467, 15469, 67, 3884, 274, 3459, 7512, 2316, 8404, 15470, 11, 20, 978, 50, 37, 270, 604, 1306, 83, 809, 7, 12481, 15471, 67, 2167, 132, 5952, 59, 970, 663, 15465, 37,

In [202]:
print(target_seqs)

PackedSequence(data=tensor([  672,   324, 12876,    47,  2561,   168,    83, 12877,  9996,   191,
        11965,  1350,  7516,   584,  2904,   402,    59,   717,   249,   445,
           47,   470,   471,     0,   410,  1050,   287,  1648,  1789,  1989,
          993,  1851,    47,  3723,     0,     0,   897,  1056,  4901, 12218,
           47, 12219,   250,  4635,     0,     0,     0,     0]), batch_sizes=tensor([12, 12, 12, 12]))


In [194]:
# tf = transforms.Compose([ToTensor(lang)])
tf = None
dataset = SummaryDataset(fname="train.csv", transform=tf)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)


sb = next(iter(dataloader))
input_seqs = sb["text"]
target_seqs = sb["headlines"]
input_seqs = [json.loads(s) for s in input_seqs]
target_seqs = [json.loads(s) for s in target_seqs]
input_seqs, input_lengths, target_seqs, target_lengths = batch(input_seqs, target_seqs, lang)

[[1556, 196, 4575, 9260, 9261, 11, 3262, 363, 14839, 5288, 191, 14840], [14566, 132, 4761, 7, 1897, 82, 252, 11, 1124, 1337, 191, 589], [473, 3724, 381, 144, 1549, 2989, 217, 183, 3055, 47, 2305, 1249, 191, 259, 29, 16014], [1412, 2222, 953, 7, 120, 4077, 8729, 120, 29, 375]]
[261, 262, 263, 11, 757, 2321, 1060, 10521, 10522, 10523, 16015, 39, 16014, 41, 1390, 5180, 11, 50, 11, 20, 47, 2305, 1249, 11, 303, 339, 1658, 30, 1384, 11, 140, 3724, 7154, 2989, 217, 30, 183, 3055, 37, 20, 30, 5180, 1429, 6260, 1730, 860, 259, 11, 153, 235, 29, 83, 1394, 1384, 37, 19, 175, 11, 20, 303, 339, 3108, 1203, 81, 47, 2305, 1249, 11, 3647, 363, 1287, 6, 11, 5880, 482, 16016, 37, 20]
[260, 14840, 14841, 16, 1610, 5805, 7, 4575, 9260, 9261, 132, 50, 18, 3262, 117, 14839, 5288, 37, 19, 175, 11, 20, 369, 42, 803, 1750, 47, 1464, 8982, 2445, 37, 481, 16, 35, 6931, 9617, 7, 14842, 11, 132, 303, 276, 2590, 1318, 82, 482, 37, 20, 5805, 7, 30, 311, 4575, 660, 26, 78, 345, 1201, 6, 143, 5622, 29, 14843, 1233, 47

In [None]:

def train(input_batches, input_lengths, target_batches, target_lengths, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
    
    # Zero gradients of both optimizers
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    loss = 0 # Added onto for each word

    # Run words through encoder
    encoder_outputs, encoder_hidden = encoder(input_batches, input_lengths, None)
    
    # Prepare input and output variables
    decoder_input = Variable(torch.LongTensor([SOS_token] * batch_size))
    decoder_hidden = encoder_hidden[:decoder.n_layers] # Use last (forward) hidden state from encoder

    max_target_length = max(target_lengths)
    all_decoder_outputs = Variable(torch.zeros(max_target_length, batch_size, decoder.output_size))

    # Move new Variables to CUDA
    if USE_CUDA:
        decoder_input = decoder_input.cuda()
        all_decoder_outputs = all_decoder_outputs.cuda()

    # Run through decoder one time step at a time
    for t in range(max_target_length):
        decoder_output, decoder_hidden, decoder_attn = decoder(
            decoder_input, decoder_hidden, encoder_outputs
        )

        all_decoder_outputs[t] = decoder_output
        decoder_input = target_batches[t] # Next input is current target

    # Loss calculation and backpropagation
    loss = masked_cross_entropy(
        all_decoder_outputs.transpose(0, 1).contiguous(), # -> batch x seq
        target_batches.transpose(0, 1).contiguous(), # -> batch x seq
        target_lengths
    )
    loss.backward()
    
    # Clip gradient norms
    ec = torch.nn.utils.clip_grad_norm(encoder.parameters(), clip)
    dc = torch.nn.utils.clip_grad_norm(decoder.parameters(), clip)

    # Update parameters with optimizers
    encoder_optimizer.step()
    decoder_optimizer.step()
    
    return loss.data[0], ec, dc

In [None]:
DF = []
for df in pd.read_csv(RAW_DATA_DIR + 'train.csv', sep="\t", header = 0, chunksize=10):
    DF  = df
    break