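# Generator Model

Train the `Summarizer` generator on the Newsroom dataset, save and reload it, plot its learning curves, and inspect beam-search summaries, attention weights and per-batch summary files for a few training examples.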

In [1]:
import os

import torch

from PICOHelper import get_pico_datasets
from NewsroomHelper import get_newsroom_datasets
from models import Summarizer
from model_helpers import loss_function, error_function
from utils import get_index_words, produce_attention_visualization_file, summarize, get_text_triplets, produce_batch_summary_files
from pytorch_helper import VariableBatchDataLoader, ModelManipulator, plot_learning_curves

## Parameters

In [2]:
# training parameters
BATCH_SIZE = 64
NUM_EPOCHS = 1
LEARNING_RATE = 1e-3  # learning rate for the Adam optimizer built in the training cell
GAMMA = 1  # NOTE(review): not referenced by any visible cell — confirm it is still needed
SEED = 0  # fixed seed so training runs are repeatable
USE_CUDA = torch.cuda.is_available()
print(USE_CUDA)

# decoding parameters
BEAM_SIZE = 4  # beam width passed to summarize / produce_batch_summary_files

# Seed torch's RNGs (CPU, and all CUDA devices in current PyTorch) so weight
# initialisation and any torch-based shuffling are reproducible.
# NOTE(review): if VariableBatchDataLoader shuffles with `random` or `numpy`,
# those generators need seeding as well — confirm in pytorch_helper.
torch.manual_seed(SEED)

True


## Get Data

In [3]:
# Load the Newsroom train/dev/test splits together with the preprocessor, which
# exposes the word vectors and the token -> index map used below.
# (get_pico_datasets is the alternative PICO source; it is not used here.)
newsroom_dataset_train, newsroom_dataset_dev, newsroom_dataset_test, preprocessor = get_newsroom_datasets()

word_vectors = preprocessor.word_vectors
# Vocabulary indices of the sentence-boundary tokens, handed to the Summarizer.
start_index, end_index = (preprocessor.word_indices[token] for token in ('<start>', '<end>'))

11029 3676 3678
retrieving word2vec model from file


## Create Model

In [4]:
# Build the summarizer from the preprocessor and the <start>/<end> token indices.
# Both hidden-size arguments are None — presumably "use Summarizer's defaults";
# TODO confirm in models.Summarizer.
generator_model = Summarizer(
    preprocessor,
    start_index,
    end_index,
    num_hidden1=None,
    num_hidden2=None,
)

## Train and Save Model

In [None]:
# Move the model to the GPU *before* building the optimizer: the torch.optim docs
# require parameters to be in their final location when the optimizer is
# constructed. Calling .cuda() again later (e.g. inside ModelManipulator) is a no-op.
if USE_CUDA:
    generator_model = generator_model.cuda()

# Shuffled, variable-length batches of the Newsroom training split.
dataloader = VariableBatchDataLoader(newsroom_dataset_train, batch_size=BATCH_SIZE, shuffle=True)
optimizer = torch.optim.Adam(generator_model.parameters(), lr=LEARNING_RATE)

# ModelManipulator runs the training loop with the project's loss/error functions,
# validating on the dev split; stats_every / verbose_every control how often
# statistics are recorded and printed (units assumed to be steps — TODO confirm).
model_manip = ModelManipulator(generator_model, optimizer, loss_function, error_function, use_cuda=USE_CUDA)
train_stats, val_stats = model_manip.train(dataloader, NUM_EPOCHS, dataset_val=newsroom_dataset_dev, stats_every=10, verbose_every=10)

In [None]:
# torch.save does not create missing directories, so make sure `models/` exists.
os.makedirs('models', exist_ok=True)
# This pickles the whole module object, so reloading it requires the Summarizer
# class to be importable under the same module path.
torch.save(generator_model, 'models/generator_temp.model')

In [5]:
# torch.load unpickles the file, which can execute arbitrary code: only load
# checkpoints produced by this notebook. map_location lets a checkpoint saved
# from a GPU run be loaded on a CPU-only kernel.
generator_model = torch.load('models/generator_temp.model', map_location=None if USE_CUDA else 'cpu')

## Plot Learning Curves and Inspect Summaries

In [None]:
# Plot the training and validation statistics returned by model_manip.train.
# NOTE: train_stats / val_stats exist only if the training cell ran in this
# kernel session; reloading the model from disk does not restore them.
plot_learning_curves(
    training_values=train_stats,
    validation_values=val_stats,
    figure_name='graphs/generator_training_temp',
)

In [6]:
# Decode the first four training examples with beam search.
batch = newsroom_dataset_train[0:4]
results = summarize(
    batch,
    generator_model,
    beam_size=BEAM_SIZE,
)

  output, (h, c) = self.lstm(x)


In [7]:
summary_info = results[0]
triplets = get_text_triplets(batch, summary_info, preprocessor)

# For each example, show the source text, reference summary, decoded summary
# and the per-example loss stored at summary_info[2].
# NOTE(review): triplets are unpacked here as (text, reference, decoded), but the
# attention-visualization cell unpacks them as (text, decoded, reference). One of
# the two orders is wrong — confirm against utils.get_text_triplets.
for i, triplet in enumerate(triplets):
    text, reference_summary, decoded_summary = triplet
    loss = summary_info[2][i]
    print("text", text)
    print("reference summary", reference_summary)
    print("decoded summary", decoded_summary)
    print(loss)

text ['<start>', 'tuesday', ',', 'march', 'qqq', ',', 'qqq', ',', 'qqq', 'pm', 'apple', "'s", 'ceo', 'tim', 'cook', 'finally', 'revealed', 'the', 'tech', 'specs', 'for', 'the', 'highly', 'anticipated', 'apple', 'watch', 'monday', 'in', 'san', 'francisco', ',', 'when', 'he', 'announced', 'the', 'smartwatch', 'will', 'begin', 'shipping', 'april', 'qqq', 'and', 'starts', 'at', '$', 'qqq', '.', 'the', 'apple', 'watch', 'will', 'come', 'in', 'two', 'different', 'retina', 'displays', ',', 'qqq', 'and', 'qqq', ',', 'and', 'it', "'ll", 'have', 'a', 'battery', 'that', 'lasts', 'up', 'to', 'qqq', 'hours', '.', 'the', 'watch', 'will', 'also', 'come', 'in', 'four', 'different', 'colors', 'and', 'three', 'different', 'models', ':', 'apple', 'watch', 'sport', ',', 'apple', 'watch', 'and', 'apple', 'watch', 'edition', '.', 'compared', 'to', 'other', 'watches', 'like', 'the', 'samsung', 'gear', 's', 'and', 'the', 'moto', 'qqq', ',', 'the', 'apple', 'watch', 'is', 'capable', 'of', 'doing', 'the', 'same

In [8]:
summary_info = results[0]
i = 0  # which example of `batch` to visualize

triplets = get_text_triplets(batch, summary_info, preprocessor)
# NOTE(review): unpacked as (text, decoded, reference), whereas the printing cell
# unpacks (text, reference, decoded). The attention/p_gens lengths below only
# line up if `decoded_summary` really is the decoded output — confirm the order
# returned by utils.get_text_triplets.
text, decoded_summary, reference_summary = triplets[i]

# Attention weights from summary_info[4][i]: every step except the last, each
# vector trimmed of its first and last entries (presumably the <start>/<end>
# positions of the text — TODO confirm).
attentions = [
    [float(weight) for weight in step_attention[1:-1]]
    for step_attention in summary_info[4][i][:-1]
]
# Placeholder generation probabilities: all 0.0, one per decoded token except the last.
p_gens = [0.0] * (len(decoded_summary) - 1)

produce_attention_visualization_file(
    'graphs/attn_vis_data.json',
    text,
    decoded_summary,
    " ".join(reference_summary),
    attentions,
    p_gens,
)

In [9]:
# Write summary files for the batch, using "data" as the output location
# (presumably a directory — TODO confirm in utils), decoded with beam search.
produce_batch_summary_files(
    batch,
    generator_model,
    "data",
    beam_size=BEAM_SIZE,
)

4