In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import nltk
from torchtext.data import Field, BucketIterator, TabularDataset
from tqdm import tqdm, trange, tnrange, tqdm_notebook
import numpy as np
from __future__ import print_function

from model import Encoder, Decoder

In [2]:
TEXT = Field(tokenize = nltk.word_tokenize, use_vocab = True, init_token = "<s>", eos_token = "<e>", lower = True, include_lengths = True, batch_first = True)

In [3]:
tqdm.write("Loading training stories")
train = TabularDataset(path = "cnn_stories_cropped_train.txt",
                       format = 'tsv',
                       fields = [('input', TEXT), ('target', TEXT)])
tqdm.write("Loading testing stories")
test  = TabularDataset(path = "cnn_stories_cropped_test.txt",
                       format = 'tsv',
                       fields = [('input', TEXT), ('target', TEXT)])
tqdm.write("Loading validation stories")
valid = TabularDataset(path = "cnn_stories_cropped_valid.txt",
                       format = 'tsv',
                       fields = [('input', TEXT), ('target', TEXT)])
tqdm.write("Building vocabulary")
TEXT.build_vocab(train, test, valid, min_freq = 2)
tqdm.write("Vocabulary size: {}".format(len(TEXT.vocab)))

Loading training stories
Loading testing stories
Loading validation stories
Building vocabulary
Vocabulary size: 82850


In [4]:
BATCH_SIZE = 8
train_loader = BucketIterator(train,batch_size=BATCH_SIZE, device=None,
                              sort_key=lambda x: len(x.input),sort_within_batch=True,
                              repeat=False,shuffle=True)
test_loader  = BucketIterator(test,batch_size=BATCH_SIZE, device=None,
                              sort_key=lambda x: len(x.input),sort_within_batch=True,
                              repeat=False,shuffle=True)
valid_loader = BucketIterator(valid,batch_size=1, device=None,
                              sort_key=lambda x: len(x.input),sort_within_batch=True,
                              repeat=False,shuffle=True)
tqdm.write("Number of training stories: {}".format(len(train)))
tqdm.write("Number of testing stories: {}".format(len(test)))
tqdm.write("Number of validation stories: {}".format(len(valid)))

Number of training stories: 73972
Number of testing stories: 14794
Number of validation stories: 3699


In [5]:
HIDDEN_SIZE = 50
EMBED_SIZE  = 300
VOCAB_SIZE  = len(TEXT.vocab)
LEARN_RATE  = 0.001

In [6]:
encoder = Encoder(VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE, bidirec=True)
decoder = Decoder(VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE*2)

In [7]:
def load_models():
    encoder.load_state_dict(torch.load("./encoder.model"))
    decoder.load_state_dict(torch.load("./decoder.model"))
#load_models()

In [8]:
def save_models():
    torch.save(encoder.state_dict(), "./encoder.model")
    torch.save(decoder.state_dict(), "./decoder.model")

In [9]:
USE_CUDA = True
if USE_CUDA:
    tqdm.write("Using CUDA")
    if torch.cuda.device_count() > 1:
        print("Using %d devices" % (torch.cuda.device_count()))
        encoder = nn.DataParallel(encoder)
        decoder = nn.DataParallel(decoder)
    encoder = encoder.cuda()
    decoder = decoder.cuda()
decoder.embedding = encoder.embedding

Using CUDA


In [10]:
loss_function = nn.CrossEntropyLoss(ignore_index = TEXT.vocab.stoi['<pad>'])
enc_optim = optim.Adam(encoder.parameters(), lr = LEARN_RATE)
dec_optim = optim.Adam(decoder.parameters(), lr = LEARN_RATE)

In [None]:
NUM_EPOCHS          = 10
EPOCH_SAVE_INTERVAL = 1
for epoch_idx in tnrange(NUM_EPOCHS, desc = "Epochs", unit = "epoch"):
    encoder = encoder.train()
    decoder = decoder.train()
    total_loss, total_squared_loss, num_batches = 0.0, 0.0, 0
    for batch in tqdm_notebook(train_loader, desc = "Batches", unit = "batch"):
        inputs,lengths = batch.input
        targets,_ = batch.target
        decoding_start = Variable(torch.LongTensor([TEXT.vocab.stoi['<s>']]*targets.size(0))).unsqueeze(1)
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = targets.cuda()
            decoding_start = decoding_start.cuda()

        encoder.zero_grad()
        decoder.zero_grad()
        output,hidden = encoder(inputs,lengths.tolist())
        score = decoder(decoding_start,hidden,targets.size(1),output,lengths)

        loss = loss_function(score,targets.view(-1))
        total_loss += loss.data[0]
        total_squared_loss += loss.data[0]**2
        num_batches += 1
        loss.backward()
        enc_optim.step()
        dec_optim.step()
    train_loss_mean = total_loss / num_batches
    train_loss_variance = (total_squared_loss - (total_loss**2 / num_batches)) / (num_batches - 1)
    
    encoder = encoder.eval()
    decoder = decoder.eval()
    total_loss, total_squared_loss, num_batches = 0.0, 0.0, 0
    for batch in tqdm_notebook(test_loader, desc = "Batches", unit = "batch"):
        inputs,lengths = batch.input
        targets,_ = batch.target
        decoding_start = Variable(torch.LongTensor([TEXT.vocab.stoi['<s>']]*targets.size(0))).unsqueeze(1)
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = targets.cuda()
            decoding_start = decoding_start.cuda()
        output,hidden = encoder(inputs,lengths.tolist())
        score = decoder(decoding_start,hidden,targets.size(1),output,lengths)

        loss = loss_function(score,targets.view(-1))
        total_loss += loss.data[0]
        total_squared_loss += loss.data[0]**2
        num_batches += 1
    loss_mean = total_loss / num_batches
    loss_variance = (total_squared_loss - (total_loss**2 / num_batches)) / (num_batches - 1)
    tqdm.write("%3d Training --- loss mean: %7.4f, loss variance: %7.4f" % (epoch_idx + 1, train_loss_mean, train_loss_variance))
    tqdm.write("     Testing --- loss mean: %7.4f, loss variance: %7.4f" % (loss_mean, loss_variance))
    
    if (epoch_idx + 1) % EPOCH_SAVE_INTERVAL == 0:
        save_models()
        
save_models()

In [11]:
from rouge import ROUGE
from __future__ import print_function
rouge = ROUGE()

def get_string(summary):
    result = ""
    for idx in summary:
        if idx == TEXT.vocab.stoi["unk"]:
            continue
        elif idx in [TEXT.vocab.stoi["pad"], TEXT.vocab.stoi["<e>"]]:
            break
        result += TEXT.vocab.itos[idx] + " "
        
    return result

def show_selection_of_output(loader, num_to_show, num_to_calculate):
    global encoder
    global decoder
    total_rouge_score = {"rouge-1": {"recall": 0.0, "precision": 0.0},
                         "rouge-2": {"recall": 0.0, "precision": 0.0}}
    encoder = encoder.eval()
    decoder = decoder.eval()
    for i, batch in enumerate(loader):
        if i == num_to_calculate:
            break
        inputs, lengths = batch.input
        targets, _ = batch.target
        decoding_start = Variable(torch.LongTensor([TEXT.vocab.stoi['<s>']]*targets.size(0))).unsqueeze(1)
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = targets.cuda()
            decoding_start = decoding_start.cuda()

        output,hidden = encoder(inputs,lengths.tolist())
        score = decoder(decoding_start,hidden,targets.size(1),output,lengths)

        reference_summary = targets.data.cpu().numpy()[0]
        generated_summary = [np.argmax(word) for word in score.data.cpu().numpy()[0]]

        reference = get_string(reference_summary)
        generated = get_string(generated_summary)

        rouge_score = rouge.score(reference, generated)
        
        total_rouge_score["rouge-1"]["recall"] += rouge_score["rouge-1"]["recall"]
        total_rouge_score["rouge-1"]["precision"] += rouge_score["rouge-1"]["precision"]
        total_rouge_score["rouge-2"]["recall"] += rouge_score["rouge-2"]["recall"]
        total_rouge_score["rouge-2"]["precision"] += rouge_score["rouge-2"]["precision"]

        if i < num_to_show:
            print("\nReference summary:\n{}".format(reference))
            print("\nGenerated summary:\n{}".format(generated))
            print("\nROUGE score: {}\n".format(rouge_score))
        
    total_rouge_score["rouge-1"]["recall"] /= num_to_show
    total_rouge_score["rouge-1"]["precision"] /= num_to_show
    total_rouge_score["rouge-2"]["recall"] /= num_to_show
    total_rouge_score["rouge-2"]["precision"] /= num_to_show
    print("Mean ROUGE score: {}\n".format(total_rouge_score))
    encoder = encoder.train()
    decoder = decoder.train()

In [12]:
print("Selection of training stories")
show_selection_of_output(train_loader, 5, 100)

Selection of training stories

Reference summary:
<s> north korea plans to erect a statue of kim jong il and build towers across the country 

Generated summary:


ROUGE score: {'rouge-2': {'recall': 0.0, 'precision': 0.0}, 'rouge-1': {'recall': 0.0, 'precision': 0.0}}


Reference summary:
<s> the bill would fight counterfeiting and piracy 

Generated summary:


ROUGE score: {'rouge-2': {'recall': 0.0, 'precision': 0.0}, 'rouge-1': {'recall': 0.0, 'precision': 0.0}}


Reference summary:
<s> ferrari aiming for perfection after errors in malaysian and bahrain grands prix 

Generated summary:


ROUGE score: {'rouge-2': {'recall': 0.0, 'precision': 0.0}, 'rouge-1': {'recall': 0.0, 'precision': 0.0}}


Reference summary:
<s> new : al-jazeera broadcast says tape in bin laden saying , ` iraq is perfect base ' 

Generated summary:


ROUGE score: {'rouge-2': {'recall': 0.0, 'precision': 0.0}, 'rouge-1': {'recall': 0.0, 'precision': 0.0}}


Reference summary:
<s> the `` myspace shot `` makes you

In [None]:
print("Selection of testing stories")
show_selection_of_output(test_loader, 5, 100)

In [None]:
print("Selection of validation stories")
show_selection_of_output(valid_loader, 5, 100)

In [13]:
from beam import BeamSearch
beam_search = BeamSearch()

BEAM_WIDTH  = 5
BEAM_DEPTH  = 10
NUM_TO_SHOW = 3

for i, batch in enumerate(valid_loader):
    if i == NUM_TO_SHOW:
        break
    inputs, lengths = batch.input
    targets, _ = batch.target
    if USE_CUDA:
        inputs = inputs.cuda()
        targets = targets.cuda()
    outputs, hidden = encoder(inputs, lengths.tolist())
    cell = decoder.init_context(inputs.size(0))
    
    best_sequence, top_candidates = beam_search.get_words(hidden, cell, TEXT.vocab.stoi["<s>"], outputs, lengths, decoder, TEXT.vocab, BEAM_WIDTH, BEAM_DEPTH)
    
    
    target_summary = get_string(targets.data.cpu().numpy()[0])
        
    print("Article: {}".format(" ".join([TEXT.vocab.itos[idx] for idx in inputs.cpu().data[0]])))
    print("Target summary: {}".format(target_summary))
    print("Best sequence: {}".format(" ".join(best_sequence)))
    print("Top candidates:")
    for candidate in top_candidates[1:]:
        print("\t{}".format(" ".join(candidate)))
    print("\n")
    break

Article: <s> april 23 , 2014 this wednesday , we cover subjects related to civics , science and animal behavior . our first two reports center on supreme court cases , and our third examines the challenges still facing search crews more than a month after a massive landslide in washington state . we also explore the lingering effects of a 2010 oil spill and show you how some people are recovering from that . on this page you will find today 's show transcript , the daily curriculum , and a place for you to leave feedback . transcript click here <e>
Target summary: <s> this page includes the show transcript and the daily curriculum 
Best sequence: <s> translatable wilhite fourth-placed vanover banderas nll hadnot constrain fitzwilliams cheated
Top candidates:
	<s> translatable wilhite fourth-placed vanover banderas nll hadnot constrain fitzwilliams pursuant
	<s> translatable wilhite fourth-placed vanover banderas nll hadnot constrain fitzwilliams deflected
	<s> translatable wilhite four

  scores = nn.Softmax()(scores)
