In [1]:
import sys
sys.path.insert(0, ".") 

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Setup

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import nltk
import torchtext
from torchtext.data import Example, Field, BucketIterator, TabularDataset, Iterator
from tqdm import tqdm, tnrange, tqdm_notebook, trange
import numpy as np
from __future__ import print_function

from model import Encoder, Decoder

In [4]:
torch.__version__

'0.3.0'

In [5]:
import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()
# XXX: only one GPU on Colab and isn’t guaranteed
gpu = GPUs[0]
process = psutil.Process(os.getpid())
print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " | Proc size: " + humanize.naturalsize( process.memory_info().rss))
print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
!nvidia-smi

Gen RAM Free: 30.5 GB  | Proc size: 161.5 MB
GPU RAM Free: 12206MB | Used: 0MB | Util   0% | Total 12206MB
Thu Mar 29 14:43:06 2018       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.81                 Driver Version: 384.81                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX TIT...  Off  | 00000000:03:00.0 Off |                  N/A |
|  0%   66C    P0    63W / 250W |      0MiB / 12206MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                 

In [6]:
PATH="/diskA/jethro/cnn"

In [7]:
os.listdir(PATH)

['train_small.tsv',
 'valid.feather',
 'test.feather',
 'test_small.tsv',
 'encoder_small.model',
 'train.pkl',
 'valid_small.tsv',
 'raw',
 'decoder.model',
 'train.tsv',
 'encoder.model',
 'decoder_small.model',
 'stories.feather',
 'test.tsv',
 'valid.tsv',
 'train.feather']

In [8]:
TEXT = Field(tokenize = nltk.word_tokenize, use_vocab = True, init_token = "<s>", eos_token = "<e>", lower = True, include_lengths = True, batch_first = True)

In [9]:
USE_CUDA = True
USE_SMALL_DATASET = True

In [10]:
if USE_SMALL_DATASET:
    train_ds, test_ds, valid_ds = f'{PATH}/train_small.tsv', f'{PATH}/test_small.tsv', f'{PATH}/valid_small.tsv'
else:
    train_ds, test_ds, valid_ds = f'{PATH}/train.tsv', f'{PATH}/test.tsv', f'{PATH}/valid.tsv'

In [11]:
train_data = TabularDataset(path=train_ds,
                            format='tsv',
                            fields=[('input',TEXT), ('target',TEXT)])
test_data = TabularDataset(path=test_ds,
                            format='tsv',
                            fields=[('input',TEXT), ('target',TEXT)])
valid_data = TabularDataset(path=valid_ds,
                           format='tsv',
                           fields=[('input', TEXT), ('target', TEXT)])

In [12]:
TEXT.build_vocab(train_data, test_data, valid_data, min_freq=2)

tqdm.write("Vocabulary size: {}".format(len(TEXT.vocab)))

Vocabulary size: 19194


In [13]:
BATCH_SIZE = 16
train_loader = BucketIterator(train_data,batch_size=BATCH_SIZE, device=None,
                              sort_key=lambda x: len(x.input),sort_within_batch=True,
                              repeat=False,shuffle=True)
test_loader  = BucketIterator(test_data,batch_size=1, device=None,
                              sort_key=lambda x: len(x.input),sort_within_batch=True,
                              repeat=False,shuffle=True)
valid_loader = BucketIterator(valid_data,batch_size=BATCH_SIZE, device=None,
                              sort_key=lambda x: len(x.input),sort_within_batch=True,
                              repeat=False,shuffle=True)
 # May be slightly less due to skipping empty stories
tqdm.write("Number of training stories: {}".format(len(train_data)))
tqdm.write("Number of testing stories: {}".format(len(test_data)))
tqdm.write("Number of validation stories: {}".format(len(valid_data)))

Number of training stories: 3000
Number of testing stories: 1000
Number of validation stories: 1000


In [50]:
EMBED = 300

if USE_SMALL_DATASET:
    HIDDEN = 50
else:
    HIDDEN = 200
    
    
VOCAB_SIZE = len(TEXT.vocab)
LR = 1e-4

encoder = Encoder(VOCAB_SIZE,EMBED,HIDDEN,bidirec=True)
decoder = Decoder(VOCAB_SIZE,EMBED,HIDDEN*2)

if USE_CUDA:
    tqdm.write("Using CUDA")
    if torch.cuda.device_count() > 1:
        print("Using %d devices" % (torch.cuda.device_count()))
        encoder = nn.DataParallel(encoder)
        decoder = nn.DataParallel(decoder)
    encoder = encoder.cuda()
    decoder = decoder.cuda()
decoder.embedding = encoder.embedding

Using CUDA


In [86]:
encoder

Encoder(
  (dropout): Dropout(p=0.5)
  (embedding): Embedding(19194, 300)
  (lstm): LSTM(300, 50, batch_first=True, bidirectional=True)
)

In [87]:
decoder

Decoder(
  (embedding): Embedding(19194, 300)
  (dropout): Dropout(p=0.3)
  (lstm): LSTM(300, 100, batch_first=True)
  (linear): Linear(in_features=300, out_features=19194)
  (dec_attention): Attention(
    (attn): Linear(in_features=100, out_features=100)
  )
  (enc_attention): IntraTempAttention(
    (attn): Linear(in_features=100, out_features=100)
  )
)

In [88]:
loss_function = nn.CrossEntropyLoss(ignore_index=TEXT.vocab.stoi['<pad>'])
enc_optim = optim.Adam(encoder.parameters(),lr=LR)
dec_optim = optim.Adam(decoder.parameters(),lr=LR)

In [89]:
if USE_SMALL_DATASET:
    APPEND = "_small"
else:
    APPEND = ""
ENCODER_MODEL_PATH = f'{PATH}/encoder{APPEND}.model'
DECODER_MODEL_PATH = f'{PATH}/decoder{APPEND}.model'

def load_models():
    encoder.load_state_dict(torch.load(ENCODER_MODEL_PATH))
    decoder.load_state_dict(torch.load(DECODER_MODEL_PATH))
    
def save_models():
    torch.save(encoder.state_dict(), ENCODER_MODEL_PATH)
    torch.save(decoder.state_dict(), DECODER_MODEL_PATH)

In [90]:
def train(train_loader):
    global encoder
    global decoder
    encoder = encoder.train()
    decoder = decoder.train()
    total_loss, total_squared_loss, num_batches = 0.0, 0.0, 0
    for batch in tqdm_notebook(train_loader, desc="Training Batches", leave=False):
        inputs,lengths = batch.input
        targets,_ = batch.target
        decoding_start = Variable(torch.LongTensor([TEXT.vocab.stoi['<s>']]*targets.size(0))).unsqueeze(1)
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = targets.cuda()
            decoding_start = decoding_start.cuda()

        encoder.zero_grad()
        decoder.zero_grad()
        output,hidden = encoder(inputs,lengths.tolist())
        score = decoder(decoding_start,hidden,targets.size(1),output,lengths)

        loss = loss_function(score,targets.view(-1))
        total_loss += loss.data[0]
        total_squared_loss += loss.data[0]**2
        num_batches += 1
        loss.backward()
        enc_optim.step()
        dec_optim.step()
    loss_mean = total_loss / num_batches
    loss_variance = (total_squared_loss - (total_loss**2 / num_batches)) / (num_batches - 1)
    tqdm.write("Training: loss mean: %7.4f, loss variance: %7.4f" % (loss_mean, loss_variance))
    return loss_mean, loss_variance

In [91]:
def calculate_validation_loss(valid_loader):
    global encoder
    global decoder
    encoder = encoder.eval()
    decoder = decoder.eval()
    total_loss, total_squared_loss, num_batches = 0.0, 0.0, 0
    for batch in tqdm_notebook(valid_loader, desc="Validation Batches", leave=False):
        inputs,lengths = batch.input
        targets,_ = batch.target
        decoding_start = Variable(torch.LongTensor([TEXT.vocab.stoi['<s>']]*targets.size(0))).unsqueeze(1)
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = targets.cuda()
            decoding_start = decoding_start.cuda()
        output,hidden = encoder(inputs,lengths.tolist())
        score = decoder(decoding_start,hidden,targets.size(1),output,lengths)

        loss = loss_function(score,targets.view(-1))
        total_loss += loss.data[0]
        total_squared_loss += loss.data[0]**2
        num_batches += 1
    loss_mean = total_loss / num_batches
    loss_variance = (total_squared_loss - (total_loss**2 / num_batches)) / (num_batches - 1)
    tqdm.write("Validation: loss mean: %7.4f, loss variance: %7.4f" % (loss_mean, loss_variance))
    return loss_mean, loss_variance

In [92]:
from tensorboardX import SummaryWriter
writer = SummaryWriter()    

In [93]:
def write_to_tensorboard(writer, epoch, train_loss, valid_loss):
     writer.add_scalars('data/loss',
                        {'train': train_loss,
                         'valid': valid_loss},
                        epoch)

In [94]:
NUM_EPOCHS = 1000
for epoch_idx in tnrange(NUM_EPOCHS, desc="Epochs", unit="epoch"):
    train_loss, train_variance = train(train_loader)
    valid_loss, valid_variance = calculate_validation_loss(valid_loader)
    write_to_tensorboard(writer, epoch_idx, train_loss, valid_loss)
    save_models()

Training: loss mean:  6.5954, loss variance:  0.0443


Validation: loss mean:  6.9516, loss variance:  0.0343


Training: loss mean:  6.4528, loss variance:  0.0412


Validation: loss mean:  7.0170, loss variance:  0.0709


Training: loss mean:  6.3918, loss variance:  0.0414


Validation: loss mean:  7.0239, loss variance:  0.0780


Training: loss mean:  6.3470, loss variance:  0.0459


Validation: loss mean:  7.0344, loss variance:  0.0698


Training: loss mean:  6.3046, loss variance:  0.0427


Validation: loss mean:  7.0658, loss variance:  0.0535


Training: loss mean:  6.2687, loss variance:  0.0483


Validation: loss mean:  7.1013, loss variance:  0.0630


Training: loss mean:  6.2345, loss variance:  0.0494


Validation: loss mean:  7.1272, loss variance:  0.0837


Training: loss mean:  6.2044, loss variance:  0.0455


Validation: loss mean:  7.1633, loss variance:  0.0709


KeyboardInterrupt: 

# Getting the Summaries

In [80]:
from rouge import ROUGE
from __future__ import print_function
rouge = ROUGE()

def get_string(summary):
    result = " ".join([TEXT.vocab.itos[idx] for idx in summary])
    return result

def show_selection_of_output(loader, num_to_show, num_to_calculate):
    global encoder
    global decoder
    total_rouge_score = {"rouge-1": {"recall": 0.0, "precision": 0.0},
                         "rouge-2": {"recall": 0.0, "precision": 0.0}}
    encoder = encoder.eval()
    decoder = decoder.eval()
    for i, batch in enumerate(loader):
        if i == num_to_calculate:
            break
        inputs, lengths = batch.input
        targets, _ = batch.target
        decoding_start = Variable(torch.LongTensor([TEXT.vocab.stoi['<s>']]*targets.size(0))).unsqueeze(1)
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = targets.cuda()
            decoding_start = decoding_start.cuda()

        output,hidden = encoder(inputs,lengths.tolist())
        score = decoder(decoding_start, hidden, targets.size(1), output, lengths)

        reference_summary = targets.data.cpu().numpy()[0]
        generated_summary = [np.argmax(word) for word in score.data.cpu().numpy()]

        reference = get_string(reference_summary)
        generated = get_string(generated_summary)

        rouge_score = rouge.score(reference, generated)
        
        total_rouge_score["rouge-1"]["recall"] += rouge_score["rouge-1"]["recall"]
        total_rouge_score["rouge-1"]["precision"] += rouge_score["rouge-1"]["precision"]
        total_rouge_score["rouge-2"]["recall"] += rouge_score["rouge-2"]["recall"]
        total_rouge_score["rouge-2"]["precision"] += rouge_score["rouge-2"]["precision"]

        if i < num_to_show:
            print("\nReference summary:\n{}".format(reference))
            print("\nGenerated summary:\n{}".format(generated))
            print("\nROUGE score: {}\n".format(rouge_score))
        
    total_rouge_score["rouge-1"]["recall"] /= num_to_show
    total_rouge_score["rouge-1"]["precision"] /= num_to_show
    total_rouge_score["rouge-2"]["recall"] /= num_to_show
    total_rouge_score["rouge-2"]["precision"] /= num_to_show
    print("Mean ROUGE score: {}\n".format(total_rouge_score))
 

In [81]:
show_selection_of_output(train_loader, 5, 5)


Reference summary:
<s> romney and perry entered as frontrunners and left as frontrunners , said david gergen <e> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>

Generated summary:
<s> <s> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <s> <s> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <s> <s> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <s> <s> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <s> <s> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <s> <s> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <s> <s> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <s> <s> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e> <e

In [24]:
from beam import BeamSearch
beam_search = BeamSearch()

BEAM_WIDTH  = 5
BEAM_DEPTH  = 10
NUM_TO_SHOW = 3

for i, batch in enumerate(test_loader):
    if i == NUM_TO_SHOW:
        break
    inputs, lengths = batch.input
    targets, _ = batch.target
    if USE_CUDA:
        inputs = inputs.cuda()
        targets = targets.cuda()
    outputs, hidden = encoder(inputs, lengths.tolist())
    cell = decoder.init_context(inputs.size(0))
    
    best_sequence, top_candidates = beam_search.get_words(hidden, cell, TEXT.vocab.stoi["<s>"], outputs, lengths, decoder, TEXT.vocab, BEAM_WIDTH, BEAM_DEPTH)
    
    
    target_summary = get_string(targets.data.cpu().numpy()[0])
        
    print("Article: {}".format(" ".join([TEXT.vocab.itos[idx] for idx in inputs.cpu().data[0]])))
    print("Target summary: {}".format(target_summary))
    print("Best sequence: {}".format(" ".join(best_sequence)))
    print("Top candidates:")
    for candidate in top_candidates[1:]:
        print("\t{}".format(" ".join(candidate)))
    print("\n")
    break

Article: <s> port-au-prince , haiti in january 2010 a <unk> magnitude earthquake rocked haiti , killing more than 250,000 people and damaging its infrastructure , including some water systems . even before the quake , haiti 's water systems were fragile , and just months after the quake the country was hit with a devastating cholera outbreak -- the first in nearly a century . by the time the outbreak subsided , more than 8,000 people had died and hundreds of thousands more had become sick . independent studies suggest the outbreak was caused by u.n. peacekeepers who improperly disposed of <unk> <e>
Target summary: <s> haiti still recovering from cholera epidemic that left 8,000 dead <e>
Best sequence: <s> <s> <s> <s> <s> <s> <s> <s> <s> <s> <s>
Top candidates:
	<s> new <s> <s> <s> <s> <s> <s> <s> <s> <s>
	<s> `` <s> <s> <s> <s> <s> <s> <s> <s> <s>
	<s> the <s> <s> <s> <s> <s> <s> <s> <s> <s>
	<s> <s> new <s> <s> <s> <s> <s> <s> <s> <s>


