In [1]:
import tensorflow as tf
import numpy as np
from abstransformer import Transformer, masks

In [2]:
import pickle

# NOTE(security): pickle.load runs arbitrary code embedded in the file, so
# only load pickles produced by a trusted preprocessing run.
# `with` closes each handle even if unpickling raises; the original
# open()/close() pairs leaked the handle on error.
with open("document.pkl", "rb") as pickle_in:
    document = pickle.load(pickle_in)  # texts later fed to document_tokenizer

with open("summary.pkl", "rb") as pickle_in:
    # presumably a pandas Series (the next cell calls .apply on it) — confirm
    summary = pickle.load(pickle_in)

In [3]:
# Decoder targets need explicit sentinels: every summary is wrapped as
# '<go> ... <stop>' so generation has a start token to seed from and an end
# token to stop on.
# NOTE: not idempotent — re-running this cell wraps the text a second time.
summary = summary.apply(
    lambda text: '<go> ' + text + ' <stop>'
)

In [9]:
# Characters stripped from summaries before tokenizing. Unlike the Keras
# default set, '<' and '>' are deliberately absent here so that the
# '<go>' / '<stop>' markers survive as single tokens.
filters = '!"#$%&()*+,-./:;=?@[\\]^_`{|}~\t\n'
oov_token = '<unk>'

# Encoder-side vocabulary: default Keras filters, unknown words -> '<unk>'.
document_tokenizer = tf.keras.preprocessing.text.Tokenizer(oov_token=oov_token)
document_tokenizer.fit_on_texts(document)

# Decoder-side vocabulary: custom filters (see above), same OOV token.
summary_tokenizer = tf.keras.preprocessing.text.Tokenizer(
    oov_token=oov_token,
    filters=filters,
)
summary_tokenizer.fit_on_texts(summary)

In [10]:
# Hyper-parameters.
# The first four are forwarded positionally to Transformer(...) in the next
# cell and therefore must match the values the restored checkpoint was
# trained with.
num_layers = 4
num_heads = 8
d_model = 128
dff = 512

# Not referenced by any cell shown in this notebook (inference only);
# presumably kept from the training run.
EPOCHS = 23

In [11]:
# Keras Tokenizer ids start at 1 (0 is left free for padding), so the
# vocabulary size is the number of indexed words plus one.
encoder_vocab_size = len(document_tokenizer.word_index) + 1
decoder_vocab_size = len(summary_tokenizer.word_index) + 1

# NOTE(review): pe_input / pe_target are set to the vocabulary sizes here.
# They look like positional-encoding lengths — confirm this is intended, but
# do not change them without retraining: the restored checkpoint was built
# with these shapes.
transformer = Transformer(
    num_layers,
    d_model,
    num_heads,
    dff,
    encoder_vocab_size,
    decoder_vocab_size,
    pe_input=encoder_vocab_size,
    pe_target=decoder_vocab_size,
)

In [12]:
def create_masks(inp, tar):
    """Build the three attention masks for one encoder/decoder pass.

    The mask helpers are looked up on the imported ``masks`` module. The bare
    names ``create_padding_mask`` / ``create_look_ahead_mask`` are neither
    defined nor imported in the visible cells, so calling them unqualified
    raises NameError on a fresh kernel.
    NOTE(review): assumes ``abstransformer.masks`` exposes both helpers next
    to ``masks.create_masks`` (which ``evaluate`` already uses) — confirm.

    Args:
        inp: encoder input token ids.
        tar: decoder input token ids; the length of its second axis sizes the
            look-ahead mask.

    Returns:
        Tuple ``(enc_padding_mask, combined_mask, dec_padding_mask)``.
    """
    # Both the encoder self-attention mask and the decoder's mask over the
    # encoder output are derived from the encoder input.
    enc_padding_mask = masks.create_padding_mask(inp)
    dec_padding_mask = masks.create_padding_mask(inp)

    # Decoder self-attention: element-wise maximum of the target padding mask
    # and the look-ahead mask.
    look_ahead_mask = masks.create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = masks.create_padding_mask(tar)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

    return enc_padding_mask, combined_mask, dec_padding_mask

In [13]:
checkpoint_path = "checkpoints"

# Restore the most recent saved weights for inference.
ckpt = tf.train.Checkpoint(transformer=transformer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
if ckpt_manager.latest_checkpoint:
    # Only the model is tracked here. expect_partial() silences warnings about
    # checkpoint values this object graph does not consume (e.g. optimizer
    # state, if the training run saved any).
    ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
    print ('Latest checkpoint restored!!')
else:
    # Without a checkpoint the model keeps its random initial weights and
    # every summary below is meaningless — say so instead of failing silently.
    print(f'WARNING: no checkpoint found in {checkpoint_path!r}; '
          'the transformer is untrained.')

Latest checkpoint restored!!


In [17]:
def evaluate(input_document, max_input_len=400, max_output_len=75):
    """Greedy-decode a summary for one document.

    Args:
        input_document: raw document text.
        max_input_len: the token sequence is post-padded / post-truncated to
            this length before encoding (default 400, the previously
            hard-coded value).
        max_output_len: maximum number of decoding steps (default 75, the
            previously hard-coded value).

    Returns:
        Tuple ``(token_ids, attention_weights)``. ``token_ids`` is a 1-D int32
        tensor that starts with the ``<go>`` id and excludes ``<stop>``.
        ``attention_weights`` comes from the last forward pass, or is ``None``
        when ``max_output_len`` is 0.
    """
    encoded = document_tokenizer.texts_to_sequences([input_document])
    encoded = tf.keras.preprocessing.sequence.pad_sequences(
        encoded, maxlen=max_input_len, padding='post', truncating='post'
    )
    encoder_input = tf.expand_dims(encoded[0], 0)  # add batch axis

    # Sentinel ids are loop-invariant; look them up once.
    go_id = summary_tokenizer.word_index["<go>"]
    stop_id = summary_tokenizer.word_index["<stop>"]

    output = tf.expand_dims([go_id], 0)  # decoder input so far, batch of 1
    attention_weights = None
    for _ in range(max_output_len):
        enc_padding_mask, combined_mask, dec_padding_mask = masks.create_masks(encoder_input, output)

        predictions, attention_weights = transformer(
            encoder_input,
            output,
            False,  # presumably the `training` flag — confirm against Transformer.call
            enc_padding_mask,
            combined_mask,
            dec_padding_mask
        )

        # Keep only the newest position and take its most likely token.
        # assumes predictions is (batch, target_len, vocab) — TODO confirm
        predictions = predictions[:, -1:, :]
        predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

        # Stop before appending so <stop> never appears in the result.
        if predicted_id == stop_id:
            break

        output = tf.concat([output, predicted_id], axis=-1)

    return tf.squeeze(output, axis=0), attention_weights

In [18]:
def summarize(input_document):
    """Return the generated summary text for ``input_document``.

    Runs ``evaluate`` and converts the resulting token ids back to text with
    ``summary_tokenizer``. The attention weights are discarded for now; they
    could be used to plot attention heatmaps later.
    """
    token_ids, _ = evaluate(input_document=input_document)
    # Drop the leading <go> id, then restore a batch axis because
    # sequences_to_texts works on a batch of sequences.
    batch = np.expand_dims(token_ids.numpy()[1:], 0)
    texts = summary_tokenizer.sequences_to_texts(batch)
    return texts[0]  # exactly one document was summarized

In [20]:
# Smoke test: summarize one news paragraph with the restored model. The
# cell's value (the generated summary string) is displayed as the output.
# NOTE(review): the recorded output below does not reflect the input text —
# check the checkpoint / vocabulary pairing before trusting results.
summarize(
    "MahaRERA has come across a project whose developer allegedly utilised 100 per cent of funds but merely completed 20 to 30 per cent of the construction work. “We have issued a show-cause notice to the erring developer,” mentioned a release issued by MahaRERA. The case came to fore after MahaRERA Chairman Ajoy Mehta asked officials to scrutinise projects listed with the authority. According to Section 11 of the Real Estate (Regulation and Development) Act, it is mandatory for a registered developer to upload details of a project every three months, and a violation could attract a penalty of up to 30 per cent of the project cost."
)

'govt 39 s hard for 60 mn loss in chandigarh'