<a href="https://colab.research.google.com/github/futugyou/pyproject/blob/master/google_colab/generation_representation_model_04.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required dependencies
%pip install datasets
%pip install sentence_transformers
%pip install transformers
%pip install torch
%pip install tqdm
%pip install evaluate
%pip install scikit-learn

In [None]:
from datasets import load_dataset

# Download CoNLL-2003 from the Hugging Face Hub; trust_remote_code lets its loading script run.
dataset = load_dataset("conll2003", trust_remote_code=True)

# Keep direct handles on the two splits used below.
train_dataset = dataset["train"]
test_dataset = dataset["test"]

In [None]:
# CoNLL-2003 NER tag set. The outside-of-entity tag is the LETTER "O" (not the digit 0);
# seqeval recognises "O" as the outside tag, and these names end up in the model config.
# NOTE(review): the ids must match the integers stored in the dataset's `ner_tags`
# column; verify against dataset["train"].features["ner_tags"].
label2id = {
    "O": 0,
    "B-PER": 1,
    "I-PER": 2,
    "B-ORG": 3,
    "I-ORG": 4,
    "B-LOC": 5,
    "I-LOC": 6,
    "B-MISC": 7,
    "I-MISC": 8,
}

# Reverse lookup: integer class id -> tag string.
id2label = {index: label for label, index in label2id.items()}

# align_labels converts a B- tag to its I- tag by adding 1, so every I- id must follow its B- id.
assert all(
    label2id[f"I-{entity}"] == label2id[f"B-{entity}"] + 1
    for entity in ("PER", "ORG", "LOC", "MISC")
)

In [None]:
from transformers import AutoTokenizer, AutoModelForTokenClassification

model_id = "bert-base-cased"

# Tokenizer matching the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Pretrained encoder plus a token-classification head with one output per NER tag;
# the label maps are stored in the model config.
model = AutoModelForTokenClassification.from_pretrained(
    model_id,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

In [None]:
# Inspect how the tokenizer splits one pre-tokenized training sentence into sub-tokens.
example = train_dataset[848]
token_ids = tokenizer(example["tokens"], is_split_into_words=True)["input_ids"]
sub_tokens = tokenizer.convert_ids_to_tokens(token_ids)

# Display the result (an assignment alone produces no cell output).
sub_tokens

In [None]:
def align_labels(examples, tok=None):
    """Tokenize pre-split words and align word-level NER tags to sub-tokens.

    Intended for ``Dataset.map(align_labels, batched=True)``.

    Args:
        examples: Batch dict with a ``"tokens"`` column (one list of words per row)
            and a ``"ner_tags"`` column (one list of integer tags per row, one per word).
        tok: Tokenizer to use. Defaults to the notebook-level ``tokenizer``.

    Returns:
        The tokenizer's batch encoding with an added ``"labels"`` column holding
        one label per sub-token:

        - ``-100`` for tokens with no word (special tokens), which the loss ignores;
        - the word's own tag for its first sub-token;
        - for continuation sub-tokens, the ``I-`` tag matching a ``B-`` word
          (odd ids are bumped by 1; even ids are kept).
    """
    if tok is None:
        tok = tokenizer

    # Dataset column is "tokens" (pre-split words); is_split_into_words keeps word boundaries.
    tokenized_inputs = tok(examples["tokens"], truncation=True, is_split_into_words=True)

    all_labels = []
    for batch_index, word_tags in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=batch_index)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens belong to no word; -100 is skipped by the cross-entropy loss.
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # First sub-token of a word carries the word's tag unchanged.
                label_ids.append(word_tags[word_idx])
            else:
                # Continuation sub-token: a B- tag (odd id) becomes its I- tag (id + 1).
                tag = word_tags[word_idx]
                label_ids.append(tag + 1 if tag % 2 == 1 else tag)
            previous_word_idx = word_idx
        all_labels.append(label_ids)

    tokenized_inputs["labels"] = all_labels
    return tokenized_inputs

In [None]:
# Tokenize every split and attach sub-token labels; batched=True passes lists of rows to align_labels.
tokenized = dataset.map(align_labels, batched=True)

# Compare word-level tags with the aligned sub-token labels for one training sentence.
print(f"Original: {train_dataset[848]['ner_tags']}")
print(f"Updated: {tokenized['train'][848]['labels']}")

In [None]:
import evaluate

# Entity-level precision/recall/F1 over IOB tag sequences (needs the `seqeval` package).
seqeval = evaluate.load(path='seqeval')

def compute_metrics(eval_pred, metric=None, label_map=None):
    """Compute entity-level F1 for token-classification predictions.

    Passed to ``Trainer(compute_metrics=...)``, which calls it with a single
    ``(logits, labels)`` pair.

    Args:
        eval_pred: ``(logits, labels)``. Logits are argmax-ed over the last axis.
            Positions whose label is ``-100`` (special tokens / padding) are dropped.
        metric: Object with a seqeval-style ``compute(predictions=..., references=...)``.
            Defaults to the notebook-level ``seqeval`` metric.
        label_map: Mapping from integer class id to tag string. Defaults to ``id2label``.

    Returns:
        ``{"f1": overall_f1}`` as reported by the metric.
    """
    import numpy as np

    if metric is None:
        metric = seqeval
    if label_map is None:
        label_map = id2label

    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # seqeval expects one tag list per sentence; keeping each sentence intact lets it
    # join B-/I- tags into multi-token entities.
    true_predictions = []
    true_labels = []
    for prediction, label in zip(predictions, labels):
        sentence_predictions = []
        sentence_labels = []
        for token_prediction, token_label in zip(prediction, label):
            if token_label != -100:
                sentence_predictions.append(label_map[int(token_prediction)])
                sentence_labels.append(label_map[int(token_label)])
        true_predictions.append(sentence_predictions)
        true_labels.append(sentence_labels)

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {"f1": results["overall_f1"]}

In [None]:
from transformers import DataCollatorForTokenClassification

# Pads each batch dynamically; labels are padded with the collator's label pad id (-100 by default).
data_collator = DataCollatorForTokenClassification(tokenizer)

In [None]:
from transformers import TrainingArguments

# Single-epoch fine-tuning; a checkpoint is written under ./model after each epoch.
training_args = TrainingArguments(
    output_dir="model",
    num_train_epochs=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    save_strategy="epoch",
    report_to="none",  # no experiment-tracking integrations
)

In [None]:
from transformers import Trainer

# `processing_class` replaces the `tokenizer=` argument, which is deprecated since
# transformers 4.46 and removed in v5. The Trainer still saves it together with the model
# in save_model(), so a pipeline can reload both from the saved directory.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
# Score the fine-tuned model on the eval (test) split configured on the Trainer.
evaluation_results = trainer.evaluate()
evaluation_results

In [None]:
from transformers import pipeline

# Write the fine-tuned model to disk, then reload it through the high-level pipeline API.
ner_model_dir = "ner_model"
trainer.save_model(ner_model_dir)

token_classifier = pipeline(task="token-classification", model=ner_model_dir)
token_classifier("my name is Maarten")