In [1]:
from torch.nn.functional import softmax
from transformers import pipeline, GPT2TokenizerFast, GPT2LMHeadModel, AutoTokenizer, BertForMaskedLM
from autocorrect import Speller
from nltk.stem import WordNetLemmatizer, PorterStemmer
from textblob import TextBlob
from difflib import SequenceMatcher
from string import punctuation
from time import time
import numpy as np
import pandas as pd
import torch
import nltk
nltk.download('words')

[nltk_data] Downloading package words to C:\Users\bills-fish-
[nltk_data]     shack\AppData\Roaming\nltk_data...
[nltk_data]   Package words is already up-to-date!


True

In [2]:
topk = 200  # number of top predicted tokens to retrieve (before excluding non-words) 

class GPT2:
    def __init__(self, model="gpt2"):
        self.model     =   GPT2LMHeadModel.from_pretrained(model)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model)
        self.model_id  = model
    
    def get_word_probs(self, sentence, n=topk):  # adapted from raul on stackoverflow
        inputs = self.tokenizer.encode(sentence, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(inputs)
            predictions = outputs[0]
        candidates = predictions[0, -1, :]                          # Get the next token candidates.
        topk_i = torch.topk(candidates, n).indices.tolist()         # Get the top k next token candidates.
        all_probs = torch.nn.functional.softmax(candidates, dim=-1) # Get the token probabilities for all candidates.
        topk_probs = all_probs[topk_i].tolist()                     # Filter the token probabilities for the top k candidates.
        topk_tokens = [self.tokenizer.decode([idx]).strip()         # Decode the top k candidates back to words.
                       for idx in topk_i]
        return np.array(list(zip(topk_tokens, topk_probs)))

class BERT:
    def __init__(self, model="google-bert/bert-base-uncased"):
        self.model     = BertForMaskedLM.from_pretrained(model)
        self.tokenizer =   AutoTokenizer.from_pretrained(model)
        self.model_id  = model
        
    def get_word_probs(self, prompt, topk=topk):                  # Get topk masked token candidates
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits
        mask_index  = (inputs.input_ids == self.tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
        mask_logits = logits.squeeze()[mask_index].squeeze()
        probs = softmax(mask_logits, dim=-1)
        topk_probs, topk_i = torch.topk(probs, topk, dim=-1)
        topk_tokens = np.array([self.tokenizer.decode([i]) for i in topk_i])
        return np.hstack((topk_tokens.reshape(-1,1), np.array(topk_probs).reshape(-1,1)))

M_GPT2 = GPT2("gpt2")
M_BERT = BERT("google-bert/bert-base-uncased")

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at google-bert/bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with ano

In [3]:
def similar(a, b):
    common_len = np.ceil((len(a)+len(b))/2)
    adjustment = 0
    adjustment_table = {1: 0.5, 2: 0.3, 3: 0.2, 4: 0.1}
    if common_len in adjustment_table: adjustment = adjustment_table[common_len]*(np.e**(-k*(np.abs(len(a)-len(b))-ap))-bp)
    return SequenceMatcher(None, a, b).ratio() + adjustment
def rreplace(string, word, new_word):
    start = string.rfind(word)
    return string[0:start] + new_word + string[start+len(word):]
def blob_correct(sentence):
    blob = TextBlob(sentence)
    return str(blob.correct())
    
threshold = {"constant":    lambda n: base_t,
             "linear":      lambda n: base_t + (base_t * (n-1)),
             "exponential": lambda n: base_t * (n**threshold_e),
             "jump-exp":    lambda n: base_t * (max(n-1,1)**threshold_e),        # jump thresholds start growing after n=2
             "jump-lin":    lambda n: base_t + (base_t * max(n-2, 0))}
lemmatizer  = WordNetLemmatizer()
lemma       = lambda x: lemmatizer.lemmatize(x)
stemmer     = PorterStemmer()
stem        = lambda x: stemmer.stem(x)
spell       = Speller()
wl          = set(nltk.corpus.words.words())
log_map     = lambda e: np.vectorize(lambda x: np.power(np.log(x/0.5)/np.log(2), e))  # specify exponent to return vectorized mapping
after_slash = lambda x: x[(x.rfind("/")+1 if x.rfind("/") != -1 else 0):]

In [4]:
def get_probs(string, back_n):
    probs = []
    places = range(1,back_n+1)
    string = string.strip()
    words  = string.split()
    last_space = string.rfind(' ')
    for n in places:
        if n > len(words) or len(words) == 1: break
        spelled = False
        if n > 1:
            model  = M_BERT
            masked = "[MASK]" + words[-n][-1] if not words[-n][-1].isalpha() else "[MASK]"
            target = words[-n].strip(punctuation)
            prompt = ' '.join(words[:-n] + [masked] + words[len(words)-(n-1):])
        else:
            model  = M_GPT2
            string = string.strip()
            last_space = string.rfind(' ')
            prompt = string[:last_space]
        prob = model.get_word_probs(prompt)
        if len(prob) > 0:
            probs.append(prob)
        else:
            probs.append(np.array([['dummy', 0]]))
    return probs

def get_props(target, probs):
    probs[:, 1] = probs[:, 1].astype(float) / probs[:, 1].astype(float).sum()
    probsp = [(str(word), float(prob), float(similar(target.lower(), word.lower()))) 
              for word, prob in probs if word in wl]
    close_probs = [prob for prob in probsp if prob[2] > 0.5 and prob[1] >= 0.001]
    props = sorted([(word, (prob**prob_exp)*log_map(log_exp)(sim)) for word, prob, sim in close_probs], 
                   reverse=True, key=lambda x: x[1])
    return props

def make_correction(target, props, probN):
    make_correction = False
    if len(props) > 0 and float(props[0][1]) > probN:
        make_correction = True
        irr_t = float(props[0][1]) * relevency_t
        for word, score in props: 
            if float(score) < irr_t: break
            elif target.lower() == word.lower() or stem(target.lower()) == word.lower() or lemma(target.lower()) == word.lower():
                make_correction = False
    return make_correction

def correction(string, back_n, pprobs=False):
    places = reversed(range(1,back_n+1))
    if back_n == 0: places = [1, 3, 2]
    string = string.strip()
    words  = string.split()
    last_space = string.rfind(' ')
    for n in places:
        if n > len(words) or len(words) == 1: break
        spelled = False
        if n > 1:
            model  = M_BERT
            masked = "[MASK]" + words[-n][-1] if not words[-n][-1].isalpha() else "[MASK]"
            target = words[-n].strip(punctuation)
            prompt = ' '.join(words[:-n] + [masked] + words[len(words)-(n-1):])
            target = words[-n].strip(punctuation)
            if target != spell(target):
                spelled = True
        else:
            model  = M_GPT2
            string = string.strip()
            last_space = string.rfind(' ')
            prompt = string[:last_space]
            target = string[last_space+1:].strip(punctuation)
            target = words[-n].strip(punctuation)
            if target != spell(target):
                spelled = True
        if pprobs:
            probs = pprobs[n-1]
        else:
            probs = model.get_word_probs(prompt) 
        props = get_props(target, probs)
        probN = threshold[threshold_t](n)
        if make_correction(target, props, probN): return (n, props[0][0])
        if spelled: 
            target = spell(target)
            props = get_props(target, probs)
            if len(props) > 0 and float(props[0][1]) > probN and props[0][0] == target:
                return (n, target)
    return False
    
def process_correction(string, back_n, pprobs=False):
    corrected = string
    words     = string.split()
    is_correction = correction(string, back_n, pprobs)
    if is_correction:
        n, word = is_correction
        words[-n] = word if words[-n][-1] not in punctuation else word + words[-n][-1]
        corrected = " ".join(words)
    return corrected if corrected != string else False

In [5]:
strings = pd.read_csv("strings.txt", quotechar='"', header=None, index_col=False, skipinitialspace=True).values
strings[:5], len(strings)

(array([['when you come in can you remember to feel the cat?',
         'when you come in can you remember to feed the cat?'],
        ['when you come in can you remember to feed the cad?',
         'when you come in can you remember to feed the cat?'],
        ['when you come in can you remember to feed the cad? He',
         'when you come in can you remember to feed the cat? He'],
        ['when you come in can you remember to feed the cad? He needs',
         'when you come in can you remember to feed the cat? He needs'],
        ['can you remember to feed the cad when',
         'can you remember to feed the cat when']], dtype=object),
 295)

In [6]:
sb_size = 250
super_batch = strings[np.random.choice(strings.shape[0], sb_size, replace=False),:]
probs = []
times = []
for i, string in enumerate(super_batch):
    t0 = time()
    probs.append(get_probs(string[0], 4))
    times.append(time()-t0)
    print(f'({i+1}/{len(super_batch)}): {times[-1]:.3f} ', end='')
print(f'\naverage: {np.mean(times):.3f} seconds')

(1/250): 0.198 (2/250): 0.143 (3/250): 0.000 (4/250): 0.134 (5/250): 0.130 (6/250): 0.138 (7/250): 0.139 (8/250): 0.061 (9/250): 0.143 (10/250): 0.099 (11/250): 0.107 (12/250): 0.141 (13/250): 0.126 (14/250): 0.092 (15/250): 0.132 (16/250): 0.199 (17/250): 0.300 (18/250): 0.366 (19/250): 0.402 (20/250): 0.236 (21/250): 0.374 (22/250): 0.419 (23/250): 0.280 (24/250): 0.436 (25/250): 0.000 (26/250): 0.407 (27/250): 0.367 (28/250): 0.310 (29/250): 0.401 (30/250): 0.307 (31/250): 0.372 (32/250): 0.423 (33/250): 0.484 (34/250): 0.439 (35/250): 0.363 (36/250): 0.361 (37/250): 0.000 (38/250): 0.311 (39/250): 0.314 (40/250): 0.329 (41/250): 0.323 (42/250): 0.330 (43/250): 0.312 (44/250): 0.317 (45/250): 0.293 (46/250): 0.374 (47/250): 0.340 (48/250): 0.000 (49/250): 0.307 (50/250): 0.283 (51/250): 0.307 (52/250): 0.111 (53/250): 0.116 (54/250): 0.411 (55/250): 0.462 (56/250): 0.398 (57/250): 0.288 (58/250): 0.357 (59/250): 0.289 (60/250): 0.389 (61/250): 0.146 (62/250): 0.283 (63/250): 0.289 (

In [7]:
back_n  = 4  # number of words back from end of string, 1 is just last word, 0 is [1, 3, 2]
k       = 1.2  # exponent parameter for exponential decay of word length augmedented SequenceMatcher
ap      = 0.55  # exponent parameter
bp      = 1  # exponent parameter
log_exp      = 5  # exponent parameter for logarithmic mapping
prob_exp     = 1  # raise probability to power in ((prob**power)*log-sim)
consider_top = 100  # max top model word predictions considered
relevency_t  = 0.07  # threshold defined by portion of top proposition to exclude much smaller scored propositions for correcting
base_t       = 0.0015   # decision threshold for last word: base threshold
threshold_e  = 1.8  # exponent for exponential thresholds
threshold_t  = "exponential"  # function defines decision threshold for word n from end

batch_size = 150
batch_i = np.random.choice(super_batch.shape[0], batch_size, replace=False)
batch = super_batch[batch_i,:]
corrections = []
times = []
for i, (pprobs, (x, _)) in enumerate(zip([probs[i] for i in batch_i], batch)): 
    t0 = time()
    corrections.append(process_correction(x, back_n, pprobs))
    times.append(time()-t0)
    print(f'({i+1}/{len(batch)}): {times[-1]:.3f} ', end='')
print(f'\naverage: {np.mean(times):.3f} seconds')

(1/150): 0.015 (2/150): 2.571 (3/150): 0.000 (4/150): 0.007 (5/150): 0.013 (6/150): 0.010 (7/150): 0.121 (8/150): 0.000 (9/150): 0.009 (10/150): 0.007 (11/150): 0.008 (12/150): 0.000 (13/150): 0.010 (14/150): 0.008 (15/150): 0.008 (16/150): 0.004 (17/150): 0.006 (18/150): 0.005 (19/150): 0.012 (20/150): 0.006 (21/150): 0.000 (22/150): 0.000 (23/150): 0.016 (24/150): 0.008 (25/150): 0.009 (26/150): 0.004 (27/150): 0.005 (28/150): 0.014 (29/150): 0.000 (30/150): 0.010 (31/150): 0.010 (32/150): 0.009 (33/150): 0.007 (34/150): 0.008 (35/150): 0.007 (36/150): 0.010 (37/150): 0.000 (38/150): 0.007 (39/150): 0.103 (40/150): 0.000 (41/150): 0.010 (42/150): 0.009 (43/150): 0.008 (44/150): 0.013 (45/150): 0.013 (46/150): 0.010 (47/150): 0.009 (48/150): 0.007 (49/150): 0.003 (50/150): 0.011 (51/150): 0.000 (52/150): 0.005 (53/150): 0.013 (54/150): 0.012 (55/150): 0.000 (56/150): 0.000 (57/150): 0.013 (58/150): 0.013 (59/150): 0.000 (60/150): 0.016 (61/150): 0.011 (62/150): 0.000 (63/150): 0.009 (

In [8]:
y_h = np.array(corrections == batch[:,1])
same = batch[:,0] == batch[:,1]
not_corrected = np.array(corrections) == "False"
tn  = np.logical_and(same, not_corrected)
fn  = np.logical_and(np.logical_not(same), not_corrected)
fp  = np.logical_and(same, np.logical_not(not_corrected))
tp = np.logical_and(np.logical_not(same), np.logical_not(not_corrected))
ttp  = np.logical_and(tp, y_h)
TN = tn.sum(); FN = fn.sum(); FP = fp.sum(); TP = tp.sum(); TTP = ttp.sum(); TP = tp.sum(); FTP = TP-TTP
# [Total] True Positives, True True Positives, True Negatives, False Negatives, False Positives, False True Positives
f'TP={TP}, TTP={TTP}, TN={TN}, FN={FN}, FP={FP}, FTP={FTP}'

'TP=47, TTP=45, TN=71, FN=29, FP=3, FTP=2'

In [9]:
precision = TTP/(TP+FP)
recall    = TP/(TP+FN)
fpr       = FP/same.sum()
specificity = TN/(TN+FP+FTP)
accuracy    = (TTP+TN)/(TP+TN+FP+FN)
f1          = (2*precision*recall)/(precision+recall)
f'precision={precision:.3f}, recall={recall:.3f}, false_positive_rate={fpr:.3f}, specificity={specificity:.3f}, accuracy={accuracy:.3f}, f1={f1:.3f}'#, f'k={k}, a={ap}, b={bp}, log_exp={log_exp}, prob_exp={prob_exp}, relevency_t={relevency_t}, base_t={base_t}, threshold_e={threshold_e}'

'precision=0.900, recall=0.618, false_positive_rate=0.041, specificity=0.934, accuracy=0.773, f1=0.733'

In [10]:
corrections = []
for x, _ in batch: 
    corrections.append(blob_correct(x))

In [11]:
y_h = np.array(corrections == batch[:,1])
same = batch[:,0] == batch[:,1]
not_corrected = np.array(corrections) == batch[:,0]
tn  = np.logical_and(same, not_corrected)
fn  = np.logical_and(np.logical_not(same), not_corrected)
fp  = np.logical_and(same, np.logical_not(not_corrected))
tp = np.logical_and(np.logical_not(same), np.logical_not(not_corrected))
ttp  = np.logical_and(tp, y_h)
TN = tn.sum(); FN = fn.sum(); FP = fp.sum(); TP = tp.sum(); TTP = ttp.sum(); TP = tp.sum(); FTP = TP-TTP
# [Total] True Positives, True True Positives, True Negatives, False Negatives, False Positives, False True Positives
f'TP={TP}, TTP={TTP}, TN={TN}, FN={FN}, FP={FP}, FTP={FTP}'

'TP=38, TTP=4, TN=62, FN=38, FP=12, FTP=34'

In [12]:
precision = TTP/(TP+FP)
recall    = TP/(TP+FN)
fpr       = FP/same.sum()
specificity = TN/(TN+FP+FTP)
accuracy    = (TTP+TN)/(TP+TN+FP+FN)
f1          = (2*precision*recall)/(precision+recall)
f'precision={precision:.3f}, recall={recall:.3f}, false_positive_rate={fpr:.3f}, specificity={specificity:.3f}, accuracy={accuracy:.3f}, f1={f1:.3f}'#, f'k={k}, a={ap}, b={bp}, log_exp={log_exp}, prob_exp={prob_exp}, relevency_t={relevency_t}, base_t={base_t}, threshold_e={threshold_e}'

'precision=0.080, recall=0.500, false_positive_rate=0.162, specificity=0.574, accuracy=0.440, f1=0.138'