In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from nltk.tokenize import TreebankWordTokenizer
from nltk.corpus import wordnet as wn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the torch device once: CUDA when a GPU is visible, otherwise CPU.
# The original cell also carried a commented-out Apple "mps" branch; it
# stays disabled here.
has_cuda = torch.cuda.is_available()
print("CUDA" if has_cuda else "CPU")
device = torch.device("cuda" if has_cuda else "cpu")

CPU


In [3]:
class TestDataset(Dataset):
  """Serve (text, label) pairs from a DataFrame, one pair per row.

  Columns are read by position: column 0 is returned as the text and
  column 1 as the label. Per the original author's note, the data is
  English, typo-added and labeled.
  """

  def __init__(self, df):
    # Keep a reference to the caller's frame; no copy is made here.
    self.df = df

  def __len__(self):
    # One sample per DataFrame row.
    return len(self.df.index)

  def __getitem__(self, idx):
    # Positional lookup of the first two cells of row `idx`.
    text, label = (self.df.iloc[idx, col] for col in (0, 1))
    return text, label

In [4]:
# Fractions of total_df used for the train and test splits further down.
# They sum to 0.99, so the last ~1% of rows ends up in neither split.
train_rate = 0.9
test_rate = 0.09

# Not referenced by the cells visible in this notebook; presumably
# training-loop settings for the classification model - TODO confirm.
itr = 1
p_itr = 100
epochs = 5

# batch_size handed to every DataLoader created below.
batch = 10

In [5]:
# Load the augmented dataset, drop rows containing any missing value, keep
# only the "text" and "label" columns, and binarise the label:
# "nothate" -> 1, anything else -> 0.
# Alternative input file kept from the original cell:
#total_df = pd.read_csv('augmented_data/Dataset_aug_complex_10424_original.csv', sep=',')
total_df = (
    pd.read_csv('augmented_data/Dataset_aug_complex_10424_.csv', sep=',')
    .dropna()
    .loc[:, ["text", "label"]]
    .assign(label=lambda frame: frame["label"].eq("nothate").astype("int64"))
)
print(total_df)

# Dataset / loader over the full frame (shuffled batches of `batch` rows).
total_dataset = TestDataset(total_df)
total_loader = DataLoader(total_dataset, batch_size=batch, shuffle=True)

                                                    text  label
0      just called this hotch a cum dumpster on the t...      0
1      This dress looks fudk7nf bad on me and I look ...      1
2      If they people don't react well, to you coming...      1
3      "There is no correlation between black populat...      0
4                F6sk yeah, I would kill for some fries.      1
...                                                  ...    ...
10419  would you like to be called brown? Then stop u...      1
10420  i hate being a lesbn#1 living in this country,...      1
10421  The curfew you wanted but now you don't, make ...      1
10422  I love how HE has to come on to this forum, ch...      0
10423  would you like to be called a q7eeT? Then stop...      1

[10424 rows x 2 columns]


In [6]:
# Contiguous split of total_df by row position:
#   rows [0, test_end)          -> test_df
#   rows [test_end, train_end)  -> train_df
#   rows [train_end, end)       -> unused remainder
# The boundaries are computed exactly as before, so the split sizes do not
# change. Positional slicing replaces np.split(DataFrame, ...), which goes
# through DataFrame.swapaxes (deprecated in recent pandas releases), and it
# no longer rebinds `_`, which IPython uses for the last cell output.
# NOTE(review): rows are taken in file order with no shuffling - confirm
# that is intended for this augmented dataset.
n_total = len(total_df)
test_end = int(test_rate * n_total)
train_end = int(test_rate * n_total + train_rate * n_total)

test_df = total_df.iloc[:test_end]
train_df = total_df.iloc[test_end:train_end]
print(len(test_df), len(train_df))

train_dataset = TestDataset(train_df)
train_loader = DataLoader(train_dataset, batch_size=batch, shuffle=True)
test_dataset = TestDataset(test_df)
test_loader = DataLoader(test_dataset, batch_size=batch, shuffle=True)

938 9381


# Application of SpellChecker-Seq2Seq Model

# Hyperparameters

1. LSTM (unidirectional) with attention <br>
lr = 5e-4 <br>
embedding_dim = 512 <br>
hidden_size = 512 <br>
epochs = 20 <br>
batch = 10 <br>

2. RNN (bidirectional) <br>
lr = 5e-4 <br>
embedding_dim = 512 <br>
hidden_size = 512 <br>
epochs = 20 <br>
batch = 10 <br>

In [7]:
# Star import kept on purpose: later cells call ascii_range and
# spell_correction from this module, and the pickled model presumably needs
# the module's class definitions to be importable - TODO confirm.
from Spell_correction_model import *

# Materialise the English WordNet vocabulary once. In current NLTK releases
# wn.words() hands back a one-shot iterator, which would be exhausted after
# the first lookup(s) even though it is passed to spell_correction for every
# row. A frozenset supports repeated membership tests and iteration.
# NOTE(review): spell_correction is not visible here - confirm it only needs
# `in` / iteration on this argument.
wordnetdict = frozenset(wn.words(lang='eng'))
tokenizer = TreebankWordTokenizer()

# Location and file names of the saved spell-correction models.
model_path = "models/"
lstm_name = "spelling_lstm.model"
rnn_name = "spelling_base_rnn.model"

# SECURITY: torch.load unpickles the file, which can execute arbitrary code;
# only load model files from a trusted source.
# map_location places the loaded tensors on the device selected earlier.
model = torch.load(model_path + rnn_name, map_location=device)

In [8]:
# Do this in the preprocessing part of the classification model's training loop:
# corrected_texts = []
# for text, label in train_loader :
#     text_ascii = ["".join([c for c in t if ascii_range(c)]) for t in text]
#     text_corrected = [spell_correction(t, tokenizer, wordnetdict, model, device=device, model_type="rnn") for t in text_ascii]
#     corrected_texts += text_corrected

In [9]:
# Spell-correct every row of total_df['text'] and collect the results, in
# row order, in corrected_texts.
corrected_texts = []
interval = 1000  # print one progress line every `interval` rows

# Inference only: disable autograd so no gradient graph is built per call.
# (Harmless if spell_correction already does this internally.)
with torch.no_grad():
    for i, text in enumerate(total_df['text']):
        # Keep only the characters accepted by ascii_range before correcting.
        text_ascii = "".join(c for c in text if ascii_range(c))
        text_corrected = spell_correction(text_ascii, tokenizer, wordnetdict, model, device=device, model_type="rnn")
        corrected_texts.append(text_corrected)
        if i % interval == 0:
            print("i : {} / len : {}".format(i, len(corrected_texts)))

i : 0 / len : 1
i : 1000 / len : 1001
i : 2000 / len : 2001
i : 3000 / len : 3001
i : 4000 / len : 4001
i : 5000 / len : 5001
i : 6000 / len : 6001
i : 7000 / len : 7001
i : 8000 / len : 8001
i : 9000 / len : 9001
i : 10000 / len : 10001


In [10]:
# Sanity check: exactly one corrected string per DataFrame row.
n_rows = len(total_df)
assert len(corrected_texts) == n_rows

# Attach the corrections as a new column and write the frame to a CSV whose
# file name embeds the row count. The DataFrame index is written as well
# (to_csv default).
total_df['text_corrected'] = corrected_texts
out_path = "augmented_data/Dataset_aug_complex_{}_spellcheck_rnn.csv".format(n_rows)
total_df.to_csv(out_path, sep=',')