In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from PICOHelper import get_pico_datasets
from NewsroomHelper import get_newsroom_datasets
from SummarizationModelStructures import TextEncoder, ContextVectorNN, VocubularyDistributionNN
from utils import get_text_matrix, DataLoader
from pytorch_helper import ModelManipulator, plot_learning_curves

# Parameters

In [None]:
# Training parameters.
BATCH_SIZE = 2
NUM_EPOCHS = 1
LEARNING_RATE = 1e-3

# Fixed seed so PyTorch weight initialisation (and, presumably, the shuffling in
# utils.DataLoader if it draws from torch's RNG) is repeatable across runs.
SEED = 42
torch.manual_seed(SEED)

USE_CUDA = torch.cuda.is_available()
USE_CUDA

# Get Data

In [None]:
# Only the Newsroom corpus is loaded here; get_pico_datasets is imported above as
# an alternative loader but is not called.
newsroom_dataset_train, newsroom_dataset_dev, newsroom_dataset_test = get_newsroom_datasets()

# The embedding table and the special-token ids the model needs all come from the
# training split.
word_vectors = newsroom_dataset_train.word_vectors
start_index, end_index = (newsroom_dataset_train.word_indices[special_token]
                          for special_token in ('<start>', '<end>'))

# Model Structure

In [None]:
class GeneratorModel(nn.Module):
    """Summarizer: bidirectional text encoder, LSTMCell decoder, attention context
    network and vocabulary-distribution head.

    Inputs are word-index tensors; word vectors are looked up from ``word_vectors``
    with ``get_text_matrix`` on every call.

    Args:
        word_vectors: indexable table of word vectors; ``len(word_vectors)`` is the
            vocabulary size and ``len(word_vectors[0])`` the embedding size.
        start_index: vocabulary index of ``<start>``; first decoder input when generating.
        end_index: vocabulary index of ``<end>``; greedy generation stops on it.
        num_hidden1: hidden size of the encoder / decoder; defaults to half the
            embedding size.
        num_hidden2: hidden size passed to ``ContextVectorNN``; defaults to half the
            embedding size.
        with_coverage: if True, keep a running sum of attention weights (coverage)
            and feed it to the context network.
    """

    def __init__(self, word_vectors, start_index, end_index, num_hidden1=None, num_hidden2=None, with_coverage=False):
        # super(self.__class__, self) would recurse forever in any subclass.
        super().__init__()
        self.word_vectors = word_vectors
        num_features = len(self.word_vectors[0])
        num_vocab = len(self.word_vectors)
        self.start_index = start_index
        self.end_index = end_index
        if num_hidden1 is None:
            num_hidden1 = num_features // 2
        if num_hidden2 is None:
            num_hidden2 = num_features // 2
        self.with_coverage = with_coverage

        self.text_encoder = TextEncoder(num_features, num_hidden1, bidirectional=True)
        self.summary_decoder = nn.LSTMCell(num_features, num_hidden1)
        # Input width num_hidden1*3: presumably 2*num_hidden1 from the bidirectional
        # encoder plus num_hidden1 from the decoder state — confirm in the NN modules.
        self.context_nn = ContextVectorNN(num_hidden1 * 3, num_hidden2, with_coverage=with_coverage)
        self.vocab_nn = VocubularyDistributionNN(num_hidden1 * 3, num_vocab)

    def forward(self, text, text_length, summary=None, summary_length=None, generate_algorithm='greedy'):
        """Encode ``text`` and either score ``summary`` (teacher forcing) or generate one.

        Args:
            text: (batch, max_text_len) tensor of word indices; row i is used up to
                ``text_length[i]``.
            text_length: per-example text lengths.
            summary: optional (batch, max_summary_len) tensor of word indices. When
                given, the supervised loss is computed; when None, a summary is generated.
            summary_length: per-example summary lengths; required with ``summary``.
            generate_algorithm: decoding strategy used when ``summary`` is None; only
                ``'greedy'`` is supported.

        Returns:
            ``dict(loss=...)`` in supervised mode, otherwise the
            ``(loss, summary, summary_length)`` tuple from ``forward_generate_greedy``.

        Raises:
            ValueError: unsupported ``generate_algorithm``, or ``summary`` given
                without ``summary_length``.
        """
        # Turn each example's (truncated) index row into a word-vector matrix sized
        # to the batch's max length, then stack into one batch tensor.
        max_length = text.size(1)
        text = torch.cat([
            get_text_matrix(example[:text_length[i]], self.word_vectors, max_length)[0].unsqueeze(0)
            for i, example in enumerate(text)
        ], 0)

        text_states, (h, c) = self.text_encoder(text, text_length)

        # Coverage: running attention sum over source positions, shape (batch, 1, text_len).
        coverage = (torch.zeros((text_states.size(0), 1, text_states.size(1)), device=text_states.device)
                    if self.with_coverage else None)
        loss = 0
        if summary is not None:
            return self.forward_supervised(text_states, h, c, coverage, loss, summary, summary_length)
        if generate_algorithm == 'greedy':
            return self.forward_generate_greedy(text_states, h, c, coverage, loss)
        raise ValueError("unsupported generate_algorithm: {!r}".format(generate_algorithm))

    def forward_generate_greedy(self, text_states, h, c, coverage, loss, max_steps=50000):
        """Greedily decode a summary for every example, starting from ``start_index``.

        At each step the decoder consumes the previous token of every example that is
        still running, and the argmax of ``vocab_dist_i`` becomes its next token. An
        example stops once it emits ``end_index``.

        Args:
            text_states: batch-first encoder outputs.
            h, c: encoder final states; ``[:, 0]`` initializes the decoder.
            coverage: running attention sum, or None when coverage is disabled.
            loss: passed through unchanged (there are no targets while generating).
            max_steps: safety cap; the loop stops once more than ``max_steps`` steps
                have run, and still-running examples keep length -1.

        Returns:
            ``(loss, summary, summary_length)``: ``summary`` is a long tensor of token
            indices whose first column is ``start_index`` (rows that already ended are
            filled with 0), and ``summary_length[b]`` is the step at which row b
            emitted ``end_index`` (-1 if it never did).
        """
        batch_length = text_states.size(0)
        device = h.device
        valid_indices = torch.arange(batch_length, device=device)
        # Token ids are integral; a float start column would make the concatenated
        # summary float (or fail to concatenate on older PyTorch).
        summary = [torch.full((batch_length, 1), self.start_index, dtype=torch.long, device=device)]
        summary_length = torch.zeros(batch_length, device=device) - 1
        h, c = h[:, 0], c[:, 0]
        step = 0
        while True:
            # Feed only still-running rows so summary_i lines up row-for-row with
            # text_states[valid_indices], h[valid_indices], ... inside timestep.
            summary_i = summary[-1][valid_indices, 0]
            vocab_dist_i, h, c, coverage, loss = self.timestep(
                valid_indices, summary_i, text_states, h, c, coverage, loss)

            summary_ip1 = torch.zeros(batch_length, dtype=torch.long, device=device)
            summary_ip1[valid_indices] = torch.max(vocab_dist_i, 1)[1]
            summary.append(summary_ip1.unsqueeze(-1))

            ending = summary_ip1[valid_indices] == self.end_index
            ended_indices = valid_indices[ending]
            valid_indices = valid_indices[~ending]
            step += 1
            summary_length[ended_indices] = step
            if (summary_length >= 0).all() or step > max_steps:
                break

        return loss, torch.cat(summary, 1), summary_length

    def forward_supervised(self, text_states, h, c, coverage, loss, summary, summary_length):
        """Accumulate the teacher-forced loss of ``summary`` given the encoded text.

        At step i every example with a token at position i + 1 feeds ``summary[:, i]``
        to the decoder and is scored on ``summary[:, i + 1]``. This is the same
        next-token convention ``forward_generate_greedy`` uses. It assumes each
        summary row starts with ``<start>`` as generation does — TODO confirm
        against the dataset.

        Returns:
            ``dict(loss=...)`` with the loss summed over all scored tokens.

        Raises:
            ValueError: if ``summary_length`` is None.
        """
        if summary_length is None:
            raise ValueError("summary_length is required when summary is given")
        h, c = h[:, 0], c[:, 0]
        for i in range(summary.size(1) - 1):
            # Rows that still have a next token (position i + 1 < length).
            valid_indices = torch.nonzero((summary_length - i - 2) >= 0)[:, 0]
            if valid_indices.numel() == 0:
                break  # the valid set only shrinks as i grows
            summary_i = summary[valid_indices, i]
            target_i = summary[valid_indices, i + 1]
            vocab_dist_i, h, c, coverage, loss = self.timestep(
                valid_indices, summary_i, text_states, h, c, coverage, loss, target_i=target_i)

        return dict(loss=loss)

    def timestep(self, valid_indices, summary_i, text_states, h, c, coverage, loss, target_i=None):
        """Run one decoder step for the batch rows in ``valid_indices``.

        Args:
            valid_indices: 1-D index tensor of batch rows still decoding.
            summary_i: 1-D tensor of current input token indices, one per valid row.
            text_states: encoder outputs for the whole batch.
            h, c: decoder state for the whole batch; only valid rows are read/updated.
            coverage: running attention sum for the whole batch, or None.
            loss: running loss.
            target_i: optional 1-D tensor of next-token indices, one per valid row;
                when given, ``-vocab_dist_i[row, target]`` is added to ``loss``.

        Returns:
            ``(vocab_dist_i, h, c, coverage, loss)``; ``vocab_dist_i`` covers only the
            valid rows.
        """
        text_states_i = text_states[valid_indices]
        summary_vec_i = get_text_matrix(summary_i, self.word_vectors, len(summary_i))[0]
        h_i, c_i = h[valid_indices], c[valid_indices]
        coverage_i = None if coverage is None else coverage[valid_indices]

        # Pass the previous state; without it every step starts from zeros and the
        # decoder ignores both the encoder and earlier tokens.
        h_i, c_i = self.summary_decoder(summary_vec_i, (h_i, c_i))
        context_vector_i, attention_i = self.context_nn(text_states_i, h_i, coverage_i)
        vocab_dist_i = self.vocab_nn(context_vector_i, h_i)

        # Out-of-place scatter: h and c start as views of the encoder output, and
        # writing into them in place can invalidate tensors saved for backward.
        h = h.index_copy(0, valid_indices, h_i)
        c = c.index_copy(0, valid_indices, c_i)
        if coverage is not None:
            coverage[valid_indices] += attention_i
        if target_i is not None:
            # NOTE(review): assumes vocab_dist_i holds log-probabilities, making this a
            # negative log-likelihood — confirm against VocubularyDistributionNN.
            rows = torch.arange(target_i.size(0), device=vocab_dist_i.device)
            loss = loss - vocab_dist_i[rows, target_i.long()].sum()
        return vocab_dist_i, h, c, coverage, loss

# Input width of context_nn / vocab_nn (num_hidden1*3 in GeneratorModel) = text_encoder.num_hidden + summary_decoder.num_hidden
def loss(loss):
    """Return the model's loss unchanged.

    Presumably ModelManipulator hands this the ``loss`` entry of the model output
    (GeneratorModel.forward_supervised returns ``dict(loss=...)``) — confirm
    against pytorch_helper.
    """
    objective = loss
    return objective


def error(loss):
    """No separate error metric is tracked for this model; always returns None."""
    del loss  # accepted only so the signature mirrors ``loss``
    return None

In [None]:
# Hidden sizes are left at the model's defaults (half the word-vector size).
generator_model = GeneratorModel(
    word_vectors,
    start_index,
    end_index,
    with_coverage=True,
)

# Train and Save Model

In [None]:
# Adam over every model parameter; ModelManipulator drives the epoch loop and is
# given the dev split for validation statistics.
dataloader = DataLoader(newsroom_dataset_train, batch_size=BATCH_SIZE, shuffle=True)
optimizer = torch.optim.Adam(params=generator_model.parameters(), lr=LEARNING_RATE)
model_manip = ModelManipulator(generator_model, optimizer, loss, error, use_cuda=USE_CUDA)
train_stats, val_stats = model_manip.train(
    dataloader,
    NUM_EPOCHS,
    dataset_val=newsroom_dataset_dev,
    stats_every=10,
    verbose_every=10,
)

# Plot

In [None]:
plot_learning_curves(
    training_values=train_stats,
    validation_values=val_stats,
    figure_name='summarization_training',
)

In [None]:
# Smoke test: greedily decode two training examples. Inputs follow the model's own
# device (it is on CUDA only if ModelManipulator presumably moved it there), so this
# also runs on CPU-only machines. Gradients are off because generation needs no
# backward pass.
model_device = next(generator_model.parameters()).device
batch = newsroom_dataset_train[:2]
with torch.no_grad():
    generated = generator_model(batch['text'].to(model_device), batch['text_length'].to(model_device))
generated