In [None]:
from transformers import BertTokenizer, BertForMaskedLM
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments
import torch
import os

In [None]:
# Step 1: Resolve the pretrained checkpoint once by name, then build both the
# BERT tokenizer and the masked-language-model (MLM-head) variant from it.
model_name = "bert-base-uncased"
tokenizer, model = (
    BertTokenizer.from_pretrained(model_name),
    BertForMaskedLM.from_pretrained(model_name),
)

In [None]:
# Step 2: Register domain-specific words as whole tokens so the tokenizer no
# longer splits them into sub-word pieces. `add_tokens` returns how many of
# them were actually new to the vocabulary.
new_tokens = ["angiocardiography", "echocardiogram", "neurofibromatosis"]
num_added = tokenizer.add_tokens(new_tokens)
print("Added {} tokens.".format(num_added))

In [None]:
# Step 3: Give every id the extended tokenizer can emit a row in the model's
# embedding matrix. The extra rows are freshly initialized and only become
# meaningful through the fine-tuning below.
model.resize_token_embeddings(new_num_tokens=len(tokenizer))

In [None]:
# Step 4: Tiny in-memory demo corpus - one sentence per newly added word.
custom_corpus = []
custom_corpus.append("The echocardiogram revealed a potential defect.")
custom_corpus.append("Angiocardiography is often used in diagnostic imaging.")
custom_corpus.append("Neurofibromatosis can lead to tumor formation.")
print(custom_corpus)

In [None]:
# Step 5: Tokenize dataset.
# No padding / tensor conversion here: the data collator pads each batch only
# to that batch's longest sequence (dynamic padding) and builds the tensors
# itself, instead of padding every example to the corpus-wide maximum.
tokenized_data = tokenizer(custom_corpus, truncation=True)

# Masked Language Modeling: for every batch the collator randomly picks
# ~15% of the non-special tokens as prediction targets and builds `labels`.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
# Minimal torch Dataset so Trainer can fetch one tokenized example per index
from torch.utils.data import Dataset

class SimpleTextDataset(Dataset):
    """Index-addressable view over a tokenizer's batch output.

    Item ``idx`` is a dict mapping every encoding field (e.g. ``input_ids``,
    ``attention_mask``) to that field's ``idx``-th row.
    """

    def __init__(self, encodings):
        # Mapping of field name -> per-example rows (tensor or list).
        self.encodings = encodings

    def __len__(self):
        # Every field has one row per example; input_ids is the reference field.
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        example = {}
        for field, rows in self.encodings.items():
            example[field] = rows[idx]
        return example

# Wrap the tokenizer output so the Trainer below can index individual examples.
dataset = SimpleTextDataset(tokenized_data)

In [None]:
# Step 6: Training configuration - deliberately tiny since this is a demo.
training_args = TrainingArguments(
    # where checkpoints go (existing contents may be overwritten)
    output_dir="./bert-custom",
    overwrite_output_dir=True,
    # schedule: 5 passes over the corpus, 2 examples per device step
    num_train_epochs=5,
    per_device_train_batch_size=2,
    # checkpoint every 10 steps, keeping only the 2 most recent
    save_steps=10,
    save_total_limit=2,
    # log every 5 steps, without any external tracker (e.g. wandb)
    logging_steps=5,
    report_to="none",
)

In [None]:
# `report_to="none"` in TrainingArguments already turns off tracker
# integrations; this env var is kept only as a fallback (presumably for older
# transformers versions that read it directly - confirm if still needed).
# `os` is imported once in the first cell rather than re-imported here.
os.environ["WANDB_DISABLED"] = "true"

In [None]:
# Step 7: Combine the model, training configuration, dataset and MLM collator
# into a Trainer, then fine-tune. The training summary is the cell's output.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

train_output = trainer.train()
train_output

In [None]:
# Persist the fine-tuned weights and the extended tokenizer side by side so
# both can be reloaded from the same directory.
save_dir = "bert-custom"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

In [None]:
# Reload the tokenizer from disk (BertTokenizer is already imported in the
# first cell) and check that the domain tokens survived the save/load round
# trip. Later cells inspect this reloaded instance.
tokenizer = BertTokenizer.from_pretrained("bert-custom")
assert set(new_tokens) <= set(tokenizer.get_vocab()), "added tokens missing after reload"


In [None]:
vocab = tokenizer.get_vocab()  # token -> ID, including added tokens
inv_vocab = {token_id: token for token, token_id in vocab.items()}  # ID -> token

# Added tokens get ids directly after the base vocabulary, so the first one is
# tokenizer.vocab_size (30522 for bert-base-uncased). Deriving it avoids a
# checkpoint-specific magic number, and .get() avoids a KeyError when no
# tokens were actually added.
first_added_id = tokenizer.vocab_size
print(inv_vocab.get(first_added_id, "<no added tokens>"))  # expected: 'angiocardiography'


In [None]:
# Tokens layered on top of the base vocabulary, as token -> ID.
added = tokenizer.get_added_vocab()
print(f"{added}")
# {'angiocardiography': 30522, 'echocardiogram': 30523, 'neurofibromatosis': 30524}
# NOTE(review): depending on the installed transformers version, special
# tokens such as [PAD]/[CLS] may also appear here - confirm.


In [None]:
# Flatten the vocabulary into a list of tokens ordered by token id, so that
# (assuming ids are contiguous from 0) line i of the file holds token id i.
vocab_keys = sorted(vocab, key=vocab.get)

# Save one token per line. The encoding is explicit because the vocab may
# contain non-ASCII tokens that the platform-default encoding cannot write.
with open('./bert-custom/flat_vocab.txt', 'w', encoding='utf-8') as f:
    f.writelines(token + '\n' for token in vocab_keys)

In [None]:
# Test tokenization of a sentence with custom tokens
test_text = "The patient's echocardiogram showed no abnormalities after the angiocardiography procedure."
tokens = tokenizer.tokenize(test_text)
print(tokens)
# e.g. ['the', 'patient', "'", 's', 'echocardiogram', 'showed', 'no', 'abnormal', '##ities', 'after', 'the', 'angiocardiography', 'procedure', '.']

# The added domain words must come through as single, unsplit tokens.
assert {"echocardiogram", "angiocardiography"} <= set(tokens), tokens

# Convert to token IDs (encode also wraps the sequence in [CLS] ... [SEP])
token_ids = tokenizer.encode(test_text)
print(token_ids)
# e.g. [101, 1996, 5776, 1005, 1055, 30523, 3662, 2053, 28828, 2044, 1996, 30522, 7709, 1012, 102]
# NOTE(review): the two sample outputs above disagree - the token list splits
# 'abnormalities' into two pieces, but the id list has a single id (28828) in
# that position. Regenerate both from an actual run.

# encode() == [CLS] + ids of tokenize() + [SEP]
assert token_ids[1:-1] == tokenizer.convert_tokens_to_ids(tokens)
