# Error Detection - Training LSTM model (char + word embedding)

## Setting up environment (PyTorch + Pandas + NumPy)

In [1]:
%load_ext autoreload
%autoreload 2

import sys
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np

import collections

# Put the kbclean source directory on the import path. The membership test
# keeps repeated runs of this cell from appending the same entry again.
kbclean_dir = r"../../../kb-data-cleaning/kbclean"
if kbclean_dir not in sys.path:
    sys.path.append(kbclean_dir)

In [2]:
import random

from utils.config import load_hparams

# Hyper-parameters shared by the data pipeline and the model cells below.
hparams = load_hparams("../../config/hparams.yaml")

# Seed every RNG this notebook touches so that a Restart & Run All reproduces
# the same shuffling and weight initialisation. This must run before any
# iterator or model is created.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Loading data using TorchText

In [3]:
import regex as re

def preprocess(str1):
    """Reduce a character sequence to its character-class "shape".

    Every character in ``A-Z`` becomes ``"A"``, every character in ``a-z``
    becomes ``"a"`` and every character in ``0-9`` becomes ``"0"``; all other
    characters are left untouched.

    Args:
        str1: A string or an iterable of strings; the pieces are concatenated
            before masking.

    Returns:
        list: The masked text, one single-character string per element.
    """
    shape = "".join(str1)
    # Apply the substitutions in the same fixed order: upper, lower, digit.
    for char_class, placeholder in (("[A-Z]", "A"), ("[a-z]", "a"), ("[0-9]", "0")):
        shape = re.sub(char_class, placeholder, shape)
    return list(shape)

# Token classes, tried in insertion order: the first pattern that matches a
# whole token wins, so the narrower classes must stay ahead of the broader ones.
type_to_regex = {
    "UPPERCASE": "[A-Z]+",
    "LOWERCASE": "[a-z]+",
    "DIGIT": "[0-9]+",
    "ALPHABET": "[A-Za-z]+",
    "ALPHANUM": "[A-Za-z0-9]+",
}


def mask_word(tokens):
    """Replace each token by the name of the first class it fully matches.

    A token that matches none of the patterns in ``type_to_regex`` (including
    the empty string) is passed through unchanged.

    Args:
        tokens: Iterable of string tokens.

    Returns:
        list: One entry per input token, either a class name from
        ``type_to_regex`` or the original token.
    """
    masked_tokens = []
    for token in tokens:
        for type_, pattern in type_to_regex.items():
            # fullmatch instead of "^...$": `$` also matches just before a
            # trailing newline, which would mis-classify tokens like "abc\n".
            if re.fullmatch(pattern, token):
                masked_tokens.append(type_)
                break
        else:
            masked_tokens.append(token)
    return masked_tokens

In [4]:
from torchtext.data import (
    TabularDataset,
    Field,
    NestedField,
    LabelField,
)
from torch.utils.data import random_split, DataLoader

# Character-level field for a single word: the word is split into characters
# (tokenize=list), passed through `preprocess`, wrapped in <w> ... </w> and
# given a fixed length of hparams.max_char_length with <cpad> as padding.
nesting_field = Field(
    tokenize=list,
    preprocessing=preprocess,
    init_token="<w>",
    eos_token="</w>",
    pad_token="<cpad>",
    fix_length=hparams.max_char_length,
    batch_first=True,
)

# One character sequence per word; whole-word padding uses <wpad>.
char1w_field = NestedField(nesting_field, pad_token="<wpad>", include_lengths=True)

# Word-level view of the same text; case is preserved.
word_field = Field(
    lower=False,
    batch_first=True,
    include_lengths=True,
    pad_token="<wpad>",
)

label = LabelField()


def _pair_fields():
    """Return a fresh CSV-column -> [(attribute, field), ...] mapping.

    Both text columns are exposed twice, once through the word field and once
    through the character field; the "sim" column carries the label.
    """
    return {
        "str1": [("src_word", word_field), ("src_char", char1w_field)],
        "str2": [("trg_word", word_field), ("trg_char", char1w_field)],
        "sim": [("lbl", label)],
    }


dataset = TabularDataset(
    path="../../data/train/train_500000c.csv",
    format="csv",
    fields=_pair_fields(),
)

# The test CSV is read with a single quote as its quote character.
test_dataset = TabularDataset(
    path="../../data/test/wiki/labeled_output.csv",
    format="csv",
    fields=_pair_fields(),
    csv_reader_params={"quotechar": "'"},
)

## Building language vocabulary from data

In [5]:
from torchtext import vocab
from pathlib import Path
from torchtext.vocab import GloVe

# Build the word and character vocabularies from both columns of the
# training set only (the test set does not contribute entries).
word_field.build_vocab(dataset.src_word, dataset.trg_word)
char1w_field.build_vocab(dataset.src_char, dataset.trg_char)

label.build_vocab(dataset.lbl)
# Pin the label ids explicitly instead of relying on the order build_vocab
# produced, and keep itos consistent with the overridden stoi so that decoding
# an id gives back the label that encodes to it.
# NOTE(review): this assumes the "sim" column only ever holds the strings
# "True" and "False" — any other value raises KeyError at batch time.
label.vocab.stoi = {"True": 0, "False": 1}
label.vocab.itos = ["True", "False"]

# Record the vocabulary sizes on hparams for the model cells below.
hparams.word_vocab_size = len(word_field.vocab)
hparams.char_vocab_size = len(char1w_field.vocab)

hparams.char_vocab_size, hparams.word_vocab_size

(126, 33014)

## Creating batch iterators for the training and validation data

In [6]:
from torchtext.data import BucketIterator

# Both iterators share the same device and batch size. The labelled test set
# is the one later handed to the trainer as validation data.
iterator_kwargs = dict(device=device, batch_size=hparams.batch_size)

train_iterator = BucketIterator(dataset, **iterator_kwargs)
test_iterator = BucketIterator(test_dataset, **iterator_kwargs)

## Implementing callback function for training

In [7]:
from IPython.display import clear_output
from functools import partial


def idx2word(index_sequences, field):
    """Decode a sequence of vocabulary indices into a space-joined string.

    Padding positions are rendered as empty strings, so they still contribute
    a separating space to the result.

    Args:
        index_sequences: Iterable of indices into ``field.vocab.itos``.
        field: Object exposing ``vocab.itos`` and ``vocab.stoi`` and,
            optionally, ``pad_token``.

    Returns:
        str: The decoded tokens joined by single spaces.
    """
    # Look the padding index up in the field's own vocabulary rather than
    # assuming it is 1; keep 1 as the fallback when no pad token is defined.
    pad_token = getattr(field, "pad_token", None)
    pad_idx = field.vocab.stoi[pad_token] if pad_token is not None else 1
    return " ".join(
        [field.vocab.itos[idx] if idx != pad_idx else "" for idx in index_sequences]
    )


def show_samples(trainer):
    """Print every example of the first batch drawn from ``train_iterator``.

    Each line shows the decoded source words, the decoded target words and the
    numeric label of one example.

    Args:
        trainer: Currently unused; kept so the signature stays stable.
    """
    for inputs, labels in train_iterator:
        labels = labels.reshape(-1, 1)
        # assumes inputs is ordered (src_word, src_char, trg_word, trg_char)
        # and each entry is a (data, lengths) pair — TODO confirm
        src_words = inputs[0][0]
        trg_words = inputs[2][0]

        for row, label_value in enumerate(labels):
            print(
                f"Example {row}: {idx2word(src_words[row], word_field)} ---"
                f" {idx2word(trg_words[row], word_field)} ==>  --- {label_value.item()}"
            )
        # Only the first batch is shown.
        break


show_samples(None)

Example 0: 12C   --- 10C    ==>  --- 0
Example 1: L, 51-7  --- Oct 6   ==>  --- 1
Example 2: blocker   --- refactor    ==>  --- 1
Example 3: Apr 3, 2007 --- US7489094    ==>  --- 1
Example 4: at Texas  --- Home    ==>  --- 1
Example 5: .000   --- 16.6    ==>  --- 0
Example 6: Fort Madison, Iowa --- virginia    ==>  --- 0
Example 7: 24 Feb 2011 --- 18 Mar 2009  ==>  --- 0
Example 8: 19 Mar 2009 --- 1 Feb 2013  ==>  --- 0
Example 9: 12/10/2012 03:39AM  --- 12/10/2012 03:08AM   ==>  --- 0
Example 10: comments 0  --- # times read 1 ==>  --- 1
Example 11: Sioux City  --- 67°F    ==>  --- 1
Example 12: -7.6 °F  --- West (280°)   ==>  --- 1
Example 13: REMI   --- Mar 17, 2008  ==>  --- 1
Example 14: 1   --- /tmp/apc/apc.9BtWBv    ==>  --- 0
Example 15: Glg   --- NY    ==>  --- 0


## Building and training LSTM model

In [8]:
from pytorch_lightning import Trainer


# 300-dimensional GloVe "6B" vectors, handed to the model as pre-trained
# word embeddings.
glove = vocab.GloVe(name="6B", dim=300)

# NOTE(review): MultiCharCNN and CharCNNLSTM are not imported in any cell
# visible here — add the import so the notebook survives Restart & Run All.
char_cnn_kwargs = dict(
    char_vocab_size=hparams.char_vocab_size,
    char_embedding_size=hparams.char_embedding_size,
)
char_cnn = MultiCharCNN(**char_cnn_kwargs)

# NOTE(review): word_field's vocabulary was built without attaching vectors,
# so glove.vectors is presumably indexed by GloVe's own vocabulary rather than
# by word_field.vocab — confirm that CharCNNLSTM aligns the two.
lstm_kwargs = dict(
    word_vocab_size=hparams.word_vocab_size,
    embedding_size=hparams.embedding_size,
    hidden_size=hparams.hidden_size,
    pretrained_embeddings=glove.vectors,
)
lstm = CharCNNLSTM(char_cnn, **lstm_kwargs)

In [9]:
import os
from pytorch_lightning.callbacks import ModelCheckpoint

# Request one GPU only when CUDA is actually available; gpus=None lets
# Lightning train on the CPU, matching the CPU fallback used for `device`.
trainer = Trainer(
    gpus=1 if torch.cuda.is_available() else None,
    amp_level='O1',
    benchmark=False,
    default_save_path="../../checkpoints",
)
# The labelled test set doubles as the validation set during training.
trainer.fit(lstm, train_dataloader=train_iterator, val_dataloaders=[test_iterator])

HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=22.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=22.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=22.0, style=Pro…




1