In [10]:
from google_drive_downloader import GoogleDriveDownloader as gdd
import random
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
from itertools import cycle

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
file_path = './cnn_stories_tokenized'
gdd.download_file_from_google_drive(file_id='0BzQ6rtO2VN95cmNuc2xwUS1wdEE', 
                                    dest_path=file_path+'.zip', 
                                    unzip=True, 
                                    showsize=True)

In [11]:
from os import listdir
from tqdm import tqdm

def load_data(file_path):
    data = []
    file_list = listdir(file_path)
    for name in tqdm(file_list):
        file_name = file_path + '/' + name
        doc = load_doc(file_name)
        src, trg = split_doc(doc)
        data.append({'src': src, 'trg': trg})
    return data

def load_doc(file_name):
    file = open(file_name, encoding='utf-8')
    text = file.read()
    file.close()
    return text

def split_doc(doc):
    idx = doc.find('@highlight')
    src, trg = doc[:idx], doc[idx:].split('@highlight')
    trg = [t.strip() for t in trg if len(t) > 0]
    return src, trg

def clean_sent(sent):
    filter_list = ['/', '-LRB-', '-RRB-', '\n', '`', '\'\'', '"', '--', '...', 'NEW :']
    for token in filter_list:
        sent = sent.replace(token, ' ')
    sent = ' '.join(sent.split())
    sent = sent.lower()
    return sent

def preprocess(file_path, train_size, valid_size, test_size, max_enc_step, max_dec_step):
    # load data from file_path
    data = load_data(file_path)
    preprocessed = []
    for ex in data:
        src, trg = ex['src'], ex['trg']
        src = clean_sent(src)
        trg = [clean_sent(t) for t in trg]
        # choose the first hightlight as the truth
        # choose the first 2 highlights as the trg
        if len(trg) > 1:
            trg = trg[:2]
            trg = ' . '.join(trg)
        else:
            trg = trg[0]
        trg += ' .'
        # truncate examples
        if len(src) > max_enc_step > 0: 
            src = src.split()
            src = src[:max_enc_step]
            src = ' '.join(src)
        if len(trg) > max_dec_step > 0: 
            trg = trg.split()
            trg = trg[:max_dec_step]
            trg = ' '.join(trg)
        preprocessed.append({'src': src, 'trg': trg})
    # split data
    train_data = preprocessed[:train_size]
    valid_data = preprocessed[train_size:train_size+valid_size]
    test_data = preprocessed[-test_size:]
    print('Number of train examples: {}'.format(len(train_data)))
    print('Number of valid examples: {}'.format(len(valid_data)))
    print('Number of test examples: {}'.format(len(test_data)))
    return train_data, valid_data, test_data


train_size = 90000
valid_size = 1579
test_size = 1000
max_enc_step = 300
max_dec_step = -1
train_data, valid_data, test_data = preprocess(file_path, 
                                               train_size, 
                                               valid_size, 
                                               test_size,
                                               max_enc_step, 
                                               max_dec_step)


100%|██████████| 92579/92579 [00:51<00:00, 1810.94it/s]


Number of train examples: 90000
Number of valid examples: 1579
Number of test examples: 1000


In [12]:
class Vocab():
   
    def __init__(self, data, max_size=80000, min_freq=4):
        self.PAD = 0
        self.SOS = 1
        self.EOS = 2
        self.UNK = 3
        self.index2token = {0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: '<UNK>'}
        self.token2index = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
        self.token2count = {}
        self.n_tokens = 4
        self.max_size = max_size
        self.min_freq = min_freq

        print("Read {} example pairs".format(len(data)))
        for ex in data:
            self.add_sent(ex['src'])
            self.add_sent(ex['trg'])
        print('Number of vocab: {}'.format(self.n_tokens))

        vocab_count = list(self.token2count.items()) 
        vocab_count = sorted(vocab_count, key=lambda x: x[1], reverse=True) # Sort by counts (descending)
        if max_size-4 < len(vocab_count):
            vocab_count = vocab_count[:max_size-4]

        self.token2index = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
        self.token2count = {}
        self.index2token = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.n_tokens = 4

        for token, count in vocab_count:
            if count >= min_freq: 
                self.token2index[token] = self.n_tokens
                self.token2count[token] = count
                self.index2token[self.n_tokens] = token
                self.n_tokens += 1    
        print('Number of vocab left: {}'.format(self.n_tokens))

    def add_sent(self, sent):
        for token in sent.split():
            self.add_token(token)

    def add_token(self, token):
        if sum([char.isdigit() for char in token]) == 0:
            if token not in self.token2index:
                self.token2index[token] = self.n_tokens
                self.token2count[token] = 1
                self.index2token[self.n_tokens] = token
                self.n_tokens += 1
            else:
                self.token2count[token] += 1

    def __getitem__(self, item):
        if type(item) == str:
            return int(self.token2index.get(item, self.UNK)) # Return the index of <UNK> if input word is not in the vocab
        elif type(item) == int:
            return self.index2token[item]

def build_oovs(src_tokens):
    oovs = [token for token in src_tokens if vocab[token] == vocab.UNK]
    return oovs

def tokens2indices(tokens):
    indices = [vocab[token] for token in tokens]
    indices = [vocab.SOS] + indices + [vocab.EOS]
    return indices

def extended2indices(extended, oovs):
    indices = []
    for token in extended:
        idx = vocab[token]
        if idx == vocab.UNK:
            idx = vocab.n_tokens + oovs.index(token) if token in oovs else vocab.UNK
        indices.append(idx)
    indices = [vocab.SOS] + indices + [vocab.EOS]
    return indices

def indices2extended(indices, oovs):
    tokens = []
    for idx in indices:
        if idx > vocab.n_tokens-1:
            token = oovs[idx - vocab.n_tokens]
        else:
            token = vocab[idx]
        tokens.append(token)
    return tokens


vocab = Vocab(train_data)

Read 90000 example pairs
Number of vocab: 197195
Number of vocab left: 80000


In [13]:
class Example():
    def __init__(self, ex, max_src_len=100):

        self.src_sent = ex['src']
        self.trg_sent = ex['trg']
        self.src_tokens = self.src_sent.split() 
        self.trg_tokens = self.trg_sent.split()
        self.src = tokens2indices(self.src_tokens) # For train
        self.trg = tokens2indices(self.trg_tokens)[:-1] # For train
        self.oovs = build_oovs(self.src_tokens)
        self.src_extended_vocab = extended2indices(self.src_tokens, self.oovs) # For attn dist
        self.trg_extended_vocab = extended2indices(self.trg_tokens, self.oovs)[1:] # For eval
        self.src_len = len(self.src)
        self.trg_len = len(self.trg)

def pad_sent(sent, max_len):
    sent += [vocab.PAD]*(max_len - len(sent))
    return sent

class Batch():
    def __init__(self, example_list):
        self.batch_size = len(example_list)
        self.original_src = [example.src_sent for example in example_list]
        self.original_trg = [example.trg_sent for example in example_list]

        self.max_src_len = max([example.src_len for example in example_list])
        self.max_trg_len = max([example.trg_len for example in example_list])

        self.src = np.zeros([self.batch_size, self.max_src_len], dtype=np.int32) # For train
        self.src_mask = np.zeros([self.batch_size, self.max_src_len], dtype=np.int32) # For train
        self.trg = np.zeros([self.batch_size, self.max_trg_len], dtype=np.int32) # For loss calculation
        self.src_len = np.zeros([self.batch_size], dtype=np.float32) # For train
        self.trg_len = np.zeros([self.batch_size], dtype=np.float32) # For train
        self.src_extended_vocab = np.zeros([self.batch_size, self.max_src_len], dtype=np.int32) # For attn dist
        self.trg_extended_vocab = np.zeros([self.batch_size, self.max_trg_len], dtype=np.int32) # For eval
        self.max_oovs_len = max([len(example.oovs) for example in example_list]) # For attn dist
        self.oovs = [example.oovs for example in example_list] # For attn dist
        
        for i, example in enumerate(example_list):

            self.src[i] = pad_sent(example.src, self.max_src_len)
            self.src_mask[i] = pad_sent([1]*(example.src_len), self.max_src_len)
            self.trg[i] = pad_sent(example.trg, self.max_trg_len)
            self.src_len[i] = example.src_len 
            self.trg_len[i] = example.trg_len 
            self.src_extended_vocab[i, :] = pad_sent(example.src_extended_vocab, self.max_src_len)
            self.trg_extended_vocab[i, :] = pad_sent(example.trg_extended_vocab, self.max_trg_len)

        self.src = torch.from_numpy(self.src).long().to(device)
        self.src_mask = torch.from_numpy(self.src_mask).long().to(device)
        self.trg = torch.from_numpy(self.trg).long().to(device)
        self.src_len = torch.from_numpy(self.src_len).long().to(device)
        self.trg_len = torch.from_numpy(self.trg_len).long().to(device)
        self.src_extended_vocab = torch.from_numpy(self.src_extended_vocab).long().to(device)
        self.trg_extended_vocab = torch.from_numpy(self.trg_extended_vocab).long().to(device)
        self.oov_pad = torch.zeros(self.batch_size, self.max_oovs_len).to(device) if self.max_oovs_len > 0 else None # [batch_size, oov_len]


def batchify(data, batch_size):
    examples = [Example(ex) for ex in data]
    examples = sorted(examples, key=lambda example: example.src_len, reverse=True)
    examples = [examples[i:i+batch_size] for i in range(0, len(examples), batch_size)]
    examples = [Batch(batch) for batch in examples]
    return examples

class Dataset():
    def __init__(self, data, batch_size):
        self.batches = batchify(data, batch_size)
    
    def process_data(self):
        for batch in self.batches:
            yield batch
        
    def get_stream(self):
        return cycle(self.process_data())

    def __iter__(self):
        return self.get_stream()

In [14]:
class Encoder(nn.Module):
    def __init__(self, emb_dim, hid_dim):
        super(Encoder, self).__init__()
        self.lstm = nn.LSTM(emb_dim, hid_dim, bidirectional=True, batch_first=True)
        self.reduce_hidden = nn.Linear(hid_dim*2, hid_dim, bias=False)
        self.reduce_cell = nn.Linear(hid_dim*2, hid_dim, bias=False)
    def forward(self, embedded, src_len): 
        packed = pack_padded_sequence(embedded, src_len, batch_first=True) 
        output, (hidden, cell) = self.lstm(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)
        hidden = F.relu(self.reduce_hidden(torch.cat((hidden[-2, :, : ], hidden[-1, :, : ]), dim=1))).unsqueeze(0) # [1, batch_size, hid_dim]
        cell = F.relu(self.reduce_cell(torch.cat((cell[-2, :, : ], cell[-1, :, : ]), dim=1))).unsqueeze(0) # [1, batch_size, hid_dim]
        return output, hidden, cell

class Attention(nn.Module):
    def __init__(self, hid_dim):
        super(Attention, self).__init__()
        self.W_h = nn.Linear(2*hid_dim, 2*hid_dim, bias=False)
        self.W_s = nn.Linear(2*hid_dim, 2*hid_dim, bias=False)
        self.w_c = nn.Linear(1, 2*hid_dim, bias=False)
        self.v = nn.Linear(2*hid_dim, 1, bias=False)
    def forward(self, output, hidden, cell, src_mask, coverage):
        hidden = torch.cat((hidden, cell), dim=2).permute(1, 0, 2) # [batch_size, src_len, 2*hid_dim]
        energy = self.v(torch.tanh(self.W_h(output) + self.W_s(hidden) + self.w_c(coverage.unsqueeze(2)))) # [batch_size, src_len, 2*hid_dim] EQ.11
        attn_dist = F.softmax(energy.squeeze(2), dim=1) # [batch_size, src_len]
        attn_dist = attn_dist * src_mask # [batch_size, src_len]
        attn_dist = attn_dist / attn_dist.sum(1, keepdim=True)
        context = torch.bmm(attn_dist.unsqueeze(1), output).squeeze(1) # [batch_size, 2*hid_dim]
        coverage = coverage.clone() + attn_dist # [batch_size, src_len] EQ.10
        return attn_dist, context, coverage

class Pointer_Generator(nn.Module):
    def __init__(self, emb_dim, hid_dim, output_dim):
        super(Pointer_Generator, self).__init__()
        self.ptr = nn.Linear(emb_dim+4*hid_dim, 1, bias=True)
        self.V1 = nn.Linear(emb_dim+4*hid_dim, hid_dim, bias=True)
        self.V2 = nn.Linear(hid_dim, output_dim, bias=True)
    def forward(self, embedded, hidden, cell, context, attn_dist, src_extended_vocab, oov_pad):

        hidden = torch.cat((hidden, cell), dim=2).squeeze(0) # [batch_size, 2*hid_dim]

        # Generation probability
        gen_prob = torch.sigmoid(self.ptr(torch.cat((embedded, hidden, context), dim=1))) # [batch_size, 1]

        # Vocabulary distribution 
        vocab_dist = F.softmax(self.V2(self.V1(torch.cat((embedded, hidden, context), dim=1))), dim=1) # [batch_size, output_dim]
        vocab_dist = gen_prob * vocab_dist # [batch_size, output_dim+oov_len]
        if oov_pad != None:
            vocab_dist = torch.cat((vocab_dist, oov_pad), dim=1) # [batch_size, output_dim+oov_len]
            
        # Attention distribution
        attn_dist = (1-gen_prob) * attn_dist # [batch_size, src_len]

        # Word distribution
        word_dist = vocab_dist.scatter_add(1, src_extended_vocab, attn_dist) # [batch_size, output_dim+oov_len]
        return word_dist

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim):
        super(Decoder, self).__init__()
        self.attention = Attention(hid_dim)
        self.lstm = nn.LSTM(emb_dim+2*hid_dim, hid_dim, bidirectional=False, batch_first=True)
        self.pointer_generator = Pointer_Generator(emb_dim, hid_dim, output_dim)
    def forward(self, src_mask, embedded, hidden, cell, output, src_extended_vocab, oov_pad, coverage):

        attn_dist, context, coverage = self.attention(output, hidden, cell, src_mask, coverage)
        decoder_input = torch.cat((embedded, context), dim=1).unsqueeze(1) # [batch_size, 1, emb_dim+2*hid_dim]
        _, (hidden, cell) = self.lstm(decoder_input, (hidden, cell))
        word_dist = self.pointer_generator(embedded, hidden, cell, context, attn_dist, src_extended_vocab, oov_pad) # [batch_size, output_dim+oov_len]
        return word_dist, hidden, cell, attn_dist, coverage

class Param():
    def __init__(self):
        self.vocab_dim = vocab.n_tokens
        self.emb_dim = 300
        self.hid_dim = 300
        self.lr = 0.001
        self.n_iters = 100000
        self.batch_size = 16

class Model(nn.Module):
    def __init__(self, param):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(param.vocab_dim, param.emb_dim).to(device)
        self.encoder = Encoder(param.emb_dim, param.hid_dim).to(device)
        self.decoder = Decoder(param.vocab_dim, param.emb_dim, param.hid_dim).to(device)

param = Param()
model = Model(param)
model

Model(
  (embedding): Embedding(80000, 300)
  (encoder): Encoder(
    (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
    (reduce_hidden): Linear(in_features=600, out_features=300, bias=False)
    (reduce_cell): Linear(in_features=600, out_features=300, bias=False)
  )
  (decoder): Decoder(
    (attention): Attention(
      (W_h): Linear(in_features=600, out_features=600, bias=False)
      (W_s): Linear(in_features=600, out_features=600, bias=False)
      (w_c): Linear(in_features=1, out_features=600, bias=False)
      (v): Linear(in_features=600, out_features=1, bias=False)
    )
    (lstm): LSTM(900, 300, batch_first=True)
    (pointer_generator): Pointer_Generator(
      (ptr): Linear(in_features=1500, out_features=1, bias=True)
      (V1): Linear(in_features=1500, out_features=300, bias=True)
      (V2): Linear(in_features=300, out_features=80000, bias=True)
    )
  )
)

In [15]:
def calc_loss(vocab_dist, trg_step, attn_dist, coverage):
    probs = torch.gather(vocab_dist, 1, trg_step.unsqueeze(1)).squeeze()
    step_loss = -torch.log(probs + 1e-9) 
    covloss = torch.sum(torch.min(attn_dist, coverage), dim=1) 
    return step_loss + covloss

def calc_rouge(refs, cans):
    rouge_1 = 0
    rouge_2 = 0
    rouge_3 = 0
    count = 0
    for ref, can in zip(refs, cans):
        assert type(ref) == type(can) == str
        ref = ref.split()
        can = can.split()
        
        rouge_1 += rouge_n([[ref]], [can], 1)
        rouge_2 += rouge_n([[ref]], [can], 2)
        rouge_3 += rouge_n([[ref]], [can], 3)
        count += 1
    
    rouge_1 = rouge_1*100 / count
    rouge_2 = rouge_2*100 / count
    rouge_3 = rouge_3*100 / count

    return rouge_1, rouge_2, rouge_3

def get_ngram(lst, n):
    ngram = [tuple(lst[i:i+n]) for i in range(len(lst) - n + 1)]
    return ngram

def divide(numerator, denominator):
    if np.min(denominator) > 0:
        return numerator / denominator
    else:
        return 0.0

def match(ref, sys):
    intersection = set(ref) & set(sys)
    return len(intersection)

def single_rouge_n(ref, sys, N):
    ref_iterator = get_ngram(ref, N)
    sys_iterator = get_ngram(sys, N)
    match_count = match(ref_iterator, sys_iterator)
    rec_count = len(ref_iterator)
    
    return match_count, rec_count
    
def multi_rouge_n(refs, sys, N):
    temp_rouge = np.zeros([len(refs), 2])
    for i in range(len(refs)):
        temp_rouge[i, :] = single_rouge_n(refs[i], sys, N)
    rogue = divide(temp_rouge[:, 0], temp_rouge[:, 1])
    ind = np.argmax(rogue)
    
    match_count, rec_count = temp_rouge[ind, ]
    return match_count, rec_count
    

def rouge_n(reference_summaries, candidate_summaries, N):
    # https://rxnlp.com/how-rouge-works-for-evaluation-of-summarization-tasks/#.XuWlF0VKhPa
    match_counts = 0.0
    rec_counts = 0.0
    
    for (refs, can) in zip(reference_summaries, candidate_summaries):
        match_count, rec_count = multi_rouge_n(refs, can, N)
        match_counts += match_count
        rec_counts += rec_count
        
    score = divide(match_counts, rec_counts)
    return score

In [16]:
def train_batch(batch, model):
    src = batch.src
    src_mask = batch.src_mask
    trg = batch.trg
    src_extended_vocab = batch.src_extended_vocab
    trg_extended_vocab = batch.trg_extended_vocab
    src_len = batch.src_len
    trg_len = batch.trg_len
    oov_pad = batch.oov_pad
    coverage = torch.zeros(src.size()).to(device)

    embedded = model.embedding(src)
    output, hidden, cell = model.encoder(embedded, src_len)

    step_losses = []
    for t in range(trg.size(1)):
        embedded = model.embedding(trg[:, t]) # [batch_size, trg_len, emb_dim]
        word_dist, hidden, cell, attn_dist, coverage = model.decoder(src_mask,
                                                                     embedded, hidden, cell, output, 
                                                                     src_extended_vocab, oov_pad,
                                                                     coverage)     
        step_loss = calc_loss(word_dist, trg_extended_vocab[:, t], coverage, attn_dist)
        step_losses.append(step_loss)
    step_losses = torch.sum(torch.stack(step_losses, dim=1), dim=1)
    batch_loss = torch.mean(step_losses/trg_len)
    return batch_loss

def eval_batch(batch, model):
    src = batch.src
    src_mask = batch.src_mask
    trg = batch.trg
    src_extended_vocab = batch.src_extended_vocab
    trg_extended_vocab = batch.trg_extended_vocab
    src_len = batch.src_len
    trg_len = batch.trg_len
    oov_pad = batch.oov_pad
    coverage = torch.zeros(src.size()).to(device)

    embedded = model.embedding(src)
    output, hidden, cell = model.encoder(embedded, src_len)

    step_losses = []
    for t in range(trg.size(1)):
        embedded = model.embedding(trg[:, t]) # [batch_size, trg_len, emb_dim]
        word_dist, hidden, cell, attn_dist, coverage = model.decoder(src_mask,
                                                                     embedded, hidden, cell, output, 
                                                                     src_extended_vocab, oov_pad,
                                                                     coverage)        
        step_loss = calc_loss(word_dist, trg_extended_vocab[:, t], coverage, attn_dist)
        step_losses.append(step_loss)
    step_losses = torch.sum(torch.stack(step_losses, dim=1), dim=1)
    batch_loss = torch.mean(step_losses/trg_len)
    return batch_loss

def eval_batches(batches, model):
    valid_losses = []
    with torch.no_grad():
        for batch in iter(batches):
            valid_loss = eval_batch(batch, model)
            valid_losses.append(valid_loss.item())
        return np.mean(valid_losses)

def infer_batch(batch, model):
    src = batch.src
    src_mask = batch.src_mask
    trg = batch.trg
    src_extended_vocab = batch.src_extended_vocab
    trg_extended_vocab = batch.trg_extended_vocab
    src_len = batch.src_len
    trg_len = batch.trg_len
    oov_pad = batch.oov_pad
    coverage = torch.zeros(src.size()).to(device)

    embedded = model.embedding(src) # [batch_size, src_len, emb_dim]
    output, hidden, cell = model.encoder(embedded, src_len)
    pred_trgs = []
    pred_trg = torch.tensor([vocab.SOS]).to(device) # initial decoder input [batch_size]
    for t in range(100):
        embedded = model.embedding(pred_trg) # [batch_size, trg_len, emb_dim]
        word_dist, hidden, cell, attn_dist, coverage = model.decoder(src_mask,
                                                                     embedded, hidden, cell, output, 
                                                                     src_extended_vocab, oov_pad,
                                                                     coverage)        
        pred_trg = word_dist.argmax(1) # [batch_size]
        pred_idx = pred_trg.item()
        pred_trgs.append(pred_idx)
        if pred_trgs[-1] == vocab.EOS:
            break
        elif pred_trgs[-1] > vocab.n_tokens-1:
            pred_trg = torch.tensor([vocab.UNK]).to(device)
    
    pred_trgs = indices2extended(pred_trgs, batch.oovs[0])
    pred_trgs = ' '.join(pred_trgs)
    return pred_trgs

def infer_batches(batches, model, save=False):
    refs = []
    cans = []
    srcs = []
    with torch.no_grad():
        for batch in batches:
            src = batch.original_src[0]
            ref = batch.original_trg[0]
            can = infer_batch(batch, model)
            if can.split()[-1] == vocab[vocab.EOS]:
                can = can.split()
                can = can[:-1]
                can = ' '.join(can)
            refs.append(ref)
            cans.append(can)
            srcs.append(src)
    score, _, _ = calc_rouge(refs, cans)
    if save == True:
        df = pd.DataFrame({'src': srcs, 'reference': refs, 'candidate': cans})
        df.to_csv(path_or_buf='model_7.csv', index=False)
    return score

In [17]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'Embedding has {count_parameters(model.embedding):,} trainable parameters')
print(f'Encoder has {count_parameters(model.encoder):,} trainable parameters')
print(f'Decoder has {count_parameters(model.decoder):,} trainable parameters')

Embedding has 24,000,000 trainable parameters
Encoder has 1,804,800 trainable parameters
Decoder has 26,695,401 trainable parameters


In [9]:
import time

train_dataset = Dataset(train_data, param.batch_size)
valid_dataset = batchify(valid_data, param.batch_size)
test_dataset = batchify(test_data, 1)

start_time = time.time()
best_valid_loss = float('inf')

train_losses = []
optimizer = optim.Adam(model.parameters(), lr=param.lr)

for i, batch in enumerate(train_dataset):
    optimizer.zero_grad()
    train_loss = train_batch(batch, model)
    train_losses.append(train_loss.item())
    train_loss.backward()
    optimizer.step()

    if i % 500 == 0:
        valid_loss = eval_batches(valid_dataset, model)
        avg_train_loss = np.mean(train_losses)
        elapsed_time = time.time() - start_time
        print(f'iter: {i:06}\ttrain_loss: {avg_train_loss:.4f}\tvalid_loss: {valid_loss:.4f}\telapsed_time: {elapsed_time:.6f}')

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model_7.w')

iter: 000000	train_loss: 14.0277	valid_loss: 10.5289	elapsed_time: 13.004163
iter: 000500	train_loss: 6.8063	valid_loss: 6.2893	elapsed_time: 253.597062
iter: 001000	train_loss: 6.5011	valid_loss: 5.9967	elapsed_time: 495.354018
iter: 001500	train_loss: 6.3298	valid_loss: 5.8253	elapsed_time: 735.711740
iter: 002000	train_loss: 6.2164	valid_loss: 5.7266	elapsed_time: 976.529805
iter: 002500	train_loss: 6.1326	valid_loss: 5.6240	elapsed_time: 1219.047491
iter: 003000	train_loss: 6.0572	valid_loss: 5.5612	elapsed_time: 1460.310432
iter: 003500	train_loss: 5.9941	valid_loss: 5.4996	elapsed_time: 1702.529197
iter: 004000	train_loss: 5.9422	valid_loss: 5.4509	elapsed_time: 1944.225616
iter: 004500	train_loss: 5.8952	valid_loss: 5.4143	elapsed_time: 2184.745651
iter: 005000	train_loss: 5.8504	valid_loss: 5.3991	elapsed_time: 2425.976974
iter: 005500	train_loss: 5.7446	valid_loss: 5.4357	elapsed_time: 2640.832815
iter: 006000	train_loss: 5.6974	valid_loss: 5.3753	elapsed_time: 2864.414781
ite

KeyboardInterrupt: ignored

In [None]:
model.load_state_dict(torch.load('model_7.w'))
infer_batches(test_dataset, model, save=True)