In [1]:
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM
from nltk.tokenize import WordPunctTokenizer
import textdistance
import numpy as np
import time

In [2]:
# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-large-cased')
model = BertForMaskedLM.from_pretrained('bert-large-cased')

In [3]:
top_k=5

In [4]:
# Input single sentence
text = input('Enter text.(Do not forget to place punctuation): ')

Enter text.(Do not forget to place punctuation): spelling correction is quote diffcult task !


In [5]:
#Break down sentence in to tokens and add special tokens to indicate begining and end of sentence
text = WordPunctTokenizer().tokenize(text)
text.insert(0,'[CLS]')
text.insert(len(text),'[SEP]')
text

['[CLS]',
 'spelling',
 'correction',
 'is',
 'quote',
 'diffcult',
 'task',
 '!',
 '[SEP]']

In [6]:
#Iterate over every token in the sequence and replace it with special token [MASK]. The model will then try to predict the token
#in the [MASK] position
for i,token in enumerate(text):
    copy_text = text[:]  
    
    if token not in ('[CLS]','[SEP]'):
        print(copy_text[i])
        original_token = copy_text[i]
        copy_text[i]='[MASK]'
        copy_text = ' '.join(copy_text)
        tokenized_text = tokenizer.tokenize(copy_text)

        masked_index = tokenized_text.index('[MASK]')
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

        # Create the segments tensors.
        segments_ids = [0] * len(tokenized_text)

        # Convert inputs to PyTorch tensors
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segments_ids])

        # Load pre-trained model (weights)
        
        model.eval()

        # Predict all tokens
        with torch.no_grad():
            predictions = model(tokens_tensor, segments_tensors)

        probs = torch.nn.functional.softmax(predictions[0][0][masked_index], dim=-1)
        
        # All redicted tokens are selected here. len(probs) can be replaced with any number suitable to the use case
        top_k_weights, top_k_indicies = torch.topk(probs, len(probs), sorted=True)

        output_dict={}
        for i, pred_idx in enumerate(top_k_indicies):
            predicted_token = tokenizer.convert_ids_to_tokens([pred_idx])[0]
            token_weight = top_k_weights[i]
            if len(original_token)>3:
                if textdistance.levenshtein.normalized_similarity(predicted_token,original_token)>0.5:
                    output_dict[predicted_token]=float(token_weight)
            else:
                output_dict[predicted_token]=float(token_weight)
        output_dict = dict(list(output_dict.items())[0: top_k]) 
        print(output_dict)

        
    
    
        

spelling
{'spell': 8.29263444757089e-05, 'spelling': 7.319104042835534e-05, 'Feeling': 3.1823601602809504e-05, 'Opening': 2.290070551680401e-05, 'something': 1.0126244887942448e-05}
correction
{'correction': 0.010122857056558132, 'convention': 0.001368048251606524, 'competition': 0.0011327865067869425, 'combination': 0.0008300655172206461, 'direction': 0.0006971591501496732}
is
{',': 0.2501995265483856, '.': 0.12372457981109619, '!': 0.06161179766058922, ';': 0.05436330288648605, '-': 0.05090939253568649}
quote
{'quite': 4.846660885959864e-06, 'quit': 7.05156750768765e-08, 'wrote': 5.465818730954197e-08, 'que': 5.041277262307631e-08, 'mute': 2.6868049118888848e-08}
diffcult
{'difficult': 0.002158178947865963, 'different': 0.00029643948073498905, 'difficulty': 4.1997889638878405e-05, 'default': 1.6654354112688452e-05, 'difficulties': 1.1171185178682208e-05}
task
{'talk': 6.015098188072443e-05, 'task': 2.4907898477977142e-05, 'track': 1.802486258384306e-05, 'taste': 1.5805690054548904e-0

#### This is a very basic implementation but still as we can see it can handle out of vocabulary as well as in vocabulary spelling mistakes. 'diffcult' is an out of vocabulary word and the model could suggest the right word. 'quote' is a spelling mistake where the intended word was 'quite' but this is a in vocabulary word. The model is able to identify that as well. The model can also handle punctuation mistakes.