In [11]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from torch.nn.functional import softmax

In [8]:
model_name = "dbmdz/bert-large-cased-finetuned-conll03-english"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 8  # Adjust this based on your GPU memory

cuda


In [9]:
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name).to(device)
model.eval()

Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), 

In [12]:
# Custom Dataset Class
class NERDataset(Dataset):
    def __init__(self, tokens, labels, tokenizer, label_to_index):
        self.tokens = tokens
        self.labels = labels
        self.tokenizer = tokenizer
        self.label_to_index = label_to_index  # Label to index mapping

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return {"tokens": self.tokens[idx], "labels": self.labels[idx]}

    def collate_fn(self, batch):
        tokens = [item['tokens'] for item in batch]
        labels = [item['labels'] for item in batch]
        encoded_inputs = self.tokenizer(tokens, is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)

        label_ids = []
        for i, label in enumerate(labels):
            word_ids = encoded_inputs.word_ids(batch_index=i)
            label_indices = [self.label_to_index.get(label_word, -100) for label_word in label]  # Use -100 for unknown or padding
            label_ids.append([-100 if word_id is None else label_indices[word_id] for word_id in word_ids])

        encoded_inputs['labels'] = torch.tensor(label_ids, dtype=torch.long).to(encoded_inputs.input_ids.device)
        return encoded_inputs

In [19]:
# Function to read and preprocess the dataset
def load_dataset(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    tokens, labels = [], []
    temp_tokens, temp_labels = [], []
    for line in lines:
        if line.strip() == '':
            if temp_tokens:
                tokens.append(temp_tokens)
                labels.append(temp_labels)
                temp_tokens, temp_labels = [], []
            continue
        parts = line.strip().split()
        temp_tokens.append(parts[0])
        temp_labels.append(parts[-1])

    if temp_tokens:
        tokens.append(temp_tokens)
        labels.append(temp_labels)

    return tokens, labels


In [15]:
# labels
label_to_index = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}

# Load data
tokens, labels = load_dataset('data/FIN3.txt')
dataset = NERDataset(tokens, labels, tokenizer, label_to_index)
loader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)

# Initialize lists
true_labels = []
pred_labels = []



In [16]:
# Model Prediction with Batching
for batch in loader:
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    label_ids = batch['labels'].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    predictions = torch.argmax(softmax(logits, dim=2), dim=2)

    for i in range(input_ids.size(0)):
        mask = attention_mask[i].bool()
        prediction = predictions[i][mask].cpu().numpy()
        true_labels_batch = label_ids[i][mask].cpu().numpy()
        pred_labels_batch = [model.config.id2label[pred] for pred in prediction]

        true_labels.extend(true_labels_batch)
        pred_labels.extend(pred_labels_batch)



In [21]:
pred_labels_indices = [label_to_index.get(label, -100) for label in pred_labels]  # Ensure all predicted labels are integers
true_labels_indices = [label_to_index.get(label, -100) if isinstance(label, str) else label for label in true_labels]

In [22]:
# calculate metrics
precision = precision_score(true_labels_indices, pred_labels_indices, average='macro', zero_division=0)
recall = recall_score(true_labels_indices, pred_labels_indices, average='macro', zero_division=0)
f1 = f1_score(true_labels_indices, pred_labels_indices, average='macro', zero_division=0)
accuracy = accuracy_score(true_labels_indices, pred_labels_indices)

In [23]:
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1-Score: {f1}")
print(f"Accuracy: {accuracy}")


Precision: 0.8689224729273751
Recall: 0.8851114769094417
F1-Score: 0.86862560674615774
Accuracy: 0.9266313309776207
