In [1]:
import os
from model.word_lm import SpellCorrectionModel
from model.char_lm import CharTokenizer
from data.dataset import TypoDataset
import torch
from tqdm import tqdm
from tabulate import tabulate
from torch.utils.tensorboard import SummaryWriter


# Evaluation on a very noisy dataset

In [2]:
# Directory handed to the model as its NCBI_BERT argument.
output_dir = "./bluebert-finetuned-mimic-v1/"

# Spell-correction model; max_candidates is fixed at 150 for every evaluation below.
model = SpellCorrectionModel(
    NCBI_BERT=output_dir,
    config_file="/config.json",
    max_candidates=150,
)

# Character-level tokenizer, passed to TypoDataset alongside model.tokenizer.
typo_tokenizer = CharTokenizer()

# Move the model onto the device it reports for itself.
model.to(model.device)

# TensorBoard writer; event files go under the relative 'logs' directory.
writer = SummaryWriter(log_dir='logs')

In [3]:
BATCH_SIZE = 30

# Noisy test split, tokenized with the model's tokenizer and the character tokenizer.
dataset_val = TypoDataset(
    os.path.join("data/dlh_multiple_misspelling", 'test.tsv'),
    model.tokenizer,
    typo_tokenizer,
    num_process=2,
)

# Evaluation has to score every row, so the final (possibly smaller) batch is
# kept: drop_last=True would silently discard up to BATCH_SIZE - 1 examples
# whenever the dataset size is not a multiple of BATCH_SIZE.
dataloader_val = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep file order so results are reproducible
    drop_last=False,
    num_workers=0,
    collate_fn=dataset_val.get_collate_fn(),
)

Read file data/dlh_multiple_misspelling/test.tsv... 30 rows
Parsing rows (2 processes)


100%|██████████| 30/30 [00:04<00:00,  7.47it/s]


In [4]:
# Number of examples in the currently loaded dataset_val.
data_length = len(dataset_val)

In [9]:
# Evaluation on the noisy test set.
#
# The dataloader is walked exactly once. The previous while/next/StopIteration
# pattern restarted the iterator whenever the loader ran out before
# global_step reached len(dataset_val), which re-scored the first batches and
# inflated the totals (and raised StopIteration on an empty loader).
model.eval()

global_step = 0      # advanced by BATCH_SIZE per batch, as before
global_total = 0     # number of examples scored so far
global_correct = 0   # number of exact matches between prediction and target

# The context manager closes the bar when the loop ends, so no stale bars are
# re-rendered in later cells.
with tqdm(total=len(dataset_val)) as progress_bar:
    for batch_val in dataloader_val:
        input_ids = batch_val["context_tokens"]
        attention_mask = batch_val['context_attention_mask']
        misspelling = batch_val['typo']
        correct_spelling = batch_val['correct']

        # Inference only: no gradients are needed.
        with torch.no_grad():
            _, prediction = model.forward(input_ids, attention_mask, misspelling, correct_spelling)

        batch_count = len(correct_spelling)
        global_total += batch_count
        # Exact-match accuracy: a prediction counts only if it equals the target.
        global_correct += sum(
            1 for target, guess in zip(correct_spelling, prediction) if target == guess
        )

        # Advance by the real batch length so a short final batch is shown correctly.
        progress_bar.update(batch_count)
        print(f'Total/Correct = {global_total} / {global_correct}')
        global_step += BATCH_SIZE

        # Running accuracy, logged once per batch; the step is the integer batch index.
        writer.add_scalar('Evaluation/Noisy-30', global_correct / global_total, global_step=global_step // BATCH_SIZE)


100%|██████████| 30/30 [00:27<00:00,  1.10it/s]
100%|██████████| 30/30 [00:02<00:00, 14.94it/s]

Total/Correct = 30 / 2


100%|██████████| 30/30 [00:15<00:00, 14.94it/s]

In [67]:
# Side-by-side view of the most recently evaluated batch.
# NOTE(review): correct_spelling, misspelling and prediction are whatever the
# last evaluation loop left behind, so this cell depends on execution order.
last_batch_columns = {
    "Correct": correct_spelling,
    "Misspelling": misspelling,
    "Prediction": prediction,
}
print(tabulate(last_batch_columns, headers="keys"))

Correct          Misspelling     Prediction
---------------  --------------  ------------
[accommodation]  acomodationn    condition
ascites          uhsits          this
aphasia          afaciia         again
asymmetry        asimetry        symmetry
basilar          bazillar        vascular
brachial         brakiale        radial
calluses         colousses       masses
catheterization  cathritzacion   radiation
circumferential  circumfrencial  concurrent
chlamydia        kluhmideeuh     hidden
cords            chords          chords
diaphragm        dyufram         cuff
dyspareunia      disparoonia     diagnosis
epididymis       epideedimus     specimen
exacerbated      eggsaberted     ##ated
hemorrhage       hemrage         damage
hygiene          hijeen          hygiene
malacia          malaysia        mass
mucus            moucous         mouth
oophorectomy     ooforektomy     effort
ophthalmology    optomology      pathology
palliative       palativee       patient
pleurisy      

## Evaluation on the MIMIC synthetic validation set (`data/mimic_synthetic/val.tsv`)

In [10]:
BATCH_SIZE = 100

# Validation split of the mimic_synthetic data, tokenized with the model's
# tokenizer and the character tokenizer.
dataset_val = TypoDataset(
    os.path.join("data/mimic_synthetic", 'val.tsv'),
    model.tokenizer,
    typo_tokenizer,
)

# Evaluation has to score every row, so the final (possibly smaller) batch is
# kept: drop_last=True would silently discard up to BATCH_SIZE - 1 examples
# whenever the dataset size is not a multiple of BATCH_SIZE.
dataloader_val = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep file order so results are reproducible
    drop_last=False,
    num_workers=0,
    collate_fn=dataset_val.get_collate_fn(),
)

Read file data/mimic_synthetic/val.tsv... 10000 rows
Parsing rows (10 processes)


100%|██████████| 10000/10000 [00:08<00:00, 1180.17it/s]


In [11]:
# Number of examples in the currently loaded dataset_val; the bare expression
# on the last line displays the value as the cell output.
data_length = len(dataset_val)
data_length

10000

In [14]:
# Evaluation on the validation set held in dataloader_val.
#
# The dataloader is walked exactly once. The previous while/next/StopIteration
# pattern restarted the iterator whenever the loader ran out before
# global_step reached len(dataset_val), which re-scored the first batches and
# inflated the totals (and raised StopIteration on an empty loader).
model.eval()

global_step = 0      # advanced by BATCH_SIZE per batch, as before
global_total = 0     # number of examples scored so far
global_correct = 0   # number of exact matches between prediction and target

# The context manager closes the bar when the loop ends, so no stale bars are
# re-rendered in later cells.
with tqdm(total=len(dataset_val)) as progress_bar:
    for batch_val in dataloader_val:
        input_ids = batch_val["context_tokens"]
        attention_mask = batch_val['context_attention_mask']
        misspelling = batch_val['typo']
        correct_spelling = batch_val['correct']

        # Inference only: no gradients are needed.
        with torch.no_grad():
            _, prediction = model.forward(input_ids, attention_mask, misspelling, correct_spelling)

        batch_count = len(correct_spelling)
        global_total += batch_count
        # Exact-match accuracy: a prediction counts only if it equals the target.
        global_correct += sum(
            1 for target, guess in zip(correct_spelling, prediction) if target == guess
        )

        # Advance by the real batch length so a short final batch is shown correctly.
        progress_bar.update(batch_count)
        print(f'Total/Correct = {global_total} / {global_correct}')
        global_step += BATCH_SIZE

        # Running accuracy, logged once per batch; the step is the integer batch index.
        writer.add_scalar('Evaluation/model-eval', global_correct / global_total, global_step=global_step // BATCH_SIZE)


  4%|▍         | 400/10000 [00:53<21:16,  7.52it/s]


Total/Correct = 100 / 62




Total/Correct = 200 / 138




Total/Correct = 300 / 207




Total/Correct = 400 / 281




Total/Correct = 500 / 355




Total/Correct = 600 / 422




Total/Correct = 700 / 494




Total/Correct = 800 / 558




Total/Correct = 900 / 620




Total/Correct = 1000 / 691




Total/Correct = 1100 / 753




Total/Correct = 1200 / 818




Total/Correct = 1300 / 895




Total/Correct = 1400 / 963




Total/Correct = 1500 / 1032




Total/Correct = 1600 / 1098




Total/Correct = 1700 / 1166




Total/Correct = 1800 / 1237




Total/Correct = 1900 / 1310




Total/Correct = 2000 / 1378




Total/Correct = 2100 / 1440




Total/Correct = 2200 / 1506




Total/Correct = 2300 / 1580




Total/Correct = 2400 / 1651




Total/Correct = 2500 / 1727




Total/Correct = 2600 / 1792




Total/Correct = 2700 / 1863




Total/Correct = 2800 / 1934




Total/Correct = 2900 / 2008




Total/Correct = 3000 / 2081




Total/Correct = 3100 / 2146




Total/Correct = 3200 / 2218




Total/Correct = 3300 / 2290




Total/Correct = 3400 / 2355




Total/Correct = 3500 / 2418




Total/Correct = 3600 / 2484




Total/Correct = 3700 / 2557




Total/Correct = 3800 / 2629




Total/Correct = 3900 / 2700




Total/Correct = 4000 / 2770




Total/Correct = 4100 / 2829




Total/Correct = 4200 / 2894




Total/Correct = 4300 / 2965




Total/Correct = 4400 / 3039




Total/Correct = 4500 / 3110




Total/Correct = 4600 / 3177




Total/Correct = 4700 / 3246




Total/Correct = 4800 / 3325




Total/Correct = 4900 / 3381




Total/Correct = 5000 / 3444




Total/Correct = 5100 / 3519




Total/Correct = 5200 / 3594




Total/Correct = 5300 / 3663




Total/Correct = 5400 / 3739




Total/Correct = 5500 / 3805




Total/Correct = 5600 / 3874




Total/Correct = 5700 / 3929




Total/Correct = 5800 / 3994




Total/Correct = 5900 / 4061




Total/Correct = 6000 / 4130




Total/Correct = 6100 / 4199




Total/Correct = 6200 / 4268




Total/Correct = 6300 / 4332




Total/Correct = 6400 / 4395




Total/Correct = 6500 / 4469




Total/Correct = 6600 / 4538




Total/Correct = 6700 / 4614




Total/Correct = 6800 / 4692




Total/Correct = 6900 / 4762




Total/Correct = 7000 / 4836




Total/Correct = 7100 / 4909




Total/Correct = 7200 / 4976




Total/Correct = 7300 / 5047




Total/Correct = 7400 / 5125




Total/Correct = 7500 / 5197




Total/Correct = 7600 / 5266




Total/Correct = 7700 / 5334




Total/Correct = 7800 / 5407




Total/Correct = 7900 / 5474




Total/Correct = 8000 / 5542




Total/Correct = 8100 / 5617




Total/Correct = 8200 / 5690




Total/Correct = 8300 / 5750




Total/Correct = 8400 / 5813




Total/Correct = 8500 / 5877




Total/Correct = 8600 / 5948




Total/Correct = 8700 / 6019




Total/Correct = 8800 / 6097




Total/Correct = 8900 / 6170




Total/Correct = 9000 / 6240




Total/Correct = 9100 / 6307




Total/Correct = 9200 / 6379




Total/Correct = 9300 / 6442




Total/Correct = 9400 / 6508




Total/Correct = 9500 / 6577




Total/Correct = 9600 / 6645




Total/Correct = 9700 / 6717




Total/Correct = 9800 / 6782




Total/Correct = 9900 / 6855




Total/Correct = 10000 / 6928


# Model Evaluation on Synthetic Validation Set

In [15]:
BATCH_SIZE = 100

# Validation split of the dlh_mimic_synthetic data, tokenized with the model's
# tokenizer and the character tokenizer.
dataset_val = TypoDataset(
    os.path.join("data/dlh_mimic_synthetic", 'val.tsv'),
    model.tokenizer,
    typo_tokenizer,
)

# Evaluation has to score every row, so the final (possibly smaller) batch is
# kept: drop_last=True would silently discard up to BATCH_SIZE - 1 examples
# whenever the dataset size is not a multiple of BATCH_SIZE.
dataloader_val = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep file order so results are reproducible
    drop_last=False,
    num_workers=0,
    collate_fn=dataset_val.get_collate_fn(),
)

# Number of examples in the dataset; the bare expression displays it as the cell output.
data_length = len(dataset_val)
data_length

Read file data/dlh_mimic_synthetic/val.tsv... 10000 rows
Parsing rows (10 processes)


100%|██████████| 10000/10000 [00:08<00:00, 1147.19it/s]


10000

In [17]:
# Evaluation on the synthetic validation set held in dataloader_val.
#
# The dataloader is walked exactly once. The previous while/next/StopIteration
# pattern restarted the iterator whenever the loader ran out before
# global_step reached len(dataset_val), which re-scored the first batches and
# inflated the totals (and raised StopIteration on an empty loader).
model.eval()

global_step = 0      # advanced by BATCH_SIZE per batch, as before
global_total = 0     # number of examples scored so far
global_correct = 0   # number of exact matches between prediction and target

# The context manager closes the bar when the loop ends, so no stale bars are
# re-rendered in later cells.
with tqdm(total=len(dataset_val)) as progress_bar:
    for batch_val in dataloader_val:
        input_ids = batch_val["context_tokens"]
        attention_mask = batch_val['context_attention_mask']
        misspelling = batch_val['typo']
        correct_spelling = batch_val['correct']

        # Inference only: no gradients are needed.
        # NOTE(review): this cell unpacks three values from model.forward, while
        # the other evaluation cells in this notebook unpack two from the same
        # call. Confirm the model's actual return arity; one of the two forms
        # will raise ValueError.
        with torch.no_grad():
            outputs, prediction, label = model.forward(input_ids, attention_mask, misspelling, correct_spelling)

        batch_count = len(correct_spelling)
        global_total += batch_count
        # Exact-match accuracy: a prediction counts only if it equals the target.
        global_correct += sum(
            1 for target, guess in zip(correct_spelling, prediction) if target == guess
        )

        # Advance by the real batch length so a short final batch is shown correctly.
        progress_bar.update(batch_count)
        print(f'Total/Correct = {global_total} / {global_correct}')
        global_step += BATCH_SIZE

        # Running accuracy, logged once per batch; the step is the integer batch index.
        writer.add_scalar('Evaluation/synthetic-eval', global_correct / global_total, global_step=global_step // BATCH_SIZE)


  0%|          | 0/10000 [12:23<?, ?it/s]
