In [1]:
import numpy as np
import torch
import torch.nn as nn
import random
import time
import copy
from torch.nn import functional as F
from torch.autograd import Variable
from torch import optim
from rouge import Rouge
from data import *
from utils import *

## Tokens
e.g.
```
[["Musicians to tackle US red tape Musicians ' groups are to tackle US visa regulations which are blamed for hindering",
  "Nigel McCune from the Musicians ' Union said British musicians"],
 ["U2 's desire to be number one U2 , who have won three prestigious Grammy Awards for their hit Vertigo",
  'But they still want more.They have to want to be'],
 ["Rocker Doherty in on-stage fight Rock singer Pete Doherty has been involved in a fight with his band 's guitarist",
  'Babyshambles , which he formed after his acrimonious departure from']]
```

In [2]:
train, dev = load_datasets('./Datasets/BBC_News_100_50.pkl', './Datasets/BBC_News_100_50.pkl')

In [3]:
train[0]

['Byrds producer Melcher dies at 62 Record producer Terry Melcher , who was behind hits by the Byrds , Ry Cooder and the Beach Boys , has died aged 62 . The son of actress Doris Day , he helped write Kokomo for the Beach Boys , which was used in the movie Cocktail , earning a 1988 Golden Globe nomination . He also produced Mr Tambourine Man for the Byrds , as well as other his such as Turn , Turn Turn . Melcher died on Friday night at his home in Beverly Hills , California , after a',
 "Record producer Terry Melcher , who was behind hits by the Byrds , Ry Cooder and the Beach Boys , has died aged 62.Rumours circulated that Melcher - who knew Manson - was the killer 's real target , because he had turned him down for a record contract.Melcher also"]

In [4]:
# train = train[0:10]
# dev = dev[0:1]

## Tokens Index
e.g.
```
[Musicians to tackle US red tape Musicians ' groups are to tackle US visa regulations which are blamed for hindering => Nigel McCune from the Musicians ' Union said British musicians
    indexed as: [2, 3, 4, 5, 6, 7, 2, 8, 9, 10, 3, 4, 5, 11, 12, 13, 10, 14, 15, 16] => [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 2],
 U2 's desire to be number one U2 , who have won three prestigious Grammy Awards for their hit Vertigo => But they still want more.They have to want to be
    indexed as: [17, 18, 19, 3, 20, 21, 22, 17, 23, 24, 25, 26, 27, 28, 29, 30, 15, 31, 32, 33] => [13, 14, 15, 16, 17, 18, 19, 16, 19, 20, 2],
 Rocker Doherty in on-stage fight Rock singer Pete Doherty has been involved in a fight with his band 's guitarist => Babyshambles , which he formed after his acrimonious departure from
    indexed as: [34, 35, 36, 37, 38, 39, 40, 41, 35, 42, 43, 44, 36, 45, 38, 46, 47, 48, 18, 49] => [21, 22, 23, 24, 25, 26, 27, 28, 29, 5, 2]]
```

In [5]:
train_data_indexed, dev_data_indexed, vocab_indexer = index_datasets(train, dev)

## Padding
- Pad the train/dev input vectors to the max length of the train/dev input documents.
- Pad the train/dev output vectors to the max length of the train/dev output summerization.

![](https://i.imgur.com/gGlkEEF.png)

In [6]:
def make_padded_input_tensor(exs, vocab_indexer, max_len):
    return np.array([[ex.x_indexed[i] if i < len(ex.x_indexed) else vocab_indexer.index_of(PAD_SYMBOL)
                        for i in range(0, max_len)] for ex in exs])

In [7]:
def make_padded_output_tensor(exs, vocab_indexer, max_len):
    return np.array([[ex.y_indexed[i] if i < len(ex.y_indexed) else vocab_indexer.index_of(PAD_SYMBOL)
                        for i in range(0, max_len)] for ex in exs])

## Batch

In [8]:
def batch_data(input_array, batch_size=2, cuda=False):
    input_batches = []
    batch_num = (int)(input_array.shape[0] / batch_size)
    start = 0
    for i in range(batch_num):
        batch = torch.from_numpy(input_array[start:start+batch_size, :])
        if cuda:
            batch = batch.cuda()
        input_batches.append(batch)
        start += batch_size
    if start != input_array.shape[0]:
        batch = torch.from_numpy(input_array[start:, :])
        if cuda:
            batch = batch.cuda()
        input_batches.append(batch)
    return input_batches

## Embedding

In [9]:
class EmbeddingLayer(nn.Module):
    # Parameters: dimension of the word embeddings, number of words, and the dropout rate to apply
    # (0.2 is often a reasonable value)
    def __init__(self, input_dim, full_dict_size, embedding_dropout_rate):
        super(EmbeddingLayer, self).__init__()
        self.dropout = nn.Dropout(embedding_dropout_rate)
        self.word_embedding = nn.Embedding(full_dict_size, input_dim)

    # Takes either a non-batched input [sent len x input_dim] or a batched input
    # [batch size x sent len x input dim]
    def forward(self, input):
        embedded_words = self.word_embedding(input)
        final_embeddings = self.dropout(embedded_words)
        return final_embeddings

## Encoder

In [10]:
# One-layer RNN encoder for batched inputs -- handles multiple sentences at once. You're free to call it with a
# leading dimension of 1 (batch size 1) but it does expect this dimension.
class RNNEncoder(nn.Module):
    # Parameters: input size (should match embedding layer), hidden size for the LSTM, dropout rate for the RNN,
    # and a boolean flag for whether or not we're using a bidirectional encoder
    def __init__(self, input_size, hidden_size, dropout, bidirect, CUDA=False):
        super(RNNEncoder, self).__init__()
        self.CUDA = CUDA
        self.bidirect = bidirect
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.reduce_h_W = nn.Linear(hidden_size * 2, hidden_size, bias=True)
        self.reduce_c_W = nn.Linear(hidden_size * 2, hidden_size, bias=True)
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True,
                               dropout=dropout, bidirectional=self.bidirect)
        self.init_weight()

    # Initializes weight matrices using Xavier initialization
    def init_weight(self):
        nn.init.xavier_uniform_(self.rnn.weight_hh_l0, gain=1)
        nn.init.xavier_uniform_(self.rnn.weight_ih_l0, gain=1)
        if self.bidirect:
            nn.init.xavier_uniform_(self.rnn.weight_hh_l0_reverse, gain=1)
            nn.init.xavier_uniform_(self.rnn.weight_ih_l0_reverse, gain=1)
        nn.init.constant_(self.rnn.bias_hh_l0, 0)
        nn.init.constant_(self.rnn.bias_ih_l0, 0)
        if self.bidirect:
            nn.init.constant_(self.rnn.bias_hh_l0_reverse, 0)
            nn.init.constant_(self.rnn.bias_ih_l0_reverse, 0)

    def get_output_size(self):
        return self.hidden_size * 2 if self.bidirect else self.hidden_size

    def sent_lens_to_mask(self, lens, max_length):
        return torch.from_numpy(np.asarray([[1 if j < lens.data[i].item() else 0 for j in range(0, max_length)] for i in range(0, lens.shape[0])]))

    # embedded_words should be a [batch size x sent len x input dim] tensor
    # input_lens is a tensor containing the length of each input sentence
    # Returns output (each word's representation), context_mask (a mask of 0s and 1s
    # reflecting where the model's output should be considered), and h_t, a *tuple* containing
    # the final states h and c from the encoder for each sentence.
    def forward(self, embedded_words, input_lens):
        # Takes the embedded sentences, "packs" them into an efficient Pytorch-internal representation
        packed_embedding = nn.utils.rnn.pack_padded_sequence(embedded_words, input_lens, batch_first=True)

        # Runs the RNN over each sequence. Returns output at each position as well as the last vectors of the RNN
        # state for each sentence (first/last vectors for bidirectional)
        output, hn = self.rnn(packed_embedding)
        

        # Unpacks the Pytorch representation into normal tensors
        output, sent_lens = nn.utils.rnn.pad_packed_sequence(output)
#         print('kdjfksdjfs: ', output.shape)
        
        # print('input_lens:', input_lens)
        max_length = input_lens.data[0].item()
        context_mask = self.sent_lens_to_mask(sent_lens, max_length)
        if self.CUDA:
            context_mask = context_mask.cuda()

        # Grabs the encoded representations out of hn, which is a weird tuple thing.
        # Note: if you want multiple LSTM layers, you'll need to change this to consult the penultimate layer
        # or gather representations from all layers.
        if self.bidirect:
            h, c = hn[0], hn[1]          # [2, 20, 200]
            # print('encoder hidden:----- ', h.shape)
            # print('encoder cell:----- ', c.shape)
            # Grab the representations from forward and backward LSTMs
            h_, c_ = torch.cat((h[0], h[1]), dim=1), torch.cat((c[0], c[1]), dim=1)      # [20, 400]
            # print('kdjfksdddddddddddjfs: ',h_.shape)
            # Reduce them by multiplying by a weight matrix so that the hidden size sent to the decoder is the same
            # as the hidden size in the encoder
            new_h = self.reduce_h_W(h_)
            new_c = self.reduce_c_W(c_)
            h_t = (new_h, new_c)
        else:
            h, c = hn[0][0], hn[1][0]
            h_t = (h, c)
        return (output, context_mask, h_t)

## Attention-based Decoder

In [11]:
class AttnRNNDecoderBahdanau(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout, bidirect):
        super(AttnRNNDecoderBahdanau, self).__init__()

        self.input_size = input_size
#         self.sent_lens = sent_lens
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout = dropout
        self.bidirect = bidirect
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers=1, 
                                dropout=dropout, bidirectional=bidirect)
        
        
        self.context = nn.Linear(hidden_size * 2 + input_size, input_size)
        self.W_h = nn.Linear(2*hidden_size, 2*hidden_size)
        self.W_c = nn.Linear(1, hidden_size * 2, bias=False)
        self.W_s = nn.Linear(hidden_size * 2, hidden_size * 2)
        self.v = nn.Linear(hidden_size * 2, 1, bias=False)
        self.V = nn.Linear(hidden_size * 2 + hidden_size, hidden_size)
        self.V_p = nn.Linear(hidden_size, output_size)
        self.P_gen_layer = nn.Linear(hidden_size * 4 + input_size, 1)

        self.init_weight()
        
    # Initializes weight matrices using Xavier initialization
    def init_weight(self):
        nn.init.xavier_uniform_(self.rnn.weight_hh_l0, gain=1)
        nn.init.xavier_uniform_(self.rnn.weight_ih_l0, gain=1)
        if self.bidirect:
            nn.init.xavier_uniform_(self.rnn.weight_hh_l0_reverse, gain=1)
            nn.init.xavier_uniform_(self.rnn.weight_ih_l0_reverse, gain=1)
        nn.init.constant_(self.rnn.bias_hh_l0, 0)
        nn.init.constant_(self.rnn.bias_ih_l0, 0)
        if self.bidirect:
            nn.init.constant_(self.rnn.bias_hh_l0_reverse, 0)
            nn.init.constant_(self.rnn.bias_ih_l0_reverse, 0)

    def forward(self, embedded_words, dec_hidden, enc_output, context_mask, pre_cont, coverage):
        sent_lens = enc_output.shape[0]
        enc_feature = enc_output.view(-1, 2*hidden_size)                # batch_size*sent_lens, 2*hidden_size
        rnn_input = self.context(torch.cat((pre_cont, embedded_words), 1))   # batch_size, input_size
        dec_output, hn = self.rnn(rnn_input.unsqueeze(1).transpose(0, 1), dec_hidden)   # 1, batch_size, hidden_size
        dec_output = dec_output.transpose(0, 1)                             # batch_size, 1, hidden_size
        h_dec, c_dec = hn
        s_t_hat = torch.cat((h_dec.view(-1, hidden_size),
                             c_dec.view(-1, hidden_size)), 1)           # batch_size, 2*hidden_size
        del h_dec, c_dec
        
        # Attention Distribution
        dec_state = self.W_s(s_t_hat)                              # batch_size, 2*hidden_size
        dec_state_expanded = dec_state.unsqueeze(1).expand(dec_state.shape[0], sent_lens, dec_state.shape[1]).contiguous() # batch_size, sent_lens, 2*hidden_size
        dec_state_expanded = dec_state_expanded.view(-1, dec_state.shape[1])  # batch_size*sent_lens, 2*hidden_size
        del dec_state
        
#         print(enc_feature.shape)
#         print(dec_state_expanded.shape)
#         print(coverage.shape)
        e = self.v(torch.tanh(self.W_h(enc_feature) + dec_state_expanded + self.W_c(coverage.view(-1, 1)))).view(-1, sent_lens)  # batch_size, sent_lens
#         att_feature = enc_feature + dec_state_expanded   # batch_size*sent_lens, 2*hidden_size
#         # Coverage
#         coverage_feature = self.W_c(coverage.view(-1, 1))  # batch_size*sent_lens, 2*hidden_size
#         att_feature = att_feature + coverage_feature   # batch_size*sent_lens, 2*hidden_size
#         e = torch.tanh(att_feature)       # batch_size*sent_lens, 2*hidden_size
#         attn_scores = self.v(e).view(-1, self.sent_lens)      # batch_size, sent_lens
        del enc_feature
        attn_distrib_ = F.softmax(e, dim=1)*context_mask.float()   # batch_size, sent_lens
        del e
        norm_factor = attn_distrib_.sum(1, keepdim=True)
        attn_distrib = attn_distrib_ / norm_factor
        del attn_distrib_, norm_factor
        coverage = coverage + attn_distrib                  # batch_size, sent_lens
        attn_distrib = attn_distrib.unsqueeze(1)            # batch_size, 1, sent_lens
        
        # Context Vector
        cont_vec = torch.bmm(attn_distrib, enc_outputs.transpose(0, 1))  # batch_size, 1, 2*hidden_size
        concat_input = torch.cat((dec_output, cont_vec), dim=-1)           # batch_size, 1, enc_hidden_size * num_directions + dec_hidden_size

        # Vocabulary Distribution
        vocab_distrib = torch.softmax(self.V_p(self.V(concat_input)), dim=-1).squeeze(1)  # batch_size, output_size

        # Pointer Generator
        P_gen_input = torch.cat((cont_vec.squeeze(1), s_t_hat, rnn_input), dim=1)  # batch_size, (2*2*hidden_size + input_size)
        P_gen = self.P_gen_layer(P_gen_input)
        P_gen = torch.sigmoid(P_gen)       # batch_size, 1
        
        return (P_gen, vocab_distrib, hn, attn_distrib.squeeze(1), cont_vec.squeeze(1), coverage)
        

In [12]:
class AttnRNNDecoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout, bidirect):
        super(AttnRNNDecoder, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout = dropout
        self.bidirect = bidirect
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers=1, 
                                dropout=dropout, bidirectional=bidirect)
        self.out = nn.Linear(hidden_size, output_size)
        self.concat = nn.Linear(hidden_size * 2 + hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size * 2, hidden_size)
        
        self.W_h = nn.Linear(hidden_size * 2, 1)
        self.W_s = nn.Linear(hidden_size, 1)
        self.W_x = nn.Linear(input_size, 1)
        self.init_weight()
        
    # Initializes weight matrices using Xavier initialization
    def init_weight(self):
        nn.init.xavier_uniform_(self.rnn.weight_hh_l0, gain=1)
        nn.init.xavier_uniform_(self.rnn.weight_ih_l0, gain=1)
        if self.bidirect:
            nn.init.xavier_uniform_(self.rnn.weight_hh_l0_reverse, gain=1)
            nn.init.xavier_uniform_(self.rnn.weight_ih_l0_reverse, gain=1)
        nn.init.constant_(self.rnn.bias_hh_l0, 0)
        nn.init.constant_(self.rnn.bias_ih_l0, 0)
        if self.bidirect:
            nn.init.constant_(self.rnn.bias_hh_l0_reverse, 0)
            nn.init.constant_(self.rnn.bias_ih_l0_reverse, 0)

    # enc_output: batch_size, sent_lens, 2*hidden_size
    def forward(self, embedded_words, dec_hidden, enc_outputs, context_mask):
        embedded_words = embedded_words.view(1, embedded_words.size(0), embedded_words.size(1))   # 1, batch_size, input_size
        context_mask = context_mask.type(torch.uint8).unsqueeze(1)      # batch_size, 1, sent_lens
  
        rnn_output, hn = self.rnn(embedded_words, dec_hidden)           # 1, batch_size, hidden_size
        rnn_output = rnn_output.transpose(0, 1)                         # batch_size, 1, hidden_size        
        
        attn_scores = rnn_output.bmm(self.linear(enc_outputs).transpose(0, 1).transpose(1, 2))      # batch_size, 1, sent_lens
        attn_scores.data.masked_fill(context_mask == 0, float('inf'))      # batch_size, 1, sent_lens
#         print(attn_scores.shape)
        
        # Attention Distribution
        attn_weights = F.softmax(attn_scores.squeeze(1), dim=1).unsqueeze(1)        # batch_size, 1, sent_lens
        # Context Vector
        context = attn_weights.bmm(enc_outputs.transpose(0, 1))                     # batch_size, 1, hidden_size * num_directions
        concat_input = torch.cat((context, rnn_output), dim=-1)     # batch_size, 1, enc_hidden_size * num_directions + dec_hidden_size
        concat_output = torch.tanh(self.concat(concat_input))                       # batch_size, 1, dec_hidden_size
        # Vocabulary Distribution
        output = self.out(concat_output).squeeze(1)                                 # batch_size, output_size
        # Pointer-Generator
        p_gen = torch.sigmoid(self.W_h(context) + self.W_s(rnn_output) + self.W_x(embedded_words.transpose(0, 1))).squeeze(1)  # batch_size, 1
#         print(p_gen.shape)
        
        return (p_gen, output, hn, attn_weights.squeeze(1))

## Encoder to Decoder

In [13]:
def encode_input_for_decoder(x_tensor, inp_lens_tensor, model_input_emb, model_enc):
    input_emb = model_input_emb.forward(x_tensor)
    (enc_output_each_word, enc_context_mask, enc_final_states) = model_enc.forward(input_emb, inp_lens_tensor)
    enc_final_states_reshaped = (enc_final_states[0].unsqueeze(0), enc_final_states[1].unsqueeze(0))
    return (enc_output_each_word, enc_context_mask, enc_final_states_reshaped)

## Loss Function

In [14]:
# Implementation of loss function: masked cross entropy
# Reference to https://github.com/spro/practical-pytorch, make some modifications
def masked_cross_entropy(logits, target, length, context_mask):
    logits_flat = logits.view(-1, logits.size(-1))                  # batch_size * sent_len, vocab_size
    log_probs_flat = F.log_softmax(logits_flat, dim=-1)             # batch_size * sent_len, vocab_size
    target_flat = target.view(-1, 1)                                # batch * sent_len, 1
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)    # batch * max_len, 1
    losses = losses_flat.view(*target.size())     # batch, max_len
    losses = losses * context_mask.float()
    loss = losses.sum() / length.float().sum()
    return loss

In [15]:
def cal_step_loss(final_distrib, attn_distrib, coverage, step_Y_tensor, step_context_mask):
    step_distrib = torch.gather(final_distrib, dim=1, index=step_Y_tensor).squeeze(1)
    step_coverage_loss = torch.sum(torch.min(attn_distrib, coverage), 1)
    step_loss = -torch.log(step_distrib + eps) + cov_loss_wt*step_coverage_loss
    step_loss = step_loss * step_context_mask
    return step_loss

## Training Step

**Training Copy**

In [16]:
def pointer_generate_train(p_gen, dec_output, dec_attn, X_tensors, CUDA):
#     final_distrib = torch.zeros(dec_output.shape, dtype=torch.float)     # batch_size, vocab_size
    final_distrib = p_gen * dec_output
    dec_attn_padding = torch.zeros(X_tensors.shape, dtype=torch.float)
    if CUDA:
        dec_attn_padding = dec_attn_padding.cuda()
    dec_attn_padding[:, 0:dec_attn.shape[1]] = dec_attn 
    final_distrib = final_distrib.scatter_add(1, X_tensors, (1-p_gen)*dec_attn_padding)
    return final_distrib

In [17]:
BATCH_SIZE = 8
lr = 0.0005
input_dim = 100
output_dim = 100
hidden_size = 256
emb_dropout = 0.2
rnn_dropout = 0.2
bidirectional = True
num_epochs = 500
teacher_forcing_ratio = 1
cov_loss_wt = 1.0
eps = 1e-12
CUDA = True
pad_idx = vocab_indexer.index_of(PAD_SYMBOL)
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

**Create indexed input/output for training**
- X_tensors_batch/Y_tensors_batch, list[array: [batch_size, sent_len], batch_num]
- inp_lens_batch/oup_lens_batch, list[array: [batch_size,], batch_num]

In [18]:
# Create indexed input/output for training
train_data_indexed.sort(key=lambda ex: len(ex.x_indexed), reverse=True)
input_train_max_len = np.max(np.asarray([len(ex.x_indexed) for ex in train_data_indexed]))
all_train_input_data = make_padded_input_tensor(train_data_indexed, vocab_indexer, input_train_max_len).astype(np.int64)

output_train_max_len = np.max(np.asarray([len(ex.y_indexed) for ex in train_data_indexed]))
all_train_output_data = make_padded_output_tensor(train_data_indexed, vocab_indexer, output_train_max_len).astype(np.int64)

X_tensors_batch = batch_data(all_train_input_data, BATCH_SIZE, cuda=CUDA)   # batch_num, batch_size, sent_len
Y_tensors_batch = batch_data(all_train_output_data, BATCH_SIZE, cuda=CUDA)  # batch_num, batch_size, sent_len
if CUDA:
    inp_lens_batch = [torch.tensor([torch.sum(X_tensor != 0) for X_tensor in X_tensors]).cuda() for X_tensors in X_tensors_batch]  # batch_num, batch_size
    oup_lens_batch = [torch.tensor([torch.sum(Y_tensor != 0) for Y_tensor in Y_tensors]).cuda() for Y_tensors in Y_tensors_batch]  # batch_num, batch_size
else:
    inp_lens_batch = [torch.tensor([torch.sum(X_tensor != 0) for X_tensor in X_tensors]) for X_tensors in X_tensors_batch]  # batch_num, batch_size
    oup_lens_batch = [torch.tensor([torch.sum(Y_tensor != 0) for Y_tensor in Y_tensors]) for Y_tensors in Y_tensors_batch]  # batch_num, batch_size

**Create model**
- model_input_emb/model_output_emb: embedding layer
- model_enc/model_dec: encoder/decoder
- optimizers: encoder/decoder

In [None]:
# Create model
model_input_emb = EmbeddingLayer(input_dim, len(vocab_indexer), emb_dropout)
model_enc = RNNEncoder(input_dim, hidden_size, rnn_dropout, bidirectional, CUDA=CUDA)
model_output_emb = EmbeddingLayer(output_dim, len(vocab_indexer), emb_dropout)
model_dec = AttnRNNDecoderBahdanau(input_size=output_dim, hidden_size=hidden_size, output_size=len(vocab_indexer), dropout=rnn_dropout, bidirect=False)
# CUDA
if CUDA:
    model_input_emb.cuda()
    model_enc.cuda()
    model_output_emb.cuda()
    model_dec.cuda()
# model_dec = AttnRNNDecoder(input_size=output_dim, hidden_size=hidden_size, output_size=len(vocab_indexer), dropout=rnn_dropout, bidirect=False)
enc_optimizer = optim.Adam(model_enc.parameters(), lr=lr)
dec_optimizer = optim.Adam(model_dec.parameters(), lr=lr)

  "num_layers={}".format(dropout, num_layers))


**Train Iteration**

In [None]:
start = time.time()
for epoch in range(0, num_epochs):
    print('--------------------- Epoch %d ---------------------'%(epoch+1))
    for X_tensors, Y_tensors, inp_lens_tensor, oup_lens_tensor in zip(X_tensors_batch, Y_tensors_batch, inp_lens_batch, oup_lens_batch):

        model_enc.train()
        model_dec.train()

        enc_optimizer.zero_grad()
        dec_optimizer.zero_grad()
        
        # Encoder
        enc_outputs, enc_context_mask, enc_hidden = encode_input_for_decoder(X_tensors, inp_lens_tensor, model_input_emb, model_enc)
        init_dec_inp = Variable(torch.LongTensor([vocab_indexer.index_of(SOS_SYMBOL)] * X_tensors.shape[0]))  
        if CUDA:
            init_dec_inp = init_dec_inp.cuda()
        dec_input = model_output_emb.forward(init_dec_inp)
        if TOPIC:
            init_dec_inp
        
        
        
        dec_hidden = enc_hidden
        cont = torch.zeros((X_tensors.shape[0], 2 * hidden_size))    # batch_size, 2*hidden_size
        coverage = torch.zeros((X_tensors.shape[0], enc_outputs.shape[0]))             # batch_size, sent_lens
        all_context_mask = torch.from_numpy(np.asarray([[1 if j < oup_lens_tensor.data[i].item() \
            else 0 for j in range(0, Y_tensors.size(1))] for i in range(0, oup_lens_tensor.shape[0])]))
        agr_loss = torch.zeros(X_tensors.shape[0])
        
        if CUDA:
            cont = cont.cuda()
            coverage = coverage.cuda()
            all_context_mask = all_context_mask.cuda()
            agr_loss = agr_loss.cuda()
        
#         # Decoder
#         all_dec_outputs = Variable(torch.zeros(output_train_max_len, X_tensors.shape[0], len(vocab_indexer)))   # sent_len, batch_size, ext_output_size
#         for idx in range(output_train_max_len):
#             p_gen, dec_output, dec_hidden, dec_attn = model_dec.forward(dec_input, dec_hidden, enc_outputs, enc_context_mask)
#             all_dec_outputs[idx] = pointer_generate_train(p_gen, dec_output, dec_attn, X_tensors)
#             max_prob_idx = torch.argmax(all_dec_outputs[idx], dim=1)
#             print('max_prob_idx:', max_prob_idx)
#             print('Y_tensors:', Y_tensors[:, idx])
# #             print(agr_loss)
# #             agr_loss = agr_loss + cal_step_loss(final_distrib, dec_attn, coverage, Y_tensors[:, idx].unsqueeze(1), all_context_mask[:, idx].float())
#             dec_input = model_output_emb.forward(Y_tensors[:, idx])
# #             coverage = next_coverage         
#         loss = masked_cross_entropy(all_dec_outputs.transpose(0, 1).contiguous(), Y_tensors, oup_lens_tensor, all_context_mask)       # batch_size, sent_len, output_size
        
        
        # Decoder2
        for idx in range(output_train_max_len):
            p_gen, dec_output, dec_hidden, dec_attn, cont, next_coverage = model_dec.forward(dec_input, dec_hidden, enc_outputs, enc_context_mask, cont, coverage)
            final_distrib = pointer_generate_train(p_gen, dec_output, dec_attn, X_tensors, CUDA)
            max_prob_idx = torch.argmax(final_distrib, dim=1)
#             print(final_distrib)
#             print('max_prob_idx:', max_prob_idx)
#             print('Y_tensors:', Y_tensors[:, idx])
#             print(agr_loss)
            agr_loss = agr_loss + cal_step_loss(final_distrib, dec_attn, coverage, Y_tensors[:, idx].unsqueeze(1), all_context_mask[:, idx].float())
            dec_input = model_output_emb.forward(Y_tensors[:, idx])
            coverage = next_coverage
        
        batch_avg_loss = agr_loss/oup_lens_tensor.float()
        loss = torch.mean(batch_avg_loss)
        
        loss.backward()

        enc_optimizer.step()
        dec_optimizer.step()

        print('loss', loss.item())
        
elapsed_time = time.time() - start
print('Time: %.2fs'%(elapsed_time))

--------------------- Epoch 1 ---------------------
loss 6.682219982147217
loss 7.096613883972168
loss 7.289017200469971
loss 6.96723747253418
loss 6.931626796722412
loss 7.120256423950195
loss 6.948846817016602
loss 7.004876136779785
loss 6.574513912200928
loss 6.941911220550537
loss 7.446579456329346
loss 6.987937927246094
loss 7.371664524078369
loss 7.43435525894165
loss 7.313004016876221
loss 6.834506988525391
loss 7.0308966636657715
loss 6.51055383682251
loss 7.370500087738037
loss 6.93660831451416
loss 6.362265110015869
loss 6.800319671630859
loss 6.689985275268555
loss 6.918300151824951
loss 7.141538143157959
loss 6.856302738189697
loss 6.781208038330078
loss 6.575254917144775
loss 6.801943778991699
loss 6.5886383056640625
loss 6.740764617919922
loss 6.209800720214844
loss 6.7847723960876465
loss 6.911956787109375
loss 6.4391865730285645
loss 7.091922760009766
loss 6.662973403930664
loss 6.254928112030029
loss 6.951873302459717
loss 6.890693664550781
loss 6.748363971710205
loss 

loss 6.260390758514404
loss 6.039236068725586
loss 6.219856262207031
loss 5.9032793045043945
loss 6.230953693389893
loss 6.319899559020996
loss 5.956056594848633
loss 6.485411643981934
loss 6.171507358551025
loss 5.820878982543945
loss 6.405047416687012
loss 6.332647323608398
loss 6.217918395996094
loss 6.540416717529297
loss 6.588804244995117
loss 6.134335517883301
loss 6.196417808532715
loss 6.239518165588379
loss 6.1439032554626465
loss 6.115933418273926
loss 5.918435096740723
loss 5.942765235900879
loss 6.086821556091309
loss 6.1803154945373535
loss 6.166346549987793
loss 6.187134265899658
loss 6.360445976257324
loss 6.438915729522705
loss 6.3634724617004395
loss 6.174997806549072
loss 6.229969501495361
loss 5.974547386169434
loss 5.883945465087891
loss 6.489138603210449
loss 5.7050275802612305
--------------------- Epoch 7 ---------------------
loss 5.894809722900391
loss 6.251573085784912
loss 6.28465461730957
loss 6.112685203552246
loss 6.0316572189331055
loss 6.224917888641357


loss 5.663633823394775
loss 5.559226036071777
loss 5.378233909606934
loss 5.827363014221191
loss 4.8823957443237305
--------------------- Epoch 12 ---------------------
loss 5.395224571228027
loss 5.688724040985107
loss 5.661533355712891
loss 5.559519290924072
loss 5.530464172363281
loss 5.685520648956299
loss 5.600069999694824
loss 5.602293968200684
loss 5.45303201675415
loss 5.569540977478027
loss 5.684627056121826
loss 5.6933183670043945
loss 5.931724548339844
loss 5.785111904144287
loss 5.722293376922607
loss 5.381943702697754
loss 5.50319242477417
loss 5.412695407867432
loss 5.624261856079102
loss 5.544377326965332
loss 5.404471397399902
loss 5.6783928871154785
loss 5.438896179199219
loss 5.6447296142578125
loss 5.797823429107666
loss 5.651298999786377
loss 5.646653175354004
loss 5.531062602996826
loss 5.651065349578857
loss 5.437340259552002
loss 5.59537410736084
loss 5.422532081604004
loss 5.60612154006958
loss 5.601322650909424
loss 5.348388671875
loss 5.889214992523193
loss 5.

loss 4.920263767242432
loss 5.168356418609619
loss 4.959718227386475
loss 4.9899492263793945
loss 4.906375885009766
loss 4.996639251708984
loss 4.86946439743042
loss 5.021608829498291
loss 4.820359706878662
loss 4.95493221282959
loss 5.084301948547363
loss 4.754262924194336
loss 5.126285552978516
loss 4.780130386352539
loss 4.684189796447754
loss 4.808468818664551
loss 4.876697540283203
loss 4.888808727264404
loss 5.071101665496826
loss 5.021739959716797
loss 4.58389139175415
loss 4.5764875411987305
loss 5.059024810791016
loss 4.799126148223877
loss 4.7855682373046875
loss 4.772384166717529
loss 4.911135673522949
loss 4.758112907409668
loss 4.584278583526611
loss 4.814074516296387
loss 4.7244873046875
loss 4.8671698570251465
loss 4.916433334350586
loss 4.7411885261535645
loss 4.774366855621338
loss 4.817635536193848
loss 4.8128228187561035
loss 4.595564842224121
loss 4.851517677307129
loss 3.6006150245666504
--------------------- Epoch 18 ---------------------
loss 4.661112308502197
lo

loss 4.011950492858887
loss 4.225798606872559
loss 4.309284687042236
loss 4.071982383728027
loss 4.080846786499023
loss 4.158544540405273
loss 4.202315330505371
loss 4.025012016296387
loss 4.206754684448242
loss 2.6888062953948975
--------------------- Epoch 23 ---------------------
loss 3.806586980819702
loss 4.033792495727539
loss 3.953136920928955
loss 4.041162490844727
loss 4.002616882324219
loss 4.10928201675415
loss 4.140981197357178
loss 4.157440185546875
loss 4.159093856811523
loss 4.170629024505615
loss 3.9149816036224365
loss 4.3233323097229
loss 4.168878078460693
loss 4.109511852264404
loss 3.790921926498413
loss 3.7835798263549805
loss 3.7018251419067383
loss 3.9634125232696533
loss 3.8100860118865967
loss 3.6996474266052246
loss 3.9601311683654785
loss 4.1524553298950195
loss 3.978564500808716
loss 3.9883439540863037
loss 4.354497909545898
loss 4.1356024742126465
loss 4.034270763397217
loss 3.9664723873138428
loss 4.211947917938232
loss 4.203221321105957
loss 4.05123758316

loss 3.270214557647705
loss 3.2504448890686035
loss 3.1244850158691406
loss 3.212632894515991
loss 3.110793113708496
loss 3.2321701049804688
loss 3.230846643447876
loss 3.3114984035491943
loss 3.448481798171997
loss 3.5445148944854736
loss 3.245776653289795
loss 3.1631886959075928
loss 3.6023201942443848
loss 3.2893733978271484
loss 3.281744956970215
loss 3.233367681503296
loss 3.4518022537231445
loss 3.316284418106079
loss 3.4026238918304443
loss 3.4250714778900146
loss 3.369476556777954
loss 3.51983642578125
loss 3.433988571166992
loss 3.534770965576172
loss 3.397387742996216
loss 3.3278369903564453
loss 3.2891101837158203
loss 3.4691152572631836
loss 3.477787494659424
loss 3.3158419132232666
loss 3.5107333660125732
loss 3.332801342010498
loss 3.1954691410064697
loss 3.5359294414520264
loss 3.3189098834991455
loss 3.289714813232422
loss 3.3593153953552246
loss 3.389375686645508
loss 3.291616439819336
loss 3.0845882892608643
loss 3.3602585792541504
loss 3.18339204788208
loss 3.4337098

loss 2.947390556335449
loss 2.64106822013855
loss 2.6150450706481934
loss 2.6993303298950195
loss 2.8927040100097656
loss 2.8374764919281006
loss 2.7256603240966797
loss 2.8928990364074707
loss 2.8455939292907715
loss 2.6245946884155273
loss 2.902045249938965
loss 2.770348310470581
loss 2.7636284828186035
loss 2.5699994564056396
loss 2.788529396057129
loss 2.739956855773926
loss 2.526165008544922
loss 2.701024293899536
loss 2.6353180408477783
loss 2.79329776763916
loss 2.9432854652404785
loss 2.7569029331207275
loss 2.6169538497924805
loss 2.707595109939575
loss 2.8409533500671387
loss 2.6760730743408203
loss 2.7880983352661133
loss 1.4672794342041016
--------------------- Epoch 34 ---------------------
loss 2.6799521446228027
loss 2.906698703765869
loss 2.6626951694488525
loss 2.7848570346832275
loss 2.5759854316711426
loss 2.8757011890411377
loss 2.6011545658111572
loss 2.79970383644104
loss 2.653364419937134
loss 2.901303291320801
loss 2.7988038063049316
loss 2.8880624771118164
loss

loss 2.103853464126587
loss 2.283562421798706
loss 2.274219512939453
loss 2.2803733348846436
loss 2.278688907623291
loss 1.021514654159546
--------------------- Epoch 39 ---------------------
loss 2.088799238204956
loss 2.2813167572021484
loss 2.1464784145355225
loss 2.108797550201416
loss 2.1120011806488037
loss 2.254930019378662
loss 2.219916582107544
loss 2.2767577171325684
loss 2.164870023727417
loss 2.248642921447754
loss 2.3656539916992188
loss 2.355726480484009
loss 2.2574527263641357
loss 2.2459425926208496
loss 2.131302833557129
loss 2.2519757747650146
loss 2.2087655067443848
loss 2.1646018028259277
loss 2.119184970855713
loss 2.048128128051758
loss 2.1295526027679443
loss 2.3048999309539795
loss 2.214344024658203
loss 1.9866589307785034
loss 2.4068524837493896
loss 2.2025768756866455
loss 2.1273841857910156
loss 2.032191753387451
loss 2.2880232334136963
loss 2.113028049468994
loss 2.2322745323181152
loss 2.1831412315368652
loss 2.1565914154052734
loss 2.220348596572876
loss 2

loss 1.9634177684783936
loss 1.870779275894165
loss 1.7680225372314453
loss 1.698221206665039
loss 1.7463165521621704
loss 1.7423207759857178
loss 1.612805724143982
loss 1.8954546451568604
loss 2.076310157775879
loss 1.7804591655731201
loss 1.861909031867981
loss 2.016956329345703
loss 1.872581124305725
loss 1.8526328802108765
loss 1.663135290145874
loss 1.9952486753463745
loss 1.749100923538208
loss 1.8945393562316895
loss 1.9041774272918701
loss 1.653784155845642
loss 1.9281127452850342
loss 1.6648794412612915
loss 1.917726993560791
loss 1.875666618347168
loss 1.7145476341247559
loss 1.7709711790084839
loss 1.9158812761306763
loss 1.8889079093933105
loss 1.887462854385376
loss 1.9362503290176392
loss 1.8527295589447021
loss 1.7671198844909668
loss 1.9679455757141113
loss 1.7938581705093384
loss 1.907630205154419
loss 1.7591171264648438
loss 1.8615989685058594
loss 1.7027320861816406
loss 1.7743414640426636
loss 1.800984263420105
loss 1.72874116897583
loss 1.8481096029281616
loss 1.93

loss 1.4602404832839966
loss 1.5045219659805298
loss 1.5286904573440552
loss 1.591550350189209
loss 1.5762676000595093
loss 1.530222773551941
loss 1.4713166952133179
loss 1.5735907554626465
loss 1.6111615896224976
loss 1.4783236980438232
loss 1.636704921722412
loss 1.544494390487671
loss 1.4835841655731201
loss 1.497857689857483
loss 1.3723567724227905
loss 1.711335301399231
loss 1.506631851196289
loss 1.631784439086914
loss 1.490631341934204
loss 1.4413282871246338
loss 1.4991880655288696
loss 1.633709192276001
loss 1.5597896575927734
loss 1.7399147748947144
loss 1.5367588996887207
loss 1.4862549304962158
loss 1.4824999570846558
loss 1.4548277854919434
loss 1.6217601299285889
loss 1.76715087890625
loss 0.6708067655563354
--------------------- Epoch 50 ---------------------
loss 1.3735558986663818
loss 1.7381664514541626
loss 1.6417384147644043
loss 1.5129735469818115
loss 1.510137677192688
loss 1.5823794603347778
loss 1.532226800918579
loss 1.5240964889526367
loss 1.551793098449707
lo

loss 1.3140678405761719
loss 1.382828950881958
loss 1.2763234376907349
loss 1.3226346969604492
loss 1.536137580871582
loss 1.2520432472229004
loss 1.369230031967163
loss 1.3311378955841064
loss 1.3478410243988037
loss 1.3901110887527466
loss 1.5277676582336426
loss 0.5733506083488464
--------------------- Epoch 55 ---------------------
loss 1.289381742477417
loss 1.5406817197799683
loss 1.3512651920318604
loss 1.3113878965377808
loss 1.2501639127731323
loss 1.2693177461624146
loss 1.2831969261169434
loss 1.3306344747543335
loss 1.3752549886703491
loss 1.4404044151306152
loss 1.3036718368530273
loss 1.321230411529541
loss 1.4412198066711426
loss 1.376633882522583
loss 1.2788379192352295
loss 1.3281525373458862
loss 1.3666179180145264
loss 1.2922950983047485
loss 1.278952956199646
loss 1.3593997955322266
loss 1.2270694971084595
loss 1.4441518783569336
loss 1.308197259902954
loss 1.3619410991668701
loss 1.5355311632156372
loss 1.3750882148742676
loss 1.2573957443237305
loss 1.308365345001

loss 1.2464536428451538
loss 1.1346272230148315
loss 1.1561775207519531
loss 1.1549961566925049
loss 1.2784188985824585
loss 1.1390190124511719
loss 1.215406060218811
loss 1.2910257577896118
loss 1.136930227279663
loss 1.1359776258468628
loss 1.1136051416397095
loss 1.1313738822937012
loss 1.0394632816314697
loss 1.102618932723999
loss 1.1647307872772217
loss 1.1690528392791748
loss 1.230008602142334
loss 1.2391952276229858
loss 1.1144258975982666
loss 1.3451287746429443
loss 1.078784704208374
loss 1.1641744375228882
loss 1.089640498161316
loss 1.2261271476745605
loss 1.1832373142242432
loss 1.1475621461868286
loss 1.1315375566482544
loss 1.0913865566253662
loss 1.1489115953445435
loss 1.1005210876464844
loss 1.2323020696640015
loss 1.1543174982070923
loss 1.1342350244522095
loss 1.159743070602417
loss 1.2628371715545654
loss 1.1901757717132568
loss 1.2402406930923462
loss 1.1826598644256592
loss 1.120090365409851
loss 1.1775665283203125
loss 1.158569097518921
loss 1.0111709833145142
l

loss 1.1283849477767944
loss 0.9763308763504028
loss 0.9647219181060791
loss 0.8101416826248169
loss 1.0221834182739258
loss 1.034515142440796
loss 1.008669137954712
loss 0.8873827457427979
loss 1.0191441774368286
loss 1.0909574031829834
loss 1.0028371810913086
loss 1.041899561882019
loss 1.0401391983032227
loss 1.0988593101501465
loss 1.0514891147613525
loss 1.096877098083496
loss 1.0431513786315918
loss 1.028302788734436
loss 1.0764257907867432
loss 0.9238641262054443
loss 0.9984949827194214
loss 1.1164400577545166
loss 0.9259751439094543
loss 1.0588724613189697
loss 0.9513096809387207
loss 0.9846163988113403
loss 0.9393723607063293
loss 0.9270359873771667
loss 1.0154366493225098
loss 1.0437098741531372
loss 0.9942163825035095
loss 1.0719460248947144
loss 0.937789797782898
loss 0.932292640209198
loss 0.9145527482032776
loss 0.9567843079566956
loss 1.0243027210235596
loss 1.0530132055282593
loss 0.46307918429374695
--------------------- Epoch 66 ---------------------
loss 0.8230437636

loss 1.063010573387146
loss 0.8467102646827698
loss 0.9023066759109497
loss 0.8632581830024719
loss 0.8609495759010315
loss 1.0053789615631104
loss 0.8370115160942078
loss 0.7613644599914551
loss 0.8375295996665955
loss 0.9607895612716675
loss 0.7788576483726501
loss 0.9209446310997009
loss 0.8697850108146667
loss 0.9375882148742676
loss 0.8443294167518616
loss 0.820056140422821
loss 0.8894909620285034
loss 0.8508692979812622
loss 0.8733620643615723
loss 0.9564884901046753
loss 0.3252265155315399
--------------------- Epoch 71 ---------------------
loss 0.8657115697860718
loss 1.2134382724761963
loss 0.9489212036132812
loss 0.8304892182350159
loss 0.8252907395362854
loss 0.957995593547821
loss 0.9181064963340759
loss 0.9507592916488647
loss 0.9255585074424744
loss 0.8920100927352905
loss 0.8193227052688599
loss 0.8710161447525024
loss 1.03899085521698
loss 0.9502179622650146
loss 0.9635941386222839
loss 0.7864966988563538
loss 0.8737404942512512
loss 0.8616213202476501
loss 0.857307434

loss 0.6994363069534302
loss 0.8754249811172485
loss 0.8158606290817261
loss 0.4444260001182556
--------------------- Epoch 76 ---------------------
loss 0.6839226484298706
loss 0.9910184144973755
loss 0.7495858669281006
loss 0.7211683392524719
loss 0.7165560722351074
loss 0.7784227132797241
loss 0.7585728168487549
loss 0.8849508762359619
loss 0.791135311126709
loss 0.7252575755119324
loss 0.7865986824035645
loss 0.8035749197006226
loss 0.8662052750587463
loss 0.8978942036628723
loss 0.7783665657043457
loss 0.6920024156570435
loss 0.7780017256736755
loss 0.7702270746231079
loss 0.7142243385314941
loss 0.7770546078681946
loss 0.7510182857513428
loss 0.7931092977523804
loss 0.7219836115837097
loss 0.8564304709434509
loss 0.8587449789047241
loss 0.8622000217437744
loss 0.79442298412323
loss 0.694556474685669
loss 0.9631822109222412
loss 0.6814762353897095
loss 0.773777425289154
loss 0.6779904365539551
loss 0.7125352621078491
loss 0.7794719934463501
loss 0.7209805846214294
loss 0.815186381

loss 0.7834445238113403
loss 0.8103252053260803
loss 0.7556467652320862
loss 0.6824995279312134
loss 0.6979905366897583
loss 0.7087035179138184
loss 0.6261752843856812
loss 0.6794961094856262
loss 0.6972942352294922
loss 0.6555590629577637
loss 0.761680543422699
loss 0.6907879114151001
loss 0.6633884906768799
loss 0.7384322881698608
loss 0.6862931251525879
loss 0.7081435918807983
loss 0.6290077567100525
loss 0.7085090279579163
loss 0.7200654149055481
loss 0.6484211683273315
loss 0.6558043360710144
loss 0.712381899356842
loss 0.7451252341270447
loss 0.6511214971542358
loss 0.785216748714447
loss 0.697226881980896
loss 0.6843724846839905
loss 0.7786070704460144
loss 0.827168881893158
loss 0.7744343876838684
loss 0.7590380907058716
loss 0.7315020561218262
loss 0.7343336939811707
loss 0.6772853136062622
loss 0.7137864828109741
loss 0.6903927326202393
loss 0.7136063575744629
loss 0.6369161009788513
loss 0.6670151352882385
loss 0.6947941780090332
loss 0.6533985137939453
loss 0.71921086311340

loss 0.614702582359314
loss 0.6097208261489868
loss 0.6151211261749268
loss 0.7393091917037964
loss 0.5623244643211365
loss 0.6042817831039429
loss 0.6246651411056519
loss 0.6979106068611145
loss 0.5968998670578003
loss 0.5767524838447571
loss 0.605891764163971
loss 0.6987789273262024
loss 0.6491739749908447
loss 0.633976936340332
loss 0.6222840547561646
loss 0.7067809700965881
loss 0.6732022166252136
loss 0.7047982811927795
loss 0.626406729221344
loss 0.7362367510795593
loss 0.5318584442138672
loss 0.6299446225166321
loss 0.5924950838088989
loss 0.6456809639930725
loss 0.5394518971443176
loss 0.7316031455993652
loss 0.5996037721633911
loss 0.6763849258422852
loss 0.6335757970809937
loss 0.6196620464324951
loss 0.6038855910301208
loss 0.6725431084632874
loss 0.5828480124473572
loss 0.6486150622367859
loss 0.292701780796051
--------------------- Epoch 87 ---------------------
loss 0.6203746795654297
loss 0.881248950958252
loss 0.6949098110198975
loss 0.5842900276184082
loss 0.6262578368

loss 0.577671468257904
loss 0.6744919419288635
loss 0.5284373164176941
loss 0.5748373866081238
loss 0.6207204461097717
loss 0.6327804923057556
loss 0.6281633973121643
loss 0.5042322874069214
loss 0.5810011625289917
loss 0.6806303858757019
loss 0.5884031653404236
loss 0.6202354431152344
loss 0.5919889211654663
loss 0.5036223530769348
loss 0.6082348823547363
loss 0.5516165494918823
loss 0.32254889607429504
--------------------- Epoch 92 ---------------------
loss 0.5867788195610046
loss 0.8120084404945374
loss 0.5575810670852661
loss 0.5625299215316772
loss 0.5283203721046448
loss 0.6101559996604919
loss 0.5281073451042175
loss 0.5693330764770508
loss 0.47646814584732056
loss 0.5239652395248413
loss 0.5604095458984375
loss 0.6336624026298523
loss 0.7003213167190552
loss 0.7114686369895935
loss 0.5923388600349426
loss 0.5886076092720032
loss 0.5975644588470459
loss 0.5744808316230774
loss 0.5957579016685486
loss 0.5825977325439453
loss 0.5220397114753723
loss 0.5964181423187256
loss 0.645

loss 0.6565539836883545
loss 0.7719975113868713
loss 0.547683835029602
loss 0.5022724866867065
loss 0.4654814302921295
loss 0.5454860925674438
loss 0.47765564918518066
loss 0.6136277318000793
loss 0.4843003749847412
loss 0.5088139176368713
loss 0.547877848148346
loss 0.5020872354507446
loss 0.5078868865966797
loss 0.5816514492034912
loss 0.4711948335170746
loss 0.5478132963180542
loss 0.5620085000991821
loss 0.5285356044769287
loss 0.5305919051170349
loss 0.4933290183544159
loss 0.5510438084602356
loss 0.5455303192138672
loss 0.5385456085205078
loss 0.5338278412818909
loss 0.6764528751373291
loss 0.5138461589813232
loss 0.5729058980941772
loss 0.5123180150985718
loss 0.5843256115913391
loss 0.48980361223220825
loss 0.5075106620788574
loss 0.5269128680229187
loss 0.4957852363586426
loss 0.5519905090332031
loss 0.5709874033927917
loss 0.5219010710716248
loss 0.45563918352127075
loss 0.5232431292533875
loss 0.5396391153335571
loss 0.6032170057296753
loss 0.5606306791305542
loss 0.53533488

loss 0.4973079562187195
loss 0.505480945110321
loss 0.5120275020599365
loss 0.6051666140556335
loss 0.534457802772522
loss 0.5388909578323364
loss 0.4545917510986328
loss 0.4894777834415436
loss 0.5192076563835144
loss 0.5679072737693787
loss 0.533894419670105
loss 0.5133161544799805
loss 0.5923168659210205
loss 0.5159578323364258
loss 0.508388102054596
loss 0.47121816873550415
loss 0.46953392028808594
loss 0.5846787095069885
loss 0.48821303248405457
loss 0.45260363817214966
loss 0.5485087037086487
loss 0.4715031683444977
loss 0.4876733124256134
loss 0.4820837080478668
loss 0.5183837413787842
loss 0.47196656465530396
loss 0.5585271716117859
loss 0.5835059285163879
loss 0.4514647424221039
loss 0.466668039560318
loss 0.46308809518814087
loss 0.48682186007499695
loss 0.5798189640045166
loss 0.39402228593826294
loss 0.5545338988304138
loss 0.4971248209476471
loss 0.563467800617218
loss 0.4696275591850281
loss 0.4347243905067444
loss 0.49799269437789917
loss 0.6502249240875244
loss 0.529751

loss 0.48419758677482605
loss 0.5618741512298584
loss 0.5041208863258362
loss 0.42394548654556274
loss 0.40073448419570923
loss 0.43897557258605957
loss 0.4690951108932495
loss 0.49294960498809814
loss 0.5103262662887573
loss 0.43562448024749756
loss 0.45229586958885193
loss 0.4874822199344635
loss 0.5201349258422852
loss 0.46705496311187744
loss 0.4473230838775635
loss 0.4915960431098938
loss 0.4505082964897156
loss 0.5241612792015076
loss 0.3973888158798218
loss 0.5263941884040833
loss 0.5231674313545227
loss 0.47696131467819214
loss 0.5713017582893372
loss 0.4941410720348358
loss 0.47967642545700073
loss 0.43239539861679077
loss 0.5372949242591858
loss 0.560322642326355
loss 0.49260449409484863
loss 0.45253920555114746
loss 0.505431056022644
loss 0.49872830510139465
loss 0.49651002883911133
loss 0.4858388304710388
loss 0.23232030868530273
--------------------- Epoch 108 ---------------------
loss 0.4113296866416931
loss 0.7436438798904419
loss 0.4495323598384857
loss 0.4895912408828

loss 0.3777197301387787
loss 0.4948538839817047
loss 0.4389674961566925
loss 0.48148664832115173
loss 0.38484904170036316
loss 0.37515774369239807
loss 0.5177443027496338
loss 0.46873313188552856
loss 0.36415350437164307
loss 0.5201740264892578
loss 0.5141422152519226
loss 0.5181798338890076
loss 0.4514090120792389
loss 0.47627097368240356
loss 0.535525918006897
loss 0.46631842851638794
loss 0.38734057545661926
loss 0.44354549050331116
loss 0.4173698127269745
loss 0.4963340759277344
loss 0.4292497932910919
loss 0.16045650839805603
--------------------- Epoch 113 ---------------------
loss 0.4557460844516754
loss 0.6144772171974182
loss 0.42296159267425537
loss 0.41441968083381653
loss 0.3911697566509247
loss 0.448837012052536
loss 0.45610353350639343
loss 0.5317264199256897
loss 0.48983561992645264
loss 0.4519725739955902
loss 0.44305360317230225
loss 0.4787508547306061
loss 0.48349112272262573
loss 0.48983943462371826
loss 0.4020819067955017
loss 0.40737074613571167
loss 0.38957449793

loss 0.4338931441307068
loss 0.43403130769729614
loss 0.4203384518623352
loss 0.3969242572784424
loss 0.400986909866333
loss 0.3667486608028412
loss 0.42546647787094116
loss 0.38811245560646057
loss 0.3953288197517395
loss 0.4387250542640686
loss 0.15131191909313202
--------------------- Epoch 118 ---------------------
loss 0.43386924266815186
loss 0.5631784796714783
loss 0.448164165019989
loss 0.3858354687690735
loss 0.3642728924751282
loss 0.4262772500514984
loss 0.41353046894073486
loss 0.4471053183078766
loss 0.36795058846473694
loss 0.43307560682296753
loss 0.4561817944049835
loss 0.5359686613082886
loss 0.5293490886688232
loss 0.48851776123046875
loss 0.46509289741516113
loss 0.4127974510192871
loss 0.49427494406700134
loss 0.38872843980789185
loss 0.428527295589447
loss 0.5443768501281738
loss 0.415912926197052
loss 0.5025980472564697
loss 0.38927552103996277
loss 0.3486175239086151
loss 0.40692973136901855
loss 0.42068570852279663
loss 0.453787237405777
loss 0.3903665542602539


loss 0.39585715532302856
loss 0.6748644709587097
loss 0.46147578954696655
loss 0.3958928883075714
loss 0.26886895298957825
loss 0.45001447200775146
loss 0.39536052942276
loss 0.34394699335098267
loss 0.454412043094635
loss 0.4190412163734436
loss 0.45931339263916016
loss 0.4596955180168152
loss 0.47689497470855713
loss 0.4536043107509613
loss 0.3750067353248596
loss 0.34829193353652954
loss 0.3977283835411072
loss 0.4726139307022095
loss 0.3992518484592438
loss 0.4104917645454407
loss 0.4050901532173157
loss 0.40455400943756104
loss 0.4638693630695343
loss 0.42378389835357666
loss 0.49510714411735535
loss 0.45765215158462524
loss 0.4078911542892456
loss 0.41038188338279724
loss 0.4291733503341675
loss 0.4325510859489441
loss 0.3417871296405792
loss 0.4325695037841797
loss 0.45133984088897705
loss 0.4568600356578827
loss 0.3997730612754822
loss 0.3674769103527069
loss 0.42016905546188354
loss 0.44451603293418884
loss 0.41119199991226196
loss 0.4977778196334839
loss 0.37720781564712524
l

loss 0.43499594926834106
loss 0.46048402786254883
loss 0.48785167932510376
loss 0.4208122193813324
loss 0.35474613308906555
loss 0.3981115520000458
loss 0.34361904859542847
loss 0.4409294128417969
loss 0.3935142457485199
loss 0.4110682010650635
loss 0.6135696172714233
loss 0.4230228662490845
loss 0.414829283952713
loss 0.43352654576301575
loss 0.35801562666893005
loss 0.3580731153488159
loss 0.4154476523399353
loss 0.3833749294281006
loss 0.4670047163963318
loss 0.3607714772224426
loss 0.3825303316116333
loss 0.4973807632923126
loss 0.44969403743743896
loss 0.35034480690956116
loss 0.4179840683937073
loss 0.37541520595550537
loss 0.4767743647098541
loss 0.37649407982826233
loss 0.40240299701690674
loss 0.4680737853050232
loss 0.36513328552246094
loss 0.36884090304374695
loss 0.47525814175605774
loss 0.3476407825946808
loss 0.35073286294937134
loss 0.3589525818824768
loss 0.45657646656036377
loss 0.3100243806838989
loss 0.2922249138355255
loss 0.37079349160194397
loss 0.4640726447105407

loss 0.373803973197937
loss 0.36591243743896484
loss 0.31706228852272034
loss 0.3527177572250366
loss 0.3576257824897766
loss 0.4138972759246826
loss 0.40616634488105774
loss 0.36312466859817505
loss 0.3300536274909973
loss 0.3626733124256134
loss 0.2959756851196289
loss 0.479775071144104
loss 0.3563081622123718
loss 0.34095054864883423
loss 0.39310845732688904
loss 0.4245975613594055
loss 0.30273863673210144
loss 0.3465613126754761
loss 0.3626135587692261
loss 0.39342382550239563
loss 0.3618278205394745
loss 0.3857068121433258
loss 0.4622649550437927
loss 0.42176908254623413
loss 0.4073598086833954
loss 0.2840624451637268
loss 0.4340020418167114
loss 0.4116487205028534
loss 0.4359748959541321
loss 0.3621590733528137
loss 0.39333081245422363
loss 0.33522289991378784
loss 0.3631424009799957
loss 0.36091554164886475
loss 0.4331454336643219
loss 0.3888512849807739
loss 0.3592221140861511
loss 0.3073289096355438
loss 0.28882837295532227
loss 0.338627427816391
loss 0.40011972188949585
loss 

loss 0.37469029426574707
loss 0.4246140122413635
loss 0.3145223557949066
loss 0.39056429266929626
loss 0.41701602935791016
loss 0.435499370098114
loss 0.36359918117523193
loss 0.40534546971321106
loss 0.32954978942871094
loss 0.4402717053890228
loss 0.35304808616638184
loss 0.3389992415904999
loss 0.4415262043476105
loss 0.33083850145339966
loss 0.31264036893844604
loss 0.27596181631088257
loss 0.3027462065219879
loss 0.3506903052330017
loss 0.3224759101867676
loss 0.4033212661743164
loss 0.32866424322128296
loss 0.36287209391593933
loss 0.4841790199279785
loss 0.3399460017681122
loss 0.3927091658115387
loss 0.38605576753616333
loss 0.3116157650947571
loss 0.31303641200065613
loss 0.30821749567985535
loss 0.35236990451812744
loss 0.32275086641311646
loss 0.29831451177597046
--------------------- Epoch 139 ---------------------
loss 0.31374090909957886
loss 0.5794798135757446
loss 0.31244659423828125
loss 0.3350990116596222
loss 0.2983396351337433
loss 0.35632088780403137
loss 0.3089262

loss 0.27330541610717773
loss 0.3148789703845978
loss 0.35373419523239136
loss 0.30132612586021423
loss 0.31310543417930603
loss 0.3152986168861389
loss 0.35708123445510864
loss 0.27768391370773315
loss 0.2861897349357605
loss 0.3386419713497162
loss 0.34036412835121155
loss 0.3542996644973755
loss 0.3376915156841278
loss 0.33888161182403564
loss 0.312079519033432
loss 0.34347695112228394
loss 0.2740821838378906
loss 0.29330211877822876
loss 0.3883952796459198
loss 0.2769203782081604
loss 0.40945711731910706
loss 0.1669611781835556
--------------------- Epoch 144 ---------------------
loss 0.33697932958602905
loss 0.49660515785217285
loss 0.3293059766292572
loss 0.3520488739013672
loss 0.27685296535491943
loss 0.29098933935165405
loss 0.3382396101951599
loss 0.30134010314941406
loss 0.3155430555343628
loss 0.3346025347709656
loss 0.3259471654891968
loss 0.33580294251441956
loss 0.3936181366443634
loss 0.31567251682281494
loss 0.3250616192817688
loss 0.3182058334350586
loss 0.3293333053

loss 0.3614330291748047
loss 0.3546130061149597
loss 0.3500686287879944
loss 0.2996794283390045
loss 0.34660351276397705
loss 0.32875338196754456
loss 0.301257848739624
loss 0.3136903941631317
loss 0.2893061637878418
loss 0.2980382740497589
loss 0.3218097388744354
loss 0.18644627928733826
--------------------- Epoch 149 ---------------------
loss 0.40582969784736633
loss 0.6015352010726929
loss 0.32232367992401123
loss 0.4007469713687897
loss 0.28241202235221863
loss 0.30456894636154175
loss 0.30781471729278564
loss 0.40164417028427124
loss 0.23341166973114014
loss 0.3429523706436157
loss 0.3048735558986664
loss 0.32649505138397217
loss 0.38289323449134827
loss 0.3663386106491089
loss 0.3354494571685791
loss 0.3013555705547333
loss 0.3392783999443054
loss 0.3737139403820038
loss 0.34645554423332214
loss 0.36377185583114624
loss 0.29057857394218445
loss 0.3919134736061096
loss 0.3401944041252136
loss 0.31631574034690857
loss 0.4265686869621277
loss 0.2479034960269928
loss 0.338468730449

loss 0.34071803092956543
loss 0.14722417294979095
--------------------- Epoch 154 ---------------------
loss 0.3433256149291992
loss 0.5993198752403259
loss 0.2998882234096527
loss 0.2801114320755005
loss 0.2700763940811157
loss 0.3490147292613983
loss 0.3723498582839966
loss 0.3344722092151642
loss 0.41839876770973206
loss 0.2858770191669464
loss 0.3760059177875519
loss 0.3674311339855194
loss 0.30629774928092957
loss 0.2874080538749695
loss 0.4083602726459503
loss 0.2567247450351715
loss 0.32471969723701477
loss 0.36469364166259766
loss 0.3237181007862091
loss 0.38527125120162964
loss 0.30685901641845703
loss 0.29500091075897217
loss 0.27550122141838074
loss 0.31344619393348694
loss 0.3859159052371979
loss 0.280173122882843
loss 0.3203040659427643
loss 0.33960819244384766
loss 0.28259217739105225
loss 0.2753888964653015
loss 0.38356295228004456
loss 0.28223878145217896
loss 0.33670574426651
loss 0.3251977562904358
loss 0.3042038083076477
loss 0.37673094868659973
loss 0.33363837003707

loss 0.2930462658405304
loss 0.25100693106651306
loss 0.3141259551048279
loss 0.3217093050479889
loss 0.28333744406700134
loss 0.36588254570961
loss 0.30559343099594116
loss 0.27096524834632874
loss 0.2519417405128479
loss 0.2743777334690094
loss 0.2939611077308655
loss 0.3278469741344452
loss 0.29757484793663025
loss 0.307424932718277
loss 0.30057492852211
loss 0.2949689030647278
loss 0.35151153802871704
loss 0.33975693583488464
loss 0.3144998848438263
loss 0.2961270213127136
loss 0.27169251441955566
loss 0.30457377433776855
loss 0.31857213377952576
loss 0.3334362804889679
loss 0.267793208360672
loss 0.38365134596824646
loss 0.34544092416763306
loss 0.24130751192569733
loss 0.27068135142326355
loss 0.28013062477111816
loss 0.31063294410705566
loss 0.29970452189445496
loss 0.3211654722690582
loss 0.32145559787750244
loss 0.33679354190826416
loss 0.3038923442363739
loss 0.3448750376701355
loss 0.28117161989212036
loss 0.29629284143447876
loss 0.22915010154247284
loss 0.34660136699676514

loss 0.30998915433883667
loss 0.26387640833854675
loss 0.3724473714828491
loss 0.25203317403793335
loss 0.3463703989982605
loss 0.2940697968006134
loss 0.29730963706970215
loss 0.35403215885162354
loss 0.2942934036254883
loss 0.2472815215587616
loss 0.2991616427898407
loss 0.2902318239212036
loss 0.1954975426197052
loss 0.2979593276977539
loss 0.2977184057235718
loss 0.3324529528617859
loss 0.26602011919021606
loss 0.26630184054374695
loss 0.24921879172325134
loss 0.2488306313753128
loss 0.3214792013168335
loss 0.33146005868911743
loss 0.221390500664711
loss 0.32227879762649536
loss 0.27278730273246765
loss 0.26081952452659607
loss 0.3136183023452759
loss 0.31850606203079224
loss 0.2879255712032318
loss 0.2692834436893463
loss 0.32628923654556274
loss 0.2502637505531311
loss 0.3186146914958954
loss 0.323102742433548
loss 0.36347782611846924
loss 0.2865435779094696
loss 0.2938992977142334
loss 0.2873757481575012
loss 0.2864348292350769
loss 0.3265341818332672
loss 0.3081169128417969
los

loss 0.42841994762420654
loss 0.30789005756378174
loss 0.29728859663009644
loss 0.3030523359775543
loss 0.3713999390602112
loss 0.33342111110687256
loss 0.3193974196910858
loss 0.31185248494148254
loss 0.3138074576854706
loss 0.37774384021759033
loss 0.3056284189224243
loss 0.3305799961090088
loss 0.27309420704841614
loss 0.29914233088493347
loss 0.3135618567466736
loss 0.36896976828575134
loss 0.32232075929641724
loss 0.2559121251106262
loss 0.2690717875957489
loss 0.30193892121315
loss 0.3094736337661743
loss 0.22434605658054352
loss 0.3030098080635071
loss 0.34413546323776245
loss 0.37797173857688904
loss 0.29786407947540283
loss 0.3850451409816742
loss 0.30353033542633057
loss 0.35700923204421997
loss 0.2930310070514679
loss 0.24707579612731934
loss 0.2783539295196533
loss 0.3015882074832916
loss 0.31902068853378296
loss 0.28383710980415344
loss 0.1184573620557785
--------------------- Epoch 170 ---------------------
loss 0.34807708859443665
loss 0.47337251901626587
loss 0.38328748

loss 0.3169485032558441
loss 0.2554911673069
loss 0.26814818382263184
loss 0.313199520111084
loss 0.27818799018859863
loss 0.2679077386856079
loss 0.31108298897743225
loss 0.23983080685138702
loss 0.2802805006504059
loss 0.31480491161346436
loss 0.28713661432266235
loss 0.2943194508552551
loss 0.25417789816856384
loss 0.2627364695072174
loss 0.31722766160964966
loss 0.2870675325393677
loss 0.3391992449760437
loss 0.2562906742095947
loss 0.275860071182251
loss 0.31824761629104614
loss 0.26790666580200195
loss 0.3313698172569275
loss 0.2893569767475128
loss 0.2946193516254425
loss 0.3589094877243042
loss 0.1873188018798828
--------------------- Epoch 175 ---------------------
loss 0.2570268511772156
loss 0.5014052987098694
loss 0.32573172450065613
loss 0.2614385485649109
loss 0.28067025542259216
loss 0.3220650553703308
loss 0.28382670879364014
loss 0.28380465507507324
loss 0.25171589851379395
loss 0.3169223964214325
loss 0.24621926248073578
loss 0.2886441648006439
loss 0.2531913518905639

loss 0.44757911562919617
loss 0.41896316409111023
loss 0.4092648923397064
loss 0.35673367977142334
loss 0.4784468710422516
loss 0.43969687819480896
loss 0.5014187097549438
loss 0.4225154519081116
loss 0.376264363527298
loss 0.3603043854236603
loss 0.5562562942504883
loss 0.343655526638031
loss 0.4241122603416443
loss 0.40798550844192505
loss 0.33915582299232483
loss 0.36804014444351196
loss 0.27574723958969116
--------------------- Epoch 180 ---------------------
loss 0.5414193868637085
loss 0.6627988219261169
loss 0.3782949149608612
loss 0.32825160026550293
loss 0.3961640000343323
loss 0.434635728597641
loss 0.31571251153945923
loss 0.3760494887828827
loss 0.4105933904647827
loss 0.4100518822669983
loss 0.37493348121643066
loss 0.3023000657558441
loss 0.44814205169677734
loss 0.4242839217185974
loss 0.34017103910446167
loss 0.3311592638492584
loss 0.3298012614250183
loss 0.4741256535053253
loss 0.40464383363723755
loss 0.38754457235336304
loss 0.367891788482666
loss 0.3577703833580017

loss 0.35298585891723633
loss 0.2330114096403122
loss 0.32278308272361755
loss 0.27533772587776184
loss 0.24566033482551575
loss 0.3623916506767273
loss 0.11196532845497131
--------------------- Epoch 185 ---------------------
loss 0.3376689553260803
loss 0.5497244000434875
loss 0.303877592086792
loss 0.29495900869369507
loss 0.22577090561389923
loss 0.32315877079963684
loss 0.2561693787574768
loss 0.2824965715408325
loss 0.23808154463768005
loss 0.24601252377033234
loss 0.356723427772522
loss 0.35124701261520386
loss 0.3238462805747986
loss 0.3492809236049652
loss 0.2623760998249054
loss 0.27447834610939026
loss 0.3105948567390442
loss 0.2825635075569153
loss 0.32647544145584106
loss 0.334696888923645
loss 0.29319262504577637
loss 0.3366376757621765
loss 0.26727718114852905
loss 0.30219942331314087
loss 0.28381338715553284
loss 0.24585272371768951
loss 0.2961648404598236
loss 0.4219072461128235
loss 0.32055357098579407
loss 0.24657416343688965
loss 0.31075459718704224
loss 0.352739453

loss 0.28089287877082825
loss 0.2647789716720581
loss 0.22180038690567017
loss 0.2367183268070221
loss 0.28696534037590027
loss 0.2737051844596863
loss 0.3668941259384155
loss 0.24013596773147583
loss 0.27913007140159607
loss 0.23613768815994263
loss 0.335144579410553
loss 0.28783121705055237
loss 0.24721869826316833
loss 0.2577897608280182
loss 0.2768748998641968
loss 0.3146592080593109
loss 0.2546342611312866
loss 0.24032188951969147
loss 0.31872451305389404
loss 0.3598247766494751
loss 0.2998592257499695
loss 0.2712849974632263
loss 0.24139247834682465
loss 0.2736872434616089
loss 0.34960120916366577
loss 0.32916632294654846
loss 0.34357088804244995
loss 0.32783159613609314
loss 0.2632695734500885
loss 0.27975520491600037
loss 0.30359455943107605
loss 0.2476455718278885
loss 0.30519339442253113
loss 0.270489364862442
loss 0.2714800238609314
loss 0.2828855812549591
loss 0.28194254636764526
loss 0.2680009603500366
loss 0.31076616048812866
loss 0.31285104155540466
loss 0.33680999279022

loss 0.2625613212585449
loss 0.2976694703102112
loss 0.26890695095062256
loss 0.25721216201782227
loss 0.2830163240432739
loss 0.25534096360206604
loss 0.2886734902858734
loss 0.1961883157491684
loss 0.28583449125289917
loss 0.2778281569480896
loss 0.24455800652503967
loss 0.2856454849243164
loss 0.2722930610179901
loss 0.2574165165424347
loss 0.2500511407852173
loss 0.20395199954509735
loss 0.24553315341472626
loss 0.36373311281204224
loss 0.2660362124443054
loss 0.2413739264011383
loss 0.2988205552101135
loss 0.28100916743278503
loss 0.2519508898258209
loss 0.21296226978302002
loss 0.2799088656902313
loss 0.2491341233253479
loss 0.2711094617843628
loss 0.251290500164032
loss 0.25033077597618103
loss 0.29723796248435974
loss 0.2849252223968506
loss 0.2406284660100937
loss 0.25532203912734985
loss 0.24296322464942932
loss 0.27073267102241516
loss 0.2276005893945694
loss 0.345211923122406
loss 0.30983766913414
loss 0.20605915784835815
loss 0.33377474546432495
loss 0.2657943069934845
los

loss 0.2854786217212677
loss 0.25735270977020264
loss 0.24352619051933289
loss 0.22957095503807068
loss 0.2729280889034271
loss 0.21328508853912354
loss 0.3049382269382477
loss 0.28084754943847656
loss 0.24379555881023407
loss 0.20401175320148468
loss 0.2330690622329712
loss 0.24366790056228638
loss 0.26052457094192505
loss 0.21008838713169098
loss 0.2748032212257385
loss 0.2894152104854584
loss 0.3148672580718994
loss 0.2787923514842987
loss 0.2649741768836975
loss 0.26757150888442993
loss 0.27886196970939636
loss 0.278413325548172
loss 0.3081001341342926
loss 0.28751111030578613
loss 0.24614350497722626
loss 0.2326594442129135
loss 0.2653445899486542
loss 0.2608376443386078
loss 0.24052771925926208
loss 0.3203868567943573
loss 0.30444949865341187
loss 0.2663339078426361
loss 0.24511095881462097
loss 0.23537932336330414
loss 0.30617621541023254
loss 0.30178219079971313
loss 0.24452939629554749
loss 0.2328839898109436
loss 0.2492828667163849
loss 0.24118930101394653
loss 0.265377998352

In [42]:
with open('./Models/train_400_120_enc.pkl', 'wb') as fp:
    pickle.dump(model_enc, fp)

In [43]:
with open('./Models/train_400_120_dec.pkl', 'wb') as fp:
    pickle.dump(model_dec, fp)

In [44]:
with open('./Models/train_400_120_inp_emb.pkl', 'wb') as fp:
    pickle.dump(model_input_emb, fp)

In [45]:
with open('./Models/train_400_120_out_emb.pkl', 'wb') as fp:
    pickle.dump(model_output_emb, fp)

## Evaluation Step

**Entended Vocabulary**

In [29]:
ext_vocab_indexer = copy.deepcopy(vocab_indexer)
for (x, y) in dev:
    for x_tok, y_tok in zip(tokenize(x), tokenize(y)):
        ext_vocab_indexer.get_index(x_tok)
        ext_vocab_indexer.get_index(y_tok)
        
ext_dev_data_indexed = index_data(dev, ext_vocab_indexer)
ext_dev_data_indexed.sort(key=lambda ex: len(ex.x_indexed), reverse=True)
ext_input_dev_max_len = np.max(np.asarray([len(ex.x_indexed) for ex in ext_dev_data_indexed]))
ext_all_dev_input_data = make_padded_input_tensor(ext_dev_data_indexed, ext_vocab_indexer, ext_input_dev_max_len).astype(np.int64)
ext_X_tensors_batch_dev = batch_data(ext_all_dev_input_data, BATCH_SIZE, cuda=CUDA)   # batch_num, batch_size, sent_len

**Evaluation Copy**

In [30]:
def pointer_generate_dev(p_gen, dec_output, dec_attn, ext_X_tensors, vocab_indexer, ext_vocab_indexer, cuda):
    final_distrib = torch.zeros([dec_output.shape[0], len(ext_vocab_indexer)], dtype=torch.float)    # add new words in the end
    if cuda:
        final_distrib = final_distrib.cuda()
    final_distrib[:, 0:len(vocab_indexer)] = p_gen * dec_output
    dec_attn_padding = torch.zeros(ext_X_tensors.shape, dtype=torch.float)
    if cuda:
        dec_attn_padding = dec_attn_padding.cuda()
    dec_attn_padding[:, 0:dec_attn.shape[1]] = dec_attn
    final_distrib = final_distrib.scatter_add(1, ext_X_tensors, (1-p_gen)*dec_attn_padding)
    final_distrib[:, vocab_indexer.index_of(UNK_SYMBOL)] = 0
    return final_distrib

**Performance Evaluation**

In [57]:
def getModelPerf(scores_lst):
    result_dic = {'rouge-1': {'f':0, 'p':0, 'r':0}, 'rouge-2': {'f':0, 'p':0, 'r':0}, 'rouge-l': {'f':0, 'p':0, 'r':0}}
    for scores in scores_lst:
        for key1 in scores:
            for key2 in scores[key1]:
                result_dic[key1][key2] += scores[key1][key2]
    sent_nums = len(scores_lst)
    for key1 in result_dic:
        for key2 in result_dic[key1]:
            result_dic[key1][key2] /= sent_nums
    return result_dic

**Create indexed input/output for development**

In [31]:
# Create indexed input/output for dev
dev_data_indexed.sort(key=lambda ex: len(ex.x_indexed), reverse=True)
input_dev_max_len = np.max(np.asarray([len(ex.x_indexed) for ex in dev_data_indexed]))
all_dev_input_data = make_padded_input_tensor(dev_data_indexed, vocab_indexer, input_dev_max_len).astype(np.int64)
output_dev_max_len = np.max(np.asarray([len(ex.y_indexed) for ex in dev_data_indexed]))
X_tensors_batch_dev = batch_data(all_dev_input_data, BATCH_SIZE, cuda=CUDA)   # batch_num, batch_size, sent_len
if CUDA:
    inp_lens_batch_dev = [torch.tensor([torch.sum(X_tensor != 0) for X_tensor in X_tensors]).cuda() for X_tensors in X_tensors_batch_dev]  # batch_num, batch_size
else:
    inp_lens_batch_dev = [torch.tensor([torch.sum(X_tensor != 0) for X_tensor in X_tensors]) for X_tensors in X_tensors_batch_dev]  # batch_num, batch_size

**Evaluation**

In [40]:
best_data = []
model_enc.eval()
model_dec.eval()
for X_tensors, inp_lens_tensor, ext_X_tensors in zip(X_tensors_batch_dev, inp_lens_batch_dev, ext_X_tensors_batch_dev):
    enc_outputs, enc_context_mask, enc_hidden = encode_input_for_decoder(X_tensors, inp_lens_tensor, model_input_emb, model_enc)
    dec_hidden = enc_hidden
    init_dec_inp = Variable(torch.LongTensor([vocab_indexer.index_of(SOS_SYMBOL)] * X_tensors.shape[0]))  
    if CUDA:
        init_dec_inp = init_dec_inp.cuda()
    dec_input = model_output_emb.forward(init_dec_inp)
    cont = torch.zeros((X_tensors.shape[0], 2 * hidden_size))    # batch_size, 2*hidden_size
    coverage = torch.zeros((X_tensors.shape[0], enc_outputs.shape[0]))             # batch_size, sent_lens
    all_dec_outputs = Variable(torch.zeros(output_dev_max_len, X_tensors.shape[0], len(ext_vocab_indexer)))   # sent_len, batch_size, ext_output_size
    
    if CUDA:
        cont = cont.cuda()
        coverage = coverage.cuda()
        all_dec_outputs = all_dec_outputs.cuda()
    
    for idx in range(output_dev_max_len):
        p_gen, dec_output, dec_hidden, dec_attn, cont, next_coverage = model_dec.forward(dec_input, dec_hidden, enc_outputs, enc_context_mask, cont, coverage)
        all_dec_outputs[idx] = pointer_generate_dev(p_gen, dec_output, dec_attn, ext_X_tensors, vocab_indexer, ext_vocab_indexer, CUDA)
        max_prob_idx = torch.argmax(all_dec_outputs[idx], dim=1)
        max_prob_idx[max_prob_idx >= len(vocab_indexer)] = vocab_indexer.index_of(UNK_SYMBOL)   # new words should be UNK when serving as next input
        dec_input = model_output_emb.forward(max_prob_idx)
    for best_sent in torch.argmax(all_dec_outputs, dim=2).transpose(0, 1).contiguous():
        best_ex = []
        for word_idx in best_sent:            # don't need to include EOS tok
            if word_idx.item() == ext_vocab_indexer.index_of(EOS_SYMBOL):
                break
            best_ex.append(ext_vocab_indexer.get_object(word_idx.item()))     # pred tok
        best_data.append(best_ex)

Both Mr Brown and Mr Blair rose to prominence when Lord Kinnock led Labour between 1983 and 1992.Former Labour leader Lord Kinnock said the chancellor would be best placed to take over from Mr Blair.Tony Blair has become the Labour Party 's longest-serving prime minister.Labour won a huge majority of 167 over the Conservatives in 2001 , but Mr Blair has since been criticised by many in his own party.Gordon Brown , chancellor of the exchequer under Mr Blair , became Britain 's longest-serving chancellor of modern times in 2004.In 1997 , Mr Blair became the youngest premier of the 20th century , when he came to power at the age of 43 .
Both Mr Brown and Blair Mr Blair rose to prominence when when Labour Labour Labour Labour and Former leader 's 's said take majority since Brown Brown Labour , of And Blair , chancellor chancellor of the the 20th under placed , of May , the and '' last said under would would Blair a future premier by Kinnock system and contest chancellor 's huge Kinnock Ki

BBC correspondent Stephen Cape said the combined unions represented `` a formidable force '' which could embarrass the government in the run-up to the General Election.The UK 's biggest civil service union is to ballot its 290,000 members on strikes in protest at government plans to extend their pension age to 65.The government says unions will be consulted before any changes are made to the pension system.The Public and Commercial Services Union will co-ordinate any action with up to six other public sector unions.PCS leader Mark Serwotka warned there could be further walkouts unless there was a government rethink.Unions have already earmarked 23 March for a one-day strike which could involve up to 1.4 million UK workers .
BBC Stephen Stephen in the which . `` correspondent Stephen , correspondent formidable which correspondent the some the represented General General General General General consulted General the on one pension the pension says up to plans its years which years stoppa

to to Criminal can can public disappointed the confirm found found action case to case a case to the question guilty.Greek of Citigroup the prosecutor has been criticism by the prosecutor move has been disappointed 's the a investigation in evidence Citigroup , '' Citigroup The The a referred to individuals a of , to case the be the the the back 's a investigation said financial German August We should German Citigroup signs of the prosecutor saw Citigroup , brought in bought .
[{'rouge-1': {'f': 0.5245901593697931, 'p': 0.7441860465116279, 'r': 0.4050632911392405}, 'rouge-2': {'f': 0.0846560798689849, 'p': 0.10666666666666667, 'r': 0.07017543859649122}, 'rouge-l': {'f': 0.268467511780758, 'p': 0.4418604651162791, 'r': 0.24050632911392406}}]
Land owners claim the National Land Institute has made mistakes in classifying lands as public or private.Under a 2001 land law , the government can tax or seize unused farm sites.In a statement , Mr Rangel said the land reform is not against the c

Another potential option , Glasgow flanker Andrew Wilson , has been ruled out for a month after damaging ligaments in both knees against Northampton recently.The Borders flanker has a knee injury and joins Donnie Macfadyen and Allister Hogg on the sidelines.Scotland doctor James Robson said : `` A scan has shown damage to the medial ligaments of Scott 's right knee ruling him out of the first part of the Championship.Star number eight Simon Taylor will miss at least the first two games after damaging an ankle during his comeback.Scotland 's back row crisis has worsened ahead of the RBS Six Nations with news that Scott Gray will miss out on the opening matches .
may potential , Another may miss out , ruled , , option to first Simon a knee two Hogg the two two eight Glasgow Simon a . weekend for damaging his and miss Allister opener Robson number for his of damaging Robson damage for the first option , to the prop cartilage in damaging the two two two have Another Another how Gray , Will

[{'rouge-1': {'f': 0.6239999950540799, 'p': 0.6964285714285714, 'r': 0.5652173913043478}, 'rouge-2': {'f': 0.17910447262691534, 'p': 0.16981132075471697, 'r': 0.18947368421052632}, 'rouge-l': {'f': 0.5326020332256153, 'p': 0.6071428571428571, 'r': 0.4927536231884058}}]
Rapper 50 Cent has become the first solo artist to have three singles in the US top five in the same week.Newcomer The Game 's debut reached the top of the charts five weeks ago , while 50 Cent 's second collection The Masscre was released in the US at the end of last week.Last May , R & B star Usher scored a chart first , with three concurrent singles in the US Top 10 , a feat which was previously matched only by The Bee Gees and The Beatles.50 Cent also appears on rap protege The Game 's song How We do , number four in the US but now outside the UK top ten .
Rapper 50 Cent Cent become has become weeks Rapper has become the become remained has has become the 50 while solo in the singles and singles three in the same fiv

[{'rouge-1': {'f': 0.5333333290055556, 'p': 0.8421052631578947, 'r': 0.3902439024390244}, 'rouge-2': {'f': 0.159090904432464, 'p': 0.2153846153846154, 'r': 0.12612612612612611}, 'rouge-l': {'f': 0.33683029823181404, 'p': 0.6578947368421053, 'r': 0.3048780487804878}}]
Ray has been nominated in six Oscar categories including best film and best actor for Jamie Foxx.Oscar-nominated film biopic Ray has surpassed its US box office takings with a combined tally of $ 80m ( £43m ) from DVD and video sales and rentals.Ray director Taylor Hackford , responsible for the classic 1982 film An Officer and a Gentleman , has also received an Oscar nomination in the best director category.In its first week on home entertainment release the film was the number one selling DVD , with the limited edition version coming in at number 11 .
Ray has been nominated its Oscar its Oscar-nominated its its Oscar its its its Oscar for Jamie surpassed and surpassed best ( its its its US with takings £43m surpassed and

**Performance**

In [54]:
rouge = Rouge()
scores_lst = []
for test_ex, best_ex in zip(dev_data_indexed, best_data):
    test_str = ' '.join(test_ex.y_tok)
    best_str = ' '.join(best_ex)
    print(test_str)
    print(best_str)
    scores = rouge.get_scores(best_str, test_str)
    print(scores)
    scores_lst.append(scores[0])

Both Mr Brown and Mr Blair rose to prominence when Lord Kinnock led Labour between 1983 and 1992.Former Labour leader Lord Kinnock said the chancellor would be best placed to take over from Mr Blair.Tony Blair has become the Labour Party 's longest-serving prime minister.Labour won a huge majority of 167 over the Conservatives in 2001 , but Mr Blair has since been criticised by many in his own party.Gordon Brown , chancellor of the exchequer under Mr Blair , became Britain 's longest-serving chancellor of modern times in 2004.In 1997 , Mr Blair became the youngest premier of the 20th century , when he came to power at the age of 43 .
Both Mr Brown and Blair Mr Blair rose to prominence when when Labour Labour Labour Labour and Former leader 's 's said take majority since Brown Brown Labour , of And Blair , chancellor chancellor of the the 20th under placed , of May , the and '' last said under would would Blair a future premier by Kinnock system and contest chancellor 's huge Kinnock Ki

But officials job of but said of are are job in to unemployment was of job job crowd.Holmes Germany , losses million million and effect job of the , top of Schroeder job creation does put to has job job job at job stifling , of labour Gerhard of 0.1 job predicted.Local creation of are labour , 10.8 work .
[{'rouge-1': {'f': 0.44642856700414546, 'p': 0.6756756756756757, 'r': 0.3333333333333333}, 'rouge-2': {'f': 0.07228915208012802, 'p': 0.10344827586206896, 'r': 0.05555555555555555}, 'rouge-l': {'f': 0.23681982866631537, 'p': 0.43243243243243246, 'r': 0.21333333333333335}}]
Excluding the car sector , US retail sales were up 0.6 % in January , twice what some analysts had been expecting.Excluding the car sector , sales rose by just 0.3 % .In December , overall retail sales rose by 1.1 % .US retail sales fell 0.3 % in January , the biggest monthly decline since last August , driven down by a heavy fall in car sales.The 3.3 % fall in car sales had been expected , coming after December 's 

India 's government has given its backing to cheaper and more accessible air travel.When it was set up the firm offered tickets that were 50 % cheaper than other Indian airlines.Air Deccan has ordered 30 Airbus A320 planes in a $ 1.8bn ( £931m ) deal as India 's first low-cost airline expands in the fast-growing domestic market.Beer magnate Vijay Mallya recently set up Kingfisher Airlines , while UK entrepreneur Richard Branson has said he is keen to start a local operation.Air Deccan was set up last year and wants to lure travellers away from the railway network and pricier rivals.The potential of the Indian market has attracted attention at home and abroad .
India 's more has given its backing to cheaper and more accessible accessible travel.When it was set up the firm offered in a $ 1.8bn entrepreneur in in January $ $ 50 offered Deccan it tickets tickets the firm offered it % to expected recently the it accessible was expands it was magnate entrepreneur it said expands was by to by

On Tuesday , Morientes had said : `` I like Liverpool and I am pleased that a club of their stature want to buy me.I have told Madrid that I want it to happen.But Benitez could yet turn his attentions to the younger Anelka should Morientes be reluctant to pledge his future to Liverpool.Newcastle have joined the race to sign Real Madrid striker Fernando Morientes and scupper Liverpool 's bid to snap up the player , according to reports.Real are believed to still want £7m before selling Morientes.If Madrid do not want me then it 's in the best interests of everyone that they are realistic .
On Tuesday , Morientes had said : `` `` I like to : I I want reluctant to should should want has should I want I I want I I I being bidding I could . `` like being being bidding attentions bidding want to stature I . `` like like then attentions had had being being . `` being like to permanent I I played could could . `` have like being being bidding had had had could like to I Madrid turn stature sta

[{'rouge-1': {'f': 0.5233644815791773, 'p': 0.8, 'r': 0.3888888888888889}, 'rouge-2': {'f': 0.2206896505721761, 'p': 0.3076923076923077, 'r': 0.17204301075268819}, 'rouge-l': {'f': 0.3542390110617143, 'p': 0.6571428571428571, 'r': 0.3194444444444444}}]
The Duchy of Lancaster provides the Queen 's private income , while the Duchy of Cornwall provides Prince Charles ' annual income.Aides from the Duchy of Lancaster and Duchy of Cornwall will appear before the Commons Public Accounts Committee.Duchy officials , who will appear before the committee on Monday , are only responsible for generating money.Senior officials at the two bodies generating private income for the Queen and Prince of Wales are to be questioned by MPs.The prince has voluntarily paid income tax - currently 40 % - since 1993 .
The Duchy of Lancaster provides the the Charles . `` `` Duchy appear Wales and of committee two Duchy Charles , while appear appear appear Wales last year 's William William William , while , befor

In [58]:
result_dic = getModelPerf(scores_lst)
print(result_dic)

{'rouge-1': {'f': 0.5672983961105551, 'p': 0.8445771616593384, 'r': 0.44006148343795387}, 'rouge-2': {'f': 0.20048847585455193, 'p': 0.2519793439842693, 'r': 0.17646345408604294}, 'rouge-l': {'f': 0.3980085871232608, 'p': 0.6954182204560879, 'r': 0.3623109768543346}}


In [38]:
dev

[["Byrds producer Melcher dies at 62  Record producer Terry Melcher , who was behind hits by the Byrds , Ry Cooder and the Beach Boys , has died aged 62 .  The son of actress Doris Day , he helped write Kokomo for the Beach Boys , which was used in the movie Cocktail , earning a 1988 Golden Globe nomination . He also produced Mr Tambourine Man for the Byrds , as well as other his such as Turn , Turn Turn . Melcher died on Friday night at his home in Beverly Hills , California , after a long battle with skin cancer . He joined Columbia Records as a producer in the mid-1960s , and also worked with Gram Parsons and the Mamas and the Papas .  Earlier in his career , Melcher had hits as part of duo called Bruce & Terry , with future Beach Boy Bruce Johnston , which evolved into the Rip Chords group . Melcher also worked closely with his mother , producing The Doris Day Show and helping to run her charitable activities . In 1969 his name became linked with the Charles Manson murders , which 

In [39]:
for test_ex, best_ex in zip(dev_data_indexed, best_data):
    test_str = ' '.join(test_ex.y_tok)
    print(len(test_ex.y_tok), test_str)

114 Both Mr Brown and Mr Blair rose to prominence when Lord Kinnock led Labour between 1983 and 1992.Former Labour leader Lord Kinnock said the chancellor would be best placed to take over from Mr Blair.Tony Blair has become the Labour Party 's longest-serving prime minister.Labour won a huge majority of 167 over the Conservatives in 2001 , but Mr Blair has since been criticised by many in his own party.Gordon Brown , chancellor of the exchequer under Mr Blair , became Britain 's longest-serving chancellor of modern times in 2004.In 1997 , Mr Blair became the youngest premier of the 20th century , when he came to power at the age of 43 .
101 More than five million households in the UK have broadband and that number is growing fast.The Demos report looked at the impact of broadband on people 's net habits.More significantly , argues the report , broadband is encouraging people to take a more active role online.The Demos report , entitled Broadband Britain : The End Of Asymmetry ? , was 

70 After their 3-2 win over Manchester City , McClaren said : `` We are playing exciting football , it 's a magnificent result to keep us in the top five.McClaren also praised winger Stewart Downing and strikers Jimmy Floyd Hasselbaink and Mark Viduka , who both ended barren runs in front of goal.He added : `` If Stewart keeps playing like this Sven-Goran Eriksson has got to pick him .
69 Munster 's Heineken Cup quarter-final tie against Biarritz on 3 April has been switched to Real Sociedad 's Paseo de Anoeta stadium in San Sebastian.Ulster were the last Irish team to play at the Paseo de Anoeta stadium where they faced a Euskarians side during a pre-season tour in 1998.Real 's ground holds 32,000 whereas the Parc des Sports Aguilera in Biarritz has a capacity of just 12,667 .
75 Ronaldo refused to commit his long-term future to the club.I 'm happy but nobody knows the future . `` `` The United board have already made an offer to renew the contract but I 'm trying not to think about i

In [None]:
# Create indexed input/output for training
# train_data_indexed.sort(key=lambda ex: len(ex.x_indexed), reverse=True)
# input_train_max_len = np.max(np.asarray([len(ex.x_indexed) for ex in train_data_indexed]))
# all_train_input_data = make_padded_input_tensor(train_data, input_indexer, input_train_max_len, args.reverse_input)
# output_train_max_len = np.max(np.asarray([len(ex.y_indexed) for ex in train_data]))
# all_train_output_data = make_padded_output_tensor(train_data, output_indexer, output_train_max_len)


# # Create indexed input/output for dev
# dev_data_indexed.sort(key=lambda ex: len(ex.x_indexed), reverse=True)
# input_dev_max_len = np.max(np.asarray([len(ex.x_indexed) for ex in dev_data_indexed]))
# all_dev_input_data = make_padded_input_tensor(dev_data_indexed, input_indexer, input_dev_max_len, args.reverse_input)
# output_dev_max_len = np.max(np.asarray([len(ex.y_indexed) for ex in dev_data_indexed]))
# all_dev_output_data = make_padded_output_tensor(dev_data_indexed, output_indexer, output_dev_max_len)



#         else:
#             for idx in range(output_train_max_len):
# #                 dec_output, dec_hidden = model_dec.forward(dec_input, dec_hidden)
#                 p_gen, dec_output, dec_hidden, dec_attn = model_dec.forward(dec_input, dec_hidden, enc_outputs, enc_context_mask)
#                 all_dec_outputs[idx] = pointer_generate_train(p_gen, dec_output, dec_attn, X_tensors)
#                 max_prob_idx = torch.argmax(all_dec_outputs[idx], dim=1)
#                 dec_input = model_output_emb.forward(max_prob_idx)

#         all_context_mask = torch.from_numpy(np.asarray([[1 if j < oup_lens_tensor.data[i].item() \
#             else 0 for j in range(0, Y_tensors.size(1))] for i in range(0, oup_lens_tensor.shape[0])], dtype=np.uint8))
#         loss = masked_cross_entropy(all_dec_outputs.transpose(0, 1).contiguous(), Y_tensors, oup_lens_tensor, all_context_mask)       # batch_size, sent_len, output_size
#                                                                                                                                       # batch_size, sent_len
#         Y_resize = torch.transpose(Y_tensors, 0, 1).contiguous()
#         loss = criterion(all_dec_outputs.view(-1, all_dec_outputs.shape[2]), Y_resize.view(-1))

# use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False