In [1]:
from torch.nn.functional import softmax
from transformers import pipeline, GPT2TokenizerFast, GPT2LMHeadModel, AutoTokenizer, BertForMaskedLM
from autocorrect import Speller
from nltk.stem import WordNetLemmatizer, PorterStemmer
from difflib import SequenceMatcher
from string import punctuation
from time import time
import numpy as np
import pandas as pd
import torch
import nltk
nltk.download('words')

[nltk_data] Downloading package words to C:\Users\bills-fish-
[nltk_data]     shack\AppData\Roaming\nltk_data...
[nltk_data]   Package words is already up-to-date!


True

In [2]:
topk = 2000  # number of top predicted tokens to retrieve (before excluding non-words) 

class GPT2:
    def __init__(self, model="gpt2"):
        self.model     =   GPT2LMHeadModel.from_pretrained(model)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model)
        self.model_id  = model
    
    def get_word_probs(self, sentence, n=topk):  # adapted from raul on stackoverflow
        inputs = self.tokenizer.encode(sentence, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(inputs)
            predictions = outputs[0]
        candidates = predictions[0, -1, :]                          # Get the next token candidates.
        topk_i = torch.topk(candidates, n).indices.tolist()         # Get the top k next token candidates.
        all_probs = torch.nn.functional.softmax(candidates, dim=-1) # Get the token probabilities for all candidates.
        topk_probs = all_probs[topk_i].tolist()                     # Filter the token probabilities for the top k candidates.
        topk_tokens = [self.tokenizer.decode([idx]).strip()         # Decode the top k candidates back to words.
                       for idx in topk_i]
        return list(zip(topk_tokens, topk_probs))

class BERT:
    def __init__(self, model="google-bert/bert-base-uncased"):
        self.model     = BertForMaskedLM.from_pretrained(model)
        self.tokenizer =   AutoTokenizer.from_pretrained(model)
        self.model_id  = model
        
    def get_word_probs(self, prompt, topk=topk):                  # Get topk masked token candidates
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits
        mask_index  = (inputs.input_ids == self.tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
        mask_logits = logits.squeeze()[mask_index].squeeze()
        probs = softmax(mask_logits, dim=-1)
        topk = 5000
        topk_probs, topk_i = torch.topk(probs, topk, dim=-1)
        topk_tokens = np.array([self.tokenizer.decode([i]) for i in topk_i])
        return np.hstack((topk_tokens.reshape(-1,1), np.array(topk_probs).reshape(-1,1)))

M_GPT2 = GPT2("gpt2")
M_BERT = BERT("google-bert/bert-base-uncased")

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From ðŸ‘‰v4.50ðŸ‘ˆ onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at google-bert/bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or wi

In [3]:
def similar(a, b):
    common_len = np.ceil((len(a)+len(b))/2)
    adjustment = 0
    adjustment_table = {1: 0.5, 2: 0.3, 3: 0.2, 4: 0.1}
    if common_len in adjustment_table: adjustment = adjustment_table[common_len]*(np.e**(-k*(np.abs(len(a)-len(b))-ap))-bp)
    return SequenceMatcher(None, a, b).ratio() + adjustment
def rreplace(string, word, new_word):
    start = string.rfind(word)
    return string[0:start] + new_word + string[start+len(word):]

lemmatizer  = WordNetLemmatizer()
lemma       = lambda x: lemmatizer.lemmatize(x)
stemmer     = PorterStemmer()
stem        = lambda x: stemmer.stem(x)
spell       = Speller()
wl          = set(nltk.corpus.words.words())
log_map     = lambda e: np.vectorize(lambda x: np.power(np.log(x/0.5)/np.log(2), e))  # specify exponent to return vectorized mapping
after_slash = lambda x: x[(x.rfind("/")+1 if x.rfind("/") != -1 else 0):]

In [4]:
def correction(string, back_n):
    places = reversed(range(1,back_n+1))
    if back_n == 0: places = [1, 3, 2]
    string = string.strip()
    words  = string.split()
    last_space = string.rfind(' ')
    for n in places:
        if n > len(words) or len(words) == 1: break
        spelled = False
        if n > 1:
            model  = M_BERT
            masked = "[MASK]" + words[-n][-1] if not words[-n][-1].isalpha() else "[MASK]"
            target = words[-n].strip(punctuation)
            prompt = ' '.join(words[:-n] + [masked] + words[len(words)-(n-1):])
            target = words[-n].strip(punctuation)
            if target != spell(target):
                spelled = True
        else:
            model  = M_GPT2
            string = string.strip()
            last_space = string.rfind(' ')
            prompt = string[:last_space]
            target = string[last_space+1:].strip(punctuation)
            target = words[-n].strip(punctuation)
            if target != spell(target):
                spelled = True
        probs  = model.get_word_probs(prompt)                
        probsp = [(str(word), float(prob), float(similar(target.lower(), word.lower()))) for word, prob in probs if word in wl]
        close_probs = [prob for prob in probsp if prob[2] > 0.5 and prob[1] >= min(0.001, probsp[consider_top][1])]
        props = [(word, (prob**prob_exp)*log_map(log_exp)(sim)) for word, prob, sim in close_probs]
        props = sorted(props, reverse=True, key=lambda x: x[1])
        props = [prop for prop in props if prop[1] > 0.000001]
        probN = threshold(n)
        make_correction = False
        if len(props) > 0 and props[0][1] > probN:
            make_correction = True
            irr_t = props[0][1] * relevency_t
            for word, score in props: 
                if score < irr_t: break
                elif target.lower() == word.lower() or stem(target.lower())  == word.lower() or lemma(target.lower()) == word.lower():
                    make_correction = False
        if make_correction: return (n, props[0][0], spelled)
        if spelled: return (n, target, spelled)
    return False
def process_correction(string, back_n):
    corrected = string
    words     = string.split()
    is_correction = correction(string, back_n)
    if is_correction:
        n, word, _ = is_correction
        words[-n] = word if words[-n][-1] not in punctuation else word + words[-n][-1]
        corrected = " ".join(words)
        #is_correction = correction(corrected, back_n)
    return corrected if corrected != string else False

In [5]:
strings = [
("What if we went to the stove?",                       "What if we went to the store?"),
("when you come in can you remember to feel the cat?",  "when you come in can you remember to feed the cat?"),
("when you come in can you remember to feed the cad?",  "when you come in can you remember to feed the cat?"),
("I have to right a note",                              "I have to write a note"),
("This is Eric. He's going to fry to",                  "This is Eric. He's going to try to"),
("do you want any walt or pepper",                      "do you want any salt or pepper"),
("Please don't forget to turn off the store when",      "Please don't forget to turn off the stove when"),
("are you going to wear the yellow hat or the bed one", "are you going to wear the yellow hat or the red one"),
("when do you want to get up to see the fun rise?",     "when do you want to get up to see the sun rise?"),
("can you go and let my",                               "can you go and get my"),
("can you let my",                                      "can you let my"),
("you led us on a wild goose case",                     "you led us on a wild goose chase"),
("when you come over can you remember to being",        "when you come over can you remember to bring"),
("when you come over can you being the soup",           "when you come over can you bring the soup"),
("when you come over can you, being the",               "when you come over can you, being the"),
("I went outside and the wind flew my hat",             "I went outside and the wind blew my hat"),
("After I get out of the shower I usually growl",       "After I get out of the shower I usually growl"),
("Don't step on my wet bug",                            "Don't step on my wet rug, I'm cleaning"),
("the rally big bar",                                   "the really big bar"),
("I think you need to pat more attention",              "I think you need to pay more attention"),
("This method isn't really as grate",                   "This method isn't really as great"),
("Don't step on the wet rug, we're gleaning",           "Don't step on the wet rug, we're cleaning"),
("Who dat",                                             "Who dat"),
("When you come over, can you bring the flock we talked", "When you come over, can you bring the flock we talked"),
("can you remember to feed my cad?",                    "can you remember to feed my cat?"),
("When you come home can",                              "When you come home can"),
("everything that I want to happen. Im",                "everything that I want to happen. Im"),
("I'm really not that tipe",                            "I'm really not that tipe"),           
("Are you hot as well?",                                "Are you hot as well?"),
("This isn't the final version, but it'll work",        "This isn't the final version, but it'll work"),
("I just wanted",                                       "I just wanted"),
("I want to read a good bool",                          "I want to read a good book"),
("The car drove down the rode",                         "The car drove down the road"),
("They went on a walk in the pouring rein",             "They went on a walk in the pouring rain"),
("The shoe was way too lose",                           "The shoe was way too loose"),
("He bought a gift for his mother",                     "He bought a gift for his mother"),
("The movie was really exiting",                        "The movie was really exciting"),
("Please bring a piece offering",                       "Please bring a peace offering"),
("I'll meat you at",                                    "I'll meet you at"),
("Let's go out to the bark tomorrow",                   "Let's go out to the park tomorrow"),
("Can you tell me a good peace of advice",              "Can you tell me a good piece of advice"),
("I need to go get some more bread at the barkery",     "I need to go get some more bread at the bakery"),
("She wanted to file a complain",                       "She wanted to file a complaint"),
("She has a really unique styl",                        "She has a really unique style"),
("The fire truck went dawn the road",                   "The fire truck went down the road"),
("I love the new flwer arrangement",                    "I love the new flower arrangement"),
("Please bare with me",                                 "Please bear with me"),
("Please don't make any load noises",                   "Please don't make any loud noises"),
("Please don't make any loud nose",                     "Please don't make any loud noise"),
("The sign read: no praking" ,                          "The sign read: no parking"),
("The cloths were drying",                              "The clothes were drying"),
("He got a big laugh out of that hoke",                 "He got a big laugh out of that joke"),
("He was a brave solder",                               "He was a brave soldier"),
# add negatives (sentences not changed), maybe 50/50
]
for i, x in enumerate(strings): print(f'{i}: {x}')

0: ('What if we went to the stove?', 'What if we went to the store?')
1: ('when you come in can you remember to feel the cat?', 'when you come in can you remember to feed the cat?')
2: ('when you come in can you remember to feed the cad?', 'when you come in can you remember to feed the cat?')
3: ('I have to right a note', 'I have to write a note')
4: ("This is Eric. He's going to fry to", "This is Eric. He's going to try to")
5: ('do you want any walt or pepper', 'do you want any salt or pepper')
6: ("Please don't forget to turn off the store when", "Please don't forget to turn off the stove when")
7: ('are you going to wear the yellow hat or the bed one', 'are you going to wear the yellow hat or the red one')
8: ('when do you want to get up to see the fun rise?', 'when do you want to get up to see the sun rise?')
9: ('can you go and let my', 'can you go and get my')
10: ('can you let my', 'can you let my')
11: ('you led us on a wild goose case', 'you led us on a wild goose chase')
12:

In [6]:
back_n  = 0  # number of words back from end of string, 1 is just last word, 0 is [1, 3, 2]
k       = 1.2  # exponent parameter for exponential decay of word length augmedented SequenceMatcher
ap      = 0.57  # exponent parameter
bp      = 1  # exponent parameter
log_exp      = 5  # exponent parameter for logarithmic mapping
prob_exp     = 1.5  # raise probability to power in ((prob**power)*log-sim)
consider_top = 100  # max top model word predictions considered
relevency_t  = 0.05  # threshold defined by portion of top proposition to exclude much smaller scored propositions for correcting
base_t       = 0.0001  # decision threshold for last word: base threshold
threshold_e  = 1.5  # exponent for exponential thresholds
threshold_t  = "exponential"  # function defines decision threshold for word n from end
threshold    = {"constant":    lambda n: base_t,
                "linear":      lambda n: base_t + (base_t * (n-1)),
                "exponential": lambda n: base_t * (n**threshold_e),
                "jump-exp":    lambda n: base_t * (max(n-1,1)**threshold_e),        # jump thresholds start growing after n=2
                "jump-lin":    lambda n: base_t + (base_t * max(n-2, 0))
               }[threshold_t]

y_h = []
for x, y in strings: y_h.append(process_correction(x, back_n) == y)
y_h = np.array(y_h)
f'total: {y_h.shape[0]}, true: {np.array(y_h).sum()}, accuracy: {np.array(y_h).sum()/y_h.shape[0]}'

'total: 53, true: 32, accuracy: 0.6037735849056604'