In [18]:
import datasets
import transformers
import rouge

### Loading Data

In [32]:
train_data =       datasets.load_dataset("cnn_dailymail", "3.0.0", split="train")
validation_data =  datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation")
test_data =        datasets.load_dataset("cnn_dailymail", "3.0.0", split="test")

Reusing dataset cnn_dailymail (/Users/jeroen/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234)
Reusing dataset cnn_dailymail (/Users/jeroen/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234)
Reusing dataset cnn_dailymail (/Users/jeroen/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234)


### Sample of the data

In [33]:
def show_examples(dataset, num_samples=3, seed=42):
    samples = dataset.shuffle(seed=seed).select(range(num_samples))
        
    for idx, sample in enumerate(samples):
        display(f'sample {idx}: {sample["article"]} \n')
        display(f'highlight {idx}: {sample["highlights"]} \n')
        display(f'id: {sample["id"]}')
        display('-------')
        
def get_samples(dataset, num_samples=10):
    return dataset.shuffle(seed=1).select(range(num_samples))

def get_random_sample(dataset):
    sample = dataset.shuffle(seed=1).select(range(1)) 
    return [sample["article"][0], sample["highlights"][0]]

In [34]:
# get_random_sample(train_data)

### Tokenizer

In [35]:
batch_size=4 # change to 16 for full training
encoder_max_length=512
decoder_max_length=128

tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased")
# tokenizer = transformers.DistilBertTokenizerFast()
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token

def tokenize_data_to_model_input(batch):
    inputs  = tokenizer(batch["article"], padding="max_length", 
                       truncation=True, max_length=encoder_max_length)
    
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(batch["highlights"], 
                           padding="max_length", truncation=True, max_length=decoder_max_length)
    
    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    
    batch["decoder_input_ids"] = labels.input_ids
    batch["decoder_attention_mask"] = labels.attention_mask
    batch["labels"] = labels["input_ids"]
    
    # because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`. 
    # We have to make sure that the PAD token is ignored
    batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]
    
    return batch

# For now subsample is being used
percentile = 0.0001
amount = round(len(train_data) * percentile)
print(amount)
train_data = train_data.select(range(amount))

train_data = train_data.map(
    tokenize_data_to_model_input, 
    batched=True, 
    batch_size=batch_size, 
    remove_columns=["article", "highlights", "id"]
)

train_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)


# only use 16 training examples for notebook - DELETE LINE FOR FULL TRAINING
amount = round(len(validation_data) * percentile)
validation_data = validation_data.select(range(amount))

validation_data = validation_data.map(
    tokenize_data_to_model_input, 
    batched=True, 
    batch_size=batch_size, 
    remove_columns=["article", "highlights", "id"]
)
validation_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)


test_data = test_data.select(range(128))
# test_data = test_data.map(
#     tokenize_data_to_model_input, 
#     batched=True, 
#     batch_size=batch_size, 
#     remove_columns=["article", "highlights", "id"]
# )
# test_data.set_format(type="torch")
    

29


100%|████████████████████████████████████████████| 8/8 [00:00<00:00, 156.28ba/s]
100%|████████████████████████████████████████████| 1/1 [00:00<00:00, 375.40ba/s]


### Encoder - Decoder

In [None]:
from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig, EncoderDecoderModel, EncoderDecoderConfig

# bert2bert = transformers.AutoModelForSeq2SeqLM.from_pretrained("bert-base-uncased", output_loading_info=True)



# config_encoder = DistilBertConfig.from_pretrained('distilbert-base-uncased')
# config_decoder = DistilBertConfig.from_pretrained('distilbert-base-uncased')

# config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

# distilbert2distilbert = EncoderDecoderModel(config=config)

bert2bert = transformers.EncoderDecoderModel.from_encoder_decoder_pretrained("distilbert-base-uncased", "distilbert-base-uncased")


In [36]:
# set special tokens
bert2bert.config.decoder_start_token_id = tokenizer.bos_token_id
bert2bert.config.eos_token_id = tokenizer.eos_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id

# sensible parameters for beam search
bert2bert.config.vocab_size = bert2bert.config.decoder.vocab_size
bert2bert.config.max_length = 142
bert2bert.config.min_length = 56
bert2bert.config.no_repeat_ngram_size = 3
bert2bert.config.early_stopping = True
bert2bert.config.length_penalty = 2.0
bert2bert.config.num_beams = 4

### Evaluation Metric

In [12]:
from rouge import Rouge 

rouge_scorer = Rouge()

def compute_evaluation_metric(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # all unnecessary tokens are removed
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    
    score = rouge_scorer.get_scores(label_str, pred_str)
    f = score[0]["rouge-2"]["f"]
     
    return {
        "rouge2_fmeasure": f
    }

### Training

In [10]:
from typing_extensions import Protocol, runtime_checkable
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

batch_size = 2
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    evaluation_strategy = "epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=10,
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model=bert2bert,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_evaluation_metric,
    train_dataset=train_data,
    eval_dataset=validation_data,
)
    
# create_trainer()
trainer.train()

***** Running training *****
  Num examples = 29
  Num Epochs = 10
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 2
  Gradient Accumulation steps = 1
  Total optimization steps = 150


Epoch,Training Loss,Validation Loss,Rouge2 Fmeasure
1,No log,,0.0
2,No log,,0.0
3,No log,,0.0
4,No log,,0.0
5,No log,,0.0
6,No log,,0.0
7,No log,,0.0
8,No log,,0.0
9,No log,,0.0
10,No log,,0.0


***** Running Evaluation *****
  Num examples = 1
  Batch size = 2
***** Running Evaluation *****
  Num examples = 1
  Batch size = 2
***** Running Evaluation *****
  Num examples = 1
  Batch size = 2
***** Running Evaluation *****
  Num examples = 1
  Batch size = 2
***** Running Evaluation *****
  Num examples = 1
  Batch size = 2
***** Running Evaluation *****
  Num examples = 1
  Batch size = 2
***** Running Evaluation *****
  Num examples = 1
  Batch size = 2
***** Running Evaluation *****
  Num examples = 1
  Batch size = 2
***** Running Evaluation *****
  Num examples = 1
  Batch size = 2
***** Running Evaluation *****
  Num examples = 1
  Batch size = 2


Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=150, training_loss=0.0, metrics={'train_runtime': 345.8153, 'train_samples_per_second': 0.839, 'train_steps_per_second': 0.434, 'total_flos': 109841307955200.0, 'train_loss': 0.0, 'epoch': 10.0})

### Evaluation

In [24]:
# map data correctly
def generate_summary(batch):
    # Tokenizer will automatically set [BOS] <text> [EOS]
    # cut off at BERT max length 512
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    outputs = model.generate(input_ids, attention_mask=attention_mask)

    # all special tokens including will be removed
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    batch["pred"] = output_str

    return batch

In [30]:
model = transformers.EncoderDecoderModel.from_pretrained("./checkpoint-1500")

results = test_data.map(generate_summary, batched=True, batch_size=batch_size)


  0%|                                                    | 0/32 [00:00<?, ?ba/s]


TypeError: can't convert np.ndarray of type numpy.str_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.

## Displaying Results

### Util function

In [None]:
# Credits tot Mauro Di Pietro
# https://towardsdatascience.com/text-summarization-with-nlp-textrank-vs-seq2seq-vs-bart-474943efeb09

import re
import difflib
import nltk
# nltk.download()

'''
Find the matching substrings in 2 strings.
:parameter
    :param a: string - raw text
    :param b: string - raw text
:return
    2 lists used in to display matches
'''
def utils_split_sentences(a, b):
    ## find clean matches
    match = difflib.SequenceMatcher(isjunk=None, a=a, b=b, autojunk=True)
    lst_match = [block for block in match.get_matching_blocks() if block.size > 20]
    
    ## difflib didn't find any match
    if len(lst_match) == 0:
        lst_a, lst_b = nltk.sent_tokenize(a), nltk.sent_tokenize(b)
    
    ## work with matches
    else:
        first_m, last_m = lst_match[0], lst_match[-1]

        ### a
        string = a[0 : first_m.a]
        lst_a = [t for t in nltk.sent_tokenize(string)]
        for n in range(len(lst_match)):
            m = lst_match[n]
            string = a[m.a : m.a+m.size]
            lst_a.append(string)
            if n+1 < len(lst_match):
                next_m = lst_match[n+1]
                string = a[m.a+m.size : next_m.a]
                lst_a = lst_a + [t for t in nltk.sent_tokenize(string)]
            else:
                break
        string = a[last_m.a+last_m.size :]
        lst_a = lst_a + [t for t in nltk.sent_tokenize(string)]

        ### b
        string = b[0 : first_m.b]
        lst_b = [t for t in nltk.sent_tokenize(string)]
        for n in range(len(lst_match)):
            m = lst_match[n]
            string = b[m.b : m.b+m.size]
            lst_b.append(string)
            if n+1 < len(lst_match):
                next_m = lst_match[n+1]
                string = b[m.b+m.size : next_m.b]
                lst_b = lst_b + [t for t in nltk.sent_tokenize(string)]
            else:
                break
        string = b[last_m.b+last_m.size :]
        lst_b = lst_b + [t for t in nltk.sent_tokenize(string)]
    
    return lst_a, lst_b


'''
Highlights the matched strings in text.
:parameter
    :param a: string - raw text
    :param b: string - raw text
    :param both: bool - search a in b and, if True, viceversa
    :param sentences: bool - if False matches single words
:return
    text html, it can be visualized on notebook with display(HTML(text))
'''
def display_string_matching(a, b, both=True, sentences=True, titles=[]):
    if sentences is True:
        lst_a, lst_b = utils_split_sentences(a, b)
    else:
        lst_a, lst_b = a.split(), b.split()       
    
    ## highlight a
    first_text = []
    for i in lst_a:
        if re.sub(r'[^\w\s]', '', i.lower()) in [re.sub(r'[^\w\s]', '', z.lower()) for z in lst_b]:
            first_text.append('<span style="background-color:rgba(255,215,0,0.3);">' + i + '</span>')
        else:
            first_text.append(i)
    first_text = ' '.join(first_text)
    
    ## highlight b
    second_text = []
    if both is True:
        for i in lst_b:
            if re.sub(r'[^\w\s]', '', i.lower()) in [re.sub(r'[^\w\s]', '', z.lower()) for z in lst_a]:
                second_text.append('<span style="background-color:rgba(255,215,0,0.3);">' + i + '</span>')
            else:
                second_text.append(i)
    else:
        second_text.append(b) 
    second_text = ' '.join(second_text)
    
    ## concatenate
    if len(titles) > 0:
        first_text = "<strong>"+titles[0]+"</strong><br>"+first_text
    if len(titles) > 1:
        second_text = "<strong>"+titles[1]+"</strong><br>"+second_text
    else:
        second_text = "---"*65+"<br><br>"+second_text
    final_text = first_text +'<br><br>'+ second_text
    return final_text

## Results

In [None]:
from IPython.core.display import display, HTML
from rouge import Rouge 

rouge_new = Rouge()

for i in range(10):
    article = results["article"][i]
    highlight = results["highlights"][i]
    prediction = results["pred"][i]
    score = rouge_new.get_scores(highlight, prediction)
    
    rouge_1_f = score[0]["rouge-1"]["f"] * 100
    rouge_2_f = score[0]["rouge-2"]["f"] * 100
    s = f"rouge-1: {rouge_1_f}, rouge-2:  {rouge_2_f}" 
    
#     match_article_prediction = display_string_matching(article, prediction, 
#                                                        both=True, sentences=False,
#                                                       titles=["Article", "Predicted Summary"])
    
    match_summary = display_string_matching(highlight, prediction, both=True, 
                                    sentences=False, 
                                    titles=["Real Summary", f"Predicted Summary ({s})"])
    
    display(HTML(match_summary))
    print("---")