In [1]:
import re
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import pytorch_lightning as pl
from transformers import BertTokenizer, BertForMaskedLM
import random
from string import ascii_letters
from datasets import Dataset
from pytorch_lightning.loggers import TensorBoardLogger

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from lightning_bert import *

In [3]:
def map_text(items):
    """Tokenize one batch into (corrupted input, clean target) pairs.

    Called by ``Dataset.map(..., batched=True)``, so ``items["input"]`` is a
    batch (list) of raw sentences.

    The model-side encoding (``input_ids``, ``token_type_ids``,
    ``attention_mask``) comes from each sentence after ``replace_augment``.
    ``replace_augment`` presumably injects spelling noise; it is defined in
    ``lightning_bert`` (TODO confirm). The clean sentence is tokenized
    separately, and its ids / mask are stored under ``target_ids`` /
    ``target_attention_mask``.

    Both encodings use ``truncation=True`` and ``padding="max_length"`` with
    no explicit ``max_length``, so both are padded to the same length.
    NOTE(review): the two sentences are tokenized independently. A corrupted
    word may split into a different number of word pieces, so positions in
    ``input_ids`` and ``target_ids`` are not guaranteed to line up.
    """
    clean_texts = items["input"]
    noisy_texts = [replace_augment(text) for text in clean_texts]

    encoded = tokenizer(noisy_texts, truncation=True, padding="max_length")
    reference = tokenizer(clean_texts, truncation=True, padding="max_length")

    encoded["target_ids"] = reference["input_ids"]
    encoded["target_attention_mask"] = reference["attention_mask"]
    return encoded

In [4]:
# Build the spelling-correction dataset: each row pairs a corrupted sentence
# (model input) with its clean original (target), then split into train/test.
SEED = 42            # fixes the train/test split so val loss is comparable across runs
NUM_PROC = 4         # worker processes used by Dataset.map
TEST_FRACTION = 0.2  # share of sentences held out for validation

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Lowercase before augmentation so replace_augment sees the same casing the
# uncased tokenizer will produce.
texts = [text.lower() for text in load_data("en.txt")]
my_dict = {"input": texts}

# NOTE(review): replace_augment presumably draws from `random`. Python
# reseeds `random` in forked children, so corruption produced in the
# NUM_PROC map workers is not reproducible run to run — confirm if that matters.
dataset = Dataset.from_dict(my_dict)
dataset = dataset.map(map_text, batched=True, num_proc=NUM_PROC)
dataset = dataset.remove_columns("input")
dataset.set_format(
    type="torch",
    columns=[
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "target_ids",
        "target_attention_mask",
    ],
)

# Seeded split: without a seed, every run validates on a different subset.
dataset = dataset.train_test_split(test_size=TEST_FRACTION, seed=SEED)

train_dataloader = DataLoader(dataset["train"], shuffle=True, batch_size=8)
test_dataloader = DataLoader(dataset["test"], batch_size=16)



#0:   0%|          | 0/3 [00:00<?, ?ba/s]
#0:  33%|███▎      | 1/3 [00:00<00:01,  1.24ba/s]
#0:  67%|██████▋   | 2/3 [00:01<00:00,  1.27ba/s]
#2: 100%|██████████| 3/3 [00:01<00:00,  1.54ba/s]
#3: 100%|██████████| 3/3 [00:01<00:00,  1.54ba/s]
#1: 100%|██████████| 3/3 [00:01<00:00,  1.53ba/s]
#0: 100%|██████████| 3/3 [00:01<00:00,  1.52ba/s]


In [5]:
# Fine-tune SpellingBert on (corrupted -> clean) pairs, logging to TensorBoard.

# Seed torch / numpy / random so DataLoader shuffling is reproducible.
pl.seed_everything(42)

model = SpellingBert()

logger = TensorBoardLogger("tb_logs", name="my_model")

# `accelerator` + `devices` replace the deprecated `gpus=` argument.
# For a quicker experimental run, add `limit_train_batches=0.5`.
trainer = pl.Trainer(
    logger=logger,
    accelerator="gpu",
    devices=1,
    precision=16,
    max_epochs=2,
)
trainer.fit(model, train_dataloader, test_dataloader)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | BertForMaskedLM  | 109 M 
1 | loss

Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  6.80it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 1: 100%|██████████| 1125/1125 [03:58<00:00,  4.71it/s, loss=2.02, v_num=21]


In [17]:
# Sanity check: feed a misspelled sentence through the fine-tuned MLM and
# decode the most likely token at every position.

# Disable dropout: Trainer.fit may leave the module in training mode, which
# would make these predictions noisy and non-deterministic.
model.eval()

# Move inputs to wherever the model's weights currently live.
inputs = tokenizer("my nahme is peter", return_tensors="pt").to(model.device)

with torch.no_grad():
    logits = model.model(**inputs).logits

predicted_token_ids = logits.argmax(dim=-1)
tokenizer.decode(predicted_token_ids[0])

[CLS] my family is is. [SEP]


In [7]:
# Encode with special tokens; add_special_tokens=True is also the default.
a = tokenizer(
    "my name is slim",
    add_special_tokens=True,
    return_tensors="pt",
)
a

{'input_ids': tensor([[  101,  2026,  2171,  2003, 11754,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [8]:
# Round-trip the ids of the single encoded sentence; [CLS]/[SEP] appear in the text.
first_ids = a["input_ids"][0]
tokenizer.decode(first_ids)

'[CLS] my name is slim [SEP]'

In [9]:
# Same sentence without [CLS]/[SEP]: only the word-piece ids remain.
tokenizer(
    "my name is slim",
    add_special_tokens=False,
    return_tensors="pt",
)

{'input_ids': tensor([[ 2026,  2171,  2003, 11754]]), 'token_type_ids': tensor([[0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1]])}