In [1]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import matplotlib as mpl
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, AutoTokenizer, BertForMaskedLM
from torch.nn.functional import softmax
from difflib import SequenceMatcher
from spellchecker import SpellChecker
from string import punctuation
from time import time
import nltk
nltk.download('words')

[nltk_data] Downloading package words to C:\Users\bills-fish-
[nltk_data]     shack\AppData\Roaming\nltk_data...
[nltk_data]   Package words is already up-to-date!


True

In [2]:
topk = 2000  # number of top predicted tokens to retrieve (before excluding non-words)

class GPT2:
    """Wrapper around a GPT-2 language-model head used to rank candidate next words."""

    def __init__(self, model="gpt2"):
        """Load the pretrained model and its tokenizer for the checkpoint named `model`."""
        self.model     =   GPT2LMHeadModel.from_pretrained(model)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model)
        self.model_id  = model

    def get_word_probs(self, sentence, n=topk):  # adapted from raul on stackoverflow
        """Return the `n` most likely next tokens after `sentence`.

        Args:
            sentence: Text used as the prompt.
            n: Number of candidates requested. Values larger than the number of
                scores produced by the model are clamped instead of raising.

        Returns:
            A list of `(token_text, probability)` tuples, most likely first.
            Token text is decoded and stripped of surrounding whitespace;
            probabilities come from a softmax over all scores at the last position.

        Raises:
            ValueError: If `sentence` encodes to no tokens and the tokenizer has
                no BOS token to stand in for the missing context.
        """
        inputs = self.tokenizer.encode(sentence, return_tensors="pt")
        if inputs.numel() == 0:
            # An empty prompt gives the model nothing to condition on; use the
            # beginning-of-sequence token as the sole context instead.
            bos_id = self.tokenizer.bos_token_id
            if bos_id is None:
                raise ValueError("cannot predict from an empty prompt: tokenizer has no BOS token")
            inputs = torch.tensor([[bos_id]], dtype=torch.long)
        with torch.no_grad():
            outputs = self.model(inputs)
            predictions = outputs[0]  # first output element; indexed below as (batch, position, vocab)
        candidates = predictions[0, -1, :]                           # Scores for the token following the prompt.
        n = max(0, min(int(n), candidates.shape[-1]))                # torch.topk raises if k exceeds the dimension size.
        top = torch.topk(candidates, n)                              # Rank on raw scores (softmax preserves order).
        all_probs = torch.nn.functional.softmax(candidates, dim=-1)  # Probabilities over every candidate.
        topk_probs = all_probs[top.indices].tolist()                 # Keep only the top-n probabilities.
        topk_tokens = [self.tokenizer.decode([idx]).strip()          # Decode the top-n ids back to text.
                       for idx in top.indices.tolist()]
        return list(zip(topk_tokens, topk_probs))

class BERT:
    """Wrapper around a BERT masked-language model used to rank fillers for one [MASK]."""

    def __init__(self, model="google-bert/bert-base-uncased"):
        """Load the pretrained model and its tokenizer for the checkpoint named `model`."""
        self.model     = BertForMaskedLM.from_pretrained(model)
        self.tokenizer =   AutoTokenizer.from_pretrained(model)
        self.model_id  = model

    def get_word_probs(self, prompt, topk=topk):                  # Get topk masked token candidates
        """Return the `topk` most likely tokens for the single mask in `prompt`.

        Args:
            prompt: Text containing exactly one mask token.
            topk: Number of candidates requested. The value is honoured (it was
                previously overwritten with a hard-coded 5000) and clamped to
                the number of probabilities available.

        Returns:
            A 2-D numpy array with one row per candidate, most likely first:
            column 0 is the decoded token, column 1 its probability. Stacking
            text with numbers makes numpy store both columns as strings, so
            callers convert the probability back with `float(...)`.

        Raises:
            ValueError: If `prompt` does not contain exactly one mask token.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits
        mask_positions = (inputs.input_ids[0] == self.tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
        if mask_positions.numel() != 1:
            # Zero masks used to yield an empty result and several masks a
            # wrongly-shaped one; fail loudly instead.
            raise ValueError(
                "prompt must contain exactly one mask token, found %d" % mask_positions.numel()
            )
        mask_logits = logits[0, mask_positions[0]]
        probs = softmax(mask_logits, dim=-1)
        k = max(0, min(int(topk), probs.shape[-1]))               # torch.topk raises if k exceeds the dimension size
        topk_probs, topk_i = torch.topk(probs, k, dim=-1)
        topk_tokens = np.array([self.tokenizer.decode([i]) for i in topk_i.tolist()])
        return np.hstack((topk_tokens.reshape(-1,1), np.array(topk_probs).reshape(-1,1)))

# Load each model once at module level so later cells can reuse the same instances.
M_GPT2 = GPT2(model="gpt2")
M_BERT = BERT(model="google-bert/bert-base-uncased")

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at google-bert/bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with ano

In [3]:
def similar(a, b):
    """Score how alike two strings are, with a bonus for short, equal-length pairs.

    The base score is `SequenceMatcher(None, a, b).ratio()`. When the rounded
    mean length of the two strings is 1-4, a bonus of 0.4/0.3/0.2/0.1 is added,
    scaled by e**(-|len(a) - len(b)|) so it fades as the lengths diverge.
    """
    len_a, len_b = len(a), len(b)
    base_ratio = SequenceMatcher(None, a, b).ratio()
    short_bonus = {1: 0.4, 2: 0.3, 3: 0.2, 4: 0.1}.get(round((len_a + len_b) / 2))
    if short_bonus is None:
        # Longer words get no bonus.
        return base_ratio + 0
    return base_ratio + short_bonus * (np.e ** (-1 * np.abs(len_a - len_b)))
def rreplace(string, word, new_word):
    """Replace the last occurrence of `word` in `string` with `new_word`.

    Returns `string` unchanged when `word` does not occur in it. Without this
    guard the -1 returned by `rfind` would be used as a slice index and
    produce a corrupted result.
    """
    start = string.rfind(word)
    if start == -1:
        return string
    return string[:start] + new_word + string[start + len(word):]
# Word list from the NLTK "words" corpus, held as a set so `word in wl` is a hash lookup.
wl          = set(nltk.corpus.words.words())
def log_map(e):
    """Build a vectorized mapping x -> (log(x / 0.5) / log(2)) ** e for the exponent `e`.

    The inner expression is the base-2 logarithm of x / 0.5, so x = 0.5 maps
    to 0 and x = 1 maps to 1 for any positive exponent.
    """
    def _scaled_log(x):
        return np.power(np.log(x/0.5)/np.log(2), e)

    return np.vectorize(_scaled_log)
def after_slash(x):
    """Return the part of `x` after its last "/" (all of `x` when there is none)."""
    # rfind gives -1 when "/" is absent, so adding 1 already yields index 0.
    return x[x.rfind("/") + 1:]

In [4]:
def correct(string, back_n):
    """Propose a replacement for a likely wrong word among the last `back_n` words.

    Positions are examined from the end of `string` backwards (n = 1 is the
    last word). The last word is scored with `M_GPT2`, using the preceding text
    as the prompt; earlier words are scored with `M_BERT`, with the word
    replaced by "[MASK]". Predictions are kept only if they are in `wl`, their
    `similar` score against the original word exceeds 0.5, and they are
    probable enough. Each survivor is scored as
    `prob**prob_exp * log_map(log_exp)(similarity)`.

    A correction is proposed when the best score exceeds `threshold(n)`, unless
    the original word is itself among the candidates scoring at least
    `relevency_t` times the best score.

    Args:
        string: The text to check.
        back_n: How many words back from the end to examine.

    Returns:
        `(n, replacement)` for the first position where a correction is
        proposed, otherwise `False`.
    """
    words = string.split()
    sim_weight = log_map(log_exp)  # built once rather than once per candidate
    for n in range(1, back_n + 1):
        if n > len(words): break
        if n > 1:
            model    = M_BERT
            original = words[-n]
            # Keep a trailing non-letter (e.g. a comma) attached to the mask.
            masked = "[MASK]" if original[-1].isalpha() else "[MASK]" + original[-1]
            target = original.strip(punctuation)
            prompt = ' '.join(words[:-n] + [masked] + words[len(words)-(n-1):])
        else:
            model  = M_GPT2
            string = string.strip()
            prompt, space, last_word = string.rpartition(' ')
            if not space:
                # Single-word input: there is no preceding text to predict from.
                # (Slicing with the -1 from rfind used to chop the word's last letter.)
                continue
            target = last_word.strip(punctuation)
        probs  = model.get_word_probs(prompt)
        probsp = [(str(word), float(prob), float(similar(target, word))) for word, prob in probs if word in wl]
        if not probsp:
            continue
        # Probability floor: 0.001 or the probability of the `consider_top`-th
        # in-vocabulary prediction, whichever is lower. The index is clamped so
        # a short candidate list no longer raises IndexError.
        min_prob = min(0.001, probsp[min(consider_top, len(probsp) - 1)][1])
        close_probs = [cand for cand in probsp if cand[2] > 0.5 and cand[1] >= min_prob]
        props = [(word, (prob**prob_exp)*sim_weight(sim)) for word, prob, sim in close_probs]
        props = sorted(props, reverse=True, key=lambda x: x[1])
        props = [prop for prop in props if prop[1] > 0.000001]
        probN = threshold(n)
        make_correction = False
        if len(props) > 0 and props[0][1] > probN:
            make_correction = True
            irr_t = props[0][1] * relevency_t
            # If the original word is itself a competitive candidate, leave it alone.
            for word, score in props:
                if score < irr_t: break
                elif target.lower() == word.lower():
                    make_correction = False
        if make_correction: return (n, props[0][0])
    return False

In [5]:
# Example sentences for trying out the corrector. Every entry is currently
# commented out, so `strings` evaluates to an empty list. Entries tagged
# "don't correct" are cases where no correction should be proposed.
strings = [#"when you come in can you remember to feel the cat",
           #"I have to right a note",
           #"This is Eric. He's going to fry to",
           #"do you want any walt or pepper",    
           #"Please don't forget to turn off the store when",
           #"are you going to wear the yellow hat or the bed one",
           #"when do you want to get up to see the fun rise?",
           #"can you go and let my",
           #"can you let my",
           #"you led us on a wild goose case",
           #"when you come over can you remember to being",
           #"when you come over can you being the soup",
           #"when you come over can you, being the",
           #"I went outside and the wind flew my hat",
           #"After I get out of the shower I usually growl",  # don't correct
           #"Don't step on my wet bug", 
           #"the rally",
           #"can you really climb all the way up that really tall birdling?",
           #"I think you need to pat more attention",
           #"This method isn't really as grate",
           #"Don't step on the wet floor, we're freaching",  # don't correct
           #"Who dat",  # don't correct
           #"When you come over, can you bring the flock we talked about?"
           ]
back_n = 3  # how many words back from the end of the string to examine; 1 means only the last word

# Scoring parameters: score = (prob ** prob_exp) * log_map(log_exp)(similarity)
log_exp        = 4       # exponent applied to the logarithmic similarity mapping
prob_exp       = 1.6     # exponent applied to the model probability
consider_top   = 200     # upper bound on how many top model predictions are considered
relevency_t    = 0.2     # candidates scoring below this fraction of the best score are ignored
base_t         = 0.0001  # decision threshold for the last word (n = 1)

# How the decision threshold changes with n, the word's distance from the end.
# The rule is picked once, here, by name. Note "exponential" grows as n**2.
threshold_type = "jump-exp"
_threshold_rules = {
    "constant":    lambda n: base_t,
    "linear":      lambda n: base_t + (base_t * (n-1)),
    "exponential": lambda n: base_t * (n**2),
    "jump-exp":    lambda n: base_t * (max(n-1,1)**2),  # same threshold for n=1 and n=2, grows afterwards
    "jump-lin":    lambda n: base_t + (base_t * max(n-2, 0)),
}
threshold = _threshold_rules[threshold_type]

# Demo: check the last three words of one example sentence.
correct(string="when you come in can you remember to feel the cat", back_n=3)

(3, 'feed')