In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up automatically before each cell executes
# (presumably for the local BERT / BERTdataloader modules imported below).
%load_ext autoreload
%autoreload 2

In [None]:
import os

# CUDA debugging switches, set before `torch` is imported below.
# CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous so errors surface
# at the offending call. NOTE(review): TORCH_USE_CUDA_DSA may only take effect
# when PyTorch itself was built with device-side assertions — confirm.
os.environ.update({
    "CUDA_LAUNCH_BLOCKING": "1",
    "TORCH_USE_CUDA_DSA": "1",
})

from BERT import BERTForNER, ModelArgs, compute_loss, evaluate
from BERTdataloader import BERTDataLoader
import torch
from transformers import AutoTokenizer
from datasets import load_dataset

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
# TensorBoard writer shared by the training loop and by `evaluate`;
# event files are written under this run directory.
log_dir = "runs/bert_ner_experiment_1"
writer = SummaryWriter(log_dir)


In [None]:
# Identifiers of the pretrained tokenizer and of the dataset to load.
TOKENIZER_NAME = "bert-base-uncased"
DATASET_NAME = "conll2003"

# Only the tokenizer is loaded from the pretrained checkpoint here.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
ds = load_dataset(DATASET_NAME)


In [None]:
# Project-local wrapper around the tokenizer and the raw dataset.
loader = BERTDataLoader(tokenizer, ds)

# Build the split loaders: 'train' first, then 'validation'.
train_loader, validation_loader = (
    loader.get_dataloader(split) for split in ("train", "validation")
)

# Label metadata exposed by the wrapper: the number of NER classes (sizes the
# model's classification head) and the label list handed to `evaluate`.
num_ner_labels, label_list = loader.num_ner_labels, loader.label_list


In [None]:
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    LinearLR,
    SequentialLR
)

# ---------------------------------------------------------------------------
# Training configuration: model, optimizer, loss and LR schedule.
# ---------------------------------------------------------------------------
batch_size = 8                  # Physical batch size (presumably what BERTDataLoader yields — verify)
grad_accum_steps = 2            # Micro-batches accumulated per optimizer update
effective_batch_size = batch_size * grad_accum_steps  # Simulated batch size
device = "cuda"
args = ModelArgs(dim=768, n_heads=8, n_layers=8, device=device, max_batch_size=batch_size, vocab_size=tokenizer.vocab_size)
# args = ModelArgs(dim=1024, n_heads=16, n_layers = 24, device = device, max_batch_size = 8, vocab_size = tokenizer.vocab_size)
model = BERTForNER(args, num_ner_labels=num_ner_labels).to(device)

epochs = 6

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
# Targets equal to -100 are excluded from the loss (presumably the value the
# data loader assigns to padding positions — verify against BERTDataLoader).
criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)

# The training loop calls `scheduler.step()` once per optimizer update, i.e.
# once every `grad_accum_steps` batches plus once for a trailing partial
# window. The schedule length must therefore be counted in optimizer updates,
# not in batches; otherwise warmup is stretched and the cosine decay never
# reaches `eta_min` before training ends.
steps_per_epoch = (len(train_loader) + grad_accum_steps - 1) // grad_accum_steps  # ceil division
total_steps = steps_per_epoch * epochs
# 1% warmup; at least one step so LinearLR never gets total_iters=0.
warmup_steps = max(1, int(0.01 * total_steps))

# Short linear warmup
warmup = LinearLR(
    optimizer,
    start_factor=0.01,  # Start at 1% of max LR
    end_factor=1.0,
    total_iters=warmup_steps
)

# Long cosine decay
cosine = CosineAnnealingLR(
    optimizer,
    T_max=max(1, total_steps - warmup_steps),  # Decay over the remaining optimizer updates
    eta_min=1e-5                               # LR floor at the end of the decay
)

# Combine them
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[warmup_steps]  # Switch to cosine after warmup
)


In [None]:
# Train for `epochs` passes over `train_loader` with gradient accumulation,
# logging to TensorBoard, then run validation after every epoch.

# Number of optimizer/scheduler updates performed so far across ALL epochs.
# Used as the x-axis of the LR curve so later epochs do not overwrite the
# points logged by earlier ones.
optimizer_updates = 0
num_batches = len(train_loader)

for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    epoch_loss = 0.0
    # Per-epoch batch counter; deliberately not named `total_steps`, which is
    # the scheduler horizon defined in the configuration cell.
    batches_seen = 0

    progress_bar = tqdm(
        enumerate(train_loader),
        total=num_batches,
        desc=f"Epoch {epoch + 1}/{epochs}",
        leave=True
    )

    for step, batch_load in progress_bar:
        # Move every tensor in the batch to the target device exactly once.
        # Only `.items()` is required of the batch object, so this works for a
        # plain dict as well as for mapping-like batch containers.
        batch = {k: v.to(device) for k, v in batch_load.items()}
        outputs = model(
            batch["input_ids"],
            attn_mask=batch["attention_mask"]
        )
        labels = batch["labels"]
        # Flatten all leading dimensions so each position is one classification example.
        loss = criterion(outputs.view(-1, outputs.shape[-1]), labels.view(-1))
        batch_loss = loss.item()  # unscaled loss, read once (each .item() forces a device sync)

        # Scale so that gradients summed over an accumulation window average out.
        (loss / grad_accum_steps).backward()

        epoch_loss += batch_loss
        batches_seen += 1

        # Update on every full accumulation window, and on the final batch so
        # a trailing partial window is not dropped.
        if (step + 1) % grad_accum_steps == 0 or (step + 1) == num_batches:
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            optimizer_updates += 1

            current_lr = optimizer.param_groups[0]["lr"]
            progress_bar.set_postfix({
                "loss": f"{epoch_loss / batches_seen:.4f}",
                "lr": current_lr
            })
            writer.add_scalar("LR", current_lr, global_step=optimizer_updates)
        writer.add_scalar("Loss/train", batch_loss, global_step=epoch * num_batches + step)

    print(f"Epoch {epoch + 1} Training Loss: {epoch_loss / max(1, batches_seen):.4f}")

    # Validation (`evaluate` is project code from the BERT module; it receives
    # the writer and the epoch index as its global step).
    validation_metrics = evaluate(
        model=model,
        validation_loader=validation_loader,
        criterion=criterion,
        label_list=label_list,
        device=device,
        writer=writer,
        global_step=epoch,
    )
    # Push buffered events to disk so TensorBoard reflects the finished epoch.
    writer.flush()
    print("*" * 30)
    

In [None]:
# Persist the model's state dict (not the whole module object).
final_weights = model.state_dict()
torch.save(final_weights, "final.pt")