In [1]:
from google_drive_downloader import GoogleDriveDownloader as gdd
import random
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
from itertools import cycle
from torch.distributions import Categorical

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
file_path = './cnn_stories_tokenized'
gdd.download_file_from_google_drive(file_id='0BzQ6rtO2VN95cmNuc2xwUS1wdEE', 
                                    dest_path=file_path+'.zip', 
                                    unzip=True, 
                                    showsize=True)

In [2]:
from os import listdir
from tqdm import tqdm

def load_data(file_path):
    data = []
    file_list = listdir(file_path)
    for name in tqdm(file_list):
        file_name = file_path + '/' + name
        doc = load_doc(file_name)
        src, trg = split_doc(doc)
        data.append({'src': src, 'trg': trg})
    return data

def load_doc(file_name):
    file = open(file_name, encoding='utf-8')
    text = file.read()
    file.close()
    return text

def split_doc(doc):
    idx = doc.find('@highlight')
    src, trg = doc[:idx], doc[idx:].split('@highlight')
    trg = [t.strip() for t in trg if len(t) > 0]
    return src, trg

def clean_sent(sent):
    filter_list = ['/', '-LRB-', '-RRB-', '\n', '`', '\'\'', '"', '--', '...', 'NEW :']
    for token in filter_list:
        sent = sent.replace(token, ' ')
    sent = ' '.join(sent.split())
    sent = sent.lower()
    return sent

def preprocess(file_path, train_size, valid_size, test_size, max_enc_step, max_dec_step):
    # load data from file_path
    data = load_data(file_path)
    preprocessed = []
    for ex in data:
        src, trg = ex['src'], ex['trg']
        src = clean_sent(src)
        trg = [clean_sent(t) for t in trg]
        # choose the first hightlight as the truth
        # choose the first 2 highlights as the trg
        if len(trg) > 1:
            trg = trg[:2]
            trg = ' . '.join(trg)
        else:
            trg = trg[0]
        trg += ' .'
        # truncate examples
        if len(src) > max_enc_step > 0: 
            src = src.split()
            src = src[:max_enc_step]
            src = ' '.join(src)
        if len(trg) > max_dec_step > 0: 
            trg = trg.split()
            trg = trg[:max_dec_step]
            trg = ' '.join(trg)
        preprocessed.append({'src': src, 'trg': trg})
    # split data
    train_data = preprocessed[:train_size]
    valid_data = preprocessed[train_size:train_size+valid_size]
    test_data = preprocessed[-test_size:]
    print('Number of train examples: {}'.format(len(train_data)))
    print('Number of valid examples: {}'.format(len(valid_data)))
    print('Number of test examples: {}'.format(len(test_data)))
    return train_data, valid_data, test_data


train_size = 90000
valid_size = 1579
test_size = 1000
max_enc_step = 300
max_dec_step = -1
train_data, valid_data, test_data = preprocess(file_path, 
                                               train_size, 
                                               valid_size, 
                                               test_size,
                                               max_enc_step, 
                                               max_dec_step)


100%|██████████| 92579/92579 [00:02<00:00, 32272.44it/s]


Number of train examples: 90000
Number of valid examples: 1579
Number of test examples: 1000


In [3]:
class Vocab():
   
    def __init__(self, data, max_size=80000, min_freq=4):
        self.PAD = 0
        self.SOS = 1
        self.EOS = 2
        self.UNK = 3
        self.index2token = {0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: '<UNK>'}
        self.token2index = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
        self.token2count = {}
        self.n_tokens = 4
        self.max_size = max_size
        self.min_freq = min_freq

        print("Read {} example pairs".format(len(data)))
        for ex in data:
            self.add_sent(ex['src'])
            self.add_sent(ex['trg'])
        print('Number of vocab: {}'.format(self.n_tokens))

        vocab_count = list(self.token2count.items()) 
        vocab_count = sorted(vocab_count, key=lambda x: x[1], reverse=True) # Sort by counts (descending)
        if max_size-4 < len(vocab_count):
            vocab_count = vocab_count[:max_size-4]

        self.token2index = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
        self.token2count = {}
        self.index2token = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.n_tokens = 4

        for token, count in vocab_count:
            if count >= min_freq: 
                self.token2index[token] = self.n_tokens
                self.token2count[token] = count
                self.index2token[self.n_tokens] = token
                self.n_tokens += 1    
        print('Number of vocab left: {}'.format(self.n_tokens))

    def add_sent(self, sent):
        for token in sent.split():
            self.add_token(token)

    def add_token(self, token):
        if sum([char.isdigit() for char in token]) == 0:
            if token not in self.token2index:
                self.token2index[token] = self.n_tokens
                self.token2count[token] = 1
                self.index2token[self.n_tokens] = token
                self.n_tokens += 1
            else:
                self.token2count[token] += 1

    def __getitem__(self, item):
        if type(item) == str:
            return int(self.token2index.get(item, self.UNK)) # Return the index of <UNK> if input word is not in the vocab
        elif type(item) == int:
            return self.index2token[item]

def build_oovs(src_tokens):
    oovs = [token for token in src_tokens if vocab[token] == vocab.UNK]
    return oovs

def tokens2indices(tokens):
    indices = [vocab[token] for token in tokens]
    indices = [vocab.SOS] + indices + [vocab.EOS]
    return indices

def extended2indices(extended, oovs):
    indices = []
    for token in extended:
        idx = vocab[token]
        if idx == vocab.UNK:
            idx = vocab.n_tokens + oovs.index(token) if token in oovs else vocab.UNK
        indices.append(idx)
    indices = [vocab.SOS] + indices + [vocab.EOS]
    return indices

def indices2extended(indices, oovs):
    tokens = []
    for idx in indices:
        idx = idx.item()
        if idx > vocab.n_tokens-1:
            token = oovs[idx - vocab.n_tokens]
        else:
            token = vocab[idx]
        tokens.append(token)
    return tokens

vocab = Vocab(train_data)

Read 90000 example pairs
Number of vocab: 197195
Number of vocab left: 80000


In [4]:
class Example():
    def __init__(self, ex, max_src_len=100):

        self.src_sent = ex['src']
        self.trg_sent = ex['trg']
        self.src_tokens = self.src_sent.split() 
        self.trg_tokens = self.trg_sent.split()
        self.src = tokens2indices(self.src_tokens) # For train
        self.trg = tokens2indices(self.trg_tokens)[:-1] # For train
        self.oovs = build_oovs(self.src_tokens)
        self.src_extended_vocab = extended2indices(self.src_tokens, self.oovs) # For attn dist
        self.trg_extended_vocab = extended2indices(self.trg_tokens, self.oovs)[1:] # For eval
        self.src_len = len(self.src)
        self.trg_len = len(self.trg)

def pad_sent(sent, max_len):
    sent += [vocab.PAD]*(max_len - len(sent))
    return sent

class Batch():
    def __init__(self, example_list):
        self.batch_size = len(example_list)
        self.original_src = [example.src_sent for example in example_list]
        self.original_trg = [example.trg_sent for example in example_list]

        self.max_src_len = max([example.src_len for example in example_list])
        self.max_trg_len = max([example.trg_len for example in example_list])

        self.src = np.zeros([self.batch_size, self.max_src_len], dtype=np.int32) # For train
        self.src_mask = np.zeros([self.batch_size, self.max_src_len], dtype=np.int32) # For train
        self.trg = np.zeros([self.batch_size, self.max_trg_len], dtype=np.int32) # For loss calculation
        self.src_len = np.zeros([self.batch_size], dtype=np.float32) # For train
        self.trg_len = np.zeros([self.batch_size], dtype=np.float32) # For train
        self.src_extended_vocab = np.zeros([self.batch_size, self.max_src_len], dtype=np.int32) # For attn dist
        self.trg_extended_vocab = np.zeros([self.batch_size, self.max_trg_len], dtype=np.int32) # For eval
        self.max_oovs_len = max([len(example.oovs) for example in example_list]) # For attn dist
        self.oovs = [example.oovs for example in example_list] # For attn dist
        
        for i, example in enumerate(example_list):

            self.src[i] = pad_sent(example.src, self.max_src_len)
            self.src_mask[i] = pad_sent([1]*(example.src_len), self.max_src_len)
            self.trg[i] = pad_sent(example.trg, self.max_trg_len)
            self.src_len[i] = example.src_len 
            self.trg_len[i] = example.trg_len 
            self.src_extended_vocab[i, :] = pad_sent(example.src_extended_vocab, self.max_src_len)
            self.trg_extended_vocab[i, :] = pad_sent(example.trg_extended_vocab, self.max_trg_len)

        self.src = torch.from_numpy(self.src).long().to(device)
        self.src_mask = torch.from_numpy(self.src_mask).long().to(device)
        self.trg = torch.from_numpy(self.trg).long().to(device)
        self.src_len = torch.from_numpy(self.src_len).long().to(device)
        self.trg_len = torch.from_numpy(self.trg_len).long().to(device)
        self.src_extended_vocab = torch.from_numpy(self.src_extended_vocab).long().to(device)
        self.trg_extended_vocab = torch.from_numpy(self.trg_extended_vocab).long().to(device)
        self.oov_pad = torch.zeros(self.batch_size, self.max_oovs_len).to(device) if self.max_oovs_len > 0 else None # [batch_size, oov_len]


def batchify(data, batch_size):
    examples = [Example(ex) for ex in data]
    examples = sorted(examples, key=lambda example: example.src_len, reverse=True)
    examples = [examples[i:i+batch_size] for i in range(0, len(examples), batch_size)]
    examples = [Batch(batch) for batch in examples]
    return examples

class Dataset():
    def __init__(self, data, batch_size):
        self.batches = batchify(data, batch_size)
    
    def process_data(self):
        for batch in self.batches:
            yield batch
        
    def get_stream(self):
        return cycle(self.process_data())

    def __iter__(self):
        return self.get_stream()

In [5]:
class Encoder(nn.Module):
    def __init__(self, emb_dim, hid_dim):
        super(Encoder, self).__init__()
        self.lstm = nn.LSTM(emb_dim, hid_dim, bidirectional=True, batch_first=True)
        self.reduce_hidden = nn.Linear(hid_dim*2, hid_dim, bias=False)
        self.reduce_cell = nn.Linear(hid_dim*2, hid_dim, bias=False)
    def forward(self, embedded, src_len): 
        packed = pack_padded_sequence(embedded, src_len, batch_first=True) 
        output, (hidden, cell) = self.lstm(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)
        hidden = F.relu(self.reduce_hidden(torch.cat((hidden[-2, :, : ], hidden[-1, :, : ]), dim=1))).unsqueeze(0) # [1, batch_size, hid_dim]
        cell = F.relu(self.reduce_cell(torch.cat((cell[-2, :, : ], cell[-1, :, : ]), dim=1))).unsqueeze(0) # [1, batch_size, hid_dim]
        return output, hidden, cell

class Attention(nn.Module):
    def __init__(self, hid_dim):
        super(Attention, self).__init__()
        self.W_h = nn.Linear(2*hid_dim, 2*hid_dim, bias=False)
        self.W_s = nn.Linear(2*hid_dim, 2*hid_dim, bias=False)
        self.v = nn.Linear(2*hid_dim, 1, bias=False)
    def forward(self, output, hidden, cell, src_mask):
        hidden = torch.cat((hidden, cell), dim=2).permute(1, 0, 2) # [batch_size, src_len, 2*hid_dim]
        energy = self.v(torch.tanh(self.W_h(output) + self.W_s(hidden))) # [batch_size, src_len, 2*hid_dim] EQ.11
        attn_dist = F.softmax(energy.squeeze(2), dim=1) # [batch_size, src_len]
        # Masking
        attn_dist = attn_dist * src_mask # [batch_size, src_len]
        attn_dist = attn_dist / attn_dist.sum(1, keepdim=True)
        context = torch.bmm(attn_dist.unsqueeze(1), output).squeeze(1) # [batch_size, 2*hid_dim]
        return attn_dist, context

class Pointer_Generator(nn.Module):
    def __init__(self, emb_dim, hid_dim, output_dim):
        super(Pointer_Generator, self).__init__()
        self.ptr = nn.Linear(emb_dim+4*hid_dim, 1, bias=True)
        self.V1 = nn.Linear(emb_dim+4*hid_dim, hid_dim, bias=True)
        self.V2 = nn.Linear(hid_dim, output_dim, bias=True)
    def forward(self, embedded, hidden, cell, context, attn_dist, src_extended_vocab, oov_pad):

        hidden = torch.cat((hidden, cell), dim=2).squeeze(0) # [batch_size, 2*hid_dim]

        # Generation probability
        gen_prob = torch.sigmoid(self.ptr(torch.cat((embedded, hidden, context), dim=1))) # [batch_size, 1]

        # Vocabulary distribution 
        vocab_dist = F.softmax(self.V2(self.V1(torch.cat((embedded, hidden, context), dim=1))), dim=1) # [batch_size, output_dim]
        vocab_dist = gen_prob * vocab_dist # [batch_size, output_dim+oov_len]
        vocab_dist = torch.cat((vocab_dist, oov_pad), dim=1) if oov_pad != None else vocab_dist
            
        # Attention distribution
        attn_dist = (1-gen_prob) * attn_dist # [batch_size, src_len]

        # Word distribution
        word_dist = vocab_dist.scatter_add(1, src_extended_vocab, attn_dist) # [batch_size, output_dim+oov_len]
        return word_dist

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim):
        super(Decoder, self).__init__()
        self.attention = Attention(hid_dim)
        self.lstm = nn.LSTM(emb_dim+2*hid_dim, hid_dim, bidirectional=False, batch_first=True)
        self.pointer_generator = Pointer_Generator(emb_dim, hid_dim, output_dim)
    def forward(self, src_mask, embedded, hidden, cell, output, src_extended_vocab, oov_pad):
        attn_dist, context = self.attention(output, hidden, cell, src_mask)
        decoder_input = torch.cat((embedded, context), dim=1).unsqueeze(1) # [batch_size, 1, emb_dim+2*hid_dim]
        _, (hidden, cell) = self.lstm(decoder_input, (hidden, cell))
        word_dist = self.pointer_generator(embedded, hidden, cell, context, attn_dist, src_extended_vocab, oov_pad) # [batch_size, output_dim+oov_len]
        return word_dist, hidden, cell

class Param():
    def __init__(self):
        self.vocab_dim = vocab.n_tokens
        self.emb_dim = 300
        self.hid_dim = 300
        self.lr = 0.001
        self.n_iters = 100000
        self.batch_size = 16

class Model(nn.Module):
    def __init__(self, param):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(param.vocab_dim, param.emb_dim).to(device)
        self.encoder = Encoder(param.emb_dim, param.hid_dim).to(device)
        self.decoder = Decoder(param.vocab_dim, param.emb_dim, param.hid_dim).to(device)

param = Param()
model = Model(param)
model

Model(
  (embedding): Embedding(80000, 300)
  (encoder): Encoder(
    (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
    (reduce_hidden): Linear(in_features=600, out_features=300, bias=False)
    (reduce_cell): Linear(in_features=600, out_features=300, bias=False)
  )
  (decoder): Decoder(
    (attention): Attention(
      (W_h): Linear(in_features=600, out_features=600, bias=False)
      (W_s): Linear(in_features=600, out_features=600, bias=False)
      (v): Linear(in_features=600, out_features=1, bias=False)
    )
    (lstm): LSTM(900, 300, batch_first=True)
    (pointer_generator): Pointer_Generator(
      (ptr): Linear(in_features=1500, out_features=1, bias=True)
      (V1): Linear(in_features=1500, out_features=300, bias=True)
      (V2): Linear(in_features=300, out_features=80000, bias=True)
    )
  )
)

In [6]:
def get_ngram(lst, n):
    ngram = [tuple(lst[i:i+n]) for i in range(len(lst) - n + 1)]
    return ngram

def match(n_gram_1, n_gram_2):
    intersection = set(n_gram_1) & set(n_gram_2)
    return len(intersection)

def single_rouge_n(reference, generated, N=1):
    reference = get_ngram(reference.split(), N)
    generated = get_ngram(generated.split(), N)
    match_count = match(reference, generated)
    rec_count = len(reference)
    score = match_count / rec_count if rec_count > 0 else 0
    return score

def calc_rouge(reference_sents, generated_sents, N=1):
    scores = []
    for (reference, generated) in zip(reference_sents, generated_sents):
        score = single_rouge_n(reference, generated, N)
        scores.append(score)
    return torch.tensor(scores).float().to(device) # [batch_size]

In [7]:
def train_batch(batch, model, gamma):
    src = batch.src
    src_len = batch.src_len
    embedded = model.embedding(src)
    state = model.encoder(embedded, src_len)
    loss_CE = train_batch_CE(batch, model, state) if gamma < 1 else torch.zeros(1).to(device)
    loss_RL = train_batch_RL(batch, model, state) if gamma > 0 else torch.zeros(1).to(device)
    loss = gamma*loss_RL + (1-gamma)*loss_CE
    return loss

def sample_decode(batch, model, state):
    # https://pytorch.org/docs/stable/distributions.html
    src_mask = batch.src_mask
    src_extended_vocab = batch.src_extended_vocab
    oov_pad = batch.oov_pad
    (output, hidden, cell) = state

    actions = []
    log_probs = []
    eos_masks = []

    action = torch.tensor([vocab.SOS]).repeat(batch.batch_size).to(device)
    for t in range(100):
        embedded = model.embedding(action)
        probs, hidden, cell = model.decoder(src_mask, embedded, hidden, cell, output, src_extended_vocab, oov_pad)
        m = Categorical(probs)
        action = m.sample()
        log_prob = m.log_prob(action)
        action, eos_mask = next_action(action)

        actions.append(action)
        eos_masks.append(eos_mask)
        log_probs.append(log_prob)

    actions = torch.stack(actions, dim=1)
    eos_masks = torch.stack(eos_masks, dim=1)
    log_probs = torch.stack(log_probs, dim=1)*eos_masks

    log_probs_sum = torch.sum(log_probs, dim=1)/torch.sum(eos_masks, dim=1)   
    sents = actions2sents(actions, batch)
    return sents, log_probs_sum

def greedy_decode(batch, model, state):
    with torch.no_grad():
        src_mask = batch.src_mask
        src_extended_vocab = batch.src_extended_vocab
        oov_pad = batch.oov_pad
        (output, hidden, cell) = state

        actions = []
        
        action = torch.tensor([vocab.SOS]).repeat(batch.batch_size).to(device)
        for t in range(100):
            embedded = model.embedding(action)
            probs, hidden, cell = model.decoder(src_mask, embedded, hidden, cell, output, src_extended_vocab, oov_pad)
            _, action = torch.max(probs, dim=1)
            action, _ = next_action(action)

            actions.append(action)

        actions = torch.stack(actions, dim=1)
        sents = actions2sents(actions, batch)
    return sents

def actions2sents(actions, batch):
    sents = []    
    for i in range(batch.batch_size):
        sent = list(actions[i].cpu().numpy())
        sent = indices2extended(sent, batch.oovs[i])
        eos = vocab[vocab.EOS]
        sent = sent[:sent.index(eos)] if eos in sent else sent
        sent = ' '.join(sent)
        sents.append(sent)
    return sents

def next_action(action):
    oov_mask = (action > vocab.n_tokens-1).long()                                  
    action = oov_mask*vocab.UNK + (1-oov_mask)*action
    eos_mask = torch.ones(action.size(0)).to(device)
    eos_mask[action == vocab.EOS] == 0
    return action, eos_mask

def train_batch_RL(batch, model, state):
    sample_sents, log_probs = sample_decode(batch, model, state)
    greedy_sents = greedy_decode(batch, model, state)
    sample_reward = calc_rouge(batch.original_src, sample_sents)
    baseline_reward = calc_rouge(batch.original_src, greedy_sents)
    loss_RL = torch.mean((baseline_reward - sample_reward) * log_probs)
    return loss_RL

def calc_CE_loss(vocab_dist, trg_step):
    probs = torch.gather(vocab_dist, 1, trg_step.unsqueeze(1)).squeeze()
    step_loss = -torch.log(probs + 1e-9) 
    return step_loss

def train_batch_CE(batch, model, state):
    src_mask = batch.src_mask
    trg = batch.trg
    trg_len = batch.trg_len
    src_extended_vocab = batch.src_extended_vocab
    trg_extended_vocab = batch.trg_extended_vocab
    oov_pad = batch.oov_pad
    (output, hidden, cell) = state

    step_losses = []
    for t in range(trg.size(1)):
        embedded = model.embedding(trg[:, t])
        word_dist, hidden, cell = model.decoder(src_mask, embedded, hidden, cell, output, src_extended_vocab, oov_pad)
        step_loss = calc_CE_loss(word_dist, trg_extended_vocab[:, t])
        step_losses.append(step_loss)
    step_losses = torch.sum(torch.stack(step_losses, dim=1), dim=1)
    loss_CE = torch.mean(step_losses/trg_len)
    return loss_CE

In [8]:
def infer_batch(batch, model):
    assert batch.batch_size == 1
    src = batch.src
    src_len = batch.src_len
    embedded = model.embedding(src)
    state = model.encoder(embedded, src_len)
    sents = greedy_decode(batch, model, state)
    return sents

def infer_batches(batches, model, save=False, name='rl.csv'):
    refs = []
    cans = []
    srcs = []
    scores = []
    with torch.no_grad():
        for batch in batches:
            src = batch.original_src[0]
            ref = batch.original_trg[0]
            can = infer_batch(batch, model)[0]
            score = single_rouge_n(ref, can)

            refs.append(ref)
            cans.append(can)
            srcs.append(src)
            scores.append(score)

    score = np.mean(scores)
    if save == True:
        df = pd.DataFrame({'src': srcs, 'reference': refs, 'candidate': cans})
        df.to_csv(path_or_buf=name, index=False)
    return score

In [9]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'Embedding has {count_parameters(model.embedding):,} trainable parameters')
print(f'Encoder has {count_parameters(model.encoder):,} trainable parameters')
print(f'Decoder has {count_parameters(model.decoder):,} trainable parameters')

Embedding has 24,000,000 trainable parameters
Encoder has 1,804,800 trainable parameters
Decoder has 26,694,801 trainable parameters


In [13]:
import time

train_dataset = Dataset(train_data, param.batch_size)
valid_dataset = batchify(valid_data, 1)
test_dataset = batchify(test_data, 1)

In [11]:
model.load_state_dict(torch.load('pretrained.w'))
infer_batches(test_dataset, model, save=True, name='pretrained.csv')

KeyboardInterrupt: ignored

In [15]:
model.load_state_dict(torch.load('pretrained.w'))

start_time = time.time()
best_rouge = 0
gamma = 0.9
train_losses = []
optimizer = optim.Adam(model.parameters(), lr=param.lr)

for i, batch in enumerate(train_dataset):
    if i % 300 == 0 and i != 0:
        gamma += 0.05

    optimizer.zero_grad()
    train_loss = train_batch(batch, model, gamma)
    train_losses.append(train_loss.item())
    train_loss.backward()
    optimizer.step()

    if i % 100 == 0:
        valid_rouge = infer_batches(valid_dataset, model)
        avg_train_loss = np.mean(train_losses)
        elapsed_time = time.time() - start_time
        print(f'iter: {i:06}\tgamma: {gamma}\ttrain_loss: {avg_train_loss:.4f}\tvalid_rouge: {valid_rouge:.4f}\telapsed_time: {elapsed_time:.6f}')

    if valid_rouge > best_rouge:
        best_rouge = valid_rouge
        torch.save(model.state_dict(), 'rl.w')

iter: 000000	gamma: 0.9	train_loss: 0.4133	valid_rouge: 0.2221	elapsed_time: 312.149242
iter: 000100	gamma: 0.9	train_loss: 0.3615	valid_rouge: 0.2312	elapsed_time: 816.794215
iter: 000200	gamma: 0.9	train_loss: 0.3533	valid_rouge: 0.2401	elapsed_time: 1320.661373
iter: 000300	gamma: 0.9500000000000001	train_loss: 0.3527	valid_rouge: 0.2275	elapsed_time: 1823.751273
iter: 000400	gamma: 0.9500000000000001	train_loss: 0.3073	valid_rouge: 0.2345	elapsed_time: 2325.733480
iter: 000500	gamma: 0.9500000000000001	train_loss: 0.2796	valid_rouge: 0.2315	elapsed_time: 2828.787404
iter: 000600	gamma: 1.0	train_loss: 0.2603	valid_rouge: 0.2393	elapsed_time: 3330.650475
iter: 000700	gamma: 1.0	train_loss: 0.2254	valid_rouge: 0.2810	elapsed_time: 3794.255626
iter: 000800	gamma: 1.0	train_loss: 0.2208	valid_rouge: 0.2178	elapsed_time: 4259.083356
iter: 000900	gamma: 1.05	train_loss: 0.2194	valid_rouge: 0.1936	elapsed_time: 4723.061026


KeyboardInterrupt: ignored

In [16]:
model.load_state_dict(torch.load('rl.w'))
infer_batches(test_dataset, model, save=True)

0.280294509739343