# 03: Train Encoder-Only Model (BERT-style)
This notebook demonstrates how to train a masked language model (MLM) using an encoder-only architecture like BERT.

In [None]:
!pip install torch transformers

In [1]:
import sys
import os
sys.path.append(os.path.abspath(".."))

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from models.encoder_only import BertStyleEncoder
import random

## Load and tokenize dataset

In [None]:
if not os.path.exists("../data/tiny_shakespeare.txt"):
    from urllib.request import urlretrieve
    os.makedirs("../data", exist_ok=True)
    urlretrieve("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", "../data/tiny_shakespeare.txt")

In [3]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
with open("../data/tiny_shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()
tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True)["input_ids"].squeeze(0)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

## Masked Language Modeling Dataset

In [4]:
class MLMDataset(Dataset):
    def __init__(self, tokens, mask_prob=0.15, block_size=64):
        self.samples = [tokens[i:i+block_size] for i in range(0, len(tokens)-block_size, block_size)]
        self.mask_prob = mask_prob
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        input_ids = self.samples[idx].clone()
        labels = input_ids.clone()
        mask = torch.rand(input_ids.shape) < self.mask_prob
        input_ids[mask] = tokenizer.mask_token_id
        return input_ids, labels

dataset = MLMDataset(tokens)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

## Initialize BERT-style encoder model

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = BertStyleEncoder(
    vocab_size=tokenizer.vocab_size,
    embed_dim=768,
    depth=6,
    heads=12,
    ff_dim=2048,
    max_len=64
).to(device)

## Train with MLM loss

In [6]:
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
criterion = nn.CrossEntropyLoss()
for epoch in range(3):
    model.train()
    total_loss = 0
    for input_ids, labels in dataloader:
        input_ids, labels = input_ids.to(device), labels.to(device)
        logits = model(input_ids)
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} loss: {total_loss / len(dataloader):.4f}")

torch.save(model.state_dict(), "bert_style_encoder.pt")
print("Model saved.")

Epoch 1 loss: 10.4540
Epoch 2 loss: 9.8412
Epoch 3 loss: 9.3071
Model saved.
