## Editing Model Demo

This notebook demonstrates how to use the T5 transformer model to automate language editing.

### Data Collection

A high-quality dataset of unedited and edited sentence pairs is required to train this model. Unfortunately, no such dataset is freely available. As an alternative, this notebook demonstrates the process using an artificially constructed dataset: a range of errors is introduced into sentences from an NLTK corpus (the news category of the Brown corpus). The original sentence then serves as the edited sentence, and the sentence with introduced errors serves as the unedited sentence.

For demonstration purposes, three simple errors will be applied to the sentences: 1) plural nouns will be randomly singularized and singular nouns will be randomly pluralized; 2) indefinite articles (a/an) will be randomly switched; and 3) commas and semicolons will be randomly switched. 
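
One way to read "randomly" here is as an independent draw per candidate token that fires with probability p. The sketch below is an illustration only; the helper names `should_corrupt` and `integer_draw_probability` are hypothetical. It also shows why an integer-based draw such as `random.randint(0, int(1/p) - 1) == 0` matches p only when 1/p is a whole number.

```python
import random


def should_corrupt(p, rng=random):
    """Return True with probability p, for any p in [0, 1]."""
    # rng.random() is uniform on [0, 1), so the comparison succeeds
    # with probability p.
    return rng.random() < p


def integer_draw_probability(p):
    """Probability that random.randint(0, int(1/p) - 1) == 0."""
    # randint picks uniformly among int(1/p) integers, one of which is 0.
    return 1 / int(1 / p)


# Requested probability next to what the integer draw actually delivers
for p in (0.5, 0.3, 0.7):
    print(p, integer_draw_probability(p))
```

With p = 0.5 the two approaches agree, but with p = 0.7 the integer draw fires every time, because `int(1/0.7)` is 1.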

In [132]:
import nltk
from nltk.corpus import brown
nltk.download('brown')
from nltk.tokenize.treebank import TreebankWordDetokenizer
import random
import pandas as pd
import pickle

[nltk_data] Downloading package brown to
[nltk_data]     C:\Users\User\AppData\Roaming\nltk_data...
[nltk_data]   Package brown is already up-to-date!


In [114]:
def plural_error(tagged_sent, p):
    """
    Introduce plural errors to nouns in tagged_sent with
    probability p; tagged_sent is a list of pairs where the first
    pair item is the word and the second pair item is the tag.
    The list is modified in place and also returned.
    """
    for i, (word, tag) in enumerate(tagged_sent):
        # random.random() < p fires with probability p for any p in [0, 1]
        if tag == "NNS" and random.random() < p:
            # Naively singularize a plural noun
            if word[-3:] == "ies":
                tagged_sent[i] = (word[:-3] + "y", tag)
            else:
                tagged_sent[i] = (word[:-1], tag)
        elif tag == "NN" and random.random() < p:
            # Naively pluralize a singular noun
            if word[-1] == "y":
                tagged_sent[i] = (word[:-1] + "ies", tag)
            else:
                tagged_sent[i] = (word + "s", tag)
    return tagged_sent

In [120]:
sent = brown.tagged_sents(categories=['news'])[4618]
sent = plural_error(sent, 0.5)
print(TreebankWordDetokenizer().detokenize([w for w, t in sent]))

In college library , 57 per cents of the totals numbers of book are owned by 124 of 1,509 institutions surveyed last years by the U.S. Office of Education.
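
The suffix rules are purely string-based, so irregular nouns and even some regular ones are not handled correctly. The sketch below restates the two rules as standalone helpers to make the edge cases easy to inspect; the names `naive_singularize` and `naive_pluralize` are hypothetical and do not appear in the notebook.

```python
def naive_singularize(word):
    """Strip a plural suffix: 'ies' -> 'y', otherwise drop the last letter."""
    if word.endswith("ies"):
        return word[:-3] + "y"
    return word[:-1]


def naive_pluralize(word):
    """Add a plural suffix: trailing 'y' -> 'ies', otherwise append 's'."""
    if word.endswith("y"):
        return word[:-1] + "ies"
    return word + "s"


# Cases the rules handle as intended
print(naive_singularize("libraries"), naive_pluralize("year"))  # prints: library years
# Cases where the rules produce non-words
print(naive_singularize("boxes"), naive_pluralize("day"))       # prints: boxe daies
```

For a real dataset, such cases could be handled with a proper inflection library or simply tolerated as extra noise; which is preferable is a design choice left open here.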


In [60]:
def indefinite_article_error(tagged_sent, p):
    """
    Introduce indefinite article errors (a/an) to tagged_sent with
    probability p; tagged_sent is a list of pairs where the first
    pair item is the word and the second pair item is the tag.
    The list is modified in place and also returned.
    """
    for i, (word, tag) in enumerate(tagged_sent):
        # random.random() < p fires with probability p for any p in [0, 1]
        if word == "a" and random.random() < p:
            tagged_sent[i] = ("an", tag)
        elif word == "an" and random.random() < p:
            tagged_sent[i] = ("a", tag)
    return tagged_sent

In [131]:
sent = brown.tagged_sents(categories=['news'])[237]
sent = plural_error(sent, 0.5)
sent = indefinite_article_error(sent, 0.5)
print(TreebankWordDetokenizer().detokenize([w for w, t in sent]))

-- President Kennedy today proposed an mammoth new medical cares program whereby social security taxes on 70 million American worker would be raised to pay the hospital and some other medical bills of 14.2 million Americans over 65 who are covered by social security or railroad retirements program.


In [73]:
def semicolon_error(tagged_sent, p):
    """
    Introduce semicolon errors (;/,) to tagged_sent with
    probability p; tagged_sent is a list of pairs where the first
    pair item is the word and the second pair item is the tag.
    The list is modified in place and also returned.
    """
    for i, (word, tag) in enumerate(tagged_sent):
        # random.random() < p fires with probability p for any p in [0, 1]
        if word == "," and random.random() < p:
            tagged_sent[i] = (";", tag)
        elif word == ";" and random.random() < p:
            tagged_sent[i] = (",", tag)
    return tagged_sent

In [87]:
sent = brown.tagged_sents(categories=['news'])[101]
sent = plural_error(sent, 0.5)
sent = indefinite_article_error(sent, 0.5)
sent = semicolon_error(sent, 0.5)
print(TreebankWordDetokenizer().detokenize([w for w, t in sent]))

Under committee rule; it went automatically to a subcommittees for one week.
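
The article and punctuation errors follow the same pattern: swap a token for its counterpart with probability p. As a sketch of how that pattern could be factored out, the hypothetical helper `swap_tokens` below takes the swap table as a parameter and returns a new list instead of modifying its input. It is not part of the notebook, and the example sentence is tagged by hand.

```python
import random


def swap_tokens(tagged_sent, swaps, p, rng=random):
    """
    Return a copy of tagged_sent in which each word found in the
    swaps dict is replaced by its counterpart with probability p.
    """
    result = []
    for word, tag in tagged_sent:
        if word in swaps and rng.random() < p:
            result.append((swaps[word], tag))
        else:
            result.append((word, tag))
    return result


ARTICLE_SWAPS = {"a": "an", "an": "a"}
PUNCTUATION_SWAPS = {",": ";", ";": ","}

# Hand-tagged example sentence (Brown-style tags), so no corpus is needed
sent = [("Under", "IN"), ("committee", "NN"), ("rules", "NNS"), (",", ","),
        ("it", "PPS"), ("went", "VBD"), ("to", "IN"), ("a", "AT"),
        ("subcommittee", "NN"), (".", ".")]

# p=1.0 makes every eligible swap fire, which keeps the demo deterministic
corrupted = swap_tokens(swap_tokens(sent, ARTICLE_SWAPS, 1.0),
                        PUNCTUATION_SWAPS, 1.0)
print(" ".join(word for word, _ in corrupted))
# prints: Under committee rules ; it went to an subcommittee .
```

Returning a copy keeps the clean sentence available alongside the corrupted one, which is convenient when building (unedited, edited) pairs.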


In [121]:
detokenizer = TreebankWordDetokenizer()

# Original corpus sentences serve as the edited (target) sentences
edited_sents = [detokenizer.detokenize(sent) for sent in brown.sents(categories=['news'])]

# Apply all three error types to obtain the unedited (input) sentences
error_tagged_sents = [semicolon_error(indefinite_article_error(plural_error(sent, 0.5), 0.5), 0.5)
                      for sent in brown.tagged_sents(categories=['news'])]
unedited_sents = [detokenizer.detokenize([w for w, t in error_tagged_sent])
                  for error_tagged_sent in error_tagged_sents]

In [122]:
df = pd.DataFrame({"unedited": unedited_sents, "edited": edited_sents})
df

| | unedited | edited |
|---|---|---|
| 0 | The Fulton County Grand Jury said Friday an in... | The Fulton County Grand Jury said Friday an in... |
| 1 | The jury further said in term-ends presentment... | The jury further said in term-end presentments... |
| 2 | The September-October terms juries had been ch... | The September-October term jury had been charg... |
| 3 | " Only an relative handful of such report was... | " Only a relative handful of such reports was... |
| 4 | The juries said it did find that many of Georg... | The jury said it did find that many of Georgia... |
| ... | ... | ... |
| 4618 | In college library , 57 per cent of the total ... | In college libraries , 57 per cent of the tota... |
| 4619 | And over 66 per cent of the elementary schools... | And over 66 per cent of the elementary schools... |
| 4620 | In every aspect of services--to the public, to... | In every aspect of service--to the public, to ... |
| 4621 | Only public understandings and supports can pr... | Only public understanding and support can prov... |


In [134]:
with open("dataset", "wb") as file:
    pickle.dump(df, file)

### Prepare Dataloaders

In [4]:
import pickle
from sklearn.model_selection import train_test_split
import pandas as pd

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [2]:
with open("dataset", "rb") as file:
    df = pickle.load(file)

In [5]:
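# Remove entries without editing changes; copy so the assignment below
# modifies an independent DataFrame rather than a slice of the original
df = df[df.unedited != df.edited].copy()

# Add the "edit: " instruction prefix to the input sentences for the T5 model.
# Prepending the string directly avoids index misalignment: a freshly built
# Series is indexed 0..n-1, which no longer matches the filtered index.
df["unedited"] = "edit: " + df["unedited"]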

Split the dataset into training (80%) and validation (20%) sets.

In [6]:
df_train, df_val = train_test_split(df, test_size=0.2)
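
`train_test_split` shuffles the rows before splitting, so the partition differs between runs unless a seed is fixed (the function accepts a `random_state` argument for this). To make the idea concrete without depending on scikit-learn, the sketch below implements a simplified seeded shuffle-and-split over a list of pairs. The helper `split_pairs` and the toy data are hypothetical, and the code does not reproduce scikit-learn's exact behaviour.

```python
import math
import random


def split_pairs(pairs, val_fraction=0.2, seed=0):
    """Shuffle pairs with a fixed seed and split into (train, val) lists."""
    shuffled = list(pairs)                 # copy so the input is untouched
    random.Random(seed).shuffle(shuffled)  # seeded, so reruns match
    n_val = math.ceil(len(shuffled) * val_fraction)
    return shuffled[n_val:], shuffled[:n_val]


# Toy (unedited, edited) pairs standing in for the real dataset
pairs = [(f"edit: unedited {i}", f"edited {i}") for i in range(10)]
train, val = split_pairs(pairs)
print(len(train), len(val))  # prints: 8 2
```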

In [7]:
class EditDataset(Dataset):
    """Pairs each unedited sentence with its edited counterpart."""

    def __init__(self, unedited_sentences, edited_sentences):
        self.unedited_sentences = unedited_sentences
        self.edited_sentences = edited_sentences

    def __len__(self):
        return len(self.unedited_sentences)

    def __getitem__(self, i):
        return self.unedited_sentences[i], self.edited_sentences[i]

In [8]:
train_data = EditDataset(df_train.unedited.values, df_train.edited.values)
val_data = EditDataset(df_val.unedited.values, df_val.edited.values)

# dataloaders
trainloader = DataLoader(train_data, batch_size=8, shuffle=True)
valloader = DataLoader(val_data, batch_size=8, shuffle=True)
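
Because each dataset item is a pair of strings, a batch from these loaders is expected to arrive as two parallel groups of strings (unedited and edited) rather than as tensors, so tokenization still has to happen later, for example inside the training loop. The pure-Python sketch below mimics that grouping. It is a simplification for illustration, not PyTorch's collation code, and the helper `batch_pairs` and the toy sentences are hypothetical.

```python
def batch_pairs(pairs, batch_size):
    """Yield (unedited_batch, edited_batch) tuples of at most batch_size items."""
    for start in range(0, len(pairs), batch_size):
        chunk = pairs[start:start + batch_size]
        # zip(*chunk) transposes [(u1, e1), (u2, e2)] into (u1, u2), (e1, e2)
        unedited, edited = zip(*chunk)
        yield unedited, edited


pairs = [("edit: a apple", "an apple"),
         ("edit: two book", "two books"),
         ("edit: yes; please", "yes, please")]

# The last batch is smaller because 3 is not a multiple of the batch size
for unedited, edited in batch_pairs(pairs, batch_size=2):
    print(unedited, edited)
```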