This project is a work in progress. I want to add a feature that can predict how a misspelled word would be pronounced using the linguistic rules of English.

This first cell contains the necessary libraries for this Jupyter notebook, as well as configuration variables to fine-tune the functions in the sentence prediction/spellchecking model.

In [1]:
import nltk
import re
import math
import string

from nltk import pos_tag
from nltk.corpus import cmudict, brown, reuters, stopwords
from collections import defaultdict, Counter

# Corpora and word lists imported above; each download is a no-op when the
# package is already up to date (see the log below).
nltk.download("cmudict")
nltk.download("brown")
nltk.download("reuters")
nltk.download("stopwords")

# Tagger/tokenizer resources. force=True re-downloads the tagger packages on
# every run of this cell.
# NOTE(review): presumably forced to repair a broken local copy -- confirm
# whether it is still needed.
nltk.download('averaged_perceptron_tagger', force=True)
nltk.download('averaged_perceptron_tagger_eng', force=True)
nltk.download('punkt')

# ---------- configuration ----------
# Interpolation weights indexed by context size: LAMBDA[0] weights the unigram
# estimate, LAMBDA[3] the 4-gram estimate. The weights sum to 1.0 and longer
# n-grams receive more weight.
LAMBDA = [0.05, 0.15, 0.3, 0.5] # format: [unigram_weight, bigram_weight, trigram_weight, 4gram_weight]
# Largest n-gram order used by the model; tied to the number of weights above.
MAX_NGRAM = len(LAMBDA)
SMOOTH = 1e-8 # additive smoothing constant; also the default count for an unseen n-gram

[nltk_data] Downloading package cmudict to
[nltk_data]     C:\Users\uddin\AppData\Roaming\nltk_data...
[nltk_data]   Package cmudict is already up-to-date!
[nltk_data] Downloading package brown to
[nltk_data]     C:\Users\uddin\AppData\Roaming\nltk_data...
[nltk_data]   Package brown is already up-to-date!
[nltk_data] Downloading package reuters to
[nltk_data]     C:\Users\uddin\AppData\Roaming\nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\uddin\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\uddin\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping taggers\averaged_perceptron_tagger.zip.
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     C:\Users\uddin\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping taggers\averaged_perceptron_tagger_eng.z

This second cell contains the functions that precompute the data used by the model: the fused corpus, the n-gram counts, the per-prefix denominators, and the total count for each n-gram order.

In [2]:
def fuse_corpora(corpus_one, corpus_two, limit=None):
    """
    Concatenate the lower-cased words of two corpora into one list.

    Parameters
    ----------
    corpus_one : NLTK corpus
        supplies the first part of the result; must expose a .words() method
        whose return value can be sliced
    corpus_two : NLTK corpus
        supplies the second part of the result; same requirements
    limit : int, optional
        number of leading words to keep from each corpus; None keeps them all

    Returns
    -------
    list
        lower-cased words of corpus_one followed by those of corpus_two

    Examples
    --------
    >>> fuse_corpora(brown, reuters, 3)
    ['the', 'fulton', 'county', 'asian', 'exporters', 'fear']
    """
    fused = []
    for source in (corpus_one, corpus_two):
        fused.extend(token.lower() for token in source.words()[:limit])
    return fused

def extract_ngrams(ngram_size: int, corpus: list):
    """
    Slide a window of ``ngram_size`` tokens over a corpus.

    Every token is lower-cased first. The result lists one tuple per window
    position, in corpus order; it is empty when the corpus is shorter than
    ``ngram_size``.

    Examples
    --------
    >>> extract_ngrams(3, brown.words()[-5:-1])
    [('boucle', 'dress', 'was'), ('dress', 'was', 'stupefying')]
    """
    tokens = [token.lower() for token in corpus]
    window_count = len(tokens) - ngram_size + 1

    ngrams = []
    for start in range(window_count):
        ngrams.append(tuple(tokens[start:start + ngram_size]))
    return ngrams

def build_counts_model(corpus, min_size=1, max_size=MAX_NGRAM, smooth=SMOOTH):
    """
    Count every n-gram in a corpus, grouped by n-gram size.

    Parameters
    ----------
    corpus : list
        the tokens from which n-grams are extracted
    min_size : int
        the smallest n-gram size to count
    max_size : int
        the largest n-gram size to count (inclusive)
    smooth : float
        value produced when an n-gram that was never seen is looked up by
        indexing (``table[ngram]``); note that such a lookup also stores the
        n-gram in the table, whereas ``table.get(ngram)`` does not

    Returns
    -------
    dict
        maps each size in ``min_size..max_size`` to a defaultdict of
        {ngram tuple: occurrence count}

    Examples
    --------
    >>> build_counts_model(brown.words()[:50])
    {1: defaultdict(<function ...>, {('the',): 3, ('fulton',): 1, ...}),
     ...
     4: defaultdict(<function ...>, {('the', 'fulton', 'county', 'grand'): 1, ('fulton', 'county', 'grand', 'jury'): 1, ...})}

    >>> build_counts_model(brown.words()[:50])[2][('the', 'fulton')]
    1
    """
    def unseen_count():
        return smooth

    return {
        order: defaultdict(unseen_count, Counter(extract_ngrams(order, corpus)))
        for order in range(min_size, max_size + 1)
    }

def precompute_prefix_denominators(counts: dict, max_size: int=MAX_NGRAM):
    """
    Sum the n-gram counts that share each prefix.

    The result satisfies prefix_sums[n][prefix] = total count of n-grams of
    size n whose first n-1 tokens equal ``prefix``, for every n from 2 to
    ``max_size``. Having these sums ready lets the denominator of
    P(word | prefix) be fetched with a single dictionary lookup.

    Sizes that are missing from ``counts`` yield an empty table.
    """
    prefix_sums = {}

    for order in range(2, max_size + 1):
        sums_for_order = defaultdict(int)
        for ngram, freq in counts.get(order, {}).items():
            sums_for_order[ngram[:-1]] += freq
        prefix_sums[order] = sums_for_order

    return prefix_sums

def precompute_totals(counts: dict, smooth: float=SMOOTH):
    """
    Compute the smoothed total count for every n-gram size in ``counts``.

    For each size the total is the sum of all stored counts plus ``smooth``
    once per distinct n-gram. Returns a dict like
    {1: total_unigram_mass, 2: total_bigram_mass, ...} so probability lookups
    do not have to recompute these sums.
    """
    return {
        order: sum(table.values()) + smooth * len(table)
        for order, table in counts.items()
    }

This third cell computes all the variables that will be used in the final model. This code takes around 30 seconds to run at minimum, so print statements are included to indicate if the cell is running properly.

In [3]:
# Full corpus (Brown followed by Reuters, lower-cased) and a second copy with
# punctuation and stopwords removed.
corpus_full = fuse_corpora(brown, reuters)

punctuation = set(string.punctuation) # single ASCII punctuation characters; also used later to determine sentence boundaries
to_be_filtered = punctuation | set(stopwords.words("english")) # tokens to be removed from full corpus
corpus_filtered = [word for word in corpus_full if word not in to_be_filtered]
print("corpuses built")

# full & filtered counts of unigrams to 4-grams (sizes 1..MAX_NGRAM)
counts_full = build_counts_model(corpus_full)
counts_filtered = build_counts_model(corpus_filtered)
print("counts models built")

# precompute prefix denominators for full & filtered counts
# (per-prefix sums used as the denominator of P(word | prefix))
prefix_denoms_full = precompute_prefix_denominators(counts_full)
prefix_denoms_filtered = precompute_prefix_denominators(counts_filtered)
print("prefix denominators precomputed")

# precompute total counts for full & filtered counts
# (one smoothed total per n-gram size)
total_counts_full = precompute_totals(counts_full)
total_counts_filtered = precompute_totals(counts_filtered)
print("total counts precomputed")

# Create the pronunciation dictionary from cmudict: word -> list of
# pronunciations, each pronunciation being a list of phoneme strings.
pronounce_dictionary = cmudict.dict()
print(pronounce_dictionary['cat']) # outputs [['K', 'AE1', 'T']]

# create list of valid keys in pronounce_dictionary
validWords = list(pronounce_dictionary.keys())

# Create the reverse mapping: pronunciation tuple -> every word pronounced
# that way. Pronunciations are converted to tuples so they can be dict keys.
reverse_pronounce_dictionary = defaultdict(list)

# Reuse the dictionary built above; calling cmudict.dict() again would load
# and index the whole pronunciation file a second time.
for word, pronunciations in pronounce_dictionary.items():
    for pronunciation in pronunciations:
        reverse_pronounce_dictionary[tuple(pronunciation)].append(word)
print(reverse_pronounce_dictionary[('K', 'AE1', 'T')]) # outputs ['cat', 'catt', 'kat', 'katt']

# Collect the lower-cased words that Brown tags as adjectives (JJ*),
# prepositions (IN*), and nouns (NN*). Only Brown is needed here, since these
# sets exist for ironing out common their/they're/there type errors.
adjectives = set()
prepositions = set()
nouns = set()

for (word, tag) in brown.tagged_words():
    if tag.startswith("JJ"):
        adjectives.add(word.lower())
    if tag.startswith("IN"):
        prepositions.add(word.lower())
    if tag.startswith("NN"):
        nouns.add(word.lower())

# Fixes a tagging error ("a" tagged as a noun). discard() rather than remove()
# so the cell does not raise KeyError if "a" is not in the set.
nouns.discard("a")

corpuses built
counts models built
prefix denominators precomputed
total counts precomputed
[['K', 'AE1', 'T']]
['cat', 'catt', 'kat', 'katt']


The following cells contain the functions that are used in the model.

In [4]:
def prob_from_counts(word: str, context_size: int, context_data: tuple,
                     counts: dict, prefix_sums: dict, total_counts: dict):
    """
    Return P(word | context) estimated from counts and precomputed denominators.

    Parameters
    ----------
    word : str
        the word whose probability is estimated
    context_size : int
        how many trailing tokens of ``context_data`` to condition on;
        0 gives the unigram probability
    context_data : sequence of str
        the tokens preceding ``word``; only the last ``context_size`` are used
    counts : dict
        {n: {ngram tuple: count}}, as returned by build_counts_model
    prefix_sums : dict
        {n: {prefix tuple: summed count}}, as returned by
        precompute_prefix_denominators
    total_counts : dict
        {n: smoothed total count}, as returned by precompute_totals

    Returns
    -------
    float
        the smoothed estimate. A small back-off value is returned when the
        context is shorter than ``context_size``, and SMOOTH when no n-grams of
        size ``context_size + 1`` were counted at all.
    """
    n = context_size + 1
    counter = counts.get(n)

    if not counter:
        # Nothing was counted for this order (e.g. n > MAX_NGRAM). Return the
        # smoothing floor instead of dividing by an (almost) zero total, which
        # would produce a "probability" far greater than 1.
        return SMOOTH

    total = total_counts.get(n)
    if total is None:
        # Computed only when the precomputed total is missing: summing a large
        # count table on every call is far too slow.
        total = sum(counter.values()) + SMOOTH * len(counter)

    if context_size == 0:
        # unigram probability
        return (counter.get((word,), 0) + SMOOTH) / total

    fallback = SMOOTH / (total + SMOOTH) ** 2

    if context_size > len(context_data):
        # not enough context for an n-gram of this size
        return fallback

    prefix = tuple(context_data[-context_size:])

    # .get() on both levels: a missing order must not raise KeyError, and an
    # unseen prefix must not be inserted into the defaultdict.
    denom = prefix_sums.get(n, {}).get(prefix, 0) + SMOOTH * len(counter)
    joint_count = counter.get(prefix + (word,), 0) + SMOOTH

    if denom <= 0:
        return fallback

    return joint_count / denom


# Tests
# Smoke check on the full model: with context_size=2 only the last two context
# tokens ("the", "last") are conditioned on, so these print P(year | the last)
# and P(day | the last).
print(prob_from_counts("year", 2, ("in", "the", "last"), counts_full, prefix_denoms_full, total_counts_full))
print(prob_from_counts("day", 2, ("in", "the", "last"), counts_full, prefix_denoms_full, total_counts_full))

0.0383781100283299
0.027717523915271918


In [None]:
def nbest_continuations(sent: list, counts: dict, prefix_sums: dict, total_counts: dict, nbest: int=5):
    """
    Rank the most likely next words for the input sentence.

    Parameters
    ----------
    sent : list of str
        the tokens of the sentence so far (lower-cased internally)
    counts : dict
        {n: {ngram tuple: count}}, as returned by build_counts_model
    prefix_sums : dict
        per-prefix denominators, as returned by precompute_prefix_denominators
    total_counts : dict
        per-order totals, as returned by precompute_totals
    nbest : int
        the maximum number of results to return

    Returns
    -------
    list of (str, float)
        up to ``nbest`` (word, log probability) pairs, most probable first.
        The probability is the LAMBDA-weighted mix of the unigram up to
        MAX_NGRAM-gram estimates from prob_from_counts.
    """
    ctx = [w.lower() for w in sent]
    candidates = set()

    # Gather candidates: for every usable context length, the (up to) 100 most
    # frequent words seen after that context. Note that each order is found by
    # scanning the whole counts[ngram_size] table.
    for ngram_size in range(2, min(len(ctx) + 1, MAX_NGRAM + 1)):
        prefix = tuple(ctx[-(ngram_size - 1):])

        # one pass: keep (count, word) for every n-gram that starts with prefix
        followers = [
            (count, ngram[-1])
            for ngram, count in counts[ngram_size].items()
            if ngram[:-1] == prefix
        ]
        followers.sort(key=lambda pair: pair[0], reverse=True)
        candidates.update(word for _, word in followers[:100])

    # fallback to top unigrams only if no higher-order candidates were found
    if not candidates:
        top_unigrams = sorted(counts[1], key=counts[1].get, reverse=True)[:300]
        # keys of counts[1] are 1-tuples such as ('the',); unwrap them so the
        # candidates are plain words like the ones gathered above
        candidates.update(unigram[0] for unigram in top_unigrams)

    # Compute interpolated log probabilities
    candidate_probs = {}
    for w in candidates:
        p_mix = 0.0
        for context_size in range(MAX_NGRAM):
            p_mix += LAMBDA[context_size] * prob_from_counts(
                w, context_size, ctx, counts, prefix_sums, total_counts
            )
        candidate_probs[w] = math.log(p_mix)

    ranked = sorted(candidate_probs.items(), key=lambda x: x[1], reverse=True)
    return ranked[:nbest]

['the', 'diddler']


  test2 = re.findall(r"[\w']+|[.,!?;]", test2)


'TODO: \n-make a bunch of test cases -- seems to struggle with short sentences, especially with full counts/prefixes\n-figure out how many words to use from unigrams list and each iteration of top_continuations\n-figure out how to incorporate both full and filtered counts into 1 model\n-POS tagging for sentences that result in the use of unigram predictions (unseen words in context)\n-figure out whether or not to parse sentences for punctuation -- break for loop if ctx[-(ngram_size - 1)] in [punctuation]'

In [6]:
def close_words(word, pronounce_dict=pronounce_dictionary, reverse_pronounce_dict=reverse_pronounce_dictionary):
    """
    List every word that shares at least one pronunciation with ``word``.

    ``pronounce_dict`` maps a word to its pronunciations and
    ``reverse_pronounce_dict`` maps a pronunciation tuple to the words spoken
    that way. Returns an empty list when ``word`` has no entry in
    ``pronounce_dict``; otherwise a list without duplicates, in no particular
    order.
    """
    if word in pronounce_dict:
        homophones = set()
        for phonemes in pronounce_dict[word]:
            # .get() so an unknown pronunciation adds nothing and is not
            # inserted into the reverse dictionary
            homophones.update(reverse_pronounce_dict.get(tuple(phonemes), ()))
        return list(homophones)

    return []


In [7]:
def nextLastWordIndex(sent, current_index, punct=punctuation):
    """
    Locate the last token before the next punctuation mark.

    Scans ``sent`` from ``current_index + 1`` onwards and returns the index
    just before the first token found in ``punct``. When no such token
    follows, the index of the last token of the sentence is returned.
    """
    following_positions = range(current_index + 1, len(sent))
    return next(
        (position - 1 for position in following_positions if sent[position] in punct),
        len(sent) - 1,
    )

def best_word(sent: list, wordIndex, candidates,
             counts_full, prefix_denoms_full, total_counts_full,
             smooth=SMOOTH, lambda_vals=LAMBDA, maxSize=MAX_NGRAM, adjs=adjectives, ins=prepositions, nns=nouns):
    """
    Chooses the best candidate replacement for sent[wordIndex] using n-gram context.

    Each candidate receives a score made of two parts:

    1. past context: the ``lambda_vals``-weighted mix of P(cand | previous
       words) over context sizes 0 .. len(lambda_vals) - 1;
    2. future context: the sum of P(next | cand), P(next2 | cand next), ...
       for the following words whose n-gram window (at most ``maxSize``
       tokens) still contains the candidate, stopping at the next
       punctuation mark.

    The score is a sum of probabilities, not a log-probability. Context tokens
    are lower-cased before lookup. The original word gets a ``smooth`` bonus so
    that it wins ties, and "they're" gets a bonus of 1 when the word is one of
    they're/their/there and a preposition follows closely.

    Parameters
    ----------
    sent : list of str
        the tokenized sentence
    wordIndex : int
        position in ``sent`` of the word being checked
    candidates : iterable of str
        the replacement words to score
    counts_full, prefix_denoms_full, total_counts_full : dict
        the model tables passed through to prob_from_counts
    smooth : float
        tie-breaking bonus for the original word
    lambda_vals : list of float
        interpolation weights; lambda_vals[k] weights the estimate that
        conditions on k previous words
    maxSize : int
        the largest n-gram size available in the model
    adjs, nns : set
        accepted for interface compatibility; not used by the current scoring
    ins : set
        the words treated as prepositions by the they're/their/there rule

    Returns
    -------
    str
        the highest-scoring word (a candidate, the original word, or "they're")
    """
    wordProbs = defaultdict(float)
    stopIndex = nextLastWordIndex(sent, wordIndex)  # last word before the next punctuation mark

    # The candidate must not be part of its own context.
    past_context = [w.lower() for w in sent[:wordIndex]]

    # stopIndex is inclusive, hence the + 1.
    future_slice = [w.lower() for w in sent[wordIndex + 1: stopIndex + 1]]

    for cand in candidates:
        score = 0.0

        # ----- 1. PAST CONTEXT -----
        # compute P(cand | previous words); weight k belongs to context size k
        for context_size, weight in enumerate(lambda_vals):
            score += weight * prob_from_counts(
                cand,
                context_size,
                past_context,
                counts_full, prefix_denoms_full, total_counts_full
            )

        # ----- 2. FUTURE CONTEXT -----
        # compute P(next_word | cand), P(next2 | cand next), ...
        context_window = [cand] + future_slice  # synthetic window

        # The i-th following token is conditioned on the i + 1 tokens before it
        # (the candidate plus everything in between). The context size may not
        # exceed maxSize - 1, beyond which the window would no longer contain
        # the candidate.
        for i in range(min(len(future_slice), maxSize - 1)):
            score += prob_from_counts(
                context_window[i + 1],
                i + 1,
                context_window[:i + 1],
                counts_full, prefix_denoms_full, total_counts_full,
            )

        wordProbs[cand] = score

    wordProbs[sent[wordIndex]] += smooth   # small boost for original word

    # cases for THEY'RE / THEIR / THERE
    if sent[wordIndex] in {"they're", 'their', 'there'} and wordIndex < len(sent) - 1:
        # look only at the words that follow, never at the word itself
        lookahead = sent[wordIndex + 1: wordIndex + maxSize]
        if any(token in ins for token in lookahead):
            wordProbs["they're"] += 1

    # pick the highest-scoring candidate
    return max(wordProbs, key=wordProbs.get)

This next cell contains the code for the sentence prediction/spellchecking model. The following cell can be used to test it by inputting a sentence.

In [8]:
def sentence_prediction(sent: list, punct=punctuation, nns=nouns, max_new_words=50):
    """
    Spellcheck a tokenized sentence, then extend it with predicted words.

    Every non-punctuation token is replaced by the best-scoring word among the
    words that share a pronunciation with it (see close_words and best_word).
    Words are then predicted one at a time with nbest_continuations and
    appended until a noun or a punctuation mark has been added.

    Parameters
    ----------
    sent : list of str
        the tokenized input sentence
    punct : set
        tokens treated as punctuation; copied to the output unchanged and
        used as a stop condition for the prediction
    nns : set
        words treated as nouns; used as a stop condition for the prediction
    max_new_words : int
        upper bound on the number of predicted words, so that the loop ends
        even if the model never proposes a noun or a punctuation mark

    Returns
    -------
    list of str
        the corrected sentence followed by the predicted words
    """
    output = []
    model = (counts_full, prefix_denoms_full, total_counts_full)
    stop = nns | punct

    # spellcheck first
    for i, word in enumerate(sent):
        # special cases: , . ? ! ; - --, etc.  (signals to end parsing)
        if word in punct:
            output.append(word)
            continue
            # TODO: make cases for (), which treats what's inside as its own sentence to be analyzed, but still remembers what's outside the ()

        closest_words = close_words(word) # TODO: precache close words for all words in Brown corpus beforehand?
        output.append(best_word(sent, i, closest_words, *model))

    # predict next word until next word is noun or punctuation
    for _ in range(max_new_words):
        print(output)  # progress indicator: each prediction can take a while
        nbest = nbest_continuations(output, *model, 1)
        if not nbest:
            break  # the model has no candidate at all

        next_word = nbest[0][0]
        output.append(next_word)

        if next_word in stop:
            break

    return output

In [15]:
# Read one sentence from the user, tokenize it, and run the model on it.
print("input a sentence:")
sent = input()

# Token pattern, tried in this order:
#   \d+[:.]\d+       digits joined by a colon or period (e.g. 10:30, 3.5)
#   \w+(?:'\w+)?     a word, optionally with one apostrophe part (e.g. they're)
#   [.,!?;:()\-—]    a single punctuation mark
# NOTE(review): tokens keep the user's capitalisation, while the corpus built
# by fuse_corpora is lower-cased -- confirm whether the input should be
# lower-cased here.
pattern = r"\d+[:.]\d+|\w+(?:'\w+)?|[.,!?;:()\-—]"
sent = re.findall(pattern, sent)

sentence_prediction(sent)

# predict next word
# nbest_continuations(sent, ...) #TODO: loop function until the next word isn't in to_be_filtered / is punctuation?

# TODO: combine functions so there aren't many nested functions

input a sentence:
['i', 'like', 'to']


KeyboardInterrupt: 