In [1]:
from torch.nn.functional import softmax
from transformers import pipeline, GPT2TokenizerFast, GPT2LMHeadModel, AutoTokenizer, BertForMaskedLM
from autocorrect import Speller
from nltk.stem import WordNetLemmatizer, PorterStemmer
from difflib import SequenceMatcher
from string import punctuation
from textblob import TextBlob
from time import time
import numpy as np
import pandas as pd
import torch
import nltk
nltk.download('words')

[nltk_data] Downloading package words to C:\Users\bills-fish-
[nltk_data]     shack\AppData\Roaming\nltk_data...
[nltk_data]   Package words is already up-to-date!


True

In [2]:
topk = 200  # number of top predicted tokens to retrieve (before excluding non-words) 

class GPT2:
    def __init__(self, model="gpt2"):
        self.model     =   GPT2LMHeadModel.from_pretrained(model)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model)
        self.model_id  = model
    
    def get_word_probs(self, sentence, n=topk):  # adapted from raul on stackoverflow
        inputs = self.tokenizer.encode(sentence, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(inputs)
            predictions = outputs[0]
        candidates = predictions[0, -1, :]                          # Get the next token candidates.
        topk_i = torch.topk(candidates, n).indices.tolist()         # Get the top k next token candidates.
        all_probs = torch.nn.functional.softmax(candidates, dim=-1) # Get the token probabilities for all candidates.
        topk_probs = all_probs[topk_i].tolist()                     # Filter the token probabilities for the top k candidates.
        topk_tokens = [self.tokenizer.decode([idx]).strip()         # Decode the top k candidates back to words.
                       for idx in topk_i]
        return np.array(list(zip(topk_tokens, topk_probs)))

class BERT:
    def __init__(self, model="google-bert/bert-base-uncased"):
        self.model     = BertForMaskedLM.from_pretrained(model)
        self.tokenizer =   AutoTokenizer.from_pretrained(model)
        self.model_id  = model
        
    def get_word_probs(self, prompt, topk=topk):                  # Get topk masked token candidates
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits
        mask_index  = (inputs.input_ids == self.tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
        mask_logits = logits.squeeze()[mask_index].squeeze()
        probs = softmax(mask_logits, dim=-1)
        topk = 5000
        topk_probs, topk_i = torch.topk(probs, topk, dim=-1)
        topk_tokens = np.array([self.tokenizer.decode([i]) for i in topk_i])
        return np.hstack((topk_tokens.reshape(-1,1), np.array(topk_probs).reshape(-1,1)))

M_GPT2 = GPT2("gpt2")
M_BERT = BERT("google-bert/bert-base-uncased")

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at google-bert/bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with ano

In [3]:
def similar(a, b):
    common_len = np.ceil((len(a)+len(b))/2)
    adjustment = 0
    adjustment_table = {1: 0.5, 2: 0.3, 3: 0.2, 4: 0.1}
    if common_len in adjustment_table: adjustment = adjustment_table[common_len]*(np.e**(-k*(np.abs(len(a)-len(b))-ap))-bp)
    return SequenceMatcher(None, a, b).ratio() + adjustment
def rreplace(string, word, new_word):
    start = string.rfind(word)
    return string[0:start] + new_word + string[start+len(word):]
def blob_correct(sentence):
    blob = TextBlob(sentence)
    return str(blob.correct())

lemmatizer  = WordNetLemmatizer()
lemma       = lambda x: lemmatizer.lemmatize(x)
stemmer     = PorterStemmer()
stem        = lambda x: stemmer.stem(x)
spell       = Speller()
wl          = set(nltk.corpus.words.words())
log_map     = lambda e: np.vectorize(lambda x: np.power(np.log(x/0.5)/np.log(2), e))  # specify exponent to return vectorized mapping
after_slash = lambda x: x[(x.rfind("/")+1 if x.rfind("/") != -1 else 0):]

In [4]:
def get_probs(string, back_n):
    probs = []
    places = range(1,back_n+1)
    string = string.strip()
    words  = string.split()
    last_space = string.rfind(' ')
    for n in places:
        if n > len(words) or len(words) == 1: break
        spelled = False
        if n > 1:
            model  = M_BERT
            masked = "[MASK]" + words[-n][-1] if not words[-n][-1].isalpha() else "[MASK]"
            target = words[-n].strip(punctuation)
            prompt = ' '.join(words[:-n] + [masked] + words[len(words)-(n-1):])
        else:
            model  = M_GPT2
            string = string.strip()
            last_space = string.rfind(' ')
            prompt = string[:last_space]
        prob = model.get_word_probs(prompt)
        #print(prob)
        if len(prob) > 0:
            probs.append(prob)
        else:
            probs.append([['dummy', 0]])
    return probs
def correction(string, back_n, pprobs=False):
    places = reversed(range(1,back_n+1))
    if back_n == 0: places = [1, 3, 2]
    string = string.strip()
    words  = string.split()
    last_space = string.rfind(' ')
    for n in places:
        if n > len(words) or len(words) == 1: break
        spelled = False
        if n > 1:
            model  = M_BERT
            masked = "[MASK]" + words[-n][-1] if not words[-n][-1].isalpha() else "[MASK]"
            target = words[-n].strip(punctuation)
            prompt = ' '.join(words[:-n] + [masked] + words[len(words)-(n-1):])
            target = words[-n].strip(punctuation)
            if target != spell(target):
                spelled = True
        else:
            model  = M_GPT2
            string = string.strip()
            last_space = string.rfind(' ')
            prompt = string[:last_space]
            target = string[last_space+1:].strip(punctuation)
            target = words[-n].strip(punctuation)
            if target != spell(target):
                spelled = True
        if pprobs:
            probs = pprobs[n-1]
        else:
            probs = model.get_word_probs(prompt)
        probs[:,1] = probs[:,1].astype(float)/probs[:,1].astype(float).sum()
        probsp = [(str(word), float(prob), float(similar(target.lower(), word.lower()))) for word, prob in probs if word in wl]
        close_probs = [prob for prob in probsp if prob[2] > 0.5 and prob[1] >= 0.001]
        props = [(word, (prob**prob_exp)*log_map(log_exp)(sim)) for word, prob, sim in close_probs]
        props = sorted(props, reverse=True, key=lambda x: x[1])
        props = [prop for prop in props if prop[1] > 0.000001]
        probN = threshold(n)
        make_correction = False
        if len(props) > 0 and props[0][1] > probN:
            make_correction = True
            irr_t = props[0][1] * relevency_t
            for word, score in props: 
                if score < irr_t: break
                elif target.lower() == word.lower() or stem(target.lower()) == word.lower() or lemma(target.lower()) == word.lower():
                    make_correction = False
        if make_correction: return (n, props[0][0], spelled)
        #if spelled: return (n, target, spelled)
    return False
def process_correction(string, back_n, pprobs=False):
    corrected = string
    words     = string.split()
    is_correction = correction(string, back_n, pprobs)
    if is_correction:
        n, word, _ = is_correction
        words[-n] = word if words[-n][-1] not in punctuation else word + words[-n][-1]
        corrected = " ".join(words)
        #is_correction = correction(corrected, back_n)
    return corrected if corrected != string else False

In [5]:
strings = pd.read_csv("strings.txt", quotechar='"', header=None, index_col=False, skipinitialspace=True).values
strings[:5]

array([['when you come in can you remember to feel the cat?',
        'when you come in can you remember to feed the cat?'],
       ['when you come in can you remember to feed the cad?',
        'when you come in can you remember to feed the cat?'],
       ['when you come in can you remember to feed the cad? He',
        'when you come in can you remember to feed the cat? He'],
       ['when you come in can you remember to feed the cad? He needs',
        'when you come in can you remember to feed the cat? He needs'],
       ['can you remember to feed the cad when',
        'can you remember to feed the cat when']], dtype=object)

In [6]:
batch_size = 200
batch = strings[np.random.choice(strings.shape[0], batch_size, replace=False),:]
batch[:5]

array([['If it rains', 'If it rains'],
       ['can you go and let', 'can you go and let'],
       ['I want to read', 'I want to read'],
       ['You have to live up', 'You have to live up'],
       ['can you let my aunt use your car when she comes',
        'can you let my aunt use your car when she comes']], dtype=object)

In [7]:
probs = []
times = []
for i, string in enumerate(batch):
    t0 = time()
    probs.append(get_probs(string[0], 4))
    times.append(time()-t0)
    print(f'({i+1}/{len(batch)}): {times[-1]:.3f} ', end='')
print(f'\naverage: {np.mean(times):.3f} seconds')

(1/200): 1.733 (2/200): 0.746 (3/200): 0.338 (4/200): 0.373 (5/200): 0.371 (6/200): 0.300 (7/200): 0.323 (8/200): 0.360 (9/200): 0.273 (10/200): 0.977 (11/200): 0.365 (12/200): 0.340 (13/200): 0.423 (14/200): 0.374 (15/200): 0.383 (16/200): 0.343 (17/200): 0.364 (18/200): 0.320 (19/200): 0.239 (20/200): 0.326 (21/200): 0.261 (22/200): 0.283 (23/200): 0.342 (24/200): 0.453 (25/200): 0.426 (26/200): 0.439 (27/200): 0.407 (28/200): 0.467 (29/200): 0.291 (30/200): 0.155 (31/200): 0.393 (32/200): 0.378 (33/200): 0.289 (34/200): 0.382 (35/200): 0.378 (36/200): 0.440 (37/200): 0.410 (38/200): 0.464 (39/200): 0.459 (40/200): 0.196 (41/200): 0.197 (42/200): 0.510 (43/200): 0.476 (44/200): 0.406 (45/200): 0.696 (46/200): 0.146 (47/200): 0.409 (48/200): 0.158 (49/200): 0.449 (50/200): 0.332 (51/200): 0.417 (52/200): 0.433 (53/200): 0.396 (54/200): 0.388 (55/200): 0.480 (56/200): 0.464 (57/200): 0.488 (58/200): 0.437 (59/200): 0.410 (60/200): 0.448 (61/200): 0.470 (62/200): 0.413 (63/200): 0.416 (

In [8]:
back_n  = 4  # number of words back from end of string, 1 is just last word, 0 is [1, 3, 2]
k       = 1.2  # exponent parameter for exponential decay of word length augmedented SequenceMatcher
ap      = 0.55  # exponent parameter
bp      = 1  # exponent parameter
log_exp      = 5  # exponent parameter for logarithmic mapping
prob_exp     = 1  # raise probability to power in ((prob**power)*log-sim)
consider_top = 100  # max top model word predictions considered
relevency_t  = 0.07  # threshold defined by portion of top proposition to exclude much smaller scored propositions for correcting
base_t       = 0.0015   # decision threshold for last word: base threshold
threshold_e  = 1.8  # exponent for exponential thresholds
threshold_t  = "exponential"  # function defines decision threshold for word n from end
threshold    = {"constant":    lambda n: base_t,
                "linear":      lambda n: base_t + (base_t * (n-1)),
                "exponential": lambda n: base_t * (n**threshold_e),
                "jump-exp":    lambda n: base_t * (max(n-1,1)**threshold_e),        # jump thresholds start growing after n=2
                "jump-lin":    lambda n: base_t + (base_t * max(n-2, 0))
               }[threshold_t]

corrections = []
times = []
for i, (pprobs, (x, _)) in enumerate(zip(probs, batch)): 
    t0 = time()
    corrections.append(process_correction(x, back_n, pprobs))
    times.append(time()-t0)
    print(f'({i+1}/{len(batch)}): {times[-1]:.3f} ', end='')
print(f'\naverage: {np.mean(times):.3f} seconds')

(1/200): 0.000 (2/200): 4.316 (3/200): 0.177 (4/200): 0.209 (5/200): 0.179 (6/200): 0.128 (7/200): 0.204 (8/200): 0.213 (9/200): 0.000 (10/200): 0.209 (11/200): 0.188 (12/200): 0.261 (13/200): 0.361 (14/200): 0.313 (15/200): 0.167 (16/200): 0.083 (17/200): 0.160 (18/200): 0.183 (19/200): 0.000 (20/200): 0.148 (21/200): 0.000 (22/200): 0.000 (23/200): 0.139 (24/200): 0.141 (25/200): 0.149 (26/200): 0.155 (27/200): 0.163 (28/200): 0.149 (29/200): 0.000 (30/200): 0.000 (31/200): 0.047 (32/200): 0.197 (33/200): 0.000 (34/200): 0.157 (35/200): 0.182 (36/200): 0.057 (37/200): 0.172 (38/200): 0.181 (39/200): 0.197 (40/200): 0.000 (41/200): 0.000 (42/200): 0.125 (43/200): 0.172 (44/200): 0.274 (45/200): 0.313 (46/200): 0.000 (47/200): 0.259 (48/200): 0.000 (49/200): 0.166 (50/200): 0.000 (51/200): 0.269 (52/200): 0.145 (53/200): 0.289 (54/200): 0.173 (55/200): 0.301 (56/200): 0.185 (57/200): 0.174 (58/200): 0.076 (59/200): 0.150 (60/200): 0.146 (61/200): 0.185 (62/200): 0.142 (63/200): 0.168 (

In [9]:
y_h = np.array(corrections == batch[:,1])
same = batch[:,0] == batch[:,1]
not_corrected = np.array(corrections) == "False"
tn  = np.logical_and(same, not_corrected)
fn  = np.logical_and(np.logical_not(same), not_corrected)
fp  = np.logical_and(same, np.logical_not(not_corrected))
tp = np.logical_and(np.logical_not(same), np.logical_not(not_corrected))
ttp  = np.logical_and(tp, y_h)
TN = tn.sum(); FN = fn.sum(); FP = fp.sum(); TP = tp.sum(); TTP = ttp.sum(); TP = tp.sum(); FTP = TP-TTP
# [Total] True Positives, True True Positives, True Negatives, False Negatives, False Positives, False True Positives
f'TP={TP}, TTP={TTP}, TN={TN}, FN={FN}, FP={FP}, FTP={FTP}'

'TP=55, TTP=52, TN=95, FN=47, FP=3, FTP=3'

In [10]:
precision   = TTP/(TP+FP)
recall      = TP/(TP+FN)
specificity = TN/(TN+FP+FTP)
accuracy    = (TTP+TN)/(TP+TN+FP+FN)
f1          = (2*precision*recall)/(precision+recall)
f'precision={precision:.3f}, recall={recall:.3f}, specificity={specificity:.3f}, accuracy={accuracy:.3f}, f1={f1:.3f}'#, f'k={k}, a={ap}, b={bp}, log_exp={log_exp}, prob_exp={prob_exp}, relevency_t={relevency_t}, base_t={base_t}, threshold_e={threshold_e}'

'precision=0.897, recall=0.539, specificity=0.941, accuracy=0.735, f1=0.673'

In [11]:
corrections = []
for x, _ in batch: 
    corrections.append(blob_correct(x))

In [12]:
y_h = np.array(corrections == batch[:,1])
same = batch[:,0] == batch[:,1]
not_corrected = np.array(corrections) == batch[:,0]
tn  = np.logical_and(same, not_corrected)
fn  = np.logical_and(np.logical_not(same), not_corrected)
fp  = np.logical_and(same, np.logical_not(not_corrected))
tp = np.logical_and(np.logical_not(same), np.logical_not(not_corrected))
ttp  = np.logical_and(tp, y_h)
TN = tn.sum(); FN = fn.sum(); FP = fp.sum(); TP = tp.sum(); TTP = ttp.sum(); TP = tp.sum(); FTP = TP-TTP
# [Total] True Positives, True True Positives, True Negatives, False Negatives, False Positives, False True Positives
f'TP={TP}, TTP={TTP}, TN={TN}, FN={FN}, FP={FP}, FTP={FTP}'

'TP=44, TTP=4, TN=78, FN=58, FP=20, FTP=40'

In [13]:
precision   = TTP/(TP+FP)
recall      = TP/(TP+FN)
specificity = TN/(TN+FP+FTP)
accuracy    = (TTP+TN)/(TP+TN+FP+FN)
f1          = (2*precision*recall)/(precision+recall)
f'precision={precision:.3f}, recall={recall:.3f}, specificity={specificity:.3f}, accuracy={accuracy:.3f}, f1={f1:.3f}'#, f'k={k}, a={ap}, b={bp}, log_exp={log_exp}, prob_exp={prob_exp}, relevency_t={relevency_t}, base_t={base_t}, threshold_e={threshold_e}'

'precision=0.062, recall=0.431, specificity=0.565, accuracy=0.410, f1=0.109'