# Error Detection - Training Transformer model (character embedding)

## Setting up environment (PyTorch + Pandas + NumPy)

In [1]:
%load_ext autoreload
%autoreload 2

import sys
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np

import collections

# Make modules under the kbclean directory importable. The path is only
# appended when it is not registered yet, so re-running this cell is harmless.
if sys.path.count(r"../../../kb-data-cleaning/kbclean") == 0:
    sys.path.append(r"../../../kb-data-cleaning/kbclean")

In [2]:
from utils.config import load_hparams

# Hyper-parameters are read from the project's YAML config file.
hparams = load_hparams("../../config/hparams.yaml")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Loading data using TorchText

In [3]:
import regex as re

def preprocess(str1):
    """Collapse a string to its character-class pattern.

    Args:
        str1: A string, or an iterable of string pieces that is joined first.

    Returns:
        A list of single characters in which every ``A-Z`` letter has become
        ``"A"``, every ``a-z`` letter ``"a"`` and every ``0-9`` digit ``"0"``.
        All other characters are kept unchanged.
    """
    masked = "".join(str1)
    # Each character class is replaced by one representative symbol.
    for pattern, placeholder in (("[A-Z]", "A"), ("[a-z]", "a"), ("[0-9]", "0")):
        masked = re.sub(pattern, placeholder, masked)

    return list(masked)

In [4]:
from torch.utils.data import DataLoader, random_split
from torchtext.data import Field, LabelField, NestedField, TabularDataset

# Field for the label column; its vocabulary is built in a later cell.
label = LabelField()

# Character-level field used for both string columns: `tokenize=list` splits
# each raw string into single characters, and case is preserved (`lower=False`).
char_field = Field(
    include_lengths=True,
    batch_first=True,
    lower=False,
    pad_token="<wpad>",
    tokenize=list,
)

# Training pairs: CSV columns `str1` / `str2` are processed by the character
# field and exposed as `src_char` / `trg_char`; column `sim` becomes `lbl`.
dataset = TabularDataset(
    format="csv",
    path="../../data/train/train_500000c.csv",
    fields={
        "str1": [("src_char", char_field)],
        "str2": [("trg_char", char_field)],
        "sim": [("lbl", label)],
    },
)

# Labeled test pairs, mapped to the same attributes and fields as the
# training set. This CSV is read with a single quote as the quote character.
test_dataset = TabularDataset(
    format="csv",
    path="../../data/test/wiki/labeled_output.csv",
    csv_reader_params={"quotechar": "'"},
    fields={
        "str1": [("src_char", char_field)],
        "str2": [("trg_char", char_field)],
        "sim": [("lbl", label)],
    },
)

## Building language vocabulary from data

In [5]:
from pathlib import Path

from torchtext import vocab
from torchtext.vocab import GloVe

# Build the character vocabulary from both string columns of the training set.
char_field.build_vocab(dataset.src_char, dataset.trg_char)

# Build the label vocabulary, then pin the label -> index mapping explicitly.
# NOTE(review): only `stoi` is overridden here; `itos` keeps whatever order
# build_vocab produced, so the two may disagree - confirm nothing relies on `itos`.
label.build_vocab(dataset.lbl)
label.vocab.stoi = {"True": 0, "False": 1}

# Store the vocabulary size on hparams; the model cell reads it as `char_vocab_size`.
hparams.char_vocab_size = len(char_field.vocab)

# Bare expression so the notebook displays the resulting size.
hparams.char_vocab_size

184

## Creating batch iterators for the training and test sets

In [6]:
from torchtext.data import BucketIterator

# Batch iterators over the training set and the held-out test set. Both use
# the configured batch size and are given the selected `device`.
train_iterator = BucketIterator(
    dataset,
    batch_size=hparams.batch_size,
    device=device,
)
test_iterator = BucketIterator(
    test_dataset,
    batch_size=hparams.batch_size,
    device=device,
)

## Implementing callback function for training

In [7]:
from IPython.display import clear_output
from functools import partial


def idx2word(index_sequences, field):
    """Decode a sequence of vocabulary indices into a space-separated string.

    Args:
        index_sequences: Iterable of indices into ``field.vocab.itos``.
        field: Object exposing ``vocab.itos`` (index -> token lookup).

    Returns:
        The looked-up tokens joined by single spaces. Index 1 is rendered as
        an empty string (presumably the padding index - confirm against the
        vocabulary).
    """
    tokens = []
    for idx in index_sequences:
        if idx != 1:
            tokens.append(field.vocab.itos[idx])
        else:
            tokens.append("")
    return " ".join(tokens)


def show_samples(trainer):
    """Print the decoded string pairs and labels of one training batch.

    Args:
        trainer: Currently unused; kept so existing callers can keep passing
            the model/trainer.
    """
    for batch in train_iterator:
        x, y = batch
        y = y.reshape(-1, 1)

        # x[0] / x[1] are indexed with [0] before the row index - presumably
        # (tensor, lengths) pairs because the field uses include_lengths=True.
        # Both columns were numericalized by `char_field`, so they are decoded
        # with that field's vocabulary (`word_field` is not defined anywhere
        # in this notebook).
        for i in range(len(y)):
            print(
                f"Example {i}: {idx2word(x[0][0][i], char_field)} ---"
                f" {idx2word(x[1][0][i], char_field)} ==>  --- {y[i].item()}"
            )
        # Only the first batch is shown.
        break


## Building and training Transformer model

In [8]:
from pytorch_lightning import Trainer

from ml.nets import Transformer

# Instantiate the model. Every size comes from hparams; `char_vocab_size`
# was stored there after the character vocabulary was built.
transformer = Transformer(
    num_layers=hparams.num_layers,
    head_size=hparams.head_size,
    hidden_size=hparams.hidden_size,
    embedding_size=hparams.char_embedding_size,
    char_vocab_size=hparams.char_vocab_size,
)

In [13]:
import os
from pytorch_lightning.callbacks import ModelCheckpoint

# Request a GPU only when CUDA is present, mirroring how `device` is chosen
# earlier in the notebook. A hard-coded `gpus=1` fails on CPU-only machines
# even though the rest of the notebook falls back to the CPU.
trainer = Trainer(
    gpus=1 if torch.cuda.is_available() else None,
    amp_level='O1',
    benchmark=False,
    default_save_path="../../checkpoints",
)
# The test iterator is passed as the validation dataloader.
trainer.fit(transformer, train_dataloader=train_iterator, val_dataloaders=[test_iterator])

HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=22.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=22.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=22.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=22.0, style=Pro…




1