In [7]:
%pip install transformers[torch] datasets torch accelerate>=0.26.0

Note: you may need to restart the kernel to use updated packages.


In [15]:
import random

# Generate train data

In [16]:
# import labels and relationships from the json files
import json
with open('./../useful_dataset/lbl2ent.json') as f:
    file = json.load(f)
    entities = list(file.keys())

with open('./../useful_dataset/lbl2rel.json') as f:
    file = json.load(f)
    relations = list(file.keys())

In [17]:
templates = [
    "Who is the {} of {}?",
    "Who {} the movie {}?",
    "Who worked as {} on the movie {}?",
    "Which person is the {} of {}?"
]

In [30]:
def generate_examples(movies, relationships, templates, num_examples=100):
    sentences = []
    labels = []

    for _ in range(num_examples):
        # Randomly pick a relationship and a movie title
        relationship = random.choice(relationships)
        movie = random.choice(movies)

        # Randomly pick a template and fill it in
        template = random.choice(templates)
        question = template.format(relationship, movie)

        # Tokenize question manually
        tokens = question.split()

        # Create labels for each token
        label_seq = []
        for token in tokens:
            if token.lower() == relationship:
                # label_seq.append("B-" + relationship.replace(" ", "_").upper())
                label_seq.append("B-RELATION")
            elif token.lower() in relationship.split(" "):
                # label_seq.append("I-" + relationship.replace(" ", "_").upper())
                label_seq.append("I-RELATION")
            elif token == movie.split()[0]:
                label_seq.append("B-MOVIE")
            elif token in movie.split()[1:]:
                label_seq.append("I-MOVIE")
            else:
                label_seq.append("O")

        # Append to the list
        sentences.append(tokens)
        labels.append(label_seq)

    return sentences, labels

# Generate 200 examples
tokens, ner_tags = generate_examples(entities, relations, templates, num_examples=500)

In [33]:
# create a checkpoint of the created dataset with pickle?
import pickle

with open('data.pkl', 'wb') as f:
    pickle.dump((tokens, ner_tags), f)


In [34]:
import torch
from transformers import BertTokenizer, BertForTokenClassification, Trainer, TrainingArguments
from datasets import load_dataset, Dataset, DatasetDict
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification
import random

# Assuming `tokens` and `ner_tags` are available

# Load your dataset in the correct format
train_data = {
    'tokens': tokens,
    'ner_tags': ner_tags
}

# Map NER tags to labels (B-RELATION, B-MOVIE, etc.)
label_list = ["O", "B-RELATION", "B-MOVIE", "I-MOVIE", "I-RELATION"]

# Create the mapping from label to index
label_map = {label: i for i, label in enumerate(label_list)}

# Convert ner_tags to numerical values based on label_map
def convert_labels_to_ids(ner_tags):
    converted_tags = []
    for labels in ner_tags:
        converted_tags.append([label_map[label] for label in labels])
    return converted_tags

# Apply label conversion
train_data['ner_tags'] = convert_labels_to_ids(train_data['ner_tags'])

# Convert the dataset into Hugging Face Dataset format
dataset = Dataset.from_dict(train_data)

# Split the dataset into training and validation sets (80/20 split)
train_test_split = dataset.train_test_split(test_size=0.2)
datasets = DatasetDict({"train": train_test_split["train"], "test": train_test_split["test"]})

# Load the fast version of the DistilBERT tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

# Load the pre-trained DistilBERT model for token classification
model = DistilBertForTokenClassification.from_pretrained("distilbert-base-uncased", num_labels=len(label_list))

# Tokenize the dataset
# Function to tokenize and align labels
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples['tokens'], 
        truncation=True,  # Activate truncation
        padding=True,     # Activate padding
        is_split_into_words=True  # Ensure input is treated as pre-tokenized (word level)
    )

    labels = []
    for i, label in enumerate(examples['ner_tags']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Get word IDs for each token
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)  # Ignored label for special tokens and padding
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])  # Assign the label to the first token of a word
            else:
                label_ids.append(-100)  # Assign ignored label to sub-tokens
            previous_word_idx = word_idx
        labels.append(label_ids)
    
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

# Apply the tokenization and label alignment to the dataset
tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████| 400/400 [00:00<00:00, 1649.78 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 1928.72 examples/s]


# Train Dataset

In [35]:
# Define training arguments
# Check if GPU is available and move model accordingly
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)
model.to(device)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=50,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    save_strategy="epoch",
)

# Create Trainer instance
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    tokenizer=tokenizer,
)
# Train the model
trainer.train()

# Save the model
model.save_pretrained("./custom_ner_model")
tokenizer.save_pretrained("./custom_ner_model")

  0%|          | 10/2500 [00:31<1:36:34,  2.33s/it]

{'loss': 1.236, 'grad_norm': 2.05682373046875, 'learning_rate': 1.9920000000000002e-05, 'epoch': 0.2}


  1%|          | 20/2500 [00:50<1:18:05,  1.89s/it]

{'loss': 0.9362, 'grad_norm': 2.098996162414551, 'learning_rate': 1.9840000000000003e-05, 'epoch': 0.4}


  1%|          | 30/2500 [01:08<1:12:12,  1.75s/it]

{'loss': 0.8679, 'grad_norm': 2.9255290031433105, 'learning_rate': 1.976e-05, 'epoch': 0.6}


  2%|▏         | 40/2500 [01:27<1:13:41,  1.80s/it]

{'loss': 0.6143, 'grad_norm': 1.9422575235366821, 'learning_rate': 1.968e-05, 'epoch': 0.8}


  2%|▏         | 50/2500 [01:45<1:15:19,  1.84s/it]

{'loss': 0.4903, 'grad_norm': 2.7894644737243652, 'learning_rate': 1.9600000000000002e-05, 'epoch': 1.0}



  2%|▏         | 50/2500 [01:50<1:15:19,  1.84s/it]

{'eval_loss': 0.4394254684448242, 'eval_runtime': 5.1281, 'eval_samples_per_second': 19.5, 'eval_steps_per_second': 2.535, 'epoch': 1.0}


  2%|▏         | 60/2500 [02:12<1:20:15,  1.97s/it]

{'loss': 0.4153, 'grad_norm': 2.667789936065674, 'learning_rate': 1.9520000000000003e-05, 'epoch': 1.2}


  3%|▎         | 70/2500 [02:30<1:10:57,  1.75s/it]

{'loss': 0.2934, 'grad_norm': 2.1570897102355957, 'learning_rate': 1.944e-05, 'epoch': 1.4}


  3%|▎         | 80/2500 [02:51<1:18:53,  1.96s/it]

{'loss': 0.251, 'grad_norm': 2.7886674404144287, 'learning_rate': 1.936e-05, 'epoch': 1.6}


  4%|▎         | 90/2500 [03:08<1:08:02,  1.69s/it]

{'loss': 0.2071, 'grad_norm': 1.2972254753112793, 'learning_rate': 1.9280000000000002e-05, 'epoch': 1.8}


  4%|▍         | 100/2500 [03:25<1:06:21,  1.66s/it]

{'loss': 0.1443, 'grad_norm': 1.9633232355117798, 'learning_rate': 1.9200000000000003e-05, 'epoch': 2.0}



  4%|▍         | 100/2500 [03:30<1:06:21,  1.66s/it]

{'eval_loss': 0.18129847943782806, 'eval_runtime': 4.415, 'eval_samples_per_second': 22.65, 'eval_steps_per_second': 2.945, 'epoch': 2.0}


  4%|▍         | 110/2500 [03:54<1:28:08,  2.21s/it]

{'loss': 0.0976, 'grad_norm': 2.0399608612060547, 'learning_rate': 1.912e-05, 'epoch': 2.2}


  5%|▍         | 120/2500 [04:12<1:10:26,  1.78s/it]

{'loss': 0.1081, 'grad_norm': 0.8024411201477051, 'learning_rate': 1.904e-05, 'epoch': 2.4}


  5%|▌         | 130/2500 [04:30<1:12:02,  1.82s/it]

{'loss': 0.0932, 'grad_norm': 1.4590203762054443, 'learning_rate': 1.896e-05, 'epoch': 2.6}


  6%|▌         | 140/2500 [04:48<1:09:15,  1.76s/it]

{'loss': 0.077, 'grad_norm': 0.5680411458015442, 'learning_rate': 1.8880000000000002e-05, 'epoch': 2.8}


  6%|▌         | 150/2500 [05:05<1:03:40,  1.63s/it]

{'loss': 0.097, 'grad_norm': 3.2772316932678223, 'learning_rate': 1.88e-05, 'epoch': 3.0}



  6%|▌         | 150/2500 [05:10<1:03:40,  1.63s/it]

{'eval_loss': 0.11745478957891464, 'eval_runtime': 4.7154, 'eval_samples_per_second': 21.207, 'eval_steps_per_second': 2.757, 'epoch': 3.0}


  6%|▋         | 160/2500 [05:32<1:10:54,  1.82s/it]

{'loss': 0.0873, 'grad_norm': 0.8015313148498535, 'learning_rate': 1.8720000000000004e-05, 'epoch': 3.2}


  7%|▋         | 170/2500 [05:49<1:03:50,  1.64s/it]

{'loss': 0.0523, 'grad_norm': 0.9232264161109924, 'learning_rate': 1.864e-05, 'epoch': 3.4}


  7%|▋         | 180/2500 [06:05<1:05:41,  1.70s/it]

{'loss': 0.0554, 'grad_norm': 1.0180779695510864, 'learning_rate': 1.8560000000000002e-05, 'epoch': 3.6}


  8%|▊         | 190/2500 [06:22<1:04:27,  1.67s/it]

{'loss': 0.0309, 'grad_norm': 0.6432324051856995, 'learning_rate': 1.8480000000000003e-05, 'epoch': 3.8}


  8%|▊         | 200/2500 [06:40<1:12:33,  1.89s/it]

{'loss': 0.0199, 'grad_norm': 4.413028717041016, 'learning_rate': 1.8400000000000003e-05, 'epoch': 4.0}



  8%|▊         | 200/2500 [06:45<1:12:33,  1.89s/it]

{'eval_loss': 0.10696744918823242, 'eval_runtime': 5.0123, 'eval_samples_per_second': 19.951, 'eval_steps_per_second': 2.594, 'epoch': 4.0}


  8%|▊         | 210/2500 [07:12<1:23:35,  2.19s/it]

{'loss': 0.0435, 'grad_norm': 1.2067598104476929, 'learning_rate': 1.832e-05, 'epoch': 4.2}


  9%|▉         | 220/2500 [07:30<1:03:19,  1.67s/it]

{'loss': 0.0347, 'grad_norm': 2.71944260597229, 'learning_rate': 1.824e-05, 'epoch': 4.4}


  9%|▉         | 230/2500 [07:47<1:03:42,  1.68s/it]

{'loss': 0.0356, 'grad_norm': 1.7716636657714844, 'learning_rate': 1.8160000000000002e-05, 'epoch': 4.6}


 10%|▉         | 240/2500 [08:04<1:02:13,  1.65s/it]

{'loss': 0.0245, 'grad_norm': 1.2135217189788818, 'learning_rate': 1.8080000000000003e-05, 'epoch': 4.8}


 10%|█         | 250/2500 [08:21<1:04:37,  1.72s/it]

{'loss': 0.0124, 'grad_norm': 1.8004651069641113, 'learning_rate': 1.8e-05, 'epoch': 5.0}



 10%|█         | 250/2500 [08:26<1:04:37,  1.72s/it]

{'eval_loss': 0.09780597686767578, 'eval_runtime': 4.385, 'eval_samples_per_second': 22.805, 'eval_steps_per_second': 2.965, 'epoch': 5.0}


 10%|█         | 260/2500 [08:49<1:16:10,  2.04s/it]

{'loss': 0.0181, 'grad_norm': 2.1001110076904297, 'learning_rate': 1.792e-05, 'epoch': 5.2}


 11%|█         | 270/2500 [09:09<1:12:43,  1.96s/it]

{'loss': 0.0176, 'grad_norm': 0.4288174510002136, 'learning_rate': 1.7840000000000002e-05, 'epoch': 5.4}


 11%|█         | 280/2500 [09:27<1:02:15,  1.68s/it]

{'loss': 0.0107, 'grad_norm': 1.7580386400222778, 'learning_rate': 1.7760000000000003e-05, 'epoch': 5.6}


 12%|█▏        | 290/2500 [09:47<1:13:50,  2.00s/it]

{'loss': 0.0272, 'grad_norm': 0.19294720888137817, 'learning_rate': 1.768e-05, 'epoch': 5.8}


 12%|█▏        | 300/2500 [10:04<1:01:37,  1.68s/it]

{'loss': 0.0313, 'grad_norm': 0.7597123384475708, 'learning_rate': 1.76e-05, 'epoch': 6.0}



 12%|█▏        | 300/2500 [10:09<1:01:37,  1.68s/it]

{'eval_loss': 0.09392878413200378, 'eval_runtime': 4.956, 'eval_samples_per_second': 20.178, 'eval_steps_per_second': 2.623, 'epoch': 6.0}


 12%|█▏        | 310/2500 [10:31<1:13:44,  2.02s/it]

{'loss': 0.0168, 'grad_norm': 0.0949208214879036, 'learning_rate': 1.752e-05, 'epoch': 6.2}


 13%|█▎        | 320/2500 [10:50<1:10:16,  1.93s/it]

{'loss': 0.0133, 'grad_norm': 1.1375336647033691, 'learning_rate': 1.7440000000000002e-05, 'epoch': 6.4}


 13%|█▎        | 330/2500 [11:06<57:24,  1.59s/it]  

{'loss': 0.0097, 'grad_norm': 0.08247578889131546, 'learning_rate': 1.736e-05, 'epoch': 6.6}


 14%|█▎        | 340/2500 [11:25<1:06:59,  1.86s/it]

{'loss': 0.0031, 'grad_norm': 0.06675557047128677, 'learning_rate': 1.728e-05, 'epoch': 6.8}


 14%|█▍        | 350/2500 [11:42<59:12,  1.65s/it]  

{'loss': 0.0264, 'grad_norm': 0.2759002447128296, 'learning_rate': 1.72e-05, 'epoch': 7.0}



 14%|█▍        | 350/2500 [11:47<59:12,  1.65s/it]

{'eval_loss': 0.0930391326546669, 'eval_runtime': 4.958, 'eval_samples_per_second': 20.17, 'eval_steps_per_second': 2.622, 'epoch': 7.0}


 14%|█▍        | 360/2500 [12:06<1:00:27,  1.70s/it]

{'loss': 0.0076, 'grad_norm': 5.107741355895996, 'learning_rate': 1.7120000000000002e-05, 'epoch': 7.2}


 15%|█▍        | 370/2500 [12:24<58:50,  1.66s/it]  

{'loss': 0.0072, 'grad_norm': 0.47735264897346497, 'learning_rate': 1.704e-05, 'epoch': 7.4}


 15%|█▌        | 380/2500 [12:40<54:50,  1.55s/it]  

{'loss': 0.0141, 'grad_norm': 0.07981155067682266, 'learning_rate': 1.696e-05, 'epoch': 7.6}


 16%|█▌        | 390/2500 [12:56<55:13,  1.57s/it]

{'loss': 0.0023, 'grad_norm': 0.2992154657840729, 'learning_rate': 1.688e-05, 'epoch': 7.8}


 16%|█▌        | 400/2500 [13:12<53:04,  1.52s/it]

{'loss': 0.0106, 'grad_norm': 1.248434066772461, 'learning_rate': 1.6800000000000002e-05, 'epoch': 8.0}



 16%|█▌        | 400/2500 [13:17<53:04,  1.52s/it]

{'eval_loss': 0.0867852121591568, 'eval_runtime': 4.4396, 'eval_samples_per_second': 22.524, 'eval_steps_per_second': 2.928, 'epoch': 8.0}


 16%|█▋        | 410/2500 [13:35<56:11,  1.61s/it]  

{'loss': 0.0205, 'grad_norm': 0.05835070461034775, 'learning_rate': 1.672e-05, 'epoch': 8.2}


 17%|█▋        | 420/2500 [13:50<55:28,  1.60s/it]

{'loss': 0.0027, 'grad_norm': 0.032453134655952454, 'learning_rate': 1.664e-05, 'epoch': 8.4}


 17%|█▋        | 430/2500 [14:06<54:48,  1.59s/it]

{'loss': 0.003, 'grad_norm': 0.10624770820140839, 'learning_rate': 1.656e-05, 'epoch': 8.6}


 18%|█▊        | 440/2500 [14:21<51:49,  1.51s/it]

{'loss': 0.0103, 'grad_norm': 0.09980782121419907, 'learning_rate': 1.648e-05, 'epoch': 8.8}


 18%|█▊        | 450/2500 [14:38<57:32,  1.68s/it]  

{'loss': 0.0027, 'grad_norm': 0.1972033828496933, 'learning_rate': 1.64e-05, 'epoch': 9.0}



 18%|█▊        | 450/2500 [14:42<57:32,  1.68s/it]

{'eval_loss': 0.08775633573532104, 'eval_runtime': 4.236, 'eval_samples_per_second': 23.607, 'eval_steps_per_second': 3.069, 'epoch': 9.0}


 18%|█▊        | 460/2500 [15:02<56:40,  1.67s/it]  

{'loss': 0.0019, 'grad_norm': 0.015519374050199986, 'learning_rate': 1.632e-05, 'epoch': 9.2}


 19%|█▉        | 470/2500 [15:19<54:56,  1.62s/it]  

{'loss': 0.0025, 'grad_norm': 0.09384717047214508, 'learning_rate': 1.6240000000000004e-05, 'epoch': 9.4}


 19%|█▉        | 480/2500 [15:35<51:53,  1.54s/it]

{'loss': 0.0025, 'grad_norm': 0.0310612004250288, 'learning_rate': 1.616e-05, 'epoch': 9.6}


 20%|█▉        | 490/2500 [15:52<54:50,  1.64s/it]

{'loss': 0.0032, 'grad_norm': 0.1335475742816925, 'learning_rate': 1.6080000000000002e-05, 'epoch': 9.8}


 20%|██        | 500/2500 [16:07<50:44,  1.52s/it]

{'loss': 0.0041, 'grad_norm': 0.7003500461578369, 'learning_rate': 1.6000000000000003e-05, 'epoch': 10.0}



 20%|██        | 500/2500 [16:11<50:44,  1.52s/it]

{'eval_loss': 0.11824596673250198, 'eval_runtime': 4.1297, 'eval_samples_per_second': 24.215, 'eval_steps_per_second': 3.148, 'epoch': 10.0}


 20%|██        | 510/2500 [16:31<57:43,  1.74s/it]  

{'loss': 0.0048, 'grad_norm': 0.09376294165849686, 'learning_rate': 1.5920000000000003e-05, 'epoch': 10.2}


 21%|██        | 520/2500 [16:46<49:24,  1.50s/it]  

{'loss': 0.0159, 'grad_norm': 4.876733303070068, 'learning_rate': 1.584e-05, 'epoch': 10.4}


 21%|██        | 530/2500 [17:02<50:59,  1.55s/it]

{'loss': 0.0021, 'grad_norm': 0.026007119566202164, 'learning_rate': 1.576e-05, 'epoch': 10.6}


 22%|██▏       | 540/2500 [17:18<53:15,  1.63s/it]

{'loss': 0.0088, 'grad_norm': 0.5785508751869202, 'learning_rate': 1.5680000000000002e-05, 'epoch': 10.8}


 22%|██▏       | 550/2500 [17:37<54:02,  1.66s/it]  

{'loss': 0.0014, 'grad_norm': 0.03034861758351326, 'learning_rate': 1.5600000000000003e-05, 'epoch': 11.0}



 22%|██▏       | 550/2500 [17:45<54:02,  1.66s/it]

{'eval_loss': 0.1020650565624237, 'eval_runtime': 8.2653, 'eval_samples_per_second': 12.099, 'eval_steps_per_second': 1.573, 'epoch': 11.0}


 22%|██▏       | 560/2500 [18:11<1:10:44,  2.19s/it]

{'loss': 0.0013, 'grad_norm': 0.011421275325119495, 'learning_rate': 1.552e-05, 'epoch': 11.2}


 23%|██▎       | 570/2500 [18:30<57:39,  1.79s/it]  

{'loss': 0.0012, 'grad_norm': 0.014551719650626183, 'learning_rate': 1.544e-05, 'epoch': 11.4}


 23%|██▎       | 580/2500 [18:47<54:01,  1.69s/it]

{'loss': 0.0082, 'grad_norm': 0.03415127843618393, 'learning_rate': 1.5360000000000002e-05, 'epoch': 11.6}


 24%|██▎       | 590/2500 [19:07<55:48,  1.75s/it]  

{'loss': 0.0105, 'grad_norm': 0.26651209592819214, 'learning_rate': 1.5280000000000003e-05, 'epoch': 11.8}


 24%|██▍       | 600/2500 [19:24<57:30,  1.82s/it]

{'loss': 0.002, 'grad_norm': 0.25476396083831787, 'learning_rate': 1.5200000000000002e-05, 'epoch': 12.0}



 24%|██▍       | 600/2500 [19:29<57:30,  1.82s/it]

{'eval_loss': 0.08727285265922546, 'eval_runtime': 5.1006, 'eval_samples_per_second': 19.606, 'eval_steps_per_second': 2.549, 'epoch': 12.0}


 24%|██▍       | 610/2500 [19:52<55:54,  1.77s/it]  

{'loss': 0.002, 'grad_norm': 1.926017165184021, 'learning_rate': 1.5120000000000001e-05, 'epoch': 12.2}


 25%|██▍       | 620/2500 [20:08<50:20,  1.61s/it]

{'loss': 0.0016, 'grad_norm': 0.012159683741629124, 'learning_rate': 1.5040000000000002e-05, 'epoch': 12.4}


 25%|██▌       | 630/2500 [20:25<47:15,  1.52s/it]

{'loss': 0.0015, 'grad_norm': 0.04647526890039444, 'learning_rate': 1.496e-05, 'epoch': 12.6}


 26%|██▌       | 640/2500 [20:41<48:44,  1.57s/it]

{'loss': 0.001, 'grad_norm': 0.018428072333335876, 'learning_rate': 1.4880000000000002e-05, 'epoch': 12.8}


 26%|██▌       | 650/2500 [20:57<47:31,  1.54s/it]

{'loss': 0.0011, 'grad_norm': 0.017468370497226715, 'learning_rate': 1.48e-05, 'epoch': 13.0}



 26%|██▌       | 650/2500 [21:02<47:31,  1.54s/it]

{'eval_loss': 0.1001766100525856, 'eval_runtime': 5.2158, 'eval_samples_per_second': 19.173, 'eval_steps_per_second': 2.492, 'epoch': 13.0}


 26%|██▌       | 653/2500 [21:10<1:29:28,  2.91s/it]

KeyboardInterrupt: 