# LSTM for auto text summarization

Playground notebook; it will be fleshed out later.

Most of the setup code is adapted from [this repo](https://github.com/chen0040/keras-text-summarization/).

## Load required libraries

In [23]:
from __future__ import print_function

import os
import random
from collections import Counter

import numpy as np
import pandas as pd
from keras.callbacks import ModelCheckpoint
from keras.layers import Dense, Embedding, Input, LSTM, RepeatVector, TimeDistributed
from keras.models import Model, Sequential, load_model
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

# Expose only CUDA device 3 to the deep-learning backend.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
batch_size = 64  # Batch size for training.
epochs = 20  # Number of epochs to train for.
latent_dim = 256  # Latent dimensionality of the encoding space.
num_samples = 2000  # Number of samples to train on.

# Caps used while scanning the corpus: articles/summaries are truncated to this
# many space-separated tokens, and only the most frequent words up to the
# vocabulary sizes receive their own index.
MAX_INPUT_SEQ_LENGTH = 3000
MAX_TARGET_SEQ_LENGTH = 300
MAX_INPUT_VOCAB_SIZE = 15000
MAX_TARGET_VOCAB_SIZE = 6000
# Aliases derived from the single definitions above, so editing `epochs` or
# `batch_size` is no longer silently overridden by a second assignment.
DEFAULT_EPOCH_SIZE = epochs
DEFAULT_BATCH_SIZE = batch_size

In [2]:
# Data locations: article texts, reference summaries, and the CSV index whose
# "text_path" / "summary_path" columns name the paired files.
data_root = "/home/rlin225/test/"
text_path = data_root + "text_2000/"
summary_path = data_root + "summary_2000/"
file_list = data_root + "datalist.csv"

## Build the input and target vocabularies (word/index mappings)

In [4]:
input_counter = Counter()
target_counter = Counter()
max_input_seq_length = 0
max_target_seq_length = 0
input_seq_max_length = MAX_INPUT_SEQ_LENGTH
target_seq_max_length = MAX_TARGET_SEQ_LENGTH


def _count_tokens(path, counter, limit, prefix='', suffix=''):
    """Add the tokens of one file to *counter* and return how many were kept.

    The file content is wrapped in *prefix*/*suffix*, split on single spaces,
    lower-cased and truncated to at most *limit* tokens before counting.
    """
    with open(path, 'r') as rf:
        text = prefix + rf.read() + suffix
    tokens = [word.lower() for word in text.split(' ')][:limit]
    counter.update(tokens)
    return len(tokens)


data_df = pd.read_csv(file_list, sep=',', header='infer')

for text_name, summary_name in zip(data_df["text_path"], data_df["summary_path"]):
    # Articles: raw text, capped at input_seq_max_length tokens.
    seq_length = _count_tokens(text_path + text_name, input_counter, input_seq_max_length)
    max_input_seq_length = max(max_input_seq_length, seq_length)

    # Summaries: wrapped in '\t' (start) and '\n' (end) markers, the same way
    # transform_target_encoding wraps them, capped at target_seq_max_length.
    # NOTE: splitting on ' ' glues the markers onto the first/last word rather
    # than producing standalone marker tokens.
    seq_length = _count_tokens(summary_path + summary_name, target_counter,
                               target_seq_max_length, prefix='\t', suffix='\n')
    max_target_seq_length = max(max_target_seq_length, seq_length)

In [5]:
# Input vocabulary: ids 0 ('PAD') and 1 ('UNK') are reserved, so the most
# frequent corpus word gets id 2, the next id 3, and so on.
input_word2idx = {
    word: rank + 2
    for rank, (word, _count) in enumerate(input_counter.most_common(MAX_INPUT_VOCAB_SIZE))
}
input_word2idx['PAD'] = 0
input_word2idx['UNK'] = 1
input_idx2word = {idx: word for word, idx in input_word2idx.items()}

# Target vocabulary: id 0 is 'UNK'; the most frequent words start at id 1.
target_word2idx = {
    word: rank + 1
    for rank, (word, _count) in enumerate(target_counter.most_common(MAX_TARGET_VOCAB_SIZE))
}
target_word2idx['UNK'] = 0

target_idx2word = {idx: word for word, idx in target_word2idx.items()}

num_input_tokens = len(input_word2idx)
num_target_tokens = len(target_word2idx)

In [8]:
# Report the sequence-length and vocabulary-size statistics gathered above.
for label, value in (('max_input_seq_length:', max_input_seq_length),
                     ('max_target_seq_length:', max_target_seq_length),
                     ('num_input_tokens:', num_input_tokens),
                     ('num_target_tokens:', num_target_tokens)):
    print(label, value)

max_input_seq_length: 1802
max_target_seq_length: 71
num_input_tokens: 15002
num_target_tokens: 6001


## Helper functions

These helpers convert input texts to numerical representations and convert predicted indices back to text, so that we can read the generated summaries.

In [17]:
def transform_input_text(texts):
    """Encode a batch of news articles as a padded matrix of word ids.

    Each article is lower-cased, split on single spaces and cut to its first
    ``max_input_seq_length`` words; words missing from ``input_word2idx`` map
    to id 1 ('UNK'). The resulting id lists are padded with ``pad_sequences``.
    The shape of the result is printed before it is returned.
    """
    # At least one word is always kept, even for a non-positive limit.
    keep = max(max_input_seq_length, 1)
    encoded = []
    for article in texts:
        words = article.lower().split(' ')[:keep]
        encoded.append([input_word2idx.get(word, 1) for word in words])
    padded = pad_sequences(encoded, maxlen=max_input_seq_length)
    print(padded.shape)
    return padded

def transform_target_encoding(texts, max_length=None):
    """Tokenize reference summaries for the batch generator.

    Each summary is lower-cased, wrapped in the '\\t' start / '\\n' end markers,
    split on single spaces and cut to at most ``max_length`` tokens (at least
    one token is always kept). Mapping tokens to ids happens later, in
    ``generate_batch``.

    Args:
        texts: iterable of summary strings.
        max_length: token limit per summary; defaults to the notebook-level
            ``max_target_seq_length``.

    Returns:
        A 1-D ``object`` ndarray holding one token list per summary. The shape
        is printed before returning.
    """
    if max_length is None:
        max_length = max_target_seq_length
    keep = max(max_length, 1)
    texts = list(texts)
    # Summaries have different token counts. np.array(list_of_lists) on ragged
    # input raises ValueError in NumPy >= 1.24, and it yields a 2-D string array
    # when the lengths happen to match. So fill a 1-D object array explicitly.
    sequences = np.empty(len(texts), dtype=object)
    for i, line in enumerate(texts):
        tokens = ('\t' + line.lower() + '\n').split(' ')
        sequences[i] = tokens[:keep]
    print(sequences.shape)
    return sequences

def generate_batch(x_samples, y_samples, batch_size):
    """Yield ``(encoder_input, decoder_target)`` training batches forever.

    ``x_samples`` holds word-id sequences (as produced by transform_input_text)
    and ``y_samples`` the token lists from transform_target_encoding. Each
    batch pads the inputs to ``max_input_seq_length`` and one-hot encodes the
    target tokens into a ``(batch_size, max_target_seq_length,
    num_target_tokens)`` array. Target words not in ``target_word2idx`` (id 0,
    'UNK') and positions past the end of a summary stay all-zero rows.

    Only full batches are produced; a trailing partial batch is skipped, which
    matches the ``len(samples) // batch_size`` steps per epoch used for fitting.

    Raises:
        ValueError: on the first ``next()`` if there are fewer samples than
            ``batch_size`` (the generator would otherwise loop forever without
            yielding anything).
    """
    num_batches = len(x_samples) // batch_size
    if num_batches == 0:
        raise ValueError('need at least batch_size=%d samples, got %d'
                         % (batch_size, len(x_samples)))
    while True:
        for batch_idx in range(num_batches):
            start = batch_idx * batch_size
            end = start + batch_size
            encoder_input_data_batch = pad_sequences(x_samples[start:end],
                                                     maxlen=max_input_seq_length)
            # float32 matches Keras' default floatx and halves the memory of this
            # large one-hot tensor compared with NumPy's default float64.
            decoder_target_data_batch = np.zeros(
                shape=(batch_size, max_target_seq_length, num_target_tokens),
                dtype='float32')
            for line_idx, target_words in enumerate(y_samples[start:end]):
                for idx, w in enumerate(target_words):
                    w2idx = target_word2idx.get(w, 0)  # 0 == 'UNK'
                    if w2idx != 0:
                        decoder_target_data_batch[line_idx, idx, w2idx] = 1
            yield encoder_input_data_batch, decoder_target_data_batch

In [19]:
def summarize(input_text, summary_model=None):
    """Generate a summary for one news article.

    Args:
        input_text: raw article text.
        summary_model: trained Keras model to use; defaults to the
            notebook-level ``model`` (e.g. the one loaded for inference).

    Returns:
        A list of predicted target tokens, one per decoder time step
        ('UNK' wherever the model's best guess is id 0).
    """
    if summary_model is None:
        summary_model = model
    # Encode the same way as transform_input_text: lower-case, split on single
    # spaces, keep the FIRST max_input_seq_length words, unknown words -> 1.
    # pad_sequences alone truncates from the front and would keep the last words.
    words = input_text.lower().split(' ')[:max_input_seq_length]
    input_wids = [input_word2idx.get(word, 1) for word in words]
    input_seq = pad_sequences([input_wids], maxlen=max_input_seq_length)
    # For the CNN_LSTM architecture the prediction is
    # (1, max_target_seq_length, num_target_tokens); pick the most likely
    # token at every time step, i.e. argmax over the LAST axis.
    predicted = summary_model.predict(input_seq)
    predicted_word_idx_list = np.argmax(predicted, axis=-1)
    return [target_idx2word[wid] for wid in predicted_word_idx_list[0]]

## Model definition

This is the core of the notebook: the model architecture.

In [20]:
def CNN_LSTM(embedding_dim=128, hidden_dim=128):
    """Build and compile a simple one-shot LSTM encoder-decoder.

    Despite the name there is no convolutional layer. The article's word ids
    are embedded, and an LSTM encodes them into a single vector. ``RepeatVector``
    copies that vector ``max_target_seq_length`` times. A second LSTM followed by
    a per-time-step softmax then predicts one target token per step.

    Args:
        embedding_dim: size of the word embeddings (default 128).
        hidden_dim: number of units in each LSTM layer (default 128).

    Returns:
        A ``Sequential`` model compiled with categorical cross-entropy and Adam.
    """
    model = Sequential()
    # encoder input: word ids -> dense embeddings
    model.add(Embedding(output_dim=embedding_dim, input_dim=num_input_tokens,
                        input_length=max_input_seq_length))

    # encoder: summarize the whole article into one hidden vector
    model.add(LSTM(hidden_dim))
    # feed that same vector to the decoder at every output time step
    model.add(RepeatVector(max_target_seq_length))
    # decoder: one softmax over the target vocabulary per time step
    model.add(LSTM(hidden_dim, return_sequences=True))
    model.add(TimeDistributed(Dense(num_target_tokens, activation='softmax')))

    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    return model

## Training

The code below is a demo for training.

In [None]:
X = []
y = []

data_df = pd.read_csv(file_list, sep=',', header='infer')

# Read every article and its reference summary into parallel lists.
print("Loading the data...")
for text_name, summary_name in zip(data_df["text_path"], data_df["summary_path"]):
    with open(text_path + text_name, 'r') as rf:
        X.append(rf.read())
    with open(summary_path + summary_name, 'r') as rf:
        y.append(rf.read())

# do a split (the summaries live in lower-case `y`)
print("Spliting the data...")
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, y, test_size=0.2, random_state=42)
print('training size: ', len(Xtrain))
print('testing size: ', len(Xtest))

Ytrain = transform_target_encoding(Ytrain)
Ytest = transform_target_encoding(Ytest)

Xtrain = transform_input_text(Xtrain)
Xtest = transform_input_text(Xtest)

train_gen = generate_batch(Xtrain, Ytrain, batch_size)
test_gen = generate_batch(Xtest, Ytest, batch_size)

train_num_batches = len(Xtrain) // batch_size
test_num_batches = len(Xtest) // batch_size

VERBOSE = 1  # Keras progress-bar verbosity
# Keep the weights with the best validation loss seen so far.
checkpoint = ModelCheckpoint('lstm_summary_best.h5', monitor='val_loss',
                             verbose=VERBOSE, save_best_only=True)

# train a model
# NOTE(review): fit_generator is deprecated in tf.keras; there, use
# model.fit(train_gen, ...) with the same arguments.
print('Start fitting ...')
model = CNN_LSTM()
model.fit_generator(generator=train_gen, steps_per_epoch=train_num_batches,
                    epochs=epochs,
                    verbose=VERBOSE, validation_data=test_gen, validation_steps=test_num_batches,
                    callbacks=[checkpoint])
model.save('lstm_summary.h5')

## Inference

The code below shows how to run inference on a randomly chosen article.

In [None]:
model = load_model('lstm_summary.h5')
# randrange(n) returns 0..n-1, a valid label for the default RangeIndex that
# read_csv gives data_df (randint's upper bound is inclusive, so
# randint(0, n) - 1 could produce the invalid label -1).
ran = random.randrange(len(data_df))
with open(text_path + data_df["text_path"][ran], 'r') as rf:
    test_text = rf.read()
with open(summary_path + data_df["summary_path"][ran], 'r') as rf:
    ground_truth = rf.read()
auto_summary = summarize(test_text)
# summarize returns one token per decoder step; join them for display.
print("Auto summary:", ' '.join(auto_summary))
print("Actual summary:", ground_truth)