In [None]:
import torch
from PICOHelper import get_pico_datasets
from NewsroomHelper import get_newsroom_datasets
from SummarizationModelStructures import GeneratorModel, loss_function, error_function
from utils import get_index_words
from pytorch_helper import VariableBatchDataLoader, ModelManipulator, plot_learning_curves

# Parameters

In [None]:
# Training hyperparameters; each value below is consumed by a later cell.
BATCH_SIZE = 64        # batch_size given to VariableBatchDataLoader
NUM_EPOCHS = 1         # epoch count given to ModelManipulator.train
LEARNING_RATE = 0.01   # lr of the Adam optimizer
# INITIAL_ACCUMULATOR_VALUE = 0.1  # only used by the commented-out Adagrad optimizer
GAMMA = 1              # forwarded to GeneratorModel as `gamma`
USE_CUDA = torch.cuda.is_available()  # forwarded to ModelManipulator as `use_cuda`
print(USE_CUDA)

# Get Data

In [None]:
# Load the train/dev/test splits of the Newsroom dataset.
# (PICO alternative, currently disabled:)
# pico_dataset_train, pico_dataset_dev, pico_dataset_test = get_pico_datasets()
newsroom_dataset_train, newsroom_dataset_dev, newsroom_dataset_test = get_newsroom_datasets()

# Embeddings and the indices of the special boundary tokens all come from the
# training split; they are handed to GeneratorModel in the next cell.
word_vectors = newsroom_dataset_train.word_vectors
start_index, end_index = (
    newsroom_dataset_train.word_indices[token] for token in ('<start>', '<end>')
)

# Create Model

In [None]:
# Build the summary generator from the training-split embeddings and the
# start/end token indices. Coverage is switched off for this run.
generator_model = GeneratorModel(
    word_vectors,
    start_index,
    end_index,
    num_hidden1=None,
    num_hidden2=None,
    with_coverage=False,
    gamma=GAMMA,
)

# Train and Save Model

In [None]:
# Batch the training split (shuffle=True) and optimise with Adam.
dataloader = VariableBatchDataLoader(
    newsroom_dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
optimizer = torch.optim.Adam(generator_model.parameters(), lr=LEARNING_RATE)
# Adagrad alternative, kept for reference (needs INITIAL_ACCUMULATOR_VALUE):
# optimizer = torch.optim.Adagrad((generator_model.cuda() if USE_CUDA else generator_model).parameters(),
#                                 lr=LEARNING_RATE, initial_accumulator_value=INITIAL_ACCUMULATOR_VALUE)

# ModelManipulator bundles the model, optimizer, loss and error functions;
# train() returns the training and validation statistics plotted further down.
model_manip = ModelManipulator(
    generator_model,
    optimizer,
    loss_function,
    error_function,
    use_cuda=USE_CUDA,
)
train_stats, val_stats = model_manip.train(
    dataloader,
    NUM_EPOCHS,
    dataset_val=newsroom_dataset_dev,
    stats_every=10,
    verbose_every=10,
)

In [None]:
# Persist the whole model object (not just a state dict) so the next cell can
# restore it with a plain torch.load.
torch.save(obj=generator_model, f='data/generator_test.model')

In [None]:
# Reload the pickled model written by the save cell. When CUDA is unavailable,
# remap every storage onto the CPU; without this, a checkpoint written from a
# GPU session fails to deserialize on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects -- only load files you trust.
# NOTE(review): on recent PyTorch releases where `weights_only` defaults to True,
# loading a fully pickled module may also need weights_only=False -- confirm
# against the installed version.
generator_model = torch.load(
    'data/generator_test.model',
    map_location=None if USE_CUDA else (lambda storage, location: storage),
)

# Plot Learning Curves and Inspect Generated Summaries

In [None]:
# Draw the training and validation statistics returned by model_manip.train.
plot_learning_curves(
    training_values=train_stats,
    validation_values=val_stats,
    figure_name='generator_training_test',
)

In [None]:
# Run beam search (beam_size=4) on the first training example and keep every
# returned hypothesis (return_all=True).
batch = newsroom_dataset_train[0:1]

# Put the inputs on the same device as the model. The original called .cuda()
# unconditionally, which raises on machines without CUDA (or when the model
# lives on the CPU). `batch` itself is left untouched so the next cell can
# still call .numpy() on its tensors.
model_on_gpu = next(generator_model.parameters()).is_cuda
input_text, input_length = batch['text'], batch['text_length']
if model_on_gpu:
    input_text, input_length = input_text.cuda(), input_length.cuda()

results = generator_model(input_text, input_length, beam_size=4, return_all=True)

In [None]:
# Show the first entry of `results`: for each example print the source text,
# the reference summary, the generated summary and its loss.
loss, summary_info = results[0]
vocabulary = newsroom_dataset_train.words
for example_idx, generated in enumerate(summary_info[0]):
    source_tokens = batch['text'][example_idx]
    source_len = batch['text_length'][example_idx]
    print("text", get_index_words(source_tokens[:source_len].numpy(), vocabulary))

    reference_tokens = batch['summary'][example_idx]
    reference_len = batch['summary_length'][example_idx]
    print("summary", get_index_words(reference_tokens[:reference_len].numpy(), vocabulary))

    # summary_info[1] holds the generated length for each example; truncate to it.
    generated_len = summary_info[1][example_idx]
    print("generated summary", get_index_words(generated[:generated_len], vocabulary))
    print(loss[example_idx])