In [1]:
#!pip install torch fairseq transformers dill fastDamerauLevenshtein tensorboardX accelerate textdistance

In [12]:
import os
from model.word_lm import SpellCorrectionModel
from model.char_lm import CharTokenizer
from data.dataset import TypoDataset
import torch
from tqdm.auto import tqdm
import torch
from data.dataset import TypoOnlineDataset
from fastDamerauLevenshtein import damerauLevenshtein
from torch.utils.tensorboard import SummaryWriter
from transformers import pipeline, AutoModelForMaskedLM


In [3]:
model = SpellCorrectionModel(config_file="/bert_config.json")
typo_tokenizer = CharTokenizer()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
writer = SummaryWriter(log_dir='logs')


Some weights of the model checkpoint at ./bert/ncbi_bert_base/pytorch_model.bin were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
dataset_train = TypoOnlineDataset("data/mimic3/split", "data/lexicon/lexicon_en.json", model.tokenizer, typo_tokenizer,2)
dataset_val = TypoDataset(os.path.join("data/mimic_synthetic", 'val.tsv'), model.tokenizer, typo_tokenizer)

Read file data/mimic_synthetic/val.tsv... 10000 rows
Parsing rows (10 processes)


100%|██████████| 10000/10000 [00:09<00:00, 1098.78it/s]


In [5]:
'''
Loss function with misspelling penalty
'''
def loss_function(probabilities, correct_label, predicted_spellings, correct_spellings):
    loss = torch.nn.functional.cross_entropy(probabilities.view(-1, model.tokenizer.vocab_size), correct_label.view(-1))
    distance = damerauLevenshtein(' '.join(predicted_spellings), ' '.join(correct_spellings))
    total_loss = loss + 0.5 * distance
    return total_loss

In [6]:
BATCH_SIZE = 100
dataloader_train = torch.utils.data.DataLoader(dataset_train,
                                                batch_size=BATCH_SIZE,
                                                num_workers=0,
                                                collate_fn=dataset_train.get_collate_fn())

dataloader_val = torch.utils.data.DataLoader(dataset_val,
                                                     batch_size=BATCH_SIZE,
                                                     shuffle=False,
                                                     drop_last=True,
                                                     num_workers=0,
                                                     collate_fn=dataset_val.get_collate_fn())

In [7]:
from torch.optim import AdamW
from accelerate import Accelerator

num_epochs = 5
warmup_proportion = 0.1
num_training_steps = 10000

accelerator = Accelerator()
optimizer = AdamW(model.parameters(), lr=5e-5)
total_steps = num_training_steps * num_epochs
warmup_steps = int(warmup_proportion * total_steps)

model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, dataloader_train, dataloader_val
)


In [8]:
from transformers import get_scheduler

output_dir = "bluebert-finetuned-mimic-v1"
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=warmup_steps, 
    num_training_steps=total_steps
   )


In [9]:
model.train()
train_iter = iter(dataloader_train)
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    progress_bar = tqdm(range(1000))
    global_step = 0

    while global_step < 1000:
        # Training
        batch = next(train_iter)
        input_ids = batch["context_tokens"]
        attention_mask = batch['context_attention_mask']
        #print(attention_mask.dtype)
        misspelling = batch['typo']
        correct_spelling = batch['correct']
        
        outputs, prediction = model.forward(input_ids, attention_mask, misspelling, correct_spelling)
        
        # Compute the loss
        loss = model.loss
        train_loss += loss.item()
        accelerator.backward(loss)


        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(BATCH_SIZE)
        # Save and store
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.bert.save_pretrained(output_dir, save_function=accelerator.save)
        unwrapped_model.config.save_pretrained(output_dir, save_function=accelerator.save)
        unwrapped_model.tokenizer.save_pretrained(output_dir, save_function=accelerator.save)

        if accelerator.is_main_process:
            model.bert.save_pretrained(output_dir)
            model.config.save_pretrained(output_dir)
            model.tokenizer.save_pretrained(output_dir)
            lr_scheduler.step()
        global_step += BATCH_SIZE
    train_loss /= (1000/BATCH_SIZE)
    writer.add_scalar('Loss', train_loss, global_step=epoch)
    print(f"Epoch {epoch+1}, train loss: {train_loss:.4f}")

    

  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 1, train loss: 0.0490


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 2, train loss: 0.0525


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 3, train loss: 0.0468


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 4, train loss: 0.0464


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 5, train loss: 0.0349


In [15]:
from transformers import pipeline, AutoModelForMaskedLM

model_trained = AutoModelForMaskedLM.from_pretrained("./bluebert-finetuned-mimic")

mask_filler = pipeline(
    "fill-mask", model=output_dir, top_k= 20
)

preds = mask_filler("Patient mentioned she took tylenol for the [MASK]")

for pred in preds:
    print(f"{pred['sequence']}")
    

patient mentioned she took tylenol for the.
patient mentioned she took tylenol for the?
patient mentioned she took tylenol for the pain
patient mentioned she took tylenol for the headache
patient mentioned she took tylenol for the fever
patient mentioned she took tylenol for the rash
patient mentioned she took tylenol for the discomfort
patient mentioned she took tylenol for the ha
patient mentioned she took tylenol for the cough
patient mentioned she took tylenol for the same
patient mentioned she took tylenol for the past
patient mentioned she took tylenol for the day
patient mentioned she took tylenol for the night
patient mentioned she took tylenol for the swelling
patient mentioned she took tylenol for the nausea
patient mentioned she took tylenol for the family
patient mentioned she took tylenol for the flu
patient mentioned she took tylenol for the last
patient mentioned she took tylenol for the morning
patient mentioned she took tylenol for the am


In [16]:
a = input()
a

'apple'