In [84]:
from transformers import DistilBertTokenizer, DistilBertForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, Trainer, TrainingArguments

from transformers import pipeline
from torch.utils.data import Dataset, DataLoader



In [2]:
# Load pre-trained DistilBERT model and tokenizer
model_name = "distilbert-base-uncased"
model = DistilBertForTokenClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForTokenClassification: ['vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream

In [80]:
# Example raw text and entity offsets
data = [
    {
        "text": "Apple is looking to buy a U.K. startup for $1 billion.",
        "annotations": [(0, 5, "ORG"), (26, 30, "LOC"), (43, 53, "MONEY")]
    },
    {
        "text": "Elon Musk founded SpaceX and Tesla.",
        "annotations": [(0, 9, "PERSON"), (18, 24, "ORG"), (29, 34, "ORG")]
    },
        {
        "text": "Apple is looking to buy a U.K. startup for $1 billion.",
        "annotations": [(0, 5, "ORG"), (26, 30, "LOC"), (43, 53, "MONEY")]
    },
    {
        "text": "Elon Musk founded SpaceX and Tesla.",
        "annotations": [(0, 9, "PERSON"), (18, 24, "ORG"), (29, 34, "ORG")]
    }
]




In [95]:



class NERDataset(Dataset):
    def __init__(self, data_json, tokenizer, label_map) -> None:
        self.data = data_json
        self.tokenizer = tokenizer
        self.label_map = label_map
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        item = self.data[index]
        text = item['text']
        annots = item['annotations']
        tokenized_text = self.tokenizer(text)
        
        iob_tags = ['O'] * len(tokenized_text.tokens())
        
        for start_offset, end_offset, label in annots:
            # converting offsets based on char indices to token indices
            start_token_index = tokenized_text.char_to_token(start_offset)
            end_token_index = tokenized_text.char_to_token(end_offset - 1) # exclusive
            
            iob_tags[start_token_index] = f"B-{label}"
            
            for token_index in range(start_token_index + 1, end_token_index + 1):
                iob_tags[token_index] = f"I-{label}"
        
        label_indices = [label_map[label] for label in iob_tags]
            
        return {
            # 'text' : text,
            'attention_mask' : tokenized_text['attention_mask'],
            'input_ids' : tokenized_text['input_ids'],
            'labels' : label_indices
            
            # 'tokens' : tokenized_text.tokens() # for debugging only, can be commented out
        }


def generate_label_map():
    return {
    "O": 0,  # Outside
    "B-ORG": 1,
    "I-ORG": 2,
    "B-LOC": 3,
    "I-LOC": 4,
    "B-PERSON": 5,
    "I-PERSON": 6,
    "B-MONEY": 7,
    "I-MONEY": 8,
}


label_map = generate_label_map()            

ner_dataset = NERDataset(data, tokenizer, label_map)
eval_dataset = NERDataset(data, tokenizer, label_map)



In [96]:
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True)
dataloader = DataLoader(ner_dataset, batch_size=2, collate_fn=data_collator)


In [97]:
next(iter(dataloader))

{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]), 'input_ids': tensor([[  101,  6207,  2003,  2559,  2000,  4965,  1037,  1057,  1012,  1047,
          1012, 22752,  2005,  1002,  1015,  4551,  1012,   102],
        [  101,  3449,  2239, 14163,  6711,  2631,  2686,  2595,  1998, 26060,
          1012,   102,     0,     0,     0,     0,     0,     0]]), 'labels': tensor([[   0,    1,    0,    0,    0,    0,    0,    3,    4,    4,    4,    0,
            0,    7,    8,    8,    0,    0],
        [   0,    5,    6,    6,    6,    0,    1,    2,    0,    1,    0,    0,
         -100, -100, -100, -100, -100, -100]])}

In [100]:
# Training arguments
training_args = TrainingArguments(
    output_dir="./ner_output",
    overwrite_output_dir=True,
    num_train_epochs=30,
    per_device_train_batch_size=2,
    save_steps=500,
    save_total_limit=2,
    evaluation_strategy="epoch",
    logging_dir="./logs",
    use_mps_device=True
   
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ner_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator)

# Start training
trainer.train()

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 0.00023971761402208358, 'eval_runtime': 0.0423, 'eval_samples_per_second': 94.664, 'eval_steps_per_second': 23.666, 'epoch': 1.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 0.00011524301953613758, 'eval_runtime': 0.0429, 'eval_samples_per_second': 93.273, 'eval_steps_per_second': 23.318, 'epoch': 2.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 5.473521741805598e-05, 'eval_runtime': 0.0521, 'eval_samples_per_second': 76.754, 'eval_steps_per_second': 19.189, 'epoch': 3.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 3.251015004934743e-05, 'eval_runtime': 0.0546, 'eval_samples_per_second': 73.29, 'eval_steps_per_second': 18.322, 'epoch': 4.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 2.2197520593181252e-05, 'eval_runtime': 0.0501, 'eval_samples_per_second': 79.88, 'eval_steps_per_second': 19.97, 'epoch': 5.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 1.688627162366174e-05, 'eval_runtime': 0.0461, 'eval_samples_per_second': 86.838, 'eval_steps_per_second': 21.709, 'epoch': 6.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 1.3604042578663211e-05, 'eval_runtime': 0.0491, 'eval_samples_per_second': 81.508, 'eval_steps_per_second': 20.377, 'epoch': 7.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 1.1301890481263399e-05, 'eval_runtime': 0.0499, 'eval_samples_per_second': 80.207, 'eval_steps_per_second': 20.052, 'epoch': 8.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 9.630936801841017e-06, 'eval_runtime': 0.0487, 'eval_samples_per_second': 82.07, 'eval_steps_per_second': 20.517, 'epoch': 9.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 8.467237421427853e-06, 'eval_runtime': 0.0732, 'eval_samples_per_second': 54.642, 'eval_steps_per_second': 13.66, 'epoch': 10.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 7.63635034672916e-06, 'eval_runtime': 0.0493, 'eval_samples_per_second': 81.18, 'eval_steps_per_second': 20.295, 'epoch': 11.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 7.018923952273326e-06, 'eval_runtime': 0.049, 'eval_samples_per_second': 81.613, 'eval_steps_per_second': 20.403, 'epoch': 12.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 6.550689704454271e-06, 'eval_runtime': 0.0447, 'eval_samples_per_second': 89.56, 'eval_steps_per_second': 22.39, 'epoch': 13.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 6.18115154793486e-06, 'eval_runtime': 0.0624, 'eval_samples_per_second': 64.128, 'eval_steps_per_second': 16.032, 'epoch': 14.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 5.878176580154104e-06, 'eval_runtime': 0.0495, 'eval_samples_per_second': 80.836, 'eval_steps_per_second': 20.209, 'epoch': 15.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 5.611926098936237e-06, 'eval_runtime': 0.0492, 'eval_samples_per_second': 81.25, 'eval_steps_per_second': 20.312, 'epoch': 16.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 5.4007614380680025e-06, 'eval_runtime': 0.0413, 'eval_samples_per_second': 96.951, 'eval_steps_per_second': 24.238, 'epoch': 17.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 5.21254924024106e-06, 'eval_runtime': 0.0508, 'eval_samples_per_second': 78.689, 'eval_steps_per_second': 19.672, 'epoch': 18.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 5.044995305070188e-06, 'eval_runtime': 0.0492, 'eval_samples_per_second': 81.306, 'eval_steps_per_second': 20.326, 'epoch': 19.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 4.918755621474702e-06, 'eval_runtime': 0.0496, 'eval_samples_per_second': 80.714, 'eval_steps_per_second': 20.178, 'epoch': 20.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 4.8108781811606605e-06, 'eval_runtime': 0.0625, 'eval_samples_per_second': 64.031, 'eval_steps_per_second': 16.008, 'epoch': 21.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 4.719067419500789e-06, 'eval_runtime': 0.0535, 'eval_samples_per_second': 74.796, 'eval_steps_per_second': 18.699, 'epoch': 22.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 4.636438006855315e-06, 'eval_runtime': 0.049, 'eval_samples_per_second': 81.686, 'eval_steps_per_second': 20.422, 'epoch': 23.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 4.574465947371209e-06, 'eval_runtime': 0.0509, 'eval_samples_per_second': 78.64, 'eval_steps_per_second': 19.66, 'epoch': 24.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 4.53315033155377e-06, 'eval_runtime': 0.05, 'eval_samples_per_second': 79.974, 'eval_steps_per_second': 19.994, 'epoch': 25.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 4.491836079978384e-06, 'eval_runtime': 0.0499, 'eval_samples_per_second': 80.2, 'eval_steps_per_second': 20.05, 'epoch': 26.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 4.452816483535571e-06, 'eval_runtime': 0.0506, 'eval_samples_per_second': 79.105, 'eval_steps_per_second': 19.776, 'epoch': 27.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 4.425272891239729e-06, 'eval_runtime': 0.0629, 'eval_samples_per_second': 63.582, 'eval_steps_per_second': 15.895, 'epoch': 28.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 4.413797341840109e-06, 'eval_runtime': 0.0531, 'eval_samples_per_second': 75.31, 'eval_steps_per_second': 18.827, 'epoch': 29.0}


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 4.406911557452986e-06, 'eval_runtime': 0.0504, 'eval_samples_per_second': 79.401, 'eval_steps_per_second': 19.85, 'epoch': 30.0}
{'train_runtime': 7.9281, 'train_samples_per_second': 15.136, 'train_steps_per_second': 7.568, 'train_loss': 5.093542858958244e-05, 'epoch': 30.0}


TrainOutput(global_step=60, training_loss=5.093542858958244e-05, metrics={'train_runtime': 7.9281, 'train_samples_per_second': 15.136, 'train_steps_per_second': 7.568, 'train_loss': 5.093542858958244e-05, 'epoch': 30.0})