In [None]:
import csv
import os

from datasets import load_dataset
from tqdm import tqdm

In [None]:
!pip3 install flair
import flair
from flair.data import Sentence
from flair.models import SequenceTagger

# load tagger
tagger = SequenceTagger.load("flair/ner-english-large")

In [None]:
import anonymizer
from anonymizer import entity
from anonymizer.core import initialize
from anonymizer.cache import NECache
import anonymizer.entity.person as person
from anonymizer.entity.org import org, org_wiki
from anonymizer.entity.gpe import gpe, gpe_wiki
from functools import partial
# Set up the anonymizer package before its entity handlers (org / person / gpe)
# are used by the replacement functions below.
# NOTE(review): what initialize() loads or caches is not visible here — confirm.
initialize()

In [None]:
def replace_entities_flair_wiki(text):
    """Anonymize named entities in ``text`` using the Flair NER ``tagger``.

    Each ORG, PER and LOC span predicted by ``tagger`` is replaced by a
    surrogate from the anonymizer's ``org``, ``person`` and ``gpe`` handlers
    respectively; spans with any other label are left untouched. A span whose
    surface text was already replaced earlier in the same text reuses that
    replacement, so repeated mentions stay consistent.

    Args:
        text: Raw input string.

    Returns:
        ``text`` with the replaced spans substituted, or a string equal to
        ``text`` if no span was tagged or no replacement could be produced.
    """
    sentence = Sentence(text)
    tagger.predict(sentence)
    spans = sentence.get_spans('ner')
    if not spans:
        return text

    # Per-label surrogate generators. ORG/PER receive the entity split into
    # words plus a fresh NECache; LOC receives the whole entity as one item
    # and a plain dict.
    handlers = {
        "ORG": lambda s: org.handle(s.split(" "), NECache()),
        "PER": lambda s: person.handle(s.split(" "), NECache()),
        "LOC": lambda s: gpe.handle([s], {}),
    }

    replacements = []      # (start_char, end_char, surrogate), in document order
    replacement_map = {}   # span surface text -> surrogate already chosen for it
    for span in spans:
        if span.text in replacement_map:
            replacements.append((span.start_position, span.end_position, replacement_map[span.text]))
            continue
        handler = handlers.get(span.get_label().value)
        if handler is None:
            continue
        # A handler may return nothing or the unchanged entity; allow one
        # retry before giving up and leaving the span as it is.
        surrogate = None
        for _attempt in range(2):
            repl = handler(span.text)
            if repl and " ".join(repl) != span.text:
                surrogate = " ".join(repl)
                break
        if surrogate is None:
            continue
        replacements.append((span.start_position, span.end_position, surrogate))
        replacement_map[span.text] = surrogate

    # Stitch the untouched text between spans together with the surrogates.
    pieces = []
    cursor = 0
    for start, end, surrogate in replacements:
        pieces.append(text[cursor:start])
        pieces.append(surrogate)
        cursor = end
    pieces.append(text[cursor:])
    return "".join(pieces)

In [None]:
# Load IMDB and keep its training split for the anonymization pass below.
cls_data = load_dataset("imdb")
train_data = cls_data['train']
# Last expression -> rich display of the first example.
train_data[0]

In [None]:
# newline="" lets the csv module control line endings, so reviews containing
# newlines are quoted correctly. An explicit utf-8 encoding avoids
# platform-default encode errors on non-ASCII text.
with open("imdb_train_flair_wiki.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["text", "label"])
    for p in tqdm(train_data):
        # Reviews contain literal <br /> tags; turn them into spaces so that
        # neighbouring words/sentences are not glued together before NER.
        clean = p['text'].replace("<br /><br />", " ").replace("<br />", " ")
        src = replace_entities_flair_wiki(clean)
        writer.writerow((src, p['label']))

In [None]:
# cnn_dailymail has several configurations and no default, so a config name
# is required. "3.0.0" is the version commonly used for summarization
# (article / highlights fields used below). NOTE(review): confirm this is the
# intended version.
cls_data = load_dataset("cnn_dailymail", "3.0.0")
train_data = cls_data['train']
train_data[0]

In [None]:
# newline="" lets the csv module control line endings, so articles containing
# newlines are quoted correctly. An explicit utf-8 encoding avoids
# platform-default encode errors on non-ASCII text.
with open("cnn_dm_train_flair_wiki.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["article", "highlights"])
    for p in tqdm(train_data):
        # Anonymize source and summary independently with the Flair tagger.
        src = replace_entities_flair_wiki(p['article'])
        trg = replace_entities_flair_wiki(p['highlights'])
        writer.writerow((src, trg))

## spaCy wiki

In [None]:
import spacy
# Small English spaCy pipeline; its entity labels (ORG / PERSON / GPE) drive
# replace_entities_spacy_wiki.
nlp = spacy.load(name="en_core_web_sm")

In [None]:
def replace_entities_spacy_wiki(text):
    """Anonymize named entities in ``text`` using the spaCy pipeline ``nlp``.

    Each ORG, PERSON and GPE entity in ``nlp(text).ents`` is replaced by a
    surrogate from the anonymizer's ``org``, ``person`` and ``gpe`` handlers
    respectively; entities with other labels are left untouched. An entity
    whose surface text was already replaced earlier in the same text reuses
    that replacement, so repeated mentions stay consistent.

    Args:
        text: Raw input string.

    Returns:
        ``text`` with the replaced entities substituted, or a string equal to
        ``text`` if spaCy found no entity or no replacement could be produced.
    """
    doc = nlp(text)
    # Work on whole entity spans rather than individual tokens so that a
    # multi-word name ("New York", "Barack Obama") is passed to its handler as
    # one unit and replaced once, instead of word-by-word.
    entities = doc.ents
    if not entities:
        return text

    # Per-label surrogate generators. ORG/PERSON receive the entity split into
    # words plus a fresh NECache; GPE receives the whole entity as one item
    # and a plain dict.
    handlers = {
        "ORG": lambda s: org.handle(s.split(" "), NECache()),
        "PERSON": lambda s: person.handle(s.split(" "), NECache()),
        "GPE": lambda s: gpe.handle([s], {}),
    }

    replacements = []      # (start_char, end_char, surrogate), in document order
    replacement_map = {}   # entity surface text -> surrogate already chosen for it
    for ent in entities:
        if ent.text in replacement_map:
            replacements.append((ent.start_char, ent.end_char, replacement_map[ent.text]))
            continue
        handler = handlers.get(ent.label_)
        if handler is None:
            continue
        # A handler may return nothing or the unchanged entity; allow one
        # retry before giving up and leaving the entity as it is.
        surrogate = None
        for _attempt in range(2):
            repl = handler(ent.text)
            if repl and " ".join(repl) != ent.text:
                surrogate = " ".join(repl)
                break
        if surrogate is None:
            continue
        replacements.append((ent.start_char, ent.end_char, surrogate))
        replacement_map[ent.text] = surrogate

    # Stitch the untouched text between entities together with the surrogates.
    pieces = []
    cursor = 0
    for start, end, surrogate in replacements:
        pieces.append(text[cursor:start])
        pieces.append(surrogate)
        cursor = end
    pieces.append(text[cursor:])
    return "".join(pieces)

In [None]:
# Reload IMDB: train_data was rebound to the CNN/DailyMail split above.
cls_data = load_dataset("imdb")
train_data = cls_data['train']
# Last expression -> rich display of the first example.
train_data[0]

In [None]:
# newline="" lets the csv module control line endings, so reviews containing
# newlines are quoted correctly. An explicit utf-8 encoding avoids
# platform-default encode errors on non-ASCII text.
with open("imdb_train_spacy_wiki.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["text", "label"])
    for p in tqdm(train_data):
        # Reviews contain literal <br /> tags; turn them into spaces so that
        # neighbouring words/sentences are not glued together before NER.
        clean = p['text'].replace("<br /><br />", " ").replace("<br />", " ")
        src = replace_entities_spacy_wiki(clean)
        writer.writerow((src, p['label']))

In [None]:
# The cell above already writes this same file. Re-running the whole spaCy
# anonymization pass over every review only to overwrite it is expensive, so
# regenerate the file only when it is missing.
out_path = "imdb_train_spacy_wiki.csv"
if not os.path.exists(out_path):
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["text", "label"])
        for p in tqdm(train_data):
            # Turn literal <br /> tags into spaces so words are not glued together.
            clean = p['text'].replace("<br /><br />", " ").replace("<br />", " ")
            src = replace_entities_spacy_wiki(clean)
            writer.writerow((src, p['label']))

In [None]:
# cnn_dailymail has several configurations and no default, so a config name
# is required. "3.0.0" is the version commonly used for summarization
# (article / highlights fields used below). NOTE(review): confirm this is the
# intended version.
cls_data = load_dataset("cnn_dailymail", "3.0.0")
train_data = cls_data['train']
train_data[0]

In [None]:
# newline="" lets the csv module control line endings, so articles containing
# newlines are quoted correctly. An explicit utf-8 encoding avoids
# platform-default encode errors on non-ASCII text.
with open("cnn_dm_train_spacy_wiki.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["article", "highlights"])
    for p in tqdm(train_data):
        # Anonymize source and summary independently with the spaCy pipeline.
        src = replace_entities_spacy_wiki(p['article'])
        trg = replace_entities_spacy_wiki(p['highlights'])
        writer.writerow((src, trg))