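# Evaluate Fine-Tuned (FT) BERT

In this notebook, we fine-tune BERT on a document-classification task: predicting the subtitle of each section of USC Title 26 (the Internal Revenue Code). We then compare its performance against pretrained BERT. The baseline uses the same hyperparameters and also trains a linear projection from BERT's pooled [CLS] output to the 11 classes, but keeps the BERT parameters frozen.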

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
from lxml import etree

# Namespace prefix map passed to lxml find/findall calls on the USLM XML.
NS = dict(
    uslm='http://xml.house.gov/schemas/uslm/1.0',
    xhtml='http://www.w3.org/1999/xhtml',
)

  from .autonotebook import tqdm as notebook_tqdm


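## Read in the data

Parse each `uslm:section` of the USLM XML into its paragraph text plus the headings of its enclosing subtitle, chapter and part. Then define a PyTorch `Dataset` that tokenizes each text and pairs it with its encoded label.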

In [2]:
def get_ancestor_heading_text(section, tag, ns):
    """Return the heading text of the nearest ancestor of ``section`` with the given tag.

    Walks up the tree with ``getparent()`` (the element itself is not checked) and
    stops at the first ancestor whose tag is ``{<uslm namespace>}<tag>``.

    Args:
        section: lxml element to start from.
        tag: local (un-namespaced) name of the ancestor to look for, e.g. ``'subtitle'``.
        ns: namespace map; must contain a ``'uslm'`` entry.

    Returns:
        The stripped text of that ancestor's ``uslm:heading`` child, including text
        inside nested inline elements. Returns ``""`` if there is no matching
        ancestor or the matching ancestor has no heading.
    """
    target_tag = f"{{{ns['uslm']}}}{tag}"
    ancestor = section.getparent()
    while ancestor is not None:
        if ancestor.tag == target_tag:
            heading = ancestor.find('uslm:heading', namespaces=ns)
            if heading is None:
                return ""
            # heading.text is None for an empty heading or one that begins with a
            # child element; itertext() handles both and keeps nested text.
            return "".join(heading.itertext()).strip()
        ancestor = ancestor.getparent()
    return ""


def parse_sections_with_metadata(file_path):
    """Parse a USLM XML file into one record per ``uslm:section`` that has paragraph text.

    Args:
        file_path: path to the USLM XML file.

    Returns:
        A list of dicts shaped like::

            {"metadata": {"subtitle": str, "chapter": str, "part": str, "heading": str},
             "content": str}

        ``content`` joins the text of every ``uslm:p`` under the section, one paragraph
        per line. ``subtitle``/``chapter``/``part`` are the headings of the enclosing
        ancestors and ``heading`` is the section's own heading (each ``""`` if absent).
        Sections without any non-empty ``uslm:p`` text are skipped.
    """
    with open(file_path, 'rb') as f:
        tree = etree.parse(f)

    parsed = []
    for section in tree.findall('.//uslm:section', namespaces=NS):
        # Get all paragraphs; itertext() also collects text of nested elements.
        content_texts = []
        for p in section.findall('.//uslm:p', namespaces=NS):
            text = ' '.join(p.itertext()).strip()
            if text:
                content_texts.append(text)

        if not content_texts:
            continue

        # heading.text is None for an empty heading or one that begins with a child
        # element, so read it via itertext() instead of .text.strip().
        heading = section.find('uslm:heading', namespaces=NS)
        heading_text = "".join(heading.itertext()).strip() if heading is not None else ""

        parsed.append({
            "metadata": {
                "subtitle": get_ancestor_heading_text(section, 'subtitle', NS),
                "chapter": get_ancestor_heading_text(section, 'chapter', NS),
                "part": get_ancestor_heading_text(section, 'part', NS),
                # The section's own heading, kept as extra metadata.
                "heading": heading_text,
            },
            "content": "\n".join(content_texts),
        })

    return parsed


class SubtitleDataset(Dataset):
    """Map-style dataset that tokenizes one text on demand and pairs it with its label.

    Each item is a dict with ``input_ids`` and ``attention_mask`` (the tokenizer's
    output with the leading batch dimension squeezed away) and ``label`` (a 0-d
    ``torch.long`` tensor). The tokenizer is called with ``truncation=True`` and
    ``padding='max_length'``, so items presumably share length ``max_length`` and can
    be stacked by a DataLoader.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        # texts[i] is labelled labels[i]; the two sequences must be aligned.
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )
        # return_tensors='pt' adds a batch dimension of size 1; drop it per item.
        item = {name: encoding[name].squeeze(0) for name in ('input_ids', 'attention_mask')}
        item['label'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

## Define the model

In [7]:
class BertSubtitleClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear head that produces class logits.

    Args:
        num_labels: number of output classes.
        freeze_bert: if True, every BERT parameter gets ``requires_grad = False`` so
            only the linear head can be trained (the pretrained-BERT baseline).
        pretrained_model_name: checkpoint passed to ``BertModel.from_pretrained``.
            It must match the tokenizer used to build the inputs.
        dropout: dropout probability applied to the pooled output before the head.
    """

    def __init__(self, num_labels, freeze_bert=False, pretrained_model_name="bert-base-uncased", dropout=0.3):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        if freeze_bert:
            # Disable gradients for the encoder so the fine-tuned model can be compared to pretrained BERT.
            for param in self.bert.parameters():
                param.requires_grad = False
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits, one row per input sequence.

        Assumes ``input_ids``/``attention_mask`` are (batch, seq_len). TODO: confirm
        with the DataLoader.

        Note: ``pooler_output`` is, per the transformers docs, the [CLS] hidden state
        passed through BERT's own pooler layer, not the raw [CLS] vector. That pooler
        belongs to ``self.bert``, so it is frozen as well when ``freeze_bert=True``.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        pooled_output = self.dropout(outputs.pooler_output)
        return self.classifier(pooled_output)

## Define training and evaluation functions

In [8]:
def evaluate(model, dataloader, device):
    """Score ``model`` on ``dataloader`` without gradients; print accuracy and macro-F1.

    Switches the model to eval mode (and leaves it there).

    Args:
        model: called as ``model(input_ids, attention_mask)``; its output is argmax-ed
            over dim 1 to get predicted class ids.
        dataloader: yields dicts with ``input_ids``, ``attention_mask`` and ``label`` tensors.
        device: device the batch tensors are moved to.

    Returns:
        The macro-averaged F1 score, which callers use for model selection.
    """
    model.eval()
    predictions, targets = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids, attention_mask, labels = (
                batch[key].to(device) for key in ('input_ids', 'attention_mask', 'label')
            )
            logits = model(input_ids, attention_mask)
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(labels.cpu().numpy())

    accuracy = accuracy_score(targets, predictions)
    macro_f1 = f1_score(targets, predictions, average='macro')
    print(f"  ➤ Accuracy: {accuracy:.4f} | Macro-F1: {macro_f1:.4f}")
    return macro_f1


def train_with_early_stopping(model, train_loader, val_loader, optimizer, criterion, device, epochs=5, patience=2):
    """Train ``model``, keeping the weights with the best validation macro-F1.

    After every epoch the model is scored with ``evaluate`` on ``val_loader``. Training
    stops once ``patience`` consecutive epochs fail to beat the best score. The
    best-scoring weights are then loaded back into ``model`` before it is returned.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` returning logits.
        train_loader: iterable of dicts with ``input_ids``, ``attention_mask`` and
            ``label`` tensors; ``len(train_loader)`` must be defined.
        val_loader: same format, passed to ``evaluate``.
        optimizer: optimizer over the model's trainable parameters.
        criterion: loss called as ``criterion(logits, labels)``.
        device: device batches are moved to.
        epochs: maximum number of epochs.
        patience: consecutive non-improving epochs tolerated before stopping.

    Returns:
        ``model`` itself (modified in place), holding the best validation weights.
    """
    best_f1 = float("-inf")  # guarantees the first epoch produces a snapshot
    best_model_state = None
    epochs_no_improve = 0

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")

        # Training step
        model.train()
        total_loss = 0.0
        for batch in tqdm(train_loader, desc="Training"):
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / max(len(train_loader), 1)
        print(f" Avg train loss: {avg_loss:.4f}")

        # Validation step
        f1_macro = evaluate(model, val_loader, device)

        # Early stopping logic
        if f1_macro > best_f1:
            best_f1 = f1_macro
            # state_dict() returns references to the live parameter tensors, which the
            # optimizer keeps updating in later epochs. Clone them (on CPU, to avoid a
            # second copy in GPU memory) so the snapshot really is this epoch's weights.
            best_model_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            epochs_no_improve = 0
            print(f"  🎉 New best Macro-F1: {best_f1:.4f}")
        else:
            epochs_no_improve += 1
            print(f"  No improvement for {epochs_no_improve} epochs.")

        if epochs_no_improve >= patience:
            print(f"Stopping early after {epoch+1} epochs.")
            break

    # Load best weights before returning (load_state_dict copies across devices).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model


def train_final_model(parsed_data, num_epochs, batch_size, lr, patience=2, device=None, freeze_bert=False, save_dir="bert-subtitle-embedder"):
    """Train a BERT subtitle classifier on ``parsed_data`` and report validation metrics.

    Inputs are each entry's ``content``; labels are its ``metadata["subtitle"]``,
    encoded with ``LabelEncoder``. The data is split 90/10 (stratified,
    ``random_state=42``). The 10% split drives early stopping and is also used for the
    final evaluation, so the reported numbers are optimistic rather than a true
    held-out test score.

    Args:
        parsed_data: list of dicts as produced by ``parse_sections_with_metadata``.
        num_epochs: maximum number of training epochs.
        batch_size: batch size for both the train and validation loaders.
        lr: AdamW learning rate.
        patience: early-stopping patience in epochs.
        device: torch device; defaults to CUDA when available, else CPU.
        freeze_bert: if True, BERT is frozen and only the linear head is trained
            (the pretrained-BERT baseline). Nothing is saved in that case.
        save_dir: where the fine-tuned BERT encoder and tokenizer are written when
            ``freeze_bert`` is False. The classification head is not saved.

    Returns:
        The trained ``BertSubtitleClassifier`` holding its best validation weights.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Extract data
    texts = [entry["content"] for entry in parsed_data]
    subtitles = [entry["metadata"]["subtitle"] for entry in parsed_data]

    # Encode labels
    label_encoder = LabelEncoder()
    labels = label_encoder.fit_transform(subtitles)

    # Train-val split (90/10) for early stopping
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=0.1, stratify=labels, random_state=42
    )

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Dataset and loaders
    train_dataset = SubtitleDataset(train_texts, train_labels, tokenizer)
    val_dataset = SubtitleDataset(val_texts, val_labels, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Init model. freeze_bert must reach the model; otherwise the "frozen" baseline
    # silently fine-tunes the whole encoder.
    model = BertSubtitleClassifier(
        num_labels=len(label_encoder.classes_), freeze_bert=freeze_bert
    ).to(device)
    # Give the optimizer only trainable parameters (just the linear head when frozen).
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = AdamW(trainable_params, lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Train with early stopping
    model = train_with_early_stopping(
        model, train_loader, val_loader, optimizer, criterion, device,
        epochs=num_epochs, patience=patience
    )

    print("\n\nFinal Evaluation:\n")
    evaluate(model, val_loader, device)

    # Save only fine-tuned BERT encoder if not frozen
    if not freeze_bert:
        model.bert.save_pretrained(save_dir)
        tokenizer.save_pretrained(save_dir)
        print(f"Fine-tuned model saved to {save_dir}")
    else:
        print("Frozen BERT model trained and evaluated (not saved).")

    return model

## Set the best hyperparameters and load the data

In [10]:
# Hyperparameters, shared by the fine-tuned and frozen-BERT runs below so the
# comparison differs only in whether BERT's weights are trained.
lr = 3e-05
epochs = 4
batch_size = 8
patience = 2
seed = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): absolute, machine-specific paths; adjust for your environment.
save_dir = "/DL-data/Tax_Law_RAG/usc26/USC26_Subtitle_Classification_BERT"

# Seed torch's RNGs (classification-head init, dropout masks, DataLoader shuffling)
# so Restart & Run All reproduces the reported numbers, up to GPU nondeterminism.
torch.manual_seed(seed)

# Data
parsed_data = parse_sections_with_metadata("/DL-data/usc26.xml")

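## Evaluate frozen BERT and fine-tuned BERT on accuracy and macro-F1

Metrics are printed to stdout during training; see the cell outputs below. For reference, the training entry point is:

`train_final_model(parsed_data, num_epochs, batch_size, lr, patience=2, device=None, freeze_bert=False, save_dir="bert-subtitle-embedder")`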

In [None]:
# Fine-tune all of BERT together with the linear head; with freeze_bert=False,
# train_final_model writes the encoder and tokenizer to save_dir.
train_final_model(
    parsed_data=parsed_data,
    num_epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    patience=patience,
    device=device,
    freeze_bert=False,
    save_dir=save_dir,
)

Epoch 1/4


Training: 100%|██████████| 239/239 [01:28<00:00,  2.70it/s]


 Avg train loss: 1.1627


Evaluating: 100%|██████████| 27/27 [00:07<00:00,  3.72it/s]


  ➤ Accuracy: 0.7512 | Macro-F1: 0.4776
  🎉 New best Macro-F1: 0.4776
Epoch 2/4


Training: 100%|██████████| 239/239 [01:22<00:00,  2.89it/s]


 Avg train loss: 0.4820


Evaluating: 100%|██████████| 27/27 [00:06<00:00,  3.99it/s]


  ➤ Accuracy: 0.8169 | Macro-F1: 0.5962
  🎉 New best Macro-F1: 0.5962
Epoch 3/4


Training:  46%|████▋     | 111/239 [00:39<00:45,  2.82it/s]

In [None]:
# Baseline: same hyperparameters, but freeze_bert=True requests a frozen encoder so that
# only the linear projection is trained. train_final_model does not save in this mode.
train_final_model(
    parsed_data=parsed_data,
    num_epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    patience=patience,
    device=device,
    freeze_bert=True,
    save_dir=save_dir,
)