In [None]:
# !pip install transformers sentence-transformers stanza datasets rouge_score spacy
# ! python -m spacy download en_core_web_lg

import stanza

corenlp_dir = './corenlp'
stanza.install_corenlp(dir=corenlp_dir)

# Set the CORENLP_HOME environment variable to point to the installation location
import os
os.environ["CORENLP_HOME"] = corenlp_dir

!ls $CORENLP_HOME
from stanza.server import CoreNLPClient


import datasets
import transformers


import logging
import os

import re
import sys
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from datasets import load_dataset, load_metric
from dataclasses import dataclass, field
from typing import Optional
from rouge_score import rouge_scorer

from tqdm.notebook import tqdm
from stanza.server import CoreNLPClient
!pip install googletrans==4.0.0-rc1
import pandas as pd
import googletrans
from googletrans import Translator

import random
import spacy
spacy = spacy.load("en_core_web_lg")

In [43]:
def align_ws(old_token, new_token):
    # Align trailing whitespaces between tokens
    if old_token[-1] == new_token[-1] == " ":
        return new_token
    elif old_token[-1] == " ":
        return new_token + " "
    elif new_token[-1] == " ":
        return new_token[:-1]
    else:
        return new_token
    
NEGATABLE_TOKENS = ("are", "is", "was", "were", "have", "has", "had",
                                   "do", "does", "did", "can", "ca", "could", "may",
                                   "might", "must", "shall", "should", "will", "would")



def negation(summary):
  candidate_tokens = [token for token in summary if token.text in NEGATABLE_TOKENS]

  if not candidate_tokens:
      return None, None

  # choose random token to negate
  negated_token = random.choice(candidate_tokens)
  negated_index = negated_token.i
  L = len(summary)

  # negation occurs at the first negatable token (e.g. does not have)
  if negated_index > 0:
    if summary[negated_index-1].text in NEGATABLE_TOKENS:
      negated_token = summary[negated_index-1]
      negated_index = negated_index-1
  
  #check whether qualified by a negative
  is_negative = False
  if (L-1) > negated_index:
    if summary[negated_index + 1].text in ["not", "n't"]:
      is_negative = True
    elif summary[negated_index+1].text == "no":
      return None, None

  #add not/n't if is_negative is False, remove if True
  tokens = [token.text_with_ws for token in summary]
  if is_negative:
      if summary[negated_index + 1].text.lower() == "n't":
          if summary[negated_index + 1].text.lower() == "ca":
              tokens[negated_index] = "can" if tokens[negated_index].islower() else "Can"
          tokens[negated_index] = tokens[negated_index] + " "
      tokens.pop(negated_index + 1)

  else:
      if summary[negated_index].text.lower() in ["am", "may", "might", "must", "shall", "will"]:
          negation = "not "
      else:
          negation = random.choice(["not ", "n't "])

      if negation == "n't ":
          if summary[negated_index].text.lower() == "can":
              tokens[negated_index] = "ca" if tokens[negated_index].islower() else "Ca"
          else:
              tokens[negated_index] = tokens[negated_index][:-1]
      tokens.insert(negated_index + 1, negation)
  
  new_summary = spacy("".join(tokens))
  augmentation_span = [(negated_index, negated_index if is_negative else negated_index + 1)]

  if new_summary.text == summary.text:
    return None, None
  else:
    return new_summary, augmentation_span

CLASS_TO_PRONOUN = {
            "SUBJECT": ["you", "he", "she", "we", "they"],
            "OBJECT": ["me", "you", "him", "her", "us", "them"],
            "POSSESSIVE": ["my", "your", "his", "her", "its", "our", "your", "their"],
            "REFLEXIVE": ["myself", "yourself", "himself", "itself", "ourselves", "yourselves", "themselves"]
        }


PRONOUN_TO_CLASS = {pronoun: key for (key, values) in CLASS_TO_PRONOUN.items() for pronoun in values}
PRONOUNS = [pronoun for pronoun in PRONOUN_TO_CLASS.keys()]


def pronounswap(summary):
  summary_pronouns = [token for token in summary if token.text.lower() in PRONOUNS]

  if not summary_pronouns:
    return None, None

  swap = random.choice(summary_pronouns)
  swap_index = swap.i
  swap_class = PRONOUN_TO_CLASS[swap.text.lower()]

  candidate_tokens = [token for token in CLASS_TO_PRONOUN[swap_class] if token != swap.text.lower()]

  if not candidate_tokens:
    return None, None

  swapped_token = random.choice(candidate_tokens)
  swapped_token = align_ws(swap.text_with_ws, swapped_token)
  swapped_token = swapped_token if swap.text.islower() else swapped_token.capitalize()

  summary_tokens = [token.text_with_ws for token in summary]
  summary_tokens[swap_index] = swapped_token

  new_summary = spacy("".join(summary_tokens))

  augmentation_span = [(swap_index, swap_index)]

  if new_summary.text == summary.text:
    return None, None

  else:
    return new_summary, augmentation_span


ENTITY_CATEGORIES = ("PERSON", "ORG", "NORP", "FAC", "GPE", "LOC", "PRODUCT",
                           "WORK_OF_ART", "EVENT")

def entityswap(summary, source):
  source_ents = [ent for ent in source.ents if ent.label_ in ENTITY_CATEGORIES]
  summary_ents = [ent for ent in summary.ents if ent.label_ in ENTITY_CATEGORIES]

  if not source_ents or not summary_ents:
    return None, None

  swap = random.choice(summary_ents)
  candidate_ents = [ent for ent in source_ents if ent.text != swap.text and ent.text not in swap.text and swap.text not in ent.text]

  if not candidate_ents:
    return None, None

  swapped_ent = random.choice(candidate_ents)
  summary_tokens = [token.text_with_ws for token in summary]
  swapped_token = align_ws(swap.text_with_ws, swapped_ent.text_with_ws)
  summary_tokens = summary_tokens[:swap.start] + [swapped_token] + summary_tokens[swap.end:]

  new_summary = spacy("".join(summary_tokens))
  augmentation_span = [(swap.start, swap.start + len(swapped_ent) - 1)]

  if new_summary.text == summary.text:
    return None, None

  else:
    return new_summary, augmentation_span

NUMBER_CATEGORIES = ("PERCENT", "MONEY", "QUANTITY", "CARDINAL")

def numberswap(summary, source):
  source_ents = [ent for ent in source.ents if ent.label_ in NUMBER_CATEGORIES]
  summary_ents = [ent for ent in summary.ents if ent.label_ in NUMBER_CATEGORIES]

  if not source_ents or not summary_ents:
    return None, None

  swap = random.choice(summary_ents)
  candidate_ents = [ent for ent in source_ents if ent.text != swap.text and ent.text not in swap.text and swap.text not in ent.text]

  if not candidate_ents:
    return None, None

  swapped_ent = random.choice(candidate_ents)
  summary_tokens = [token.text_with_ws for token in summary]
  swapped_token = align_ws(swap.text_with_ws, swapped_ent.text_with_ws)
  summary_tokens = summary_tokens[:swap.start] + [swapped_token] + summary_tokens[swap.end:]

  new_summary = spacy("".join(summary_tokens))
  augmentation_span = [(swap.start, swap.start + len(swapped_ent) - 1)]

  if new_summary.text == summary.text:
    return None, None

  else:
    return new_summary, augmentation_span

DATE_CATEGORIES = ("DATE", "TIME")

def dateswap(summary, source):
  source_ents = [ent for ent in source.ents if ent.label_ in DATE_CATEGORIES]
  summary_ents = [ent for ent in summary.ents if ent.label_ in DATE_CATEGORIES]

  if not source_ents or not summary_ents:
    return None, None

  swap = random.choice(summary_ents)
  candidate_ents = [ent for ent in source_ents if ent.text != swap.text and ent.text not in swap.text and swap.text not in ent.text]

  if not candidate_ents:
    return None, None

  swapped_ent = random.choice(candidate_ents)
  summary_tokens = [token.text_with_ws for token in summary]
  swapped_token = align_ws(swap.text_with_ws, swapped_ent.text_with_ws)
  summary_tokens = summary_tokens[:swap.start] + [swapped_token] + summary_tokens[swap.end:]

  new_summary = spacy("".join(summary_tokens))
  augmentation_span = [(swap.start, swap.start + len(swapped_ent) - 1)]

  if new_summary.text == summary.text:
    return None, None

  else:
    return new_summary, augmentation_span

from stanza.server import CoreNLPClient


POS_TAGS = ("NOUN", "ADJ", "PROPN")


def openIE_filter(extract, summary):
  new_text = ""
  extract = spacy(extract)

  start_index = 0
  for token in summary:
    if token.text != extract[0].text:
      start_index+=1
    else:
      break
  tracking_index = start_index
  for word in extract:
    while (summary[tracking_index].text != word.text):
      tracking_index+=1
    if summary[tracking_index].pos_ in POS_TAGS:
      break
  start_index = tracking_index

  while (summary[tracking_index].pos_ != "NOUN" and summary[tracking_index].pos_ != "PROPN"):
    new_text+=(summary[tracking_index].text_with_ws)
    tracking_index += 1
  while (summary[tracking_index].pos_ == "NOUN" or summary[tracking_index].pos_ == "PROPN"):
    new_text+=(summary[tracking_index].text_with_ws)
    tracking_index += 1
  return new_text, start_index, tracking_index
  


# def s_o_swap(summary):
#   summary_ents = [ent for ent in summary.ents]
#   client = CoreNLPClient(timeout=150000000, be_quiet=True, annotators=['openie'], 
#   endpoint='http://localhost:9002')
#   client.start()
#   import time
#   time.sleep(2)
#   story_triples = []

#   document = client.annotate(summary.text, output_format='json')
#   triples = []
#   for sentence in document['sentences']:
#       for triple in sentence['openie']:
#           triples.append({
#             'subject': triple['subject'],
#             'relation': triple['relation'],
#               'object': triple['object']
#           })
  
#   client.stop()
#   time.sleep(2)

#   if not triples:
#     return None, None
#   candidate_triples = []
#   for triple in triples:
#     subj = triple['subject']
#     obj = triple['object']
#     allow_triple = False
#     for ent in summary_ents:
#       if subj == ent.text or subj in ent.text or ent.text in subj:
#         allow_triple = True
#       if obj == ent.text or obj in ent.text or ent.text in obj:
#         allow_triple = True
#     if allow_triple == True:
#       candidate_triples.append(triple)
#   if not candidate_triples:
#     return None, None
#   triple_swap = random.choice(candidate_triples)
#   triple_swap["subject"], s_start, s_end = openIE_filter(triple_swap["subject"], summary)
#   triple_swap["object"], o_start, o_end = openIE_filter(triple_swap["object"], summary)
#   if s_start == 0:
#     triple_swap["object"] = triple_swap["object"].capitalize()
#     if summary[s_start].pos_ != "PROPN":
#       triple_swap["subject"] = triple_swap["subject"][0].lower() + triple_swap["subject"][1:]
#   if o_start == 0:
#     triple_swap["subject"] = triple_swap["subject"].capitalize()
#     if summary[o_start].pos_ != "PROPN":
#       triple_swap["object"] = triple_swap["object"][0].lower() + triple_swap["object"][1:]
  
#   if s_end <= o_start:
#     new_summary = summary[:s_start].text_with_ws + triple_swap["object"] + summary[s_end:o_start].text_with_ws + triple_swap["subject"] + summary[o_end:].text_with_ws
#     augmentation_span = [(s_start, o_end - o_start + s_start-1), (o_end - s_end + s_start, o_end-1)]
#   elif o_end <= s_start:
#     new_summary = summary[:o_start].text_with_ws + triple_swap["subject"] + summary[o_end:s_start].text_with_ws + triple_swap["object"] + summary[s_end:].text_with_ws
#     augmentation_span = [(o_start, s_end - s_start + o_start-1), (s_end - o_end + o_start-1, s_end-2)]
#   else:
#     return None, None
#   new_summary = spacy(new_summary)
#   if new_summary.text == summary.text:
#     return None, None

#   else:
#     return new_summary, augmentation_span


def s_o_swap(summary):
  summary_ents = [ent for ent in summary.ents]
  story_triples = []
  document = client.annotate(summary.text, output_format='json')
  triples = []
  for sentence in document['sentences']:
      for triple in sentence['openie']:
          triples.append({
            'subject': triple['subject'],
            'relation': triple['relation'],
              'object': triple['object']
          })
  if not triples:
    return None, None
  candidate_triples = []
  for triple in triples:
    subj = triple['subject']
    obj = triple['object']
    allow_triple = False
    for ent in summary_ents:
      if subj == ent.text or subj in ent.text or ent.text in subj:
        allow_triple = True
      if obj == ent.text or obj in ent.text or ent.text in obj:
        allow_triple = True
    if allow_triple == True:
      candidate_triples.append(triple)
  if not candidate_triples:
    return None, None
  triple_swap = random.choice(candidate_triples)
  triple_swap["subject"], s_start, s_end = openIE_filter(triple_swap["subject"], summary)
  triple_swap["object"], o_start, o_end = openIE_filter(triple_swap["object"], summary)
  if s_start == 0:
    triple_swap["object"] = triple_swap["object"].capitalize()
    if summary[s_start].pos_ != "PROPN":
      triple_swap["subject"] = triple_swap["subject"][0].lower() + triple_swap["subject"][1:]
  if o_start == 0:
    triple_swap["subject"] = triple_swap["subject"].capitalize()
    if summary[o_start].pos_ != "PROPN":
      triple_swap["object"] = triple_swap["object"][0].lower() + triple_swap["object"][1:]
  if s_end <= o_start:
    new_summary = summary[:s_start].text_with_ws + triple_swap["object"] + summary[s_end:o_start].text_with_ws + triple_swap["subject"] + summary[o_end:].text_with_ws
    augmentation_span = [(s_start, o_end - o_start + s_start-1), (o_end - s_end + s_start, o_end-1)]
  elif o_end <= s_start:
    new_summary = summary[:o_start].text_with_ws + triple_swap["subject"] + summary[o_end:s_start].text_with_ws + triple_swap["object"] + summary[s_end:].text_with_ws
    augmentation_span = [(o_start, s_end - s_start + o_start-1), (s_end - o_end + o_start-1, s_end-2)]
  else:
    return None, None
  new_summary = spacy(new_summary)
  if new_summary.text == summary.text:
    return None, None
  else:
    return new_summary, augmentation_span

SOURCE_LANG = "en"
ACCEPTED_LANGS = ["fr", "de", "zh-TW", "es", "ru"]
translator = Translator()

def backtranslate(summary):
  new_lang = random.choice(ACCEPTED_LANGS)
  summary_trans = translator.translate(summary.text, dest=new_lang)
  summary_btrans = translator.translate(summary_trans.text, dest=SOURCE_LANG)

  new_summary = spacy(summary_btrans.text)
  augmentation_span = (new_summary[0].i, new_summary[-1].i)

  if summary.text == new_summary.text:
    return None, None
  else:
    return new_summary, [augmentation_span]


NOISE_PROB = 0.05
DELETE_PROB = 0.8
def addnoise(summary, augmentation_span):
        summary_tokens = [token.text_with_ws for token in summary]

        new_summary = []
        for ix, token in enumerate(summary_tokens):
            # don't modify text inside an augmented span
            apply_augmentation = True
            if augmentation_span:
              for aug_span in augmentation_span:
                if aug_span:
                  span_start, span_end = aug_span
                  if span_start <= ix <= span_end:
                      apply_augmentation = False

            # decide whether to add noise
            if apply_augmentation and random.random() < NOISE_PROB:
                # decide whether to replicate or delete token
                if random.random() < DELETE_PROB:
                    # update spans and skip token
                    if augmentation_span:
                      for el in range(0, len(augmentation_span)):
                        aug_span = augmentation_span[el]
                        if aug_span:
                          span_start, span_end = aug_span
                          if ix < span_start:
                            span_start -= 1
                            span_end -= 1
                          aug_span = span_start, span_end
                          augmentation_span[el] = aug_span
                      if len(new_summary) > 0:
                        if new_summary[-1][-1] != " ":
                          new_summary[-1] = new_summary[-1] + " "
                      continue      
                else:
                  if augmentation_span:
                    for el in range(0, len(augmentation_span)):
                      aug_span = augmentation_span[el]
                      if aug_span:
                        span_start, span_end = aug_span
                        if ix < span_start:
                          span_start += 1
                          span_end += 1
                        aug_span = span_start, span_end
                        augmentation_span[el] = aug_span

                  new_summary.append(token)
            new_summary.append(token)
        new_summary = spacy("".join(new_summary))

        if summary.text == new_summary.text:
            return None
        else:
            return new_summary

In [130]:
from datasets import load_dataset
from datasets import load_from_disk
# ids_to_keep = np.load('ids_to_keep.npy')
xsum_filtered_train = load_from_disk('data/xsum_filtered/train')
# xsum_filtered_train = xsum_filtered_train._data.filter(lambda x: x['keep']).remove_columns('keep')
xsum_filtered_train._data = xsum_filtered_train._data.filter(xsum_filtered_train['keep'])

In [39]:
# xsum_df = pd.DataFrame(xsum_corrupted_train)
# !pip install install-jdk -t.

In [35]:
#RUN THIS CELL TO START THE CORENLP TAGGER
client = CoreNLPClient(timeout=150000000, be_quiet=True, annotators=['openie'], 
endpoint='http://localhost:9002')
client.start()

2021-12-04 21:44:20 INFO: Writing properties to tmp file: corenlp_server-189c65bd380f4e8b.props
2021-12-04 21:44:20 INFO: Starting server with command: java -Xmx5G -cp ./corenlp/* edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9002 -timeout 150000000 -threads 5 -maxCharLength 100000 -quiet True -serverProperties corenlp_server-189c65bd380f4e8b.props -annotators openie -preload -outputFormat serialized


In [34]:
# !apt update
# !apt install default-jre
# conda install -c anaconda openjdk

In [131]:
import tqdm
xsum_corrupted_train = xsum_filtered_train
corrupted_summaries = xsum_corrupted_train['summary']
already_done_tracker = []

In [None]:
perturbations = [
#     "s_o_swap", 
#     "bt", 
    "prn", 
    "dat", 
    "num", 
    "ent", 
    "neg"
]

corrupt_percent = 0.3
np.random.seed(2021)
corrupt_indicies = np.random.choice(range(len(xsum_corrupted_train)), size=int(corrupt_percent*len(xsum_corrupted_train)))
corruption_groups = np.array_split(corrupt_indicies, len(perturbations))

total_corruption_count = 0
iteration = 0

#   client = CoreNLPClient(timeout=150000000, be_quiet=True, annotators=['openie'], 
#   endpoint='http://localhost:9002')
#   client.start()

for pert_type, indicies in zip(perturbations, corruption_groups):
  for index in tqdm.tqdm(indicies):
#     if index in already_done_tracker:
#         continue
    summary = spacy(xsum_corrupted_train[int(index)]['summary'])
    article = spacy(xsum_corrupted_train[int(index)]['document'])
#     summary = spacy(xsum_df['summary'][index])
#     article = spacy(xsum_df['document'][index])
#     try:
#         if pert_type == "s_o_swap":
#           new_summary, ags = s_o_swap(summary)
    if pert_type == "bt":
      new_summary, ags = backtranslate(summary)
    if pert_type == "prn":
      new_summary, ags = pronounswap(summary)
    if pert_type == "dat":
      new_summary, ags = dateswap(summary, article)
    if pert_type == "num":
      new_summary, ags = numberswap(summary, article)
    if pert_type == "ent":
      new_summary, ags = entityswap(summary, article)
    if pert_type == "neg":
      new_summary, ags = negation(summary)
    if not new_summary and not ags:
      new_summary = summary
      ags = []
#     except:
#         new_summary = None
    new_summary = addnoise(new_summary, ags)
    if new_summary:
      corrupted_summaries[index] = new_summary.text
      total_corruption_count += 1
    already_done_tracker.append(index)
    
# client.stop()

100%|██████████| 8145/8145 [15:34<00:00,  8.71it/s]
100%|██████████| 8145/8145 [15:16<00:00,  8.89it/s]
100%|██████████| 8145/8145 [16:07<00:00,  8.42it/s]
 25%|██▌       | 2058/8145 [04:19<11:44,  8.64it/s]

In [None]:
xsum_corrupted_train = xsum_corrupted_train.add_column('corrupted_summary', corrupted_summaries)

In [None]:
xsum_corrupted_train.save_to_disk("data/xsum_baseline_corrupted/train")


In [None]:
# xsum_train_filtered.save_to_disk("data/xsum_filtered/train")
# xsum_val_filtered.save_to_disk("data/xsum_filtered/val")
# xsum_test_filtered.save_to_disk("data/xsum_filtered/test")

In [58]:

def backtranslate(summary):
  new_lang = random.choice(ACCEPTED_LANGS)
  summary_trans = translator.translate(summary, dest=new_lang)
  summary_btrans = translator.translate(summary_trans.text, dest=SOURCE_LANG)

  new_summary = summary_btrans
  augmentation_span = (new_summary[0].i, new_summary[-1].i)

  if summary == new_summary:
    return None, None
  else:
    return new_summary, [augmentation_span]


0

In [None]:
xsum_filtered_train = xsum_filtered_train.filter(lambda x: x['keep']).remove_columns('keep')
xsum_filtered_val= xsum_filtered_train.filter(lambda x: x['keep']).remove_columns('keep')
xsum_filtered_test= xsum_filtered_train.filter(lambda x: x['keep']).remove_columns('keep')

In [None]:
xsum_corrupted_train = xsum_filtered_train
corrupted_summaries = xsum_train_corrupted['summary']

corrupt_percent = 0.4

np.random.seed(2021)
corrupt_indicies = np.random.choice(range(len(xsum_train_corrupted)), size=int(corrupt_percent*len(xsum_train_corrupted)))
corruption_groups = np.split(corrupt_indicies, len(perturbations))

total_corruption_count = 0
for pert_type, indicies in zip(perturbations, corruption_groups):
  for index in indicies:
    if pert_type == "bt":
      new_summary, ags = backtranslate(summary)
    if pert_type == "prn":
      new_summary, ags = pronounswap(summary)
    if pert_type == "dat":
      new_summary, ags = dateswap(summary, article)
    if pert_type == "num":
      new_summary, ags = numberswap(summary, article)
    if pert_type == "ent":
      new_summary, ags = entityswap(summary, article)
    if pert_type == "neg":
      new_summary, ags = negation(summary)
    if not new_summary and not ags:
      new_summary = summary
      ags = []
    new_summary = addnoise(new_summary, ags)
    if new_summary:
      corrupted_summaries[index] = new_summary.text
      total_corruption_count += 1