# HW4: Fine-tuning BERT for entity labeling
This notebook contains starter code for finetuning a BERT-style model for the task of entity recognition. It has minimal text so you can easily copy it to **handin.py** when you submit.  Please read all the comments in the code as they contain important information.

In [1]:
# This code block just contains standard setup code for running in Python
import time

# PyTorch imports
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset #random_split
import numpy as np

# Fix the random seed(s) for reproducability
torch.random.manual_seed(8942764)
torch.cuda.manual_seed(8942764)
np.random.seed(8942764)

# Please set your device by uncommenting the right version below

# On Colab or on a machine with access to an Nvidia GPU use the following setting
#device = 'cuda:0'

# if you have an Apple Silicon machine with a GPU, use the following setting
# this should about 3-4 times faster that running it on just CPU
# device = 'mps'

# If you will use a cpu, this is the setting
device = 'cpu'

# Note that in handin.py these next two lines will need to be removed
# if you are going run this on your personal machine you will need to install
# these locally in the shell/terminal.

# !pip install protobuf==3.20.2
# !pip install transformers
# !pip install datasets
# !pip install evaluate
# !pip install seqeval

from transformers import AutoTokenizer, BertModel, DataCollatorForTokenClassification

import evaluate

In [5]:
# Load the dataset
from datasets import ClassLabel, Sequence, load_dataset

data_splits = load_dataset('json', data_files={'train': 'dinos_and_deities_train_bio.jsonl', 'dev': 'dinos_and_deities_dev_bio_sm.jsonl',
                                                'test': 'dinos_and_deities_test_bio_nolabels.jsonl'})

label_names_fname = "dinos_and_deities_train_bio.jsonl.labels"
labels_int2str = []
with open(label_names_fname) as f:
    labels_int2str = f.read().split()
print(f"Labels: {labels_int2str}")
labels_str2int = {l: i for i, l in enumerate(labels_int2str)}

data_splits.cast_column("ner_tags", Sequence(ClassLabel(names=labels_int2str)))
print(data_splits)

Labels: ['I-Aquatic_animal', 'B-Deity', 'B-Mythological_king', 'I-Mythological_king', 'I-Cretaceous_dinosaur', 'B-Aquatic_animal', 'B-Aquatic_mammal', 'I-Goddess', 'I-Deity', 'B-Cretaceous_dinosaur', 'I-Aquatic_mammal', 'B-Goddess', 'O']
DatasetDict({
    train: Dataset({
        features: ['para_index', 'title', 'doc_id', 'content', 'page_id', 'id', 'tokens', 'ner_strings', 'ner_tags'],
        num_rows: 1749
    })
    dev: Dataset({
        features: ['para_index', 'title', 'doc_id', 'content', 'page_id', 'id', 'tokens', 'ner_strings', 'ner_tags'],
        num_rows: 150
    })
    test: Dataset({
        features: ['para_index', 'title', 'doc_id', 'content', 'page_id', 'id', 'tokens', 'ner_strings', 'ner_tags'],
        num_rows: 303
    })
})


In [6]:
# initialize pretrained BERT tokenizer. This might take a while the first time it's run because the model needs to be downloaded.
# Note: if you change the BERT model later, don't forget to also change this!!
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

In [7]:
# If you want you can look at some sample data items
print(data_splits["train"][8])
print(data_splits["dev"][5])

{'para_index': 0, 'title': 'Myersiohyla liliae', 'doc_id': 'Myersiohyla liliae-0', 'content': 'Myersiohyla liliae is a species of frogs in the family Hylidae. It is endemic to the Pacaraima Mountains in Guyana and known from the region of its type locality in the Kaieteur National Park and from Imbaimadai. The species is dedicated to the daughter of its describer, Lili Kok.', 'page_id': '28259031', 'id': 'Ud-DXIcB1INCf0UyAseC', 'tokens': ['Myersiohyla', 'liliae', 'is', 'a', 'species', 'of', 'frogs', 'in', 'the', 'family', 'Hylidae.', 'It', 'is', 'endemic', 'to', 'the', 'Pacaraima', 'Mountains', 'in', 'Guyana', 'and', 'known', 'from', 'the', 'region', 'of', 'its', 'type', 'locality', 'in', 'the', 'Kaieteur', 'National', 'Park', 'and', 'from', 'Imbaimadai.', 'The', 'species', 'is', 'dedicated', 'to', 'the', 'daughter', 'of', 'its', 'describer,', 'Lili', 'Kok.'], 'ner_strings': ['B-Aquatic_animal', 'I-Aquatic_animal', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '

In [8]:
# This dataset is split into a train, validation and test set, and each token has a label.
# Data from the dataset can generally be accessed like a Python dict.
print(data_splits['train'].features)

# Print the original sentence (which is whitespace tokenized).
example_input_tokens = data_splits['train'][8]['tokens']
print(f"Original tokens: {example_input_tokens}")

# Print the labels of the sentence.
example_ner_labels = data_splits['train'][8]['ner_tags']
print(f"NER labels: {example_ner_labels}")

# Map integer to string labels for the sentence
example_mapped_labels = [labels_int2str[l] for l in example_ner_labels]
print(f'Labels: {example_mapped_labels}')

# Print the sentence split into tokens.
example_tokenized = tokenizer(example_input_tokens, is_split_into_words=True)
print('BERT Tokenized: ', example_tokenized.tokens())

# Print the number of tokens in the vocabulary
print(f'Vocab size: {tokenizer.vocab_size}')

# # Print the sentence mapped to token ids.
print('Token IDs: ', tokenizer.convert_tokens_to_ids(example_tokenized.tokens()))

# Of course, there are now way more tokens than labels! Fortunately the HF tokenizer
# provides a function that will give us the mapping:
print(example_tokenized.word_ids())

{'para_index': Value(dtype='int64', id=None), 'title': Value(dtype='string', id=None), 'doc_id': Value(dtype='string', id=None), 'content': Value(dtype='string', id=None), 'page_id': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None), 'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'ner_strings': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'ner_tags': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}
Original tokens: ['Myersiohyla', 'liliae', 'is', 'a', 'species', 'of', 'frogs', 'in', 'the', 'family', 'Hylidae.', 'It', 'is', 'endemic', 'to', 'the', 'Pacaraima', 'Mountains', 'in', 'Guyana', 'and', 'known', 'from', 'the', 'region', 'of', 'its', 'type', 'locality', 'in', 'the', 'Kaieteur', 'National', 'Park', 'and', 'from', 'Imbaimadai.', 'The', 'species', 'is', 'dedicated', 'to', 'the', 'daughter', 'of', 'its', 'describer,', 'Lili', 'Kok.']
NER labels: [5, 0, 12, 12, 12, 12, 12, 12, 12, 12, 12,

In [9]:
# We can write a function that uses that along with the original labels to get the new set of labels
# for each BERT-tokenized token.
def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as previous token
            label = labels[word_id]
            str_label = labels_int2str[label]
            if str_label[0] == 'B':
                new_str_label = 'I' + str_label[1:]
                label = labels_str2int[new_str_label]
            new_labels.append(label)

    return new_labels

In [10]:
tokenizer_aligned_labels = align_labels_with_tokens(example_ner_labels, example_tokenized.word_ids())
print(f'Aligned labels: {tokenizer_aligned_labels}')
print(f'Mapped aligned labels: {[labels_int2str[l] if l >= 0 else "_" for l in tokenizer_aligned_labels]}')

Aligned labels: [-100, 5, 0, 0, 0, 0, 0, 0, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, -100]
Mapped aligned labels: ['_', 'B-Aquatic_animal', 'I-Aquatic_animal', 'I-Aquatic_animal', 'I-Aquatic_animal', 'I-Aquatic_animal', 'I-Aquatic_animal', 'I-Aquatic_animal', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '_']


In [11]:
# Let's check the function on the example from before. The special tokens don't have labels,
# so we'll just replace those with _
aligned_labels = align_labels_with_tokens(example_ner_labels, example_tokenized.word_ids())
print(f"Tokens: {example_tokenized.tokens()}")
print(f"Aligned labels: {[labels_int2str[l] if l >= 0 else '_' for l in aligned_labels]}")

Tokens: ['[CLS]', 'Myers', '##io', '##hyl', '##a', 'l', '##ilia', '##e', 'is', 'a', 'species', 'of', 'frogs', 'in', 'the', 'family', 'H', '##yl', '##idae', '.', 'It', 'is', 'endemic', 'to', 'the', 'Pac', '##ara', '##ima', 'Mountains', 'in', 'Guyana', 'and', 'known', 'from', 'the', 'region', 'of', 'its', 'type', 'locality', 'in', 'the', 'Kai', '##ete', '##ur', 'National', 'Park', 'and', 'from', 'I', '##mba', '##ima', '##dai', '.', 'The', 'species', 'is', 'dedicated', 'to', 'the', 'daughter', 'of', 'its', 'describe', '##r', ',', 'Lil', '##i', 'Ko', '##k', '.', '[SEP]']
Aligned labels: ['_', 'B-Aquatic_animal', 'I-Aquatic_animal', 'I-Aquatic_animal', 'I-Aquatic_animal', 'I-Aquatic_animal', 'I-Aquatic_animal', 'I-Aquatic_animal', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',

In [12]:
# Need to get the whole dataset into this format, so need to write a fn
# we can apply efficiently across all examples using Dataset.map.
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True
    )
    all_labels = examples["ner_tags"]
    new_labels = []
    for i, labels in enumerate(all_labels):
        word_ids = tokenized_inputs.word_ids(i)
        new_labels.append(align_labels_with_tokens(labels, word_ids))

    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs

In [13]:
# Now we can apply that fn to tokenize all the data
tokenized_data_splits = data_splits.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=data_splits["train"].column_names,
)

In [14]:
# Testing batcher
print("Examples:")
for i in range(2):
    print(tokenized_data_splits["train"][i]["labels"])

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
batch = data_collator([tokenized_data_splits["train"][i] for i in range(2)])

Examples:
[-100, 9, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 9, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, -100]
[-100, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,

In [15]:
# Evaluation: we can use the seqeval library to handle calculating span-level precision, recall and F1
metric = evaluate.load("seqeval")

labels = data_splits["train"][0]["ner_tags"]
labels = [labels_int2str[i] for i in labels]
print(labels)

# Make a small change and see how it impacts the score
predictions = labels.copy()
predictions[0] = "O"
metric.compute(predictions=[predictions], references=[labels])

['B-Cretaceous_dinosaur', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Cretaceous_dinosaur', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


{'Cretaceous_dinosaur': {'precision': 1.0,
  'recall': 0.5,
  'f1': 0.6666666666666666,
  'number': 2},
 'overall_precision': 1.0,
 'overall_recall': 0.5,
 'overall_f1': 0.6666666666666666,
 'overall_accuracy': 0.9904761904761905}

In [16]:
from seqeval.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score

In [17]:
def run_eval(model, dataset, batch_size, device, collate_fn=None, loss_fn=None, print_out=False, pad_token_id=-100):
    """
    Evaluate the model on a given dataset using the seqeval library for sequence labeling.

    Parameters:
    - model: The trained model to evaluate.
    - dataset: The dataset for evaluation.
    - batch_size: Batch size for DataLoader.
    - device: Device to perform evaluation (CPU/GPU).
    - collate_fn: Optional collate function for DataLoader.
    - loss_fn: Loss function for computing validation loss (e.g., CrossEntropyLoss).
    - print_out: If True, prints the classification report.
    - pad_token_id: Token ID used for padding (default -100).

    Returns:
    - avg_loss: Average validation loss.
    - metrics: A dictionary containing evaluation metrics.
    """
    model.eval()
    all_labels = []
    all_predictions = []
    total_loss = 0
    num_batches = 0

    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)

    with torch.no_grad():
        for batch in dataloader:
            # Move batch to the specified device
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            labels = batch.pop("labels")  # Extract labels from the batch

            # Forward pass to get logits
            logits = model(**batch)

            # Compute loss if loss_fn is provided
            if loss_fn is not None:
                loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
                total_loss += loss.item()

            num_batches += 1

            # Get predictions
            predictions = torch.argmax(logits, dim=-1)

            # Filter out padding tokens (e.g., -100) from labels and predictions
            mask = labels != pad_token_id  # Mask for non-padding tokens
            filtered_labels = labels[mask].cpu().numpy()
            filtered_predictions = predictions[mask].cpu().numpy()

            all_labels.append(filtered_labels)
            all_predictions.append(filtered_predictions)

    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

    # Convert the lists of labels and predictions into a list of lists (each sequence is a list)
    all_labels_str = [[labels_int2str[i] for i in seq] for seq in all_labels]
    all_predictions_str = [[labels_int2str[i] for i in seq] for seq in all_predictions]

    # Compute metrics using seqeval
    from seqeval.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score

    report = classification_report(all_predictions_str, all_labels_str)
    accuracy = accuracy_score(all_predictions_str, all_labels_str)
    precision = precision_score(all_predictions_str, all_labels_str)
    recall = recall_score(all_predictions_str, all_labels_str)
    f1 = f1_score(all_predictions_str,all_labels_str )

    metrics = {
        "loss": avg_loss,
        "overall_accuracy": accuracy,
        "overall_precision": precision,
        "overall_recall": recall,
        "overall_f1": f1,
        "classification_report": report,
    }

    if print_out:
        print(report)

    return avg_loss, metrics, all_labels_str, all_predictions_str, labels

In [18]:
def train(model,
          train_dataset,
          num_epochs,
          batch_size,
          optimizer_cls,
          lr,
          weight_decay,
          device,
          collate_fn,
          log_every=100):
    """
    Train a token classification model with a DataLoader.

    Returns:
        model: Trained model.
        (train_loss_history, train_acc_history): Lists of loss and accuracy per epoch.
    """
    dataloader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=collate_fn)

    # Select optimizer
    if optimizer_cls == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
    elif optimizer_cls == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    elif optimizer_cls == 'AdamW':
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer_cls: {optimizer_cls}")

    # Loss function for token classification
    lossfn = nn.CrossEntropyLoss()

    # Training history
    train_loss_history = []
    train_acc_history = []

    # Training loop
    for e in range(num_epochs):
        model.train()
        epoch_loss_history = []
        epoch_acc_history = []
        start_time = time.time()

        for i, batch in enumerate(dataloader):
            # Move batch to device (GPU/CPU)
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            y = batch.pop('labels')  # Extract labels

            # Forward pass
            outputs = model(**batch)  # Outputs logits directly
            logits = outputs  # logits are the direct output in your custom model

            # Compute the loss
            loss = lossfn(logits.view(-1, logits.size(-1)), y.view(-1))  # Reshape for token-level classification

            # Compute predictions and accuracy
            pred = torch.argmax(logits, dim=-1)  # Predicted labels
            mask = y != -100  # Ignore padding indices
            acc = (pred[mask] == y[mask]).float().mean()  # Calculate accuracy on non-padded tokens

            # Record metrics
            epoch_loss_history.append(loss.item())
            epoch_acc_history.append(acc.item())

            # Log progress at intervals
            if i % log_every == 0:
                speed = 0 if i == 0 else log_every / (time.time() - start_time)
                print(f'epoch: {e} | iter: {i} | train_loss: {np.mean(epoch_loss_history):.3e} | '
                      f'train_acc: {np.mean(epoch_acc_history):.3f} | speed: {speed:.3f} b/s')
                start_time = time.time()

            # Backpropagation
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # End of epoch metrics
        train_loss_history.append(np.mean(epoch_loss_history))
        train_acc_history.append(np.mean(epoch_acc_history))
        print(f'Epoch {e}: train_loss={train_loss_history[-1]:.3e}, train_acc={train_acc_history[-1]:.3f}')

    return model, (train_loss_history, train_acc_history)


In [35]:
class BertForTokenClassification(nn.Module):
    def __init__(self, bert_pretrained_config_name, num_classes, freeze_bert=False, dropout_prob=0.1):
        '''
        BERT with a classification MLP
        args:
        - bert_pretrained_config_name (str): model name from huggingface hub
        - num_classes (int): number of classes in the classification task
        - freeze_bert (bool): [default False] If true gradients are not computed for
                              BERT's parameters.
        - dropout_prob (float): [default 0.1] probability of dropping each activation.
        '''
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_pretrained_config_name)
        self.bert.requires_grad_(not freeze_bert)
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128,128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
            nn.LogSoftmax(dim=-1)
        )
    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        logits = self.classifier(outputs.last_hidden_state)  # [batch_size, seq_len, num_labels]
        return logits  # Only returning logits for simplicity



In [None]:
# This is where fine-tuning of the classifier happens.
# Here we are training with batch size 32 for 5 epochs.

# At the end of each epoch, you also see validation loss and validation accuracy.
# Change the device as described above if you will not be using a GPU

# Set the random seed(s) for reproducability
torch.random.manual_seed(8942764)
#torch.cuda.manual_seed(8942764)
np.random.seed(8942764)

# Make sure this is the same as you use for tokenization!
bert_model = 'bert-base-cased'

num_labels = len(labels_int2str)
print(f"Num labels: {num_labels}")

# conll hyperparams
# multiply your learning rate by k when using batch size of kN
lr = 4*2e-5 # 1e-3
weight_decay = 0.01
epochs = 5
batch_size = 32
dropout_prob = 0.2
freeze_bert = False

bert_cls = BertForTokenClassification(bert_model, num_labels, dropout_prob=dropout_prob, freeze_bert=freeze_bert)
bert_cls.to(device)
print(f'Trainable parameters: {sum([p.numel() for p in bert_cls.parameters() if p.requires_grad])}\n')

# Flag for setting "debug" mode. Set debug to False for full training.
debug = False

# Sample a subset of the training data for faster iteration in debug mode
subset_size = 1000
subset_indices = torch.randperm(len(tokenized_data_splits['train']))[:subset_size]
train_subset = Subset(tokenized_data_splits['train'], subset_indices)

bert_cls, bert_cls_logs = train(bert_cls, tokenized_data_splits['train'] if not debug else train_subset,
                                num_epochs=epochs, batch_size=batch_size, optimizer_cls='AdamW',
                                lr=lr, weight_decay=weight_decay, device=device,
                                collate_fn=data_collator, log_every=10 if debug else 100)



Num labels: 13
Trainable parameters: 108426893



In [21]:
# Save the model's weights
#torch.save(bert_cls.state_dict(), "./bert_cls_model.pth")


In [38]:
# Load custom model
bert_cls.load_state_dict(torch.load("./bert_cls_model.pth", map_location=torch.device('cpu')))


<All keys matched successfully>

In [24]:
# Now evaluating on the validation dataset
final_loss, final_metrics, true_labels, predicted_labels, labels = run_eval(bert_cls, tokenized_data_splits['dev'], batch_size=32, device=device, collate_fn=data_collator)
final_acc = final_metrics['overall_accuracy']
final_p = final_metrics['overall_precision']
final_r = final_metrics['overall_recall']
final_f1 = final_metrics['overall_f1']
print(f'\nFinal Loss: {final_loss:.3e}\t Final Accuracy: {final_acc:.3f}\t dev_p:{final_p:.3f}\t dev_r:{final_r:.3f}\t dev_f1:{final_f1:.3f}')

  _warn_prf(average, modifier, msg_start, len(result))



Final Loss: 0.000e+00	 Final Accuracy: 0.936	 dev_p:0.317	 dev_r:0.213	 dev_f1:0.255


In [25]:
for true, pred in zip(true_labels[:5], predicted_labels[:5]):
    print("True:", true)
    print("Pred:", pred)
    print()


True: ['O', 'O', 'O', 'O', 'O', 'B-Goddess', 'I-Goddess', 'I-Goddess', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Aquatic_mammal', 'I-Aquatic_mammal', 'I-Aquatic_mammal', 'I-Aquatic_mammal', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Aquatic_mammal', 'I-Aquatic_mammal', 'I-Aquatic_mammal', 'I-Aquatic_mammal', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Goddess', 'I-Goddess', 'I-Goddess', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 

In [None]:
# Now evaluating on the test dataset
final_loss, final_metrics, true_labels, predicted_labels, labels = run_eval(
    bert_cls,
    tokenized_data_splits['test'],
    batch_size=32,
    device=device,
    collate_fn=data_collator
)
final_acc = final_metrics['overall_accuracy']
final_p = final_metrics['overall_precision']
final_r = final_metrics['overall_recall']
final_f1 = final_metrics['overall_f1']
print(f'\nFinal Loss: {final_loss:.3e}\t Final Accuracy: {final_acc:.3f}\t dev_p:{final_p:.3f}\t dev_r:{final_r:.3f}\t dev_f1:{final_f1:.3f}')

In [39]:
def generate_test_output(model, dataset, batch_size, device, collate_fn, labels_int2str):
    """
    Generate the output tags for the test dataset in the required format.

    Parameters:
    - model: The trained model.
    - dataset: Test dataset.
    - batch_size: Batch size for DataLoader.
    - device: Device (CPU or GPU).
    - collate_fn: Collate function for the DataLoader.
    - labels_int2str: Mapping from label integers to string tags.

    Returns:
    - List of lists of tags for each sentence in the test dataset.
    """
    model.eval()
    test_output = []

    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)

    with torch.no_grad():
        for batch in dataloader:
            # Move tensors to the specified device
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            labels = batch.pop("labels")  # Remove labels from the batch
            
            # Get model predictions
            logits = model(**batch)
            predictions = torch.argmax(logits, dim=-1)

            # Convert predictions to tags
            for pred, label, input_ids in zip(predictions, labels, batch["input_ids"]):
                sentence_tags = []
                for p, l, token_id in zip(pred.cpu().numpy(), label.cpu().numpy(), input_ids.cpu().numpy()):
                    if l == -100:  # Ignore special tokens and padding
                        continue
                    sentence_tags.append(labels_int2str[p])
                test_output.append(sentence_tags)

    return test_output


In [40]:
# Generate predictions for the test dataset
test_tags = generate_test_output(
    model=bert_cls,
    dataset=tokenized_data_splits['test'],
    batch_size=32,
    device=device,
    collate_fn=data_collator,
    labels_int2str=labels_int2str
)

# Save to JSON
import json

with open("test_predictions_bert.json", "w") as f:
    json.dump(test_tags, f, indent=4)

print("Test predictions saved to 'test_predictions_bert.json'")


Test predictions saved to 'test_predictions_bert.json'
