# packages and dependencies


In [None]:

#!python -m spacy download nb_core_news_sm 
!spacy download nb_core_news_sm 
!pip install transformers


import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt

import string
import re
import nltk
nltk.download('punkt')
import json
import spacy
from spacy.lang.nb.examples import sentences

from collections import Counter
import copy
import pickle

import transformers
from transformers import pipeline


# spell checker and corrector


In [None]:

class spell_checker:
  """
  overview: A spell checker and auto-corrector for Norwegian.
  Candidate corrections are produced by single-letter edits of the misspelled
  word (insert, delete, replace, swap), kept only if they are dictionary
  words, and then ranked by a masked-language model. A word-frequency table
  is the fallback when the model proposes none of the candidates.
  """

  def __init__(self,model_path):
    """
    overview: It builds the fill-mask pipeline, loads the dictionary and sets the alphabet.
    Arg: name or path of a model that supports the masked-language-model task
    (the mask token '[MASK]' used by mask() assumes a BERT-style model).
    """

    self.unmasker = pipeline('fill-mask', model= model_path)

    # a Norwegian dic
    self.dic=self.read_dic()

    # the Norwegian alphabet
    self.alphabet="a b c d e f g h i j k l m n o p q r s t u v w x y z æ ø å".split()

  def _known_words(self):
    """
    overview: It gives a hashed view of self.dic so that membership tests are O(1).
    The view is rebuilt when self.dic is rebound or changes length.
    Returns: a frozenset with the dictionary words
    """
    cache = getattr(self, '_dic_cache', None)
    if cache is None or cache[0] is not self.dic or cache[1] != len(self.dic):
      cache = (self.dic, len(self.dic), frozenset(self.dic))
      self._dic_cache = cache
    return cache[2]

  def _locate(self, tokens:list, words:list)->list:
    """
    overview: It finds, for every word, the first not-yet-used position in tokens
    holding exactly that word, so repeated words map to successive positions.
    Arg: the token list and the words to look up (in order)
    Returns: a list of indices into tokens, one per word
    Raises: ValueError if a word is not (or no longer) present in tokens
    """
    taken = set()
    positions = []
    for w in words:
      idx = next((i for i, t in enumerate(tokens) if i not in taken and t == w), None)
      if idx is None:
        raise ValueError(f"{w!r} is not in the input")
      taken.add(idx)
      positions.append(idx)
    return positions

  def tokenize(self, sent:str)->list:
    """
    overview: It splits a given doc/sentence, written in norwegian, into words.
    Arg: a doc/sentence
    Returns: list of tokenized words
    """
    # use the Norwegian sentence-splitting model instead of the English default
    return nltk.word_tokenize(sent, language='norwegian')


  def lemmatize(self, doc:list)->list:
    """
    overview: a Norwegian lemmatizer which provides the base form of every token.
    Empty tokens and tokens made only of punctuation are dropped.
    Arg: a list of tokens (as returned by tokenize())
    Returns: list of lower-cased lemmas (=base form of the words)
    """
    # loading the spaCy model is expensive, so do it once per instance
    no_lemmatizer = getattr(self, '_nlp', None)
    if no_lemmatizer is None:
      no_lemmatizer = self._nlp = spacy.load("nb_core_news_sm")

    p_marks=string.punctuation

    lemmas = []
    for w in doc:
      token = w.strip()
      # skip '' and pure-punctuation tokens such as '.', '...' or '``'
      if not token or all(ch in p_marks for ch in token):
        continue
      lemmas.append(no_lemmatizer(w)[0].lemma_.lower())
    return lemmas


  def read_dic(self, path:str="dic.txt")->list:
    """
    overview: It reads the Norwegian dictionary file (UTF-8, one word per line).
    Arg: path of the dictionary file, 'dic.txt' by default
    Returns: list of the dic words, lower-cased
    """
    # the context manager closes the file even if reading fails
    with open(path,"r", encoding='utf-8') as f:
      return [line.rstrip("\n").lower() for line in f]


  def find_misspelled_ws(self,user_inp:list)->list:
    """
    overview: it detects the misspelled words in a search text by using the dic.
    arg: the tokens of the search text given by the user
    returns: a list of the lower-cased tokens that are not in the dic (in input order, repeats kept)
    """
    known = self._known_words()
    return [w.lower() for w in user_inp if w.lower() not in known]


  def insert(self,word:str)->list:
    """
    overview: it generates strings by adding one Norwegian letter at every
    position of the misspelled word: before each letter and after the last one.
    arg: a misspelled word as a String
    returns: a list of the generated strings
    """
    # len(word)+1 gaps: one in front of every letter plus one at the end
    return [word[:i] + l + word[i:]
            for i in range(len(word)+1)
            for l in self.alphabet]


  def delete(self, word:str)->list:
    """
    overview: it generates strings by deleting one letter of the misspelled word at a time.
    arg: a misspelled word as a String
    returns: a list of the generated strings; empty for words of two letters or fewer
    """
    # words of two letters or fewer are not shortened
    if len(word)<=2:
      return []

    # slice around index i so that exactly the i-th letter is dropped
    return [word[:i] + word[i+1:] for i in range(len(word))]


  def replace(self, word:str)->list:
    """
    overview: it generates strings by replacing one letter of the misspelled word with
    every letter of the alphabet (the unchanged word is therefore part of the result).
    arg: a misspelled word as a String
    returns: a list of the generated strings
    """
    return [word[:i] + l + word[i+1:]
            for i in range(len(word))
            for l in self.alphabet]


  def swap(self, word:str)->list:
    """
    overview: it generates strings by swapping every pair of neighbouring letters of the misspelled word.
    arg: a misspelled word as a String
    returns: a list of the generated strings; empty for words shorter than two letters
    """
    return [word[:i] + word[i+1] + word[i] + word[i+2:]
            for i in range(len(word)-1)]


  def check_one(self, word:str)->list:
    """
    overview: It collects all strings generated by insert, delete, replace and swap.
    arg: a misspelled word as a String
    returns: a sorted list of all generated strings without duplicates
    """
    candidates = set()
    for edit in (self.insert, self.delete, self.replace, self.swap):
      candidates.update(edit(word))
    return sorted(candidates)


  def check_all(self,miss_s_ws:list)->list:
    """
    overview: It runs check_one for ALL misspelled words.
    arg: a list of the misspelled word as Strings
    returns: a list with one list of generated strings per misspelled word.
    """
    return [self.check_one(w) for w in miss_s_ws]


  def wrds_frm_strs(self, strs:list)->list:
    """
    overview: It keeps only the generated strings which are real words, i.e. found in the dic.
    arg: a list of generated strings
    returns: a list of the strings that are dictionary words (input order kept)
    """
    known = self._known_words()
    return [w for w in strs if w in known]


  def mask(self, miss_s_ws:list,user_inp:list)->str:
    """
    Returns:
      A masked string which our model need it to predict a list of suggested words/tokens.
      Se method unmask_and_suggest_ws()
    Args:
      miss_s_ws: A list of the missspelled words
      user_inp: the tokens of the user input (not modified)
    Raises:
      ValueError if a misspelled word is not among the tokens
    """
    tokens = list(user_inp)
    for idx in self._locate(tokens, miss_s_ws):
      tokens[idx] = '[MASK]'
    return ' '.join(tokens)

  def unmask_and_suggest_ws(self, miss_s_ws:list,user_inp:list)->list:
    """
    Overview:
      By using AI, it masks the misspelled words and suggest candidates
      words of them based on context of the user input.
    Args:
      miss_s_ws: a list of the misspelled words
      user_inp: the tokens of the user input/search txt
    Returns:
      a list of lists where every list represents a set of the candidate words of a given misspelled word.
    Raises:
      Exception if the model returns nothing
     """
    masked_inp=self.mask(miss_s_ws,user_inp)
    unmasked=self.unmasker(masked_inp)
    if not unmasked:
      raise Exception("An error comes from the model!")

    # with a single mask the pipeline returns a flat list of predictions,
    # with several masks one list per mask; normalise to the nested shape
    if isinstance(unmasked[0], dict):
      unmasked = [unmasked]

    return [[pred['token_str'].strip().lower() for pred in preds]
            for preds in unmasked]


  def find_best_candidate(self, extracted_ws:list, suggested_ws:list)->str:
    """
    Overview:
     Given a list of the words which are extracted from the generated strings and a list of
     the words/tokens which are suggested/predicted by mlm model, the method tries to find a match between them.
     In other words, the method sees if one of the predicted words is in the extracted words to replace the misspelled word.
    returns:
      the first suggested token that is also an extracted word, otherwise an empty string ''
    Args:
      str list: extracted from the generated strings (for ONE misspelled word)
      str list: suggested candidate words
    """
    return next((w for w in suggested_ws if w in extracted_ws), '')


  def read_json(self, file)->dict:
    """
    overview: It reads the frequency table as a json file
    arg: path of a json file
    returns: the parsed content (find_best_freq_table uses it as a mapping word -> frequency)
    """
    # the context manager closes the file; JSON text is read as UTF-8
    with open(file, encoding='utf-8') as f:
      return json.load(f)


  def find_best_freq_table(self, extracted_wrds:list, freq_table:dict)->str:
    """
    overview: It finds the most frequent candidate out of a set of candidates
    to replace the misspelled word. Words missing from the table count as frequency 1;
    on a tie the earliest candidate wins.
    arg: a list of the candidates for a misspelled word, and the frequency table
    returns: the most frequent candidate, or '' if there are no candidates
    """
    if not extracted_wrds:
      return ''
    return max(extracted_wrds, key=lambda w: freq_table.get(w, 1))



  def auto_correct(self, search_txt:str, freq_table_path:str="freq_table.json")->str:
    """
    overview: It is the main function which puts all other functions together.
    The input is tokenized and lemmatized, so the answer is built from lemmas.
    arg: search_txt or user input as a string; freq_table_path is the json
    frequency table used when the model proposes none of the candidates
    returns: a string of the corrected input, or the (lemmatized) user
    input if no misspelled word or no candidate word for the misspelled word.
    A misspelled word without any candidate is left as it is.
    side effect: prints "Mener du:" when at least one word had a candidate
    """
    user_inp=self.lemmatize(self.tokenize(search_txt))   #a list of cleaned input

    miss_s_ws=self.find_misspelled_ws(user_inp)
    if not miss_s_ws:
      #no misspelled words in the user input
      return ' '.join(user_inp)

    # one list of dictionary words per misspelled word
    extracted_wrds=[self.wrds_frm_strs(strs) for strs in self.check_all(miss_s_ws)]
    if not any(extracted_wrds):
      # no generated string is a word, so there is nothing to suggest
      return ' '.join(user_inp)

    suggested_candidates=self.unmask_and_suggest_ws(miss_s_ws, user_inp)

    # positions are resolved once so that a repeated misspelling is fixed everywhere
    positions=self._locate(user_inp, miss_s_ws)

    corrected=user_inp[:]
    freq_table=None
    for pos, candidates, suggestions in zip(positions, extracted_wrds, suggested_candidates):
      if not candidates:
        continue

      #best candidate according to the model, if found
      best=self.find_best_candidate(candidates, suggestions)
      if not best:
        # read the table only if needed, and only once per call
        if freq_table is None:
          freq_table=self.read_json(freq_table_path)
        best=self.find_best_freq_table(candidates, freq_table)
      corrected[pos]=best

    print("Mener du:")
    return ' '.join(corrected)




def main():
  """Demo: run the auto-corrector on one deliberately misspelled Norwegian sentence and print the answer."""
  checker = spell_checker('ltgoslo/norbert2')
  demo_sentence = "Når du har ojbb må du ha et skatekort."
  corrected = checker.auto_correct(demo_sentence)
  print(corrected)


# run the demo only when the file is executed as a script
if __name__ == "__main__":
  main()

"""

output:
  Mener du:
  når du ha jobb måtte du ha en skattekort

"""



Some weights of the model checkpoint at ltgoslo/norbert2 were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


['ojbb', 'skatekort']
Mener du:
når du ha jobb måtte du ha en skattekort
