## Prepare Dataset for BERT 

In [7]:
import torch
from datasets import load_dataset
from transformers import BertTokenizer, BertForTokenClassification, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler
from sklearn.model_selection import train_test_split
import numpy as np
from tqdm import tqdm

# Load the dataset ("raw" configuration, "test" split).
dataset = load_dataset("midas/duc2001", "raw")["test"]

# Read each column once up front instead of indexing `dataset[...]` inside the
# loop. The code below assumes 'document' holds a list of word tokens per
# document and 'doc_bio_tags' one B/I/O tag per word -- TODO confirm against
# the dataset card.
doc_words = dataset['document']
doc_tags = dataset['doc_bio_tags']

# Plain-text view of every document (words joined by single spaces).
documents = [' '.join(doc) for doc in doc_words]

# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Map labels into integers
tag2idx = {'B': 0, 'I': 1, 'O': 2}
tags_vals = ['B', 'I', 'O']

# Fixed sequence length of every encoded document, *including* the
# '[CLS]' and '[SEP]' special tokens.
MAX_LEN = 64
max_pieces = MAX_LEN - 2

# Label id given to '[CLS]', '[SEP]' and padding positions (the original
# padding convention of this script).
outside_id = tag2idx['O']

# Per-document rows; each row ends up exactly MAX_LEN long.
input_ids = []
attention_masks = []
labels = []

for words, word_tags in zip(doc_words, doc_tags):
    # Tokenize word by word so every WordPiece can inherit the tag of the word
    # it came from; this keeps inputs and labels aligned by construction.
    pieces = []
    piece_labels = []
    for word, tag in zip(words, word_tags):
        if len(pieces) >= max_pieces:
            # Everything past this point would be truncated anyway.
            break
        word_pieces = tokenizer.tokenize(word)
        if not word_pieces:
            # The tokenizer produced nothing for this word: no position to label.
            continue
        # The first piece keeps the word's tag. Continuation pieces of a 'B'
        # word are still inside the same phrase, so they become 'I'.
        continuation_tag = 'I' if tag == 'B' else tag
        pieces.extend(word_pieces)
        piece_labels.append(tag2idx[tag])
        piece_labels.extend([tag2idx[continuation_tag]] * (len(word_pieces) - 1))

    # Truncate to leave room for the two special tokens.
    pieces = pieces[:max_pieces]
    piece_labels = piece_labels[:max_pieces]

    ids = (
        [tokenizer.cls_token_id]
        + tokenizer.convert_tokens_to_ids(pieces)
        + [tokenizer.sep_token_id]
    )
    pad_len = MAX_LEN - len(ids)

    input_ids.append(ids + [tokenizer.pad_token_id] * pad_len)
    attention_masks.append([1] * len(ids) + [0] * pad_len)
    # Layout: [CLS] label, one label per piece, then [SEP] + padding labels.
    labels.append([outside_id] + piece_labels + [outside_id] * (pad_len + 1))

# Convert the row lists into (num_documents, MAX_LEN) integer tensors.
input_ids = torch.tensor(input_ids, dtype=torch.long)
attention_masks = torch.tensor(attention_masks, dtype=torch.long)
labels = torch.tensor(labels, dtype=torch.long)

# Hold out 10% of the documents for validation. All three tensors go through a
# single call so their rows stay aligned; the result comes back as
# (train, validation) pairs in the same order the arrays were passed in.
split_parts = train_test_split(
    input_ids,
    labels,
    attention_masks,
    test_size=0.1,
    random_state=2018,
)
(
    train_inputs, val_inputs,
    train_labels, val_labels,
    train_masks, val_masks,
) = split_parts

# Training batches: rows are drawn in random order.
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    sampler=train_sampler,
    batch_size=32,
)

# Validation batches: rows are visited in their stored order.
valid_data = TensorDataset(val_inputs, val_masks, val_labels)
valid_sampler = SequentialSampler(valid_data)
valid_dataloader = DataLoader(
    valid_data,
    sampler=valid_sampler,
    batch_size=32,
)

# Pretrained BERT with a token-classification head: one output class per tag.
model = BertForTokenClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(tag2idx),
    output_attentions=False,
    output_hidden_states=False,
)

# Set up the optimizer
optimizer = AdamW(model.parameters(), lr=3e-5, eps=1e-8)

# Number of passes over the training set, and the gradient-norm clipping
# threshold used by the training loop.
epochs = 4
max_grad_norm = 1.0

# The scheduler is stepped once per batch, so the schedule has to span
# (batches per epoch) * (epochs) steps.
total_steps = len(train_dataloader) * epochs

# Linear learning-rate schedule with no warmup phase.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

# Run on the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Function to calculate the accuracy of predictions vs labels
def flat_accuracy(preds, labels):
    """Return the fraction of positions where the arg-max prediction equals the label.

    ``preds`` is reduced with ``argmax`` over axis 2, so it needs at least
    three dimensions (presumably batch x sequence x classes -- confirm against
    the caller). ``labels`` is flattened and compared element-wise with the
    flattened predictions. No masking is applied: every position counts.
    """
    predicted = np.argmax(preds, axis=2).reshape(-1)
    expected = labels.reshape(-1)
    n_correct = np.sum(predicted == expected)
    return n_correct / len(expected)

# Training loop
for epoch in tqdm(range(epochs), desc="Epoch"):
    model.train()
    total_loss = 0
    
    for step, batch in enumerate(train_dataloader):
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        model.zero_grad()
        
        outputs = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
        loss = outputs.loss
        
        total_loss += loss.item()
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        
        optimizer.step()
        scheduler.step()
        
    avg_train_loss = total_loss / len(train_dataloader)            
    print(f'\nAverage Training Loss: {avg_train_loss:.2f}')
    
    # ========================================
    #               Validation
    # ========================================
    # After each training epoch, measure the model's performance on the validation set.
    
    # Put the model in evaluation mode
    model.eval()
    eval_loss, eval_accuracy, nb_eval_steps = 0, 0, 0
    
    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        with torch.no_grad():
            outputs = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
        
        logits = outputs.logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').num
