<a href="https://colab.research.google.com/github/GDO-Galileo/do-voice-interaction/blob/error_correction/gdo_voicebot/grammar_correction_service/grammar_checker_model/Grammar_Checker_Model_Usage.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Using the Trained Grammar Checker Model**

This notebook describes how to use our custom Grammar Checker model (trained using our [Grammar_Checker_Model_Training](https://colab.research.google.com/drive/1_7RQQPkUHyF3ip5vCI0b2aOxSejOZcTv?usp=sharing) file).

To convert predictions from our model into '1' (correct) and '0' (incorrect) acceptability labels, the raw outputs must first be passed through a sigmoid function and then rounded to the nearest integer (which will be either '1' or '0'). This is needed because our model was trained with the `BCEWithLogitsLoss()` loss function, which applies the sigmoid internally when computing the loss, so the model itself outputs raw logits.
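
As a minimal sketch of this conversion, the snippet below applies a sigmoid to a few made-up logits and rounds the result. The values in `example_logits` are hypothetical and are not outputs of our model; the standard library `math` module is used here only to keep the example dependency-free.

```python
import math

def sigmoid(x):
    """Map a raw logit to a probability in the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

# Hypothetical raw model outputs (logits), one per sentence
example_logits = [-2.0, 0.3, 4.1]

# Squash each logit into a probability, then round to a 0/1 label
probabilities = [sigmoid(z) for z in example_logits]
labels = [int(round(p)) for p in probabilities]

print(labels)  # [0, 1, 1]
```

Because the sigmoid of a positive logit is above 0.5 and that of a negative logit is below 0.5, rounding amounts to checking the sign of the logit.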

Our model is specifically trained to make predictions for **lowercase** sentences **without punctuation except for apostrophes**, as this is the format output by the GDO's speech-to-text service.
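
If text from another source needs to be checked, it could be brought into a similar form first. The helper below is only a sketch: `normalize_for_checker` is a hypothetical name, it assumes ASCII English text, and the exact normalization performed by the speech-to-text service is not specified in this notebook.

```python
import re

def normalize_for_checker(text):
    """Lowercase text and drop punctuation other than apostrophes (hypothetical helper)."""
    lowered = text.lower()
    # Replace anything that is not a letter, digit, apostrophe or whitespace with a space
    cleaned = re.sub(r"[^a-z0-9'\s]", " ", lowered)
    # Collapse the runs of whitespace left behind by the removed characters
    return " ".join(cleaned.split())

print(normalize_for_checker("Don't turn it off, please!"))  # don't turn it off please
```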

## **Preparation for Predictions**

In [None]:
###################################
#             Imports             #
###################################

import torch
import numpy as np
from transformers import BertModel, BertTokenizer
from keras.preprocessing.sequence import pad_sequences

In [None]:
###################################
#      Upload Trained Model       #
###################################

MODEL_NAME = 'bert-base-uncased-GDO-trained.pth'

# Upload:
#   bert-base-uncased-GDO-trained.pth
from google.colab import files
uploaded = files.upload()

### **The Model Class**

This is a model class using the BERT base uncased model with an extra linear layer to give one output. For more information, see our [Grammar_Checker_Model_Training](https://colab.research.google.com/drive/1_7RQQPkUHyF3ip5vCI0b2aOxSejOZcTv#scrollTo=UGD2ncwqXuH4) file.
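
As a shape-level sketch of what the extra layer does, the snippet below uses a random tensor as a stand-in for BERT's last hidden state, so no pretrained weights are downloaded. The batch size of 2 and sequence length of 5 are arbitrary choices for illustration; only the hidden size of 768 and the single output come from the model class.

```python
import torch

torch.manual_seed(0)

# Stand-in for BERT's last hidden state: (batch, sequence length, hidden size)
fake_last_hidden_state = torch.randn(2, 5, 768)

# Same shape of layer as in the model class: 768 features in, 1 logit out
linear = torch.nn.Linear(768, 1)

# Keep only the vector at position 0 of each sequence, then project it
first_token_vectors = fake_last_hidden_state[:, 0, :]
logits = linear(first_token_vectors)

print(first_token_vectors.shape)  # torch.Size([2, 768])
print(logits.shape)               # torch.Size([2, 1])
```

Each sentence in the batch therefore yields exactly one logit, which is what the sigmoid-and-round step later turns into a label.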

In [None]:
##################################
#          Model Class           #
##################################
# with reference to https://stackoverflow.com/questions/64156202/add-dense-layer-on-top-of-huggingface-bert-model
# and documentation at https://huggingface.co/transformers/model_doc/bert.html#bertmodel

class CustomBERTModel(torch.nn.Module):
  def __init__(self):
    super(CustomBERTModel, self).__init__()
    self.bert = BertModel.from_pretrained('bert-base-uncased')
    self.linear = torch.nn.Linear(768, 1)

  # A forward pass through both the BERT model and linear layer
  def forward(self, input_ids):
    outputs = self.bert(input_ids, token_type_ids=None)

    # Gets the output of the last hidden layer of the BERT model
    last_hidden_states = outputs.last_hidden_state
    # Pass the hidden state of the first token of each sentence to the linear layer
    linear_output = self.linear(last_hidden_states[:,0,:])

    return linear_output

  # Modified state_dict to only save linear layer weights and bias
  def state_dict(self):
    return self.linear.state_dict()

  # Modified load_state_dict to only load linear layer weights and bias
  def load_state_dict(self, state_dict):
    self.linear.load_state_dict(state_dict)

### **Loading the Model**

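The model's state dictionary is loaded from local storage and applied to our custom model. Because of the modified `state_dict` and `load_state_dict` methods in the model class, the saved file only contains the weights and bias of the linear layer; the BERT weights come from the pretrained `bert-base-uncased` model.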
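
The effect of overriding these two methods can be sketched with a much smaller module. `TinyModel` and its `backbone` layer are hypothetical stand-ins (the real class wraps BERT); the point is only that the state dictionary exposes just the `linear` parameters.

```python
import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(4, 4)  # stand-in for the pretrained BERT model
        self.linear = torch.nn.Linear(4, 1)    # the extra output layer

    # Only expose the output layer's parameters for saving
    def state_dict(self):
        return self.linear.state_dict()

    # Only restore the output layer's parameters when loading
    def load_state_dict(self, state_dict):
        self.linear.load_state_dict(state_dict)

source = TinyModel()
target = TinyModel()

saved = source.state_dict()
target.load_state_dict(saved)

keys = sorted(saved.keys())
print(keys)  # ['bias', 'weight']
```

After loading, `target.linear` matches `source.linear`, while `target.backbone` keeps whatever weights it was constructed with.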

In [None]:
def load_grammar_checker_model():
  # Set device to 'CPU' since this is only for predictions
  device = torch.device('cpu')

  # Load the uploaded model weights to the model class
  grammar_checker = CustomBERTModel()
  grammar_checker.load_state_dict(torch.load(MODEL_NAME, map_location=device))

  # Set to evaluation mode to prepare for predictions
  grammar_checker.eval()

  return grammar_checker

### **Data Tokenization**

As in our [Grammar_Checker_Model_Training](https://colab.research.google.com/drive/1_7RQQPkUHyF3ip5vCI0b2aOxSejOZcTv#scrollTo=UGD2ncwqXuH4) file, tokenizing is done using the standard BERT base uncased tokenizer with the `do_lower_case` flag set to true (since the GDO speech-to-text system outputs lowercase text).

In [None]:
##################################
#           Tokenizer            #
##################################

def create_tokenizer():
  return BertTokenizer.from_pretrained('bert-base-uncased', 
                                       do_lower_case = True)

## **Making Predictions**

Making predictions is done similarly to testing. In this case, we use the `check_GE` function to go through a "batch" of sentences. As before, the sentences are tokenized, converted to numerical token IDs and padded to the length of the longest sentence before being passed to the model. After this, a forward pass is made and the outputs are converted, using sigmoid and rounding, to acceptability labels of '1' for a correct sentence and '0' for an incorrect sentence.
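
The padding step can be sketched in plain Python. `pad_post` is a hypothetical helper that mimics what `pad_sequences` does with `padding="post"`, and the token IDs below are made up rather than produced by the BERT tokenizer.

```python
def pad_post(sequences, pad_id=0):
    """Right-pad each list of token IDs with pad_id up to the longest length."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]

# Made-up token IDs for two sentences of different lengths
token_ids = [[11, 12, 13], [21, 22, 23, 24, 25]]

padded = pad_post(token_ids)
print(padded)  # [[11, 12, 13, 0, 0], [21, 22, 23, 24, 25]]
```

Padding to a common length is what allows the whole batch to be stacked into a single rectangular tensor for the forward pass.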

In [None]:
# Sigmoid function for use in converting our model's predictions
#   to '0' and '1' tags
def sigmoid(x):
  return 1/(1 + np.exp(-x))


# Check for a grammatical error using the grammar_checker model
#   Takes in a list of sentences for which it will predict a
#   '0' for an INCORRECT or '1' for a CORRECT sentence
def check_GE(sents, checker_model, tokenizer):

  ## Prepare Sentences for Input ##

  # Tokenize each inputted sentence
  tokenized_texts = [tokenizer.tokenize(str(sent)) for sent in sents]

  # Convert each sentence's tokens to their vocabulary IDs
  token_id_sequences = [tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts]

  # Find the length of the longest sentence in the batch
  max_len = max(len(ids) for ids in token_id_sequences)

  # Pad the token IDs (at the end) to the length of the longest sentence
  input_ids = pad_sequences(token_id_sequences, maxlen=max_len, dtype="long",
                            truncating="post", padding="post")

  prediction_inputs = torch.tensor(input_ids)


  ## Make Predictions ##

  # Don't compute or store gradients
  with torch.no_grad():
    # Forward pass, calculate predictions
    logits = checker_model(prediction_inputs)

  # Move predictions to CPU
  logits = logits.detach().cpu().numpy()

  # To calculate the prediction, use sigmoid (since BCEWithLogitsLoss
  #   was used for training) and then round to the nearest integer ('0' or '1')
  predictions = np.rint(sigmoid(logits))

  return predictions

## **Using `check_GE` to Predict Grammatical Errors**

Below is an example of how `check_GE` may be used to make predictions for sentences. 

In [None]:
INPUT_SENTENCES = [
    "the grammar correction on",
    "is the grammar correction on",
    "is grammar correction on",
    "turn on grammar correct",
    "turn on grammar correction",
    "the laboratory display are pretty good",
    "the laboratory displays are pretty good",
    "the laboratory displays are pretty match",
    "displays are pretty good"
]

# Load tokenizer
print("Loading tokenizer...")
tokenizer = create_tokenizer()

# Load trained model
print("Loading checker_model...")
checker_model = load_grammar_checker_model()

# Check sentences with check_GE
print("Checking sentences...")
predictions = check_GE(INPUT_SENTENCES, checker_model, tokenizer)

# Print sentences next to predictions
for i in range(len(INPUT_SENTENCES)):
    print(" " + str(INPUT_SENTENCES[i]) + "\t" + str(predictions[i][0]))
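
The array returned by `check_GE` holds one row per sentence with a single value in it, which is why the loop above reads `predictions[i][0]`. As a small sketch, the hypothetical helper `label_predictions` below turns such rows into readable labels; the prediction values passed to it are made up for illustration and are not real model output.

```python
def label_predictions(sentences, predictions):
    """Pair each sentence with a readable label (hypothetical helper)."""
    labelled = []
    for sentence, row in zip(sentences, predictions):
        # Each row holds a single value: 1.0 (correct) or 0.0 (incorrect)
        label = "correct" if int(row[0]) == 1 else "incorrect"
        labelled.append((sentence, label))
    return labelled

# Made-up predictions in the same nested shape that check_GE returns
example = label_predictions(
    ["turn on grammar correction", "turn on grammar correct"],
    [[1.0], [0.0]],
)

for sentence, label in example:
    print(sentence + "\t" + label)
```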