In [1]:
# !pip install transformers sentence-transformers stanza datasets rouge_score 
# !pip install spacy
# ! python -m spacy download en_core_web_lg
# !pip install googletrans==4.0.0-rc1
# !pip install pyinflect
# !pip install stanza

import stanza

corenlp_dir = './corenlp'
stanza.install_corenlp(dir=corenlp_dir)

# Set the CORENLP_HOME environment variable to point to the installation location
import os
os.environ["CORENLP_HOME"] = corenlp_dir
from stanza.server import CoreNLPClient


import datasets
import transformers
import logging
import os

import re
import sys
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from datasets import load_dataset, load_metric
from dataclasses import dataclass, field
from typing import Optional

from tqdm.notebook import tqdm
from stanza.server import CoreNLPClient

import pandas as pd

import pyinflect
import spacy
nlp = spacy.load("en_core_web_lg")

import pandas as pd
import random


# !pip install nltk --upgrade
import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet as wn

[nltk_data] Downloading package wordnet to /home/geoff/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
def align_ws(old_token, new_token):
    # Align trailing whitespaces between tokens
    if old_token[-1] == new_token[-1] == " ":
        return new_token
    elif old_token[-1] == " ":
        return new_token + " "
    elif new_token[-1] == " ":
        return new_token[:-1]
    else:
        return new_token
    
NEGATABLE_TOKENS = ("are", "is", "was", "were", "have", "has", "had",
                                   "do", "does", "did", "can", "ca", "could", "may",
                                   "might", "must", "shall", "should", "will", "would")



def negation(summary):
  candidate_tokens = [token for token in summary if token.text in NEGATABLE_TOKENS]

  if not candidate_tokens:
      return None, None

  # choose random token to negate
  negated_token = random.choice(candidate_tokens)
  negated_index = negated_token.i
  L = len(summary)

  # negation occurs at the first negatable token (e.g. does not have)
  if negated_index > 0:
    if summary[negated_index-1].text in NEGATABLE_TOKENS:
      negated_token = summary[negated_index-1]
      negated_index = negated_index-1
  
  #check whether qualified by a negative
  is_negative = False
  if (L-1) > negated_index:
    if summary[negated_index + 1].text in ["not", "n't"]:
      is_negative = True
    elif summary[negated_index+1].text == "no":
      return None, None

  #add not/n't if is_negative is False, remove if True
  tokens = [token.text_with_ws for token in summary]
  if is_negative:
      if summary[negated_index + 1].text.lower() == "n't":
          if summary[negated_index + 1].text.lower() == "ca":
              tokens[negated_index] = "can" if tokens[negated_index].islower() else "Can"
          tokens[negated_index] = tokens[negated_index] + " "
      tokens.pop(negated_index + 1)

  else:
      if summary[negated_index].text.lower() in ["am", "may", "might", "must", "shall", "will"]:
          negation = "not "
      else:
          negation = random.choice(["not ", "n't "])

      if negation == "n't ":
          if summary[negated_index].text.lower() == "can":
              tokens[negated_index] = "ca" if tokens[negated_index].islower() else "Ca"
          else:
              tokens[negated_index] = tokens[negated_index][:-1]
      tokens.insert(negated_index + 1, negation)
  
  new_summary = spacy("".join(tokens))
  augmentation_span = [(negated_index, negated_index if is_negative else negated_index + 1)]

  if new_summary.text == summary.text:
    return None, None
  else:
    return new_summary, augmentation_span

CLASS_TO_PRONOUN = {
            "SUBJECT": ["you", "he", "she", "we", "they"],
            "OBJECT": ["me", "you", "him", "her", "us", "them"],
            "POSSESSIVE": ["my", "your", "his", "her", "its", "our", "your", "their"],
            "REFLEXIVE": ["myself", "yourself", "himself", "itself", "ourselves", "yourselves", "themselves"]
        }


PRONOUN_TO_CLASS = {pronoun: key for (key, values) in CLASS_TO_PRONOUN.items() for pronoun in values}
PRONOUNS = [pronoun for pronoun in PRONOUN_TO_CLASS.keys()]


def pronounswap(summary):
  summary_pronouns = [token for token in summary if token.text.lower() in PRONOUNS]

  if not summary_pronouns:
    return None, None

  swap = random.choice(summary_pronouns)
  swap_index = swap.i
  swap_class = PRONOUN_TO_CLASS[swap.text.lower()]

  candidate_tokens = [token for token in CLASS_TO_PRONOUN[swap_class] if token != swap.text.lower()]

  if not candidate_tokens:
    return None, None

  swapped_token = random.choice(candidate_tokens)
  swapped_token = align_ws(swap.text_with_ws, swapped_token)
  swapped_token = swapped_token if swap.text.islower() else swapped_token.capitalize()

  summary_tokens = [token.text_with_ws for token in summary]
  summary_tokens[swap_index] = swapped_token

  new_summary = spacy("".join(summary_tokens))

  augmentation_span = [(swap_index, swap_index)]

  if new_summary.text == summary.text:
    return None, None

  else:
    return new_summary, augmentation_span


ENTITY_CATEGORIES = ("PERSON", "ORG", "NORP", "FAC", "GPE", "LOC", "PRODUCT",
                           "WORK_OF_ART", "EVENT")

def entityswap(summary, source):
  source_ents = [ent for ent in source.ents if ent.label_ in ENTITY_CATEGORIES]
  summary_ents = [ent for ent in summary.ents if ent.label_ in ENTITY_CATEGORIES]

  if not source_ents or not summary_ents:
    return None, None

  swap = random.choice(summary_ents)
  candidate_ents = [ent for ent in source_ents if ent.text != swap.text and ent.text not in swap.text and swap.text not in ent.text]

  if not candidate_ents:
    return None, None

  swapped_ent = random.choice(candidate_ents)
  summary_tokens = [token.text_with_ws for token in summary]
  swapped_token = align_ws(swap.text_with_ws, swapped_ent.text_with_ws)
  summary_tokens = summary_tokens[:swap.start] + [swapped_token] + summary_tokens[swap.end:]

  new_summary = spacy("".join(summary_tokens))
  augmentation_span = [(swap.start, swap.start + len(swapped_ent) - 1)]

  if new_summary.text == summary.text:
    return None, None

  else:
    return new_summary, augmentation_span

NUMBER_CATEGORIES = ("PERCENT", "MONEY", "QUANTITY", "CARDINAL")

def numberswap(summary, source):
  source_ents = [ent for ent in source.ents if ent.label_ in NUMBER_CATEGORIES]
  summary_ents = [ent for ent in summary.ents if ent.label_ in NUMBER_CATEGORIES]

  if not source_ents or not summary_ents:
    return None, None

  swap = random.choice(summary_ents)
  candidate_ents = [ent for ent in source_ents if ent.text != swap.text and ent.text not in swap.text and swap.text not in ent.text]

  if not candidate_ents:
    return None, None

  swapped_ent = random.choice(candidate_ents)
  summary_tokens = [token.text_with_ws for token in summary]
  swapped_token = align_ws(swap.text_with_ws, swapped_ent.text_with_ws)
  summary_tokens = summary_tokens[:swap.start] + [swapped_token] + summary_tokens[swap.end:]

  new_summary = spacy("".join(summary_tokens))
  augmentation_span = [(swap.start, swap.start + len(swapped_ent) - 1)]

  if new_summary.text == summary.text:
    return None, None

  else:
    return new_summary, augmentation_span

DATE_CATEGORIES = ("DATE", "TIME")

def dateswap(summary, source):
  source_ents = [ent for ent in source.ents if ent.label_ in DATE_CATEGORIES]
  summary_ents = [ent for ent in summary.ents if ent.label_ in DATE_CATEGORIES]

  if not source_ents or not summary_ents:
    return None, None

  swap = random.choice(summary_ents)
  candidate_ents = [ent for ent in source_ents if ent.text != swap.text and ent.text not in swap.text and swap.text not in ent.text]

  if not candidate_ents:
    return None, None

  swapped_ent = random.choice(candidate_ents)
  summary_tokens = [token.text_with_ws for token in summary]
  swapped_token = align_ws(swap.text_with_ws, swapped_ent.text_with_ws)
  summary_tokens = summary_tokens[:swap.start] + [swapped_token] + summary_tokens[swap.end:]

  new_summary = spacy("".join(summary_tokens))
  augmentation_span = [(swap.start, swap.start + len(swapped_ent) - 1)]

  if new_summary.text == summary.text:
    return None, None

  else:
    return new_summary, augmentation_span

from stanza.server import CoreNLPClient


POS_TAGS = ("NOUN", "ADJ", "PROPN")

def openIE_filter(extract, summary):
  new_text = ""
  extract = spacy(extract)
  start_index = 0
  for token in summary:
    if token.text != extract[0].text:
      start_index+=1
    else:
      break
  tracking_index = start_index
  for word in extract:
    while (summary[tracking_index].text != word.text):
      tracking_index+=1
    if summary[tracking_index].pos_ in POS_TAGS:
      break
  start_index = tracking_index
  while (summary[tracking_index].pos_ != "NOUN" and summary[tracking_index].pos_ != "PROPN" and tracking_index < len(summary)):
    new_text+=(summary[tracking_index].text_with_ws)
    tracking_index += 1
  if len(summary) == tracking_index:
    return new_text, start_index, tracking_index
  while (summary[tracking_index].pos_ == "NOUN" or summary[tracking_index].pos_ == "PROPN"):
    new_text+=(summary[tracking_index].text_with_ws)
    tracking_index += 1
  return new_text, start_index, tracking_index


def s_o_swap(summary):
  summary_ents = [ent for ent in summary.ents]
  story_triples = []
  document = client.annotate(summary.text, output_format='json')
  triples = []
  for sentence in document['sentences']:
      for triple in sentence['openie']:
          triples.append({
            'subject': triple['subject'],
            'relation': triple['relation'],
              'object': triple['object']
          })
  if not triples:
    return None, None
  candidate_triples = []
  for triple in triples:
    subj = triple['subject']
    obj = triple['object']
    allow_triple = False
    for ent in summary_ents:
      if subj == ent.text or subj in ent.text or ent.text in subj:
        allow_triple = True
      if obj == ent.text or obj in ent.text or ent.text in obj:
        allow_triple = True
    if allow_triple == True:
      candidate_triples.append(triple)
  if not candidate_triples:
    return None, None
  triple_swap = random.choice(candidate_triples)
  triple_swap["subject"], s_start, s_end = openIE_filter(triple_swap["subject"], summary)
  triple_swap["object"], o_start, o_end = openIE_filter(triple_swap["object"], summary)
  if s_start == 0:
    triple_swap["object"] = triple_swap["object"].capitalize()
    if summary[s_start].pos_ != "PROPN":
      triple_swap["subject"] = triple_swap["subject"][0].lower() + triple_swap["subject"][1:]
  if o_start == 0:
    triple_swap["subject"] = triple_swap["subject"].capitalize()
    if summary[o_start].pos_ != "PROPN":
      triple_swap["object"] = triple_swap["object"][0].lower() + triple_swap["object"][1:]
  if s_end <= o_start:
    new_summary = summary[:s_start].text_with_ws + triple_swap["object"] + summary[s_end:o_start].text_with_ws + triple_swap["subject"] + summary[o_end:].text_with_ws
    augmentation_span = [(s_start, o_end - o_start + s_start-1), (o_end - s_end + s_start, o_end-1)]
  elif o_end <= s_start:
    new_summary = summary[:o_start].text_with_ws + triple_swap["subject"] + summary[o_end:s_start].text_with_ws + triple_swap["object"] + summary[s_end:].text_with_ws
    augmentation_span = [(o_start, s_end - s_start + o_start-1), (s_end - o_end + o_start-1, s_end-2)]
  else:
    return None, None
  new_summary = spacy(new_summary)
  if new_summary.text == summary.text:
    return None, None
  else:
    return new_summary, augmentation_span


NOISE_PROB = 0.05
DELETE_PROB = 0.8
def addnoise(summary, augmentation_span):
        summary_tokens = [token.text_with_ws for token in summary]

        new_summary = []
        for ix, token in enumerate(summary_tokens):
            # don't modify text inside an augmented span
            apply_augmentation = True
            if augmentation_span:
              for aug_span in augmentation_span:
                if aug_span:
                  span_start, span_end = aug_span
                  if span_start <= ix <= span_end:
                      apply_augmentation = False

            # decide whether to add noise
            if apply_augmentation and random.random() < NOISE_PROB:
                # decide whether to replicate or delete token
                if random.random() < DELETE_PROB:
                    # update spans and skip token
                    if augmentation_span:
                      for el in range(0, len(augmentation_span)):
                        aug_span = augmentation_span[el]
                        if aug_span:
                          span_start, span_end = aug_span
                          if ix < span_start:
                            span_start -= 1
                            span_end -= 1
                          aug_span = span_start, span_end
                          augmentation_span[el] = aug_span
                      if len(new_summary) > 0:
                        if new_summary[-1][-1] != " ":
                          new_summary[-1] = new_summary[-1] + " "
                      continue      
                else:
                  if augmentation_span:
                    for el in range(0, len(augmentation_span)):
                      aug_span = augmentation_span[el]
                      if aug_span:
                        span_start, span_end = aug_span
                        if ix < span_start:
                          span_start += 1
                          span_end += 1
                        aug_span = span_start, span_end
                        augmentation_span[el] = aug_span

                  new_summary.append(token)
            new_summary.append(token)
        new_summary = spacy("".join(new_summary))

        if summary.text == new_summary.text:
            return None
        else:
            return new_summary
        



def tree_traverse(tree):
  used_names = []
  used_names.append(tree[0].name().split(".")[0])
  while tree[1:]:
    tree = random.choice(tree[1:])
    used_names.append(tree[0].name().split(".")[0])
  return used_names

def find_syn(word):
  syns = wn.synsets(word)
  if not syns:
    return None
  syn = random.choice(syns)
  return syn

def pos_to_wn(pos):
  if pos == "VERB":
    return wn.VERB
  if pos == "ADV":
    return wn.ADV
  if pos == "ADJ":
    return wn.ADJ
  


def rand(a, b):
  if not a:
    return b
  elif not b:
    return a
  else:
    r = random.random()
    if r < 0.5:
      return a
    else:
      return b


def tree_pert(word, POS):
  synonyms = []
  antonyms = []

  ssets = wn.synsets(word, pos=pos_to_wn(POS))
  if not ssets:
    ssets = wn.synsets(word)
    
  for syn in ssets:
      for l in syn.lemmas():
          synonyms.append(l.name())
          if l.antonyms():
              antonyms.append(l.antonyms()[0].name())
  if not synonyms and not antonyms:
    return None
  pert_base = rand(synonyms, antonyms)
  pert_set = []
  for el in pert_base:
    if el != word:
      pert_set.append(el)
  for wrd in pert_base:
    N = random.randint(1, 50)
    for iter in range(0, N):
      syn = find_syn(wrd)
      if not syn:
        break
      tree = wn.synset(syn.name()).mst(lambda s:s.hyponyms())
      choices = tree_traverse(tree)
      wrd = random.choice(choices)
      pert_set.append(wrd)

  pert_set = list(set(pert_set))
  new = random.choice(pert_set)
  return new


def generate_new(word, POS):
  new = tree_pert(word, POS)
  i = 0
  while new == word and i < 20:
    new = tree_pert(word, POS)
    i+=1
  if not new:
    return None
  new = new.split("_")
  new_word = ""
  for i in range(0, len(new)):
    if i < len(new) - 1:
      new_word+=(new[i] + " ")
    else:
      new_word+=new[i]
  return new_word


PERTURBABLE = ("ADJ", "ADV", "VERB")



def smartswap(summary):  
  document = client.annotate(summary.text, output_format='json')
  triples = []
  for sentence in document['sentences']:
      for triple in sentence['openie']:
          triples.append({
            'subject': triple['subject'],
            'relation': triple['relation'],
              'object': triple['object']
          })
  if not triples:
    return None, None
  candidate_triples = triples
  triple_pert = random.choice(candidate_triples)
  triple_rel = spacy(triple_pert["relation"] + " " + triple_pert["object"])
  triple_rel = [token.text for token in triple_rel]
  target_tokens = [token for token in summary if token.text in triple_rel and token.pos_ in PERTURBABLE and token.text!="-"]
  timer = 0
  while not target_tokens and timer < 20:
    triple_pert = random.choice(candidate_triples)
    triple_rel = spacy(triple_pert["relation"] + " " + triple_pert["object"])
    triple_rel = [token.text for token in triple_rel]
    target_tokens = [token for token in summary if token.text in triple_rel and token.pos_ in PERTURBABLE]
    timer+=1

  if not target_tokens:
    return None, None
  if len(target_tokens) > 4:
    target_tokens = random.sample(target_tokens, 4)
  
  indices = [token.i for token in target_tokens]
  replacements = []
  span_len = []
  for word in target_tokens:
    new = generate_new(word.text, word.pos_)
    if not new:
      new = word.text
    if new[0].isupper:
      new = new[0].lower() + new[1:]
    replacements.append(spacy(new))
    span_len.append(new.count(" ") + 1)
  summary_tokens = [token.text_with_ws for token in summary]
  for ix in range(0, len(indices)):
    summary_tokens[indices[ix]] = replacements[ix].text_with_ws + " "
  new_summary = "".join(summary_tokens)

  new_summary = spacy(new_summary)
  augmentation_span = []
  for ix in range(0, len(indices)):
    augmentation_span.append((indices[ix], indices[ix] + span_len[ix] - 1))

  new_summary_tokens = [token.text_with_ws for token in new_summary]

  for ix in range(0, len(indices)):
    if summary[indices[ix]].pos_ == "VERB":
      nlp_sum = nlp(summary.text)
      nlp_new = nlp(new_summary.text)
      if new_summary[indices[ix]].pos_ == "VERB":
        adj = nlp_new[indices[ix]]._.inflect(nlp_sum[indices[ix]].tag_)
        if adj:
          new_summary_tokens[indices[ix]] = adj + " "

  new_summary = "".join(new_summary_tokens)

  new_summary = spacy(new_summary)

  
  if summary.text == new_summary.text:
      return None, None
  else:
      return new_summary, augmentation_span


def predicateswap(summary):
  summary_preds = [token for token in summary if token.pos_ == "VERB" or token.pos_ == "ADV" or token.pos_ == "ADJ" and token.text!="-"]
  if not summary_preds:
      return None, None
  original = random.choice(summary_preds)
  ix = original.i
  new = generate_new(original.text, original.pos_)
  if not new:
    return None, None
  span_len = new.count(" ") + 1

  summary_tokens = [token.text_with_ws for token in summary]
  summary_tokens[ix] = new
  if ix < len(summary_tokens) - 1 and summary_tokens[ix+1] != "-":
    summary_tokens[ix] += " "
  new_summary = "".join(summary_tokens)
  new_summary = spacy(new_summary)
  augmentation_span = [(ix, ix + span_len - 1)]


  new_summary_tokens = [token.text_with_ws for token in new_summary]
  
  if summary[ix].pos_ == "VERB":
    nlp_sum = nlp(summary.text)
    nlp_new = nlp(new_summary.text)
    if new_summary[ix].pos_ == "VERB":
      adj = nlp_new[ix]._.inflect(nlp_sum[ix].tag_)
      if adj:
        new_summary_tokens[ix] = adj + " "

  new_summary = "".join(new_summary_tokens)

  new_summary = spacy(new_summary)

  if new_summary.text == summary.text:
    return None, None
  else:
    return new_summary, augmentation_span



In [3]:
from datasets import load_from_disk
dataset = load_from_disk('data/xsum_filtered/val')
dataset._data = dataset._data.filter(dataset['keep'])

In [4]:
# #RUN THIS CELL TO START THE CORENLP TAGGER
# client = CoreNLPClient(timeout=150000000, be_quiet=True, annotators=['openie'], 
# endpoint='http://localhost:9002')
# client.start()

In [5]:
# !apt update
# !apt install default-jre
# conda install -c anaconda openjdk

In [12]:
import pickle
import tqdm
perturbations = [
    "predicate",
    "smart",
    "s_o",
]

corrupt_percent = 0.2
np.random.seed(2021)
corrupt_indicies = np.random.choice(range(len(dataset)), size=int(corrupt_percent*len(dataset)))
corruption_groups = np.array_split(corrupt_indicies, len(perturbations))

total_corruption_count = 0
iteration = 0

spacy = nlp

output_dir = 'data/corruption_dict_test'
os.makedirs(output_dir, exist_ok=True)

no_noise = True

for pert_type, indicies in zip(perturbations, corruption_groups):
    corrupted_summaries = {}
    already_done_tracker = []
    total_corruption_count = 0
    iteration = 0
    if not no_noise:
        if os.path.exists(output_dir + f'/test_{pert_type}.pkl'):
            print('loaded previous file')
            with open(output_dir + f'/test_{pert_type}.pkl', 'rb') as f:

                corrupted_summaries = pickle.load(f)
                already_done_tracker = corrupted_summaries.keys()

    print('Corruption: ' + pert_type)
    
    for index in tqdm.tqdm(indicies):

        doc = dataset[int(index)]
        docid = doc['id']
        
        if docid in already_done_tracker:
            continue
            
        summary = nlp(doc['summary'])
        article = nlp(doc['document'])
        
        if pert_type == "s_o":
            try:
                new_summary, ags = s_o_swap(summary)
            except Exception as e:
                print(e)
                new_summary = None
                
        elif pert_type == "predicate":
                new_summary, ags = predicateswap(summary)
        elif pert_type == "smart":
            try:
                new_summary, ags = smartswap(summary)
            except Exception as e:
                print(e)
                new_summary = None

        if new_summary and pert_type != 'bt' and not no_noise:
#             print('nonoise')
            new_summary = addnoise(new_summary, ags)

        if new_summary:
            if new_summary.text != summary.text:
              corrupted_summaries[docid] = new_summary.text
              total_corruption_count += 1
                
        iteration +=1
        if iteration % 100 == 0:
          print(f"{pert_type}, save_index: {iteration} - SAVED!")
          print(f"{len(corrupted_summaries)}")
          with open(output_dir +  f'/test_{pert_type}.pkl', 'wb') as f:
            pickle.dump(corrupted_summaries, f)
    
    
    print('Corruption done and saved to: ' + output_dir + f'/test_{pert_type}.pkl')
    print('Total corruption_count = ' + str(total_corruption_count))
    
# client.stop()

Corruption: predicate


 20%|████████████████                                                                 | 101/510 [00:13<00:45,  9.01it/s]

predicate, save_index: 100 - SAVED!
96


 39%|███████████████████████████████▉                                                 | 201/510 [00:25<00:33,  9.27it/s]

predicate, save_index: 200 - SAVED!
190


 59%|███████████████████████████████████████████████▊                                 | 301/510 [00:37<00:22,  9.47it/s]

predicate, save_index: 300 - SAVED!
286


 79%|███████████████████████████████████████████████████████████████▋                 | 401/510 [00:49<00:13,  7.88it/s]

predicate, save_index: 400 - SAVED!
378


 98%|███████████████████████████████████████████████████████████████████████████████▌ | 501/510 [01:02<00:00,  9.68it/s]

predicate, save_index: 500 - SAVED!
469


100%|█████████████████████████████████████████████████████████████████████████████████| 510/510 [01:03<00:00,  8.04it/s]


Corruption done and saved to: data/corruption_dict_test/test_predicate.pkl
Total corruption_count = 491
Corruption: smart


  0%|                                                                                           | 0/509 [00:00<?, ?it/s]2021-12-11 01:16:40 INFO: Starting server with command: java -Xmx5G -cp ./corenlp/* edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 8889 -timeout 150000000 -threads 5 -maxCharLength 100000 -quiet True -serverProperties corenlp_server-93a781ec458c47ae.props -annotators openie -preload -outputFormat serialized
 20%|████████████████                                                                 | 101/509 [00:26<01:07,  6.06it/s]

smart, save_index: 100 - SAVED!
81


 40%|████████████████████████████████▏                                                | 202/509 [00:44<00:33,  9.13it/s]

smart, save_index: 200 - SAVED!
150


 59%|███████████████████████████████████████████████▉                                 | 301/509 [01:00<00:30,  6.93it/s]

smart, save_index: 300 - SAVED!
222


 79%|███████████████████████████████████████████████████████████████▊                 | 401/509 [01:17<00:18,  5.92it/s]

smart, save_index: 400 - SAVED!
297


 98%|███████████████████████████████████████████████████████████████████████████████▋ | 501/509 [01:33<00:01,  6.88it/s]

smart, save_index: 500 - SAVED!
367


100%|█████████████████████████████████████████████████████████████████████████████████| 509/509 [01:34<00:00,  5.38it/s]


Corruption done and saved to: data/corruption_dict_test/test_smart.pkl
Total corruption_count = 385
Corruption: s_o


  2%|█▉                                                                                | 12/509 [00:01<00:47, 10.53it/s]

[E040] Attempt to access token at 24, max length 24.


  7%|█████▋                                                                            | 35/509 [00:04<01:04,  7.39it/s]

[E040] Attempt to access token at 38, max length 38.


 14%|███████████▊                                                                      | 73/509 [00:09<01:02,  7.01it/s]

[E040] Attempt to access token at 18, max length 18.
[E040] Attempt to access token at 12, max length 12.


 20%|███████████████▉                                                                 | 100/509 [00:12<00:52,  7.80it/s]

s_o, save_index: 100 - SAVED!
55


 24%|███████████████████▋                                                             | 124/509 [00:15<00:51,  7.46it/s]

[E040] Attempt to access token at 12, max length 12.


 39%|███████████████████████████████▉                                                 | 201/509 [00:25<00:35,  8.78it/s]

s_o, save_index: 200 - SAVED!
119


 45%|████████████████████████████████████▌                                            | 230/509 [00:28<00:37,  7.43it/s]

[E040] Attempt to access token at 31, max length 31.


 59%|████████████████████████████████████████████████                                 | 302/509 [00:36<00:21,  9.45it/s]

s_o, save_index: 300 - SAVED!
183


 67%|██████████████████████████████████████████████████████▍                          | 342/509 [00:41<00:19,  8.38it/s]

[E040] Attempt to access token at 47, max length 47.


 79%|███████████████████████████████████████████████████████████████▊                 | 401/509 [00:49<00:14,  7.53it/s]

s_o, save_index: 400 - SAVED!
250


 96%|█████████████████████████████████████████████████████████████████████████████▉   | 490/509 [00:59<00:02,  9.36it/s]

[E040] Attempt to access token at 30, max length 30.


 98%|███████████████████████████████████████████████████████████████████████████████▋ | 501/509 [01:01<00:00,  8.23it/s]

s_o, save_index: 500 - SAVED!
311


100%|█████████████████████████████████████████████████████████████████████████████████| 509/509 [01:02<00:00,  8.20it/s]

Corruption done and saved to: data/corruption_dict_test/test_s_o.pkl
Total corruption_count = 328





In [None]:
# xsum_corrupted_train = xsum_corrupted_train.add_column('corrupted_summary', corrupted_summaries)
# xsum_corrupted_train.save_to_disk("data/xsum_new_corrupted/train")

In [None]:
corrupted_summaries

In [None]:
for docid in list(corrupted_summaries.keys())[:30]:
    ind = xsum_corrupted_train['id'].index(docid)
    print('original:\t'+xsum_corrupted_train[ind]['summary'])
    print('corrupt:\t' + corrupted_summaries[docid])