# Setup

Imports and plotting functions

In [1]:
from typing import List, Union
from parlai.utils.strings import colorize
import copy
def compute_ngram_repeats(context: Union[str, List], model_text: Union[str, List], n=3, splitted=False):
    cgrams = {}
    # compute N grams of the context
    
    if not splitted:
        context = context.split(' ')
        model_text = model_text.split(' ')

    for i in range(n, len(context) + 1):
        ngram = ' '.join(context[i - n : i])
        cgrams[ngram] = True
    # compute N grams of the model response
    creps = 0
    lreps = 0
    repetition_idxs = [0] * len(model_text)
    lreps_idxs = [0] * len(model_text)
    creps_idxs = [0] * len(model_text)
    
    lgrams = {}

    for i in range(n, len(model_text) + 1):
        ngram = ' '.join(model_text[i - n : i])
        
        if ngram in cgrams:
            creps = creps + 1
            repetition_idxs[i-1] = 1
            creps_idxs[i-1] = 1
        
        if ngram in lgrams:
            lreps = lreps + 1
            repetition_idxs[i-1] = 1
            lreps_idxs[i-1] = 1
           
        lgrams[ngram] = True
    
    for i in range(n-1, len(model_text)):
        if repetition_idxs[i] == 1:
            for j in range(1, n):
                repetition_idxs[i-j] = 1

        if creps_idxs[i] == 1:
            for j in range(1, n):
                creps_idxs[i-j] = 1

        if lreps_idxs[i] == 1:
            for j in range(1, n):
                lreps_idxs[i-j] = 1

    return creps + lreps, creps, lreps, repetition_idxs, creps_idxs, lreps_idxs


def print_with_colors(text, repeat_indices):
    colorized_tokens = []
    tokenized_text = text.split(" ")
    
    is_repeat_indices = copy.copy(repeat_indices)
    for (token, is_repeat) in zip(tokenized_text, is_repeat_indices):
       
        if is_repeat:
            colorized_token = colorize(token, "red")
        else:
            colorized_token = token

        colorized_tokens.append(colorized_token)

    return " ".join(colorized_tokens)

def print_sample(context, model_text, repeat_type='all'):
    _, _, _, arep_idxs, crep_idxs, lrep_idxs = compute_ngram_repeats(context, model_text)
    # print(context)

    rep_idxs = None
    if repeat_type == 'all':
        rep_idxs = arep_idxs
    elif repeat_type == 'context':
        rep_idxs = crep_idxs
    elif repeat_type == 'labels':
        rep_idxs = lrep_idxs

    print(print_with_colors(model_text, rep_idxs))


# Narrative QA Dataset

This dataset is not suitable for our use case as the answers are very small. 
Maybe this can be used in another setting where we generate summary from the text.

In [1]:
from datasets import load_dataset
dataset = load_dataset("narrativeqa")

Downloading builder script:   0%|          | 0.00/1.85k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

Using custom data configuration default


Downloading and preparing dataset narrativeqa/default (download: 183.61 MiB, generated: 15.21 GiB, post-processed: Unknown size, total: 15.38 GiB) to /home/mila/a/arorakus/scratch/.cache/huggingface/datasets/narrativeqa/default/0.0.0/daef7ccc51ec258bef464658d11751bb20f033da9b4c219fd84563b3a4af0422...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/32747 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10557 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3461 [00:00<?, ? examples/s]

Dataset narrativeqa downloaded and prepared to /home/mila/a/arorakus/scratch/.cache/huggingface/datasets/narrativeqa/default/0.0.0/daef7ccc51ec258bef464658d11751bb20f033da9b4c219fd84563b3a4af0422. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [36]:
example = dataset['train'][1000]
example.keys()

dict_keys(['document', 'question', 'answers'])

In [37]:
print("Document: " + example['document']['summary']['text'][:1000])
print("Question: " + example['question']['text'])
print(f"Answers:")
for i, answer in enumerate(example['answers']):
    print(f"\t{i+1}. {answer['text']}")

Document:  Following his pursuit by Kirill (in The Bourne Supremacy), Jason Bourne (Matt Damon) evades Moscow police while wounded, and deals with more flashbacks of when he first joined Operation Treadstone. Six weeks later, CIA Deputy Director Pamela Landy (Joan Allen) divulges the audiotaped confession of Ward Abbott, the late former head of Treadstone, to Director Ezra Kramer (Scott Glenn). Meanwhile, in Turin, journalist Simon Ross (Paddy Considine) of The Guardian meets an informant to learn about Bourne and Operation Blackbriar, the program succeeding Treadstone. The CIA tracks Ross as he returns to London, after his mention of "Blackbriar" during a cell-phone call to his editor is detected by the ECHELON system. Bourne reappears in Paris to inform Martin Kreutz (Daniel Brühl), the step-brother of his girlfriend Marie Helena Kreutz (Franka Potente), of her assassination in India, also in the previous film.
Bourne reads Ross's articles and arranges a meeting with him at London Wa

## Observation:
The answer size here is pretty small, hence this is not very suitable for our use case.

# Writing Prompts

In [1]:
from datasets import load_dataset

prompt_response_dataset = load_dataset("rewardsignal/reddit_writing_prompts", data_files="prompt_responses_full.csv")


Using custom data configuration rewardsignal--reddit_writing_prompts-dd5d2a64487ab606
Reusing dataset csv (/home/mila/a/arorakus/scratch/.cache/huggingface/datasets/csv/rewardsignal--reddit_writing_prompts-dd5d2a64487ab606/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)


  0%|          | 0/1 [00:00<?, ?it/s]

In [2]:
print("Prompt:")
print(prompt_response_dataset['train'][0]['prompt'])
print()
print("Response:")
print(prompt_response_dataset['train'][0]['response'])

Prompt:
[WP] "Ma'am you can't bring your emotional support dragon inside the restaurant."

Response:
The manager saw the lady in the vest coming a mile away. Literally. It wasn't a small dragon. It lumbered up the path to the Hilltop Restaurant.

*\*sigh\* Not again*, thought the manager. Last time this happened... Have you ever tried pushing a fire-breathing dragon out of a restaurant? It's not easy.

He signaled to the waiter to keep inside and be ready on backup. At least this dragon seemed more... behaved? It was looking around and trying to be careful. But, rules were rules.

He walked outside, put up his hand, and said, "Ma'am you can't bring your emotional support dragon inside the restaurant."

The dragon yipped and grabbed the woman, holding her tight. "Ssh, ssh. It's OK. He's not trying to hurt you," she cooed while stroking it softly. "Hug me as long as you need to." The dragon stopped shaking, but just stared wide-eyed at the manager.

She turned her head, looked at the man

In [9]:
from transformers import AutoTokenizer, AutoModelForCausalLM

gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")

gpt2_finetuned_model = AutoModelForCausalLM.from_pretrained("/home/mila/a/arorakus/scratch/ews/finetuned_writing_prompts/08-13-2022-05-56/")

# ELI-5 

In [3]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("yjernite/bart_eli5")

model = AutoModelForSeq2SeqLM.from_pretrained("yjernite/bart_eli5")

In [2]:
from datasets import load_dataset

dataset = load_dataset("eli5")

Downloading builder script:   0%|          | 0.00/5.63k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

Downloading and preparing dataset eli5/LFQA_reddit (download: 6.03 MiB, generated: 1.26 GiB, post-processed: Unknown size, total: 1.26 GiB) to /home/mila/a/arorakus/scratch/.cache/huggingface/datasets/eli5/LFQA_reddit/1.0.0/17574e5502a10f41bbd17beba83e22475b499fa62caa1384a3d093fc856fe6fa...


Downloading:   0%|          | 0.00/3.50k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/576M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/21.1M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/286M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/9.65M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/17.7M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/330M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/18.7M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/36.2M [00:00<?, ?B/s]

Dataset eli5 downloaded and prepared to /home/mila/a/arorakus/scratch/.cache/huggingface/datasets/eli5/LFQA_reddit/1.0.0/17574e5502a10f41bbd17beba83e22475b499fa62caa1384a3d093fc856fe6fa. Subsequent calls will reuse this data.


  0%|          | 0/9 [00:00<?, ?it/s]

In [16]:
dataset['test_eli5'][1234]

{'q_id': '359ec0',
 'title': 'Why do planes seem to ”rock" side-to-side when taking off?',
 'selftext': "I've recently been employed in a position that requires me to travel a lot. I've noticed that when taking off from a runway, planes seem to rock side to side on the runway before they take off. Any explanation would help my peace of mind.",
 'document': '',
 'subreddit': 'explainlikeimfive',
 'answers': {'a_id': ['cr29zuw'],
  'text': ['Not quite sure when you\'re experiencing the rocking, I\'m assuming it\'s just before you reach speed where the wheels lift off the ground?\n\nIf so, a key thing to realize is that large passenger planes are actually VERY flexible. They\'re designed that way to allow shock absorption and deal with some pretty large stresses when you\'re flying through turbulent or wind-gusty air that might hit one part of the plane differently from another or shove it around a little. A brittle plane that experiences those stresses might snap, but a flexible plane "b

In [19]:
from transformers import BartForConditionalGeneration, BartTokenizer
question = dataset['test_eli5'][1234]['selftext']
tok = BartTokenizer.from_pretrained("facebook/bart-large")

batch = tok(question, return_tensors="pt")

In [22]:
generated_ids = model.generate(batch["input_ids"],  do_sample=True, top_p=0.9, min_length=100)

tok.batch_decode(generated_ids, skip_special_tokens=True)

[' When taking off from a runway, planes seem to rock side to side on the runway before']

# ArXiv Summarization   

The summaries generated here are pretty good by the looks of it but these are very extractive and copies heavily from the source text. 
Can this be a symptom of degradation, where context is being copied in huge chunks rather than being abstractive in nature?
Can something like entropy aware beam (or greedy) search  reduce the extractiveness of the summary?

In [3]:
from datasets import load_dataset

# pubmed_summ_dataset = load_dataset("scientific_papers", "pubmed")
arxiv_summ_dataset = load_dataset("scientific_papers", "arxiv")

Reusing dataset scientific_papers (/home/mila/a/arorakus/scratch/.cache/huggingface/datasets/scientific_papers/arxiv/1.1.1/306757013fb6f37089b6a75469e6638a553bd9f009484938d8f75a4c5e84206f)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
arxiv_summ_dataset['test']

Dataset({
    features: ['article', 'abstract', 'section_names'],
    num_rows: 6440
})

In [5]:
from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer

bigbird_arxiv_tokenizer = AutoTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv")

bigbird_arxiv_model = BigBirdPegasusForConditionalGeneration.from_pretrained("google/bigbird-pegasus-large-arxiv")

In [38]:
import random

arxiv_summ_testset = arxiv_summ_dataset['test']

idx = random.randint(0, len(arxiv_summ_testset))

text = arxiv_summ_dataset['test'][idx]['article']
abstract = arxiv_summ_dataset['test'][idx]['abstract']

text = ' '.join(text.split()[:2000])

print(abstract)

 we study the detectability of circular polarization in a stochastic gravitational wave background from various sources such as supermassive black hole binaries , cosmic strings , and inflation in the early universe with pulsar timing arrays . 
 we calculate generalized overlap reduction functions for the circularly polarized stochastic gravitational wave background . 
 we find that the circular polarization can not be detected for an isotropic background . however , there is a chance to observe the circular polarization for an anisotropic gravitational wave background . 
 we also show how to separate polarized gravitational waves from unpolarized gravitational waves . 


In [39]:
inputs = bigbird_arxiv_tokenizer(text, return_tensors='pt')
prediction = bigbird_arxiv_model.generate(**inputs)
prediction = bigbird_arxiv_tokenizer.batch_decode(prediction)
prediction

['<s> we investigate the detectability of circular polarization in the stochastic gravitational wave background ( sgwb ) by pulsar timing arrays ( ptas ). we characterize the sgwb by the so called stokes @xmath0 parameter and calculate generalized overlap reduction functions ( orfs ) so that we can probe the circular polarization of the sgwb.<n> we also discuss a method to separate the intensity ( @xmath1 mode ) and circular polarization ( @xmath2 mode ) of the sgwb.</s>']

In [40]:
print_sample(prediction[0], text, repeat_type='context')

print("Abstract:")
print_sample(text, abstract)

print("Summary:")
print()
print("Context Repeats Highlighted")
print_sample(text, prediction[0], repeat_type='all')
print()

print("Label Repeats Highlighted")
print_sample(text, prediction[0], repeat_type='labels')


it is believed that the direct detection of gravitational waves ( gws ) will bring the era of gravitational wave astronomy . the interferometer detectors are now under operation and awaiting the first signal of gws @xcite . it is also known that [0;31mpulsar[0;0m [0;31mtiming[0;0m [0;31marrays[0;0m [0;31m([0;0m [0;31mptas[0;0m ) can be used as a detector for gws @xcite . these detectors are used to search for very low frequency ( @xmath0 ) gravitational waves , where the lower limit of the observable frequencies is determined by the inverse of total observation time @xmath1 . indeed , the total observation time has a crucial role in ptas , because ptas are most sensitive near the lower edge of observable frequencies @xcite . taking into account its sensitivity , the first direct detection of the gravitational waves might be achieved by ptas . the main target of ptas is [0;31mthe[0;0m [0;31mstochastic[0;0m [0;31mgravitational[0;0m [0;31mwave[0;0m [0;31mbackground[0;0

# Xsum & Pegasus

In [19]:
from datasets import load_dataset

xsum_dataset = load_dataset("xsum")

Using custom data configuration default
Reusing dataset xsum (/home/mila/a/arorakus/scratch/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934)


  0%|          | 0/3 [00:00<?, ?it/s]

In [20]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
pegasus_tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
pegasus_model = AutoModelForSeq2SeqLM.from_pretrained("google/pegasus-xsum")

In [52]:
import random

xsum_testset = xsum_dataset['test']

idx = random.randint(0, len(xsum_testset))

xsum_document = xsum_testset[idx]['document']
xsum_summary = xsum_testset[idx]['summary']

print("Document:")
print(xsum_document)
print()
print("Summary:")
print(xsum_summary)

Document:
A red Ford Fiesta travelling north at Drumjohn, near Carsphairn, was in collision with a white Asda delivery van heading south at about 10:40 on Friday.
The van driver Scott Kennedy, 46, was taken by ambulance to Ayr Hospital where he died a short time later.
The Fiesta driver,  50 year old Antony Sztuka, died at the scene.
Both men were from Ayrshire.
The A713 was re-opened around 18:45 hours.
Sgt Billy McEwan, of Police Scotland, said: "We would like to hear from anyone who was in the area of the time of the crash to contact police.
"We know from witnesses already spoken to that there was a white flat-bed pickup truck - the size of a transit van - on the road at the time of the crash.
"We are very keen to speak to the driver as he or she may have information that could prove vital to the investigation."

Summary:
Two men have died following a road crash on the A713 in Dumfries and Galloway.


In [53]:
inputs = pegasus_tokenizer(xsum_document, max_length=512, truncation=True, return_tensors='pt')

prediction = pegasus_model.generate(**inputs)
prediction = pegasus_tokenizer.batch_decode(prediction)
prediction

Ignored unknown kwarg option direction


['<pad> Two men have died in a crash on the A713 in South Ayrshire.</s>']

In [54]:
print_sample(prediction[0], xsum_document, repeat_type='context')

print("Abstract:")
print_sample(xsum_document, xsum_summary)

print("Summary:")
print()
print("Context Repeats Highlighted")
print_sample(xsum_document, prediction[0], repeat_type='all')
print()

print("Label Repeats Highlighted")
print_sample(xsum_document, prediction[0], repeat_type='labels')

A red Ford Fiesta travelling north at Drumjohn, near Carsphairn, was in collision with a white Asda delivery van heading south at about 10:40 on Friday.
The van driver Scott Kennedy, 46, was taken by ambulance to Ayr Hospital where he died a short time later.
The Fiesta driver,  50 year old Antony Sztuka, died at the scene.
Both men were from Ayrshire.
The A713 was re-opened around 18:45 hours.
Sgt Billy McEwan, of Police Scotland, said: "We would like to hear from anyone who was in the area of the time of the crash to contact police.
"We know from witnesses already spoken to that there was a white flat-bed pickup truck - the size of a transit van - on the road at the time of the crash.
"We are very keen to speak to the driver as he or she may have information that could prove vital to the investigation."
Abstract:
Two men have died following a road crash on the A713 in Dumfries and Galloway.
Summary:

Context Repeats Highlighted
<pad> Two men have died in a crash on the A713 in South 

In [None]:
# CNN-Daily Mail

In [5]:
from datasets import load_dataset
cnn_dm_dataset = load_dataset("cnn_dailymail", "3.0.0")

Downloading and preparing dataset cnn_dailymail/3.0.0 (download: 558.32 MiB, generated: 1.28 GiB, post-processed: Unknown size, total: 1.82 GiB) to /home/mila/a/arorakus/scratch/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234...


Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/159M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/376M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/572k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/661k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/5 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

Dataset cnn_dailymail downloaded and prepared to /home/mila/a/arorakus/scratch/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
pegasus_cnn_dailymail_tokenizer = AutoTokenizer.from_pretrained("google/pegasus-cnn_dailymail")
pegasus_cnn_dailymail_model = AutoModelForSeq2SeqLM.from_pretrained("google/pegasus-cnn_dailymail")

Downloading:   0%|          | 0.00/88.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.12k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.91M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.28G [00:00<?, ?B/s]