In [None]:
import _pickle as pickle
import hashlib

import numpy as np
import pandas as pd

## Helper functions for pickling objects to disk and loading them back
def save(file, name, folder = ""):
    """Pickle ``file`` to ``<name>.pickle``.

    Args:
        file: The object to serialize (any picklable object, not a file handle).
        name: Base file name, without the ``.pickle`` extension.
        folder: Optional directory, relative to the current working directory.
            When empty, the pickle is written to the current directory.
            The directory is not created here, so it must already exist.
    """
    if folder != "":
        path = './' + folder + '/' + name + '.pickle'
    else:
        path = name + '.pickle'
    # The context manager guarantees the handle is flushed and closed; the
    # previous `outfile.close` (no parentheses) never actually called close().
    with open(path, 'wb') as outfile:
        pickle.dump(file, outfile)
    
def load(name, folder = ""):
    """Load and return the object stored in ``<name>.pickle``.

    Args:
        name: Base file name, without the ``.pickle`` extension.
        folder: Optional directory, relative to the current working directory.
            When empty, the pickle is read from the current directory.

    Returns:
        The unpickled object.

    Note:
        Unpickling can execute arbitrary code; only load files you created
        yourself or otherwise trust.
    """
    if folder != "":
        path = './' + folder + '/' + name + '.pickle'
    else:
        path = name + '.pickle'
    # The context manager closes the handle; the previous `outfile.close`
    # (no parentheses) never actually called close().
    with open(path, 'rb') as infile:
        return pickle.load(infile)

from transformers import BertTokenizer, TFBertForSequenceClassification, TFBertForQuestionAnswering, TFBertModel, TFBertForNextSentencePrediction

from tf_transformers import *

from tqdm.notebook import tqdm

import os

from tensorflow.keras.losses import sparse_categorical_crossentropy
from tensorflow.keras.metrics import sparse_categorical_accuracy
import tensorflow as tf

from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model

from tensorflow.keras.optimizers import Adam

from tf_transformers import *
from tensorflow.keras.layers import Input, Dense, Dropout, TimeDistributed, LSTM

## Loading the dataset

In [None]:
# Characters treated as sentence-final by fix_missing_period, which compares
# only the last character of a line -- so the three-character '...' entry can
# never match on its own (a line ending in '...' already matches '.').
END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', ")"]
# Markers that get_art_abs wraps around each highlight sentence.
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
def read_text_file(text_file):
    """Read a UTF-8 text file and return its lines, whitespace-stripped.

    Blank lines are kept as empty strings, so the result has one entry per
    line of the file.
    """
    with open(text_file, "r", encoding="utf8") as handle:
        return [raw.strip() for raw in handle]


def hashhex(s):
    """Return the hexadecimal SHA-1 digest of ``s``.

    Args:
        s: Bytes-like data, or a ``str``. A ``str`` is encoded as UTF-8
            first, because hashlib only accepts bytes-like input.

    Returns:
        The 40-character lowercase hexadecimal digest.
    """
    if isinstance(s, str):
        s = s.encode("utf-8")
    return hashlib.sha1(s).hexdigest()


def get_url_hashes(url_list):
    """Return the hashhex digest of every entry of url_list, in order."""
    return [hashhex(address) for address in url_list]


def fix_missing_period(line):
    """Return line with " ." appended unless it already looks terminated.

    The line is returned unchanged when it contains the "@highlight" marker,
    is empty, or its last character is one of END_TOKENS.
    """
    already_terminated = (
        "@highlight" in line
        or line == ""
        or line[-1] in END_TOKENS
    )
    return line if already_terminated else line + " ."

In [None]:
def get_art_abs(story_file):
    """Split a story file into its article text and its highlight abstract.

    Every line is lowercased and passed through fix_missing_period (many image
    captions in the dataset lack a final period and would otherwise run into
    the article body). Empty lines and "@highlight" marker lines are dropped.
    Lines before the first marker form the article; every remaining line after
    it is treated as a highlight sentence.

    Returns:
        (article, abstract): the article lines joined by single spaces, and
        the highlights joined by single spaces with each one wrapped in
        SENTENCE_START / SENTENCE_END.
    """
    normalized = [
        fix_missing_period(raw.lower()) for raw in read_text_file(story_file)
    ]

    body_sentences = []
    summary_sentences = []
    seen_marker = False
    for text in normalized:
        if text == "":
            continue
        if text.startswith("@highlight"):
            # Once set, the flag is never cleared.
            seen_marker = True
            continue
        target = summary_sentences if seen_marker else body_sentences
        target.append(text)

    # Article as a single string.
    article = ' '.join(body_sentences)

    # Abstract as a single string, each sentence wrapped in <s> ... </s>.
    tagged = ["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in summary_sentences]
    abstract = ' '.join(tagged)

    return article, abstract

In [None]:
dir1 = 'cnn_stories_tokenized'

# List the directory once. os.listdir returns entries in arbitrary order, so
# sort them to make the row order (and the later seeded train/test split)
# reproducible across runs and machines.
story_files = sorted(os.listdir(dir1))

ar = []
ab = []
for elt in tqdm(story_files, total = len(story_files)):
    x, abstract = get_art_abs('./'+dir1+'/'+elt)

    # Drop the "-lrb- cnn -rrb-" dateline and any remaining bracket tokens
    # (presumably the tokenizer's escapes for "(" and ")").
    x = x.replace('-lrb- cnn -rrb-','')
    x = x.replace('-lrb-','')
    x = x.replace('-rrb-','')

    ar.append(x)
    # The sentence tags added by get_art_abs are not needed downstream.
    ab.append(abstract.replace('<s>', '').replace('</s>', ''))

In [None]:
dir1 = 'dm_stories_tokenized'

# List the directory once. os.listdir returns entries in arbitrary order, so
# sort them to make the row order (and the later seeded train/test split)
# reproducible across runs and machines.
story_files = sorted(os.listdir(dir1))

ar1 = []
ab1 = []
for elt in tqdm(story_files, total = len(story_files)):
    x, abstract = get_art_abs('./'+dir1+'/'+elt)

    # Drop the "-lrb- cnn -rrb-" dateline and any remaining bracket tokens
    # (presumably the tokenizer's escapes for "(" and ")").
    x = x.replace('-lrb- cnn -rrb-','')
    x = x.replace('-lrb-','')
    x = x.replace('-rrb-','')

    ar1.append(x)
    # The sentence tags added by get_art_abs are not needed downstream.
    ab1.append(abstract.replace('<s>', '').replace('</s>', ''))

In [None]:
# Concatenate into new lists instead of extending `ar`/`ab` in place, so that
# re-running this cell does not append the `ar1`/`ab1` rows a second time.
df = pd.DataFrame({'article' : ar + ar1, 'abstract' : ab + ab1})

In [None]:
# Preview the first rows of the combined article/abstract frame.
df.head()

In [None]:
# (rows, columns) of the combined frame.
df.shape

In [None]:
# Cache the parsed corpus as cnn_dm_raw.pickle in the working directory so the
# story files do not have to be re-read on later runs.
save(df, 'cnn_dm_raw')

## Tokenizing

In [None]:
# Restore the parsed corpus from cnn_dm_raw.pickle (written by the cell above).
df = load('cnn_dm_raw')

In [None]:
# Tokenize every article/abstract pair into fixed-length id and mask lists.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
max_length_in = 512
# NOTE(review): presumably one more than the decoder length so that decoder
# inputs and targets can be offset by one token -- confirm.
max_length_out = 129

# Per-row result lists, preallocated with placeholder zeros and filled by
# position inside the loop below.
X = list(np.zeros(df.shape[0]))
X_masks = list(np.zeros(df.shape[0]))
               
Y_in = list(np.zeros(df.shape[0]))
Y_in_masks = list(np.zeros(df.shape[0]))

Y_out = list(np.zeros(df.shape[0]))
Y_out_masks = list(np.zeros(df.shape[0]))
               
# NOTE(review): text_pairs is never used in the visible code.
text_pairs = []
# `index` is the DataFrame index label and is used as a list position below;
# this assumes df has the default 0..n-1 index.
for index, line in tqdm(df.iterrows(), total = df.shape[0]):
    s1 = line['article']
    s2 = line['abstract']
    
    # Every '.' character becomes a "[SEP] [CLS]" pair, including periods that
    # are not sentence ends (e.g. inside numbers).
    s1 = s1.replace('.', '[SEP] [CLS]')
    
    # NOTE(review): relies on the installed transformers version truncating to
    # max_length by default when only max_length is given -- confirm.
    tokenized = tokenizer.encode_plus(str(s1), add_special_tokens = True, max_length = max_length_in, pad_to_max_length = True)
    answer = tokenizer.encode_plus(str(s2), add_special_tokens = True, max_length = max_length_out, pad_to_max_length = True)
    
    X[index] = tokenized['input_ids']
    X_masks[index] = tokenized['attention_mask']
               
    # Decoder input. NOTE(review): `answer` is padded to max_length_out, so
    # this slice presumably keeps the whole sequence -- confirm.
    Y_in[index] = answer['input_ids'][:max_length_out]
    Y_in_masks[index] = answer['attention_mask'][:max_length_out]
    
    # Decoder target: the same ids shifted left by one, so Y_out[t] is the
    # token that follows Y_in[t].
    Y_out[index] = answer['input_ids'][1:]
    Y_out_masks[index] = answer['attention_mask'][1:]


In [None]:
# Attach the tokenized sequences to the frame, one column per list. The dict
# preserves insertion order, so the columns are added in the order shown.
# Y_out_masks is not stored.
token_columns = {
    'art': X,
    'art_mask': X_masks,
    'input': Y_in,
    'input_masks': Y_in_masks,
    'output': Y_out,
}
for column_name, column_values in token_columns.items():
    df[column_name] = column_values


In [None]:
# Cache the tokenized frame as cnn_dm_refined.pickle in the working directory.
save(df, 'cnn_dm_refined')

## Splitting into train and test set

In [None]:
# Restore the tokenized frame from cnn_dm_refined.pickle.
df = load('cnn_dm_refined')

In [None]:
# Stack the per-row token-id lists into 2-D integer arrays (one row per example).
X = np.array([list(elt) for elt in df['art'].values]).astype(int)
# X_masks = np.array([list(elt) for elt in df['art_mask'].values]).astype(int)

Y_in = np.array([list(elt) for elt in df['input'].values]).astype(int)
# Y_in_masks = np.array([list(elt) for elt in df['input_masks'].values]).astype(int)

Y_out = np.array([list(elt) for elt in df['output'].values]).astype(int)

# Append one column of padding id 0 to the targets (presumably to match the
# width of Y_in, which is one token longer). The pad column is created with
# Y_out's own dtype: a bare np.zeros is float64 and would silently promote
# every token id to float.
Y_out = np.concatenate([Y_out, np.zeros((Y_out.shape[0], 1), dtype = Y_out.dtype)], axis = 1)

In [None]:
# Shape of the encoder-input id array.
X.shape

In [None]:
# Shape of the decoder-input id array.
Y_in.shape

In [None]:
# Shape of the decoder-target id array (after the extra padding column).
Y_out.shape

In [None]:
# Drop the DataFrame now that the id arrays have been extracted, and ask the
# garbage collector to reclaim it right away.
# NOTE(review): re-running this cell raises NameError because `df` no longer
# exists; reload it first if needed.
import gc
del df
gc.collect()

In [None]:
# Keep only the leading tokens of every sequence: the first 128 encoder ids
# and the first 32 decoder input / target ids.
X, Y_in, Y_out = X[:, :128], Y_in[:, :32], Y_out[:, :32]

In [None]:
# Overwrite the final position of every truncated sequence: decoder inputs end
# with id 102 and targets end with id 0, the id that loss_function masks out.
# NOTE(review): 102 is presumably the [SEP] id of bert-base-uncased, and this
# also overwrites rows whose last position was already padding -- confirm.
Y_in[:, -1] = 102
Y_out[:,-1] = 0

In [None]:
# Split encoder inputs, decoder inputs and targets in a single call so the
# three arrays are guaranteed to share the same row permutation (previously
# two separate calls relied on a matching random_state to stay aligned).
(X_train_enc, X_test_enc,
 X_train_dec, X_test_dec,
 y_train, y_test) = train_test_split(X, Y_in, Y_out, random_state=42, test_size=0.1)

# The model takes [encoder ids, decoder ids] as its two inputs.
X_train = [X_train_enc, X_train_dec]
X_test = [X_test_enc, X_test_dec]

## Building architecture

In [None]:
## Encoder
def build_encoder(max_length_in = 512, vocab_size = 30522):
    """Build a Keras Model that encodes token ids with pretrained BERT.

    Args:
        max_length_in: Not used by the body; the input length is left variable.
        vocab_size: Not used by the body.

    Returns:
        A Model mapping a batch of int32 token ids to
        ``[encoded, attention_mask]``, where ``encoded`` is element 0 of the
        BERT model's output and ``attention_mask`` is the result of
        ``create_padding_mask(encoder_input)``.
    """

    encoder_input = Input(shape = (None,), dtype = 'int32')
    
    # create_padding_mask is not defined in this notebook (presumably provided
    # by tf_transformers). This mask is returned for use by the decoder.
    attention_mask = create_padding_mask(encoder_input)
    
    # Same mask without the extra dimension, used for BERT below.
    attention_bert = create_padding_mask(encoder_input, add_dimension = False)
    
    sentence_encoder = TFBertModel.from_pretrained(
        "bert-base-uncased", # Pretrained uncased BERT base checkpoint.
        output_attentions = False, # Whether the model returns attentions weights.
        output_hidden_states = False, # Whether the model returns all hidden-states.
    )

    # NOTE(review): presumably create_padding_mask marks padding with 1, so
    # `1 - mask` converts it to BERT's 1 = attend convention -- confirm.
    encoded = sentence_encoder(encoder_input, attention_mask = 1 - attention_bert)

    # Keep only the first element of the BERT output.
    encoded = encoded[0]

    encoder = Model(encoder_input, [encoded,attention_mask] )
    return encoder

In [None]:
# Build a stand-alone encoder for inspection. The assembly cell further down
# builds a fresh encoder again, so this instance is only used for the summary.
encoder = build_encoder()

In [None]:
# Print the encoder's layers and parameter counts.
encoder.summary()

In [None]:
def build_decoder(d_model = 768, max_length_out = 128, vocab_size = 30522):
    """Build a Keras Model that decodes target ids given the encoder output.

    Args:
        d_model: Passed to Decoder as its ``d_model``.
        max_length_out: Passed to Decoder as ``maximum_position_encoding``.
        vocab_size: Passed to Decoder as ``target_vocab_size`` and used as the
            width of the final Dense layer.

    Returns:
        A Model taking ``[decoder_input, encoder_output, encoder_mask]`` and
        returning one ``vocab_size``-wide vector per decoder position. The
        final Dense layer has no activation, so these are unnormalized scores.
    """

    decoder_input = Input(shape = (None,))
    
    # The encoder output width is fixed at 768 here, independent of d_model.
    encoder_output = Input(shape = (None, 768))
    encoder_mask = Input(shape = (1,1,None))
    
    inputs_decoder = [decoder_input, encoder_output, encoder_mask]
    
    # Decoder is not defined in this notebook (presumably from tf_transformers).
    dec = Decoder(num_layers = 6, d_model = d_model, num_heads = 8, dff = 512, target_vocab_size = vocab_size,
               maximum_position_encoding = max_length_out, rate=0.1, bidirectional_decoder = False)

    # The second return value of the decoder call is discarded.
    # NOTE(review): training=True is hard-coded; if Decoder honors this flag,
    # dropout presumably stays active during validation and inference -- confirm.
    decoded, _ = dec( decoder_input, encoder_output, training = True, padding_mask = encoder_mask)

    decoded = tf.keras.layers.Dense(vocab_size)(decoded)

    decoder = Model(inputs_decoder, decoded)
    return decoder

In [None]:
# Build a stand-alone decoder for inspection. The assembly cell further down
# builds a fresh decoder again, so this instance is only used for the summary.
decoder = build_decoder()

In [None]:
# Print the decoder's layers and parameter counts.
decoder.summary()

In [None]:
## Encoder-decoder architecture
# Wire a freshly built encoder and decoder into one model that maps
# [encoder ids, decoder ids] to the decoder's per-position vocabulary scores.

max_length_in = 512
max_length_out = 128

vocab_size = 30522

encoder_inputs = Input(shape = (None,), dtype = 'int32')
decoder_inputs = Input(shape = (None,))


inputs = [encoder_inputs,  decoder_inputs]

# Pass the named constant rather than repeating the 30522 literal, so the
# vocabulary size is defined in exactly one place in this cell.
encoder = build_encoder(max_length_in = max_length_in, vocab_size = vocab_size)

decoder = build_decoder(d_model = 768, max_length_out = max_length_out, vocab_size = vocab_size)

# The encoder returns [encoded sequence, padding mask]; both feed the decoder.
enc = encoder(encoder_inputs)
encoder_output = enc[0]
encoder_mask = enc[1] 

decoded = decoder([decoder_inputs, encoder_output, encoder_mask])

model = Model(inputs, decoded)

In [None]:
# Print the combined model's layers and parameter counts.
model.summary()

In [None]:
# Shape of the (truncated) encoder-input array that will be fed to the model.
X.shape

## Pretraining the decoder

In [None]:
# Freeze model.layers[2] so that only the remaining layers are updated during
# this stage.
# NOTE(review): assumes index 2 is the BERT encoder sub-model; a positional
# lookup is fragile -- verify against model.summary().
for layer in model.layers[2:3]:
    layer.trainable = False

In [None]:
# NOTE(review): clear_session() runs after `model` has already been built;
# confirm that resetting Keras' global state at this point is intended.
tf.keras.backend.clear_session()
# Enable XLA JIT compilation for this training stage.
tf.config.optimizer.set_jit(True)

# Per-token loss on raw logits; reduction is done manually in loss_function.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
                    from_logits=True, reduction='none')

def loss_function(real, pred):
    """Sparse cross-entropy with padding targets (id 0) zeroed out.

    NOTE(review): the mean is taken over all positions, padded ones included,
    so the value shrinks with the amount of padding in the batch.
    """
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
  
    return tf.reduce_mean(loss_)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
                    name='train_accuracy')

loss_classif     =  loss_function
# Adam's second positional parameter is beta_1, so the previous
# `Adam(3e-5, 1e-8)` set beta_1 = 1e-8 (effectively disabling the first-moment
# average). The value is almost certainly meant as epsilon; pass it by keyword.
optimizer        =  Adam(3e-5, epsilon = 1e-8)
metrics_classif  =  [train_accuracy]


model.compile(loss=loss_classif,
              optimizer=optimizer,
              metrics=metrics_classif)

In [None]:
# Show the model again to check trainable vs. non-trainable parameter counts
# after the freeze.
model.summary()

In [None]:
# Stage 1: train with model.layers[2] frozen.
batch_size = 20
epochs = 4

# Pinned to the first GPU. The held-out split from train_test_split is used as
# validation data, so it is not an untouched test set.
with tf.device('/device:GPU:0'):
    history = model.fit(X_train, y_train, batch_size=batch_size,
                                  epochs=epochs, validation_data=(X_test,  y_test))

## Training the full encoder decoder

In [None]:
# Unfreeze model.layers[2] so the whole model is trained in this stage; the
# model is re-compiled in the next cell.
# NOTE(review): assumes index 2 is the BERT encoder sub-model; a positional
# lookup is fragile -- verify against model.summary().
for layer in model.layers[2:3]:
    layer.trainable = True

In [None]:
# NOTE(review): clear_session() runs while `model` is still in use; confirm
# that resetting Keras' global state at this point is intended.
tf.keras.backend.clear_session()
# XLA JIT compilation is turned off for this training stage.
tf.config.optimizer.set_jit(False)

# Per-token loss on raw logits; reduction is done manually in loss_function.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
                    from_logits=True, reduction='none')

def loss_function(real, pred):
    """Sparse cross-entropy with padding targets (id 0) zeroed out.

    NOTE(review): the mean is taken over all positions, padded ones included,
    so the value shrinks with the amount of padding in the batch.
    """
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
  
    return tf.reduce_mean(loss_)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
                    name='train_accuracy')

loss_classif     =  loss_function
# Adam's second positional parameter is beta_1, so the previous
# `Adam(3e-5, 1e-8)` set beta_1 = 1e-8 (effectively disabling the first-moment
# average). The value is almost certainly meant as epsilon; pass it by keyword.
optimizer        =  Adam(3e-5, epsilon = 1e-8)
metrics_classif  =  [train_accuracy]


# Compile again with a fresh optimizer after the change to layer.trainable.
model.compile(loss=loss_classif,
              optimizer=optimizer,
              metrics=metrics_classif)

In [None]:
# Show the model again to check that all parameters are trainable after the
# unfreeze.
model.summary()

In [None]:
# Stage 2: train with model.layers[2] unfrozen, using a smaller batch size
# than stage 1.
batch_size = 16
epochs = 4

# Pinned to the first GPU. The held-out split from train_test_split is used as
# validation data, so it is not an untouched test set.
with tf.device('/device:GPU:0'):
    history = model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(X_test,  y_test))