# Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
## Important libs ##
import os
from pathlib import Path
import huggingface_hub
from datasets import load_dataset
from sklearn.metrics import classification_report

os.chdir(Path.cwd().parent)

from src.utils import load_env_file

load_env_file()
api_key = os.getenv("HF_TOKEN")

huggingface_hub.login(api_key)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


# Fine-tuning distilBERT (baseline to compare with LLM later)

## Loading data

In [3]:
raw_datasets = load_dataset("conll2003", trust_remote_code="true")
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

## Data processing (tokenization and padding)

In [4]:
from transformers import AutoTokenizer

model_id = "distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer(raw_datasets["train"][0]["tokens"], is_split_into_words=True)

{'input_ids': [101, 7327, 19164, 2446, 2655, 2000, 17757, 2329, 12559, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [None]:
from src.model.dataset_configs import tokenize_ner_models

tokenized_datasets = raw_datasets.map(
    tokenize_ner_models,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)

tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 3453
    })
})

In [6]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [None]:
label_names = raw_datasets["train"].features["ner_tags"].feature.names

words = raw_datasets["train"][4]["tokens"]
labels = raw_datasets["train"][4]["ner_tags"]
line1 = ""
line2 = ""
for word, label in zip(words, labels):
    full_label = label_names[label]
    max_length = max(len(word), len(full_label))
    line1 += word + " " * (max_length - len(word) + 1)
    line2 += full_label + " " * (max_length - len(full_label) + 1)

print(line1)
print(line2)

Germany 's representative to the European Union 's veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer . 
B-LOC   O  O              O  O   B-ORG    I-ORG O  O          O         B-PER  I-PER     O    O  O         O         O      O   O         O    O         O     O    B-LOC   O     O   O          O      O   O       O 


In [8]:
from transformers import AutoModelForTokenClassification

id2label = {i: label for i, label in enumerate(label_names)}
label2id = {v: k for k, v in id2label.items()}

model = AutoModelForTokenClassification.from_pretrained(
    model_id,
    id2label=id2label,
    label2id=label2id,
)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [40]:
label2id

{'O': 0,
 'B-PER': 1,
 'I-PER': 2,
 'B-ORG': 3,
 'I-ORG': 4,
 'B-LOC': 5,
 'I-LOC': 6,
 'B-MISC': 7,
 'I-MISC': 8}

In [11]:
id2label

{0: 'O',
 1: 'B-PER',
 2: 'I-PER',
 3: 'B-ORG',
 4: 'I-ORG',
 5: 'B-LOC',
 6: 'I-LOC',
 7: 'B-MISC',
 8: 'I-MISC'}

## Training (fine-tuning)

In [9]:
import wandb

project_name = "ner_fine_tuning"
group = "ner_fine_tuning"
# This will open a window so you can login to W&B.
# If that doesn't work, set your W&B API key below
# If you do, remove your key before publishing to GitHub.

# %env WANDB_API_KEY=YOUR_WANDB_API_KEY
#wandb.login()
run = wandb.init(project=project_name, group=group, mode="online")

[34m[1mwandb[0m: Currently logged in as: [33mgabrieldiasmp[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
tokenized_datasets.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"]
)

In [11]:
from src.model.dataset_configs import HFTextDataset
from torch.utils.data import DataLoader

In [12]:
train_ds = HFTextDataset(tokenized_datasets["train"])

train_loader = DataLoader(
    dataset=train_ds,
    batch_size=32,
    shuffle=True,
    collate_fn=data_collator
)

In [13]:
val_ds = HFTextDataset(tokenized_datasets["validation"])

val_loader = DataLoader(
    dataset=val_ds,
    batch_size=32,
    shuffle=False,
    collate_fn=data_collator
)

In [14]:
train_last_layers_only = True

if train_last_layers_only:
    for param in model.parameters():
        param.requires_grad = False

    for param in model.classifier.parameters():
        param.requires_grad = True

In [15]:
from src.model.training import FlexibleLightningModel, HFLightningModel, train_model_lightning
from lightning.pytorch.loggers import WandbLogger

lightning_model = FlexibleLightningModel(
    model=model, label_name="labels", learning_rate=0.05, num_classes=9, task_type="token_classification")

wandb_logger = WandbLogger(log_model="best")

trainer = train_model_lightning(
    max_epochs=5,
    project_name=project_name,
    group=group,
    metric_to_monitor="val_loss",
    mode="min"
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [65]:
trainer.fit(
    model=lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                             | Params | Mode 
---------------------------------------------------------------------------
0 | model         | DistilBertForTokenClassification | 66.4 M | eval 
1 | train_acc     | MulticlassAccuracy               | 0      | train
2 | val_acc       | MulticlassAccuracy               | 0      | train
3 | test_acc      | MulticlassAccuracy               | 0      | train
4 | val_precision | MulticlassPrecision              | 0      | train
5 | val_recall    | MulticlassRecall                 | 0      | train
6 | val_f1        | MulticlassF1Score                | 0      | train
---------------------------------------------------------------------------
6.9 K     Trainable params
66.4 M    Non-trainable params
66.4 M    Total params
265.479   Total estimated model params size (MB)
6         Modules in train mode
95        Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [66]:
wandb.finish()

0,1
epoch,▁▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆██████
train_acc_epoch,▁█▁▄▁
train_acc_step,▅█▄▄▃▅▅▁▇▄▇▆▃▂▅▆▅▇▅▅▇▆▆▅▇▅▅▃▄▄▆▄▃▆▆▆▅▄▄▅
train_loss,▄▁▄▄▄▃▃▇▂▅▁▂▄█▃▂▃▂▅▃▂▂▄▄▂▄▃▄▃▅▂▄▄▂▂▄▃▅▄▃
trainer/global_step,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
val_acc,▄▁▄█▅
val_f1_class_0,▁▂▇█▆
val_f1_class_1,▆▁█▇▆
val_f1_class_2,▄▇▁█▂
val_f1_class_3,▇▁█▆▇

0,1
epoch,4.0
train_acc_epoch,0.95465
train_acc_step,0.95607
train_loss,0.22556
trainer/global_step,2194.0
val_acc,0.95382
val_f1_class_0,0.98665
val_f1_class_1,0.78969
val_f1_class_2,0.77767
val_f1_class_3,0.64182


## Test inference

In [16]:
test_ds = HFTextDataset(tokenized_datasets["test"])

test_loader = DataLoader(
    dataset=test_ds,
    batch_size=32,
    shuffle=True,
    collate_fn=data_collator
)

In [None]:
f"{wandb.run.entity}/{project_name}/model-{wandb.run.id}:best"

"model-nc4fzygv"

'gabrieldiasmp/ner_fine_tuning/model-8hc1yf5k:best'

In [60]:
# Define checkpoint reference.
checkpoint_reference = f"{wandb.run.entity}/{project_name}/model-{'olprujj3'}:best" # wandb.run.id

# Download checkpoint locally (if not already cached).
artifact = run.use_artifact(checkpoint_reference, type="model")
artifact_dir = artifact.download()

[34m[1mwandb[0m: Downloading large artifact model-olprujj3:best, 253.28MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.4 (178.0MB/s)


In [61]:
# Load LightningModule with checkpoint
model_test = FlexibleLightningModel.load_from_checkpoint(
    checkpoint_path=str(artifact_dir) + "/model.ckpt",
    model=model,
    learning_rate=0.05,
    num_classes=9,
    label_name="labels",
    task_type="token_classification"
)

In [62]:
predicted_labels = batch_outputs = trainer.predict(model=model_test, dataloaders=test_loader)

/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:476: Your `predict_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [47]:
batch_outputs[0].keys()

dict_keys(['logits', 'preds', 'labels'])

In [49]:
batch_outputs[0]['logits'][0]

tensor([[ 1.2249e+01, -9.4911e+00, -1.6683e+01, -9.9180e-01, -7.2359e+00,
          1.5641e+00, -9.3560e-03, -2.9993e+00,  2.7563e+00],
        [ 3.8110e+00, -1.7084e+01, -2.9783e+01,  1.8779e+01,  1.5866e-01,
          5.4683e+00, -1.1140e+01,  5.0917e+00, -9.5163e+00],
        [-3.9796e+00, -5.7436e+01, -1.2956e+01, -2.7093e+00,  1.9136e+01,
         -1.9439e+01,  2.1610e+00, -1.9562e+01,  1.1541e+01],
        [ 3.3517e-01, -4.3263e+01, -1.5210e+01, -3.5497e+00,  1.4251e+01,
         -8.7733e+00,  1.9600e+00, -1.6918e+01,  4.3850e+00],
        [ 1.2856e+01, -3.3975e+01, -2.1374e+01,  1.3754e+00,  4.7026e+00,
         -6.5030e+00, -1.0460e+01, -9.6904e+00,  3.8632e-01],
        [ 1.2301e+01, -2.6842e+01, -1.4833e+01,  1.4773e-02,  3.9869e+00,
         -1.3567e+01, -9.3527e+00, -1.1409e+01,  4.0558e+00],
        [ 2.7036e+01, -2.1116e+01, -1.3880e+01, -2.7127e+00, -4.1150e+00,
         -1.1620e+01, -1.8338e+01, -1.1250e+01, -2.0302e-02],
        [ 1.6786e+01, -1.3735e+01, -1.8715e+01, 

In [50]:
import torch
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score

def evaluate_ner(batch_outputs, label_names):
    """
    Compute token-level NER metrics (precision, recall, F1) from a list of batch outputs.

    Args:
        batch_outputs (list of dicts): Each dict must contain keys:
            - 'logits': [batch, seq_len, num_classes] tensor
            - 'labels': [batch, seq_len] tensor (optional, can be None)
        label_names (list of str): Maps class indices to string labels.
    
    Returns:
        dict: Metrics including F1, precision, recall and classification report (if labels exist)
        list: all_preds (list of lists of string labels)
        list: all_labels (list of lists of string labels, empty if labels not provided)
    """
    all_preds, all_labels = [], []

    for out in batch_outputs:
        logits = out["logits"]
        labels = out.get("labels", None)

        # Token-level predicted class indices
        preds = torch.argmax(logits, dim=-1).detach().cpu().numpy()

        if labels is not None:
            labels = labels.detach().cpu().numpy()
            for p, l in zip(preds, labels):
                valid_preds = [label_names[pi] for pi, li in zip(p, l) if li != -100]
                valid_labels = [label_names[li] for pi, li in zip(p, l) if li != -100]
                all_preds.append(valid_preds)
                all_labels.append(valid_labels)
        else:
            for p in preds:
                all_preds.append([label_names[pi] for pi in p])

    metrics = {}
    if all_labels:
        metrics["f1"] = f1_score(all_labels, all_preds)
        metrics["precision"] = precision_score(all_labels, all_preds)
        metrics["recall"] = recall_score(all_labels, all_preds)
        metrics["classification_report"] = classification_report(all_labels, all_preds)
    else:
        metrics["message"] = "No labels available. Only predictions returned."

    return metrics, all_preds, all_labels


- Model with few epochs

In [51]:
metrics, all_preds, all_labels = evaluate_ner(batch_outputs, label_names)

if "classification_report" in metrics:
    print("F1:", metrics["f1"])
    print("Precision:", metrics["precision"])
    print("Recall:", metrics["recall"])
    print(metrics["classification_report"])
else:
    print(metrics["message"])


F1: 0.6712658430932135
Precision: 0.6136696978586096
Recall: 0.7407932011331445
              precision    recall  f1-score   support

         LOC       0.70      0.79      0.74      1668
        MISC       0.43      0.58      0.50       702
         ORG       0.50      0.67      0.57      1661
         PER       0.76      0.84      0.80      1617

   micro avg       0.61      0.74      0.67      5648
   macro avg       0.60      0.72      0.65      5648
weighted avg       0.62      0.74      0.68      5648



- Model fine-tuned with 20 epochs

In [63]:
metrics, all_preds, all_labels = evaluate_ner(batch_outputs, label_names)

if "classification_report" in metrics:
    print("F1:", metrics["f1"])
    print("Precision:", metrics["precision"])
    print("Recall:", metrics["recall"])
    print(metrics["classification_report"])
else:
    print(metrics["message"])


F1: 0.6473708846933999
Precision: 0.5786610878661088
Recall: 0.7345963172804533
              precision    recall  f1-score   support

         LOC       0.66      0.85      0.74      1668
        MISC       0.37      0.63      0.47       702
         ORG       0.50      0.61      0.55      1661
         PER       0.71      0.79      0.75      1617

   micro avg       0.58      0.73      0.65      5648
   macro avg       0.56      0.72      0.63      5648
weighted avg       0.59      0.73      0.65      5648



In [86]:
label2id

{'O': 0,
 'B-PER': 1,
 'I-PER': 2,
 'B-ORG': 3,
 'I-ORG': 4,
 'B-LOC': 5,
 'I-LOC': 6,
 'B-MISC': 7,
 'I-MISC': 8}

In [100]:
raw_datasets['test'][115]

{'id': '115',
 'tokens': ['New', 'Zealand', 'innings'],
 'pos_tags': [22, 22, 21],
 'chunk_tags': [11, 12, 12],
 'ner_tags': [5, 6, 0]}

In [96]:
all_preds[115]

['O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O']

In [None]:
batch_outputs[0]['logits'][0].shape

torch.Size([32, 50, 9])

## Testing LLMs