In [None]:
from tokenizers import decoders, models, normalizers, pre_tokenizers, processors, trainers, Tokenizer, Regex
from transformers import AutoTokenizer, AlbertTokenizerFast, T5TokenizerFast

In [None]:
import json

# Load the annotation file. The encoding is given explicitly so that decoding
# does not depend on the platform's default locale (JSON text is UTF-8).
with open('../input/chestxraycaption/mimic_cxr/mimic_cxr/annotation.json', encoding='utf-8') as f:
    data = json.load(f)

In [None]:
# Training corpus: the 'report' field of every entry under data['train'].
all_texts = list(map(lambda entry: entry['report'], data['train']))

In [None]:
batch_size = 1000
def batch_iterator(texts=None, size=None):
    """Yield consecutive slices of a text corpus.

    Args:
        texts: Sliceable sequence to batch. Defaults to the notebook-level
            ``all_texts`` (resolved at call time).
        size: Maximum number of items per batch. Defaults to the
            notebook-level ``batch_size`` (resolved at call time).

    Yields:
        ``texts[i : i + size]`` for ``i = 0, size, 2 * size, ...``; the final
        batch may be shorter. Nothing is yielded for an empty corpus.
    """
    if texts is None:
        texts = all_texts
    if size is None:
        size = batch_size
    for start in range(0, len(texts), size):
        yield texts[start : start + size]

In [None]:
# Start from an untrained Unigram model; its vocabulary is learned further
# down when the trainer is run on the corpus.
unigram_model = models.Unigram()
tokenizer = Tokenizer(unigram_model)

In [None]:
# Text clean-up pipeline; the steps run in the order listed.
# Regex patterns are written as raw strings so that sequences such as `\d`,
# `\.` and `\s` reach the regex engine unchanged without relying on Python
# passing unknown escapes through (which emits a SyntaxWarning on 3.12+).
tokenizer.normalizer = normalizers.Sequence([
    # Quote handling: LaTeX-style quotes become '"', then every '"' is dropped.
    normalizers.Replace("``", '"'),
    normalizers.Replace("''", '"'),
    normalizers.Replace('"', ''),
    # Drop slashes, backslashes and apostrophes.
    normalizers.Replace('/', ''),
    normalizers.Replace('\\', ''),
    normalizers.Replace("'", ''),
    # A digit, any single character, then a space -> '. '.
    # NOTE(review): the '.' is unescaped, so this also matches e.g. '12 ';
    # presumably meant for enumerations like '1. ' -- confirm before changing.
    normalizers.Replace(Regex(r'\d. '), '. '),
    # Remove runs of space-separated dots ('. .', '. . .', ...).
    normalizers.Replace(Regex(r"\.( \.)+"), ''),
    # Collapse repeated underscores, whitespace and dots.
    normalizers.Replace(Regex(r'_+'), '_'),
    normalizers.Replace(Regex(r'\s+'), ' '),
    normalizers.Replace(Regex(r'\.+'), '.'),
    # Strip the remaining punctuation characters in this class.
    normalizers.Replace(Regex(r'[?;*!%^&_+():-]'), ''),
    # Put a space before '.' and ','.
    normalizers.Replace('.', ' .'),
    normalizers.Replace(',', ' ,'),
    normalizers.Lowercase(),
    normalizers.Strip()
])

# Preview the normalizer on the first training report.
tokenizer.normalizer.normalize_str(all_texts[0])

In [None]:
# Use the Metaspace pre-tokenizer with its default settings.
metaspace_pre_tokenizer = pre_tokenizers.Metaspace()
tokenizer.pre_tokenizer = metaspace_pre_tokenizer

# Preview the pre-tokenizer on the first (un-normalized) training report.
tokenizer.pre_tokenizer.pre_tokenize_str(all_texts[0])

In [None]:
# Reference T5 tokenizer, used only as the source of the special tokens and
# the unknown token. It was referenced here without being defined anywhere in
# the notebook, so a fresh-kernel run failed with NameError.
# NOTE(review): the checkpoint originally used is not shown in this notebook;
# 't5-small' is an assumption -- confirm.
t5_tokenizer = AutoTokenizer.from_pretrained('t5-small')

trainer = trainers.UnigramTrainer(
    # Target vocabulary size: the number of distinct space-separated strings
    # in the raw (un-normalized) training reports.
    vocab_size=len(set(' '.join(all_texts).split(' '))),
    special_tokens=t5_tokenizer.all_special_tokens,
    unk_token=t5_tokenizer.unk_token,
)

# Train the Unigram model on the corpus, fed in batches.
tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)

# Display the vocabulary size reported by the tokenizer after training.
tokenizer.get_vocab_size()

In [None]:
# Id that the trained vocabulary assigns to the "</s>" token.
sep_token_id = tokenizer.token_to_id("</s>")
eos_entry = ("</s>", sep_token_id)

# Append "</s>" after each sequence; the first sequence gets type id 0 and
# the second sequence of a pair gets type id 1.
tokenizer.post_processor = processors.TemplateProcessing(
    single="$A:0 </s>:0",
    pair="$A:0 </s>:0 $B:1 </s>:1",
    special_tokens=[eos_entry],
)

# Decode with the Metaspace decoder.
tokenizer.decoder = decoders.Metaspace()

In [None]:
# Wrap the trained tokenizer in a T5TokenizerFast and save it to a local
# directory.
output_dir = 't5-mimic-cxr'
xray_tokenizer = T5TokenizerFast(tokenizer_object=tokenizer)
xray_tokenizer.save_pretrained(output_dir)