# CMPT 413 Programming Homework 1: Contextual Spell Checking

## Importing Libraries

In [None]:
from transformers import pipeline
from difflib import SequenceMatcher
import logging
import os
import csv

## Setting up Logging Config and Loading Pipeline with Pre-Trained Model

In [None]:
# Configure the root logger so that DEBUG-and-above records are emitted.
# NOTE(review): once the root logger has a handler, later calls to
# logging.basicConfig() are ignored unless they pass force=True.
logging.basicConfig(level=logging.DEBUG)

# Build the fill-mask pipeline backed by the 'distilbert-base-uncased' model.
# NOTE(review): presumably this downloads/loads model weights at import time
# - confirm before importing this module from other code.
fill_mask = pipeline('fill-mask', model='distilbert-base-uncased')
# Placeholder string of the pipeline's tokenizer; it is substituted for the
# typo word when the masked sentence is built.
mask = fill_mask.tokenizer.mask_token

## Extract Typo Locations from a TSV File.
    
### Args:
fh (file): File handle of the TSV file.
        
### Yields:
Tuple: Typo locations (indices) and the corresponding sentence tokens.


In [None]:
def get_typo_locations(fh):
    """Yield the typo positions and sentence tokens for each row of a TSV file.

    Every non-blank row is expected to hold two tab-separated fields: a
    comma-separated list of token indices, followed by the sentence, whose
    tokens are separated by whitespace.

    Args:
        fh: Open text file handle (or any iterable of lines) with the TSV data.

    Yields:
        tuple[list[int], list[str]]: The typo indices and the sentence tokens.

    Raises:
        ValueError: If a non-empty index entry is not an integer.
        IndexError: If a non-blank row has fewer than two fields.
    """
    for row in csv.reader(fh, delimiter='\t'):
        # csv.reader yields [] for blank lines (e.g. a trailing empty line);
        # skip them instead of crashing on row[0].
        if not row:
            continue
        yield (
            # An empty index entry means "no typo here"; int('') would raise.
            [int(idx) for idx in row[0].split(',') if idx.strip()],
            row[1].split(),
        )

# Select the Best Correction for a Given Typo Based on Prediction Results.
    
## Args:
typo (str): Original typo word.
predict (list): List of prediction dicts returned by the fill-mask pipeline; each holds a candidate replacement under `token_str` together with its associated score.
        
## Returns:
str: Best correction for the typo.

## Approach Used:
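The select_correction function picks the most suitable correction for a typo from a list of predictions. For every prediction it computes the `difflib.SequenceMatcher` similarity ratio between the typo and the lowercased predicted token, and it keeps the prediction with the highest ratio, ignoring a prediction that is identical to the typo. If the typo starts with an uppercase letter, the chosen correction is capitalized as well. If no prediction has any similarity to the typo, the typo itself is returned unchanged.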

In [None]:
def select_correction(typo, predict):
    """Pick the prediction that is most similar to the typo.

    Similarity is the ``difflib.SequenceMatcher`` ratio, computed
    case-insensitively so that a capitalised typo (e.g. at the start of a
    sentence) is not penalised for its first letter.

    Args:
        typo (str): The misspelled word as it appears in the sentence.
        predict (list): Prediction dicts; each must provide the candidate
            replacement under the ``'token_str'`` key.

    Returns:
        str: The best-matching replacement, capitalised when the typo starts
        with an uppercase letter. The typo itself is returned when no
        prediction shares any similarity with it.
    """
    typo_lower = typo.lower()
    # Slice instead of index so that an empty typo does not raise IndexError.
    wants_capital = typo[:1].isupper()

    best_correction = typo
    best_ratio = 0

    for pred in predict:
        pred_str = pred['token_str']
        pred_lower = pred_str.lower()

        # Proposing the typo itself (in any casing) is not a correction.
        if pred_lower == typo_lower:
            continue

        # Compare lowercase against lowercase; comparing the raw typo would
        # count an uppercase first letter as a mismatch.
        ratio = SequenceMatcher(None, typo_lower, pred_lower).ratio()

        # Strict '>' keeps the earliest (highest-ranked) prediction on ties.
        if ratio > best_ratio:
            best_ratio = ratio
            best_correction = pred_str.capitalize() if wants_capital else pred_str

    return best_correction

# Perform Spell-Checking on a Given Text File with Typos.
    
## Args:
fh (file): File handle of the TSV file; each row holds the typo indices and the sentence, separated by a tab.
        
## Yields:
Tuple: Typo locations (indices) and the spell-checked sentence tokens.

In [None]:

def spellchk(fh):
    """Spell-check every sentence of a TSV file, one typo position at a time.

    For each typo index the word at that position is replaced by the mask
    token, the fill-mask pipeline is asked for its top 200 candidates, and
    select_correction() picks the replacement. The token list is updated in
    place, so corrections made earlier in a sentence are part of the context
    used for the later typo positions.

    Args:
        fh: File handle of the TSV file, as accepted by get_typo_locations().

    Yields:
        tuple: The typo indices and the corrected list of sentence tokens.
    """
    for locations, tokens in get_typo_locations(fh):
        for pos in locations:
            # Rebuild the sentence with only the current typo masked out.
            masked_tokens = [
                mask if idx == pos else word
                for idx, word in enumerate(tokens)
            ]
            candidates = fill_mask(" ".join(masked_tokens), top_k=200)
            logging.info(candidates)
            tokens[pos] = select_correction(tokens[pos], candidates)
        yield (locations, tokens)

# Main Execution Block

In [None]:

if __name__ == '__main__':
    # Script entry point: read a TSV file of typo locations and sentences,
    # spell-check it, and print the result to stdout in the same TSV layout.
    import argparse

    # Command-line argument parsing
    argparser = argparse.ArgumentParser()
    argparser.add_argument("-i", "--inputfile",
                           dest="input",
                           default=os.path.join('data', 'input', 'dev.tsv'),
                           help="File to segment")
    argparser.add_argument("-l", "--logfile",
                           dest="logfile",
                           default=None,
                           help="Log file for debugging")
    opts = argparser.parse_args()

    # Configure logging to write to the specified log file.
    # logging.basicConfig() has already been called at module level, and a
    # repeated call is silently ignored unless force=True (Python 3.8+)
    # replaces the handlers that are already attached to the root logger.
    if opts.logfile is not None:
        logging.basicConfig(filename=opts.logfile, filemode='w',
                            level=logging.DEBUG, force=True)

    # Open the input file and perform spell-checking.
    # newline='' is what the csv module asks for on files it parses; the
    # explicit encoding avoids depending on the platform's default codec.
    with open(opts.input, encoding='utf-8', newline='') as f:
        for (locations, spellchk_sent) in spellchk(f):
            print("{locs}\t{sent}".format(
                locs=",".join(str(i) for i in locations),
                sent=" ".join(spellchk_sent)
            ))