<a href="https://colab.research.google.com/github/ftnext/ml-playground/blob/main/crf/ousia-llm-book/ousia_llm_chapter6_bert_crf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the third-party packages this notebook imports.
# The versions are not pinned; the recorded run below resolved
# datasets 2.14.1, pytorch-crf 0.7.2, transformers 4.31.0,
# spacy-alignments 0.9.0 and seqeval 1.2.2.
!pip install \
datasets \
pytorch-crf \
transformers[ja,torch] \
spacy-alignments \
seqeval

Collecting datasets
  Downloading datasets-2.14.1-py3-none-any.whl (492 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m492.4/492.4 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pytorch-crf
  Downloading pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
Collecting transformers[ja,torch]
  Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m74.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting spacy-alignments
  Downloading spacy_alignments-0.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m55.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdon

In [2]:
import torch
from datasets import load_dataset
from spacy_alignments import get_alignments
from torchcrf import CRF
from transformers import (
    AutoTokenizer,
    BatchEncoding,
    BertForTokenClassification,
    DataCollatorForTokenClassification,
    PretrainedConfig,
    PreTrainedTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.modeling_outputs import TokenClassifierOutput

In [3]:
# Load the NER dataset from the Hugging Face Hub; the recorded run generates
# train, validation and test splits.
dataset = load_dataset("llm-book/ner-wikipedia-dataset")

Downloading builder script:   0%|          | 0.00/4.14k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/641k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [4]:
# Checkpoint name shared by the tokenizer here and the model created later.
model_name = "cl-tohoku/bert-base-japanese-v3"
tokenizer = AutoTokenizer.from_pretrained(model_name)

Downloading (…)okenizer_config.json:   0%|          | 0.00/251 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/231k [00:00<?, ?B/s]

In [5]:
def create_label2id(entities_collection) -> dict[str, int]:
    """Build the BIO label -> id mapping from a collection of entity lists.

    Args:
        entities_collection: Iterable of entity lists; every entity is a
            mapping that has a ``"type"`` key.

    Returns:
        A dict with ``"O"`` mapped to 0, followed by ``"B-<type>"`` and
        ``"I-<type>"`` for each distinct type in sorted order, numbered
        consecutively from 1.
    """
    unique_types = set()
    for entities in entities_collection:
        unique_types.update(entity["type"] for entity in entities)

    mapping = {"O": 0}
    next_id = 1
    for type_name in sorted(unique_types):
        # B-<type> always gets the odd id, I-<type> the following even id.
        for prefix in ("B", "I"):
            mapping[f"{prefix}-{type_name}"] = next_id
            next_id += 1
    return mapping

In [6]:
# The label vocabulary is derived from the training split only; id2label is
# its inverse (id -> label string).
label2id = create_label2id(dataset["train"]["entities"])
id2label = {label_id: label for label, label_id in label2id.items()}

In [7]:
def tokenize(text: str, tokenizer: PreTrainedTokenizer) -> list[str]:
    """Return the token strings that ``tokenizer`` produces for ``text``.

    The text is encoded to ids with ``tokenizer.encode`` and the ids are
    mapped back to token strings. Callers in this notebook treat the first
    and last returned tokens as special tokens.
    """
    token_ids = tokenizer.encode(text)
    return tokenizer.convert_ids_to_tokens(token_ids)

In [8]:
def get_char_to_token_alignments(
    text: str, tokens: list[str]
) -> list[list[int]]:
    """Map every character of ``text`` to the indices of the tokens covering it.

    Args:
        text: The original text.
        tokens: Token strings for ``text``; the first and last entries are
            skipped when the tokens are checked against the text.

    Returns:
        One list of token indices per character, e.g.
        ``[[1], [1], [1], [2], [2]]``.

    Raises:
        AssertionError: If no ``"[UNK]"`` token is present and the inner
            tokens (with ``"##"`` prefixes removed) do not concatenate to
            ``text`` with its spaces removed.
    """
    # The round-trip check cannot be performed once "[UNK]" is present.
    has_unknown = "[UNK]" in tokens
    if not has_unknown:
        pieces = [token.removeprefix("##") for token in tokens[1:-1]]
        assert text.replace(" ", "") == "".join(pieces)

    # Only the character -> token direction of the alignment is needed.
    char_to_token_indices, _ = get_alignments(list(text), tokens)
    return char_to_token_indices

In [9]:
def output_labels(
    text: str, tokens: list[str], entities
) -> list[str]:
    """Produce one BIO label string per token.

    Args:
        text: The original text.
        tokens: Token strings for ``text``, including the first and last
            special tokens.
        entities: Iterable of mappings with ``"span"`` (character start and
            exclusive end) and ``"type"`` keys.

    Returns:
        A list as long as ``tokens`` holding ``"O"``, ``"B-<type>"`` or
        ``"I-<type>"``; the first and last positions are set to ``"-"``.
    """
    alignments = get_char_to_token_alignments(text, tokens)

    labels = ["O" for _ in tokens]
    for entity in entities:
        span, entity_type = entity["span"], entity["type"]
        first_char_tokens = alignments[span[0]]
        last_char_tokens = alignments[span[1] - 1]
        # With "[UNK]" a character can align to no token at all (indexing
        # [0] would raise IndexError), so such entities are skipped.
        if not first_char_tokens or not last_char_tokens:
            continue
        first_token: int = first_char_tokens[0]
        last_token: int = last_char_tokens[0]

        labels[first_token] = f"B-{entity_type}"
        inside_label = f"I-{entity_type}"
        for position in range(first_token + 1, last_token + 1):
            labels[position] = inside_label

    # Special tokens carry no real label.
    labels[0] = labels[-1] = "-"

    return labels

In [10]:
def preprocess_data(
    data, tokenizer: PreTrainedTokenizer, label2id: dict[str, int]
) -> BatchEncoding:
    """Convert one dataset example into model inputs plus label ids.

    Args:
        data: Example mapping with ``"text"`` and ``"entities"`` keys.
        tokenizer: Tokenizer used to encode the text.
        label2id: Mapping from BIO label string to integer id.

    Returns:
        A dict of 1-D tensors (the tokenizer outputs with the leading
        dimension squeezed away) plus ``"labels"``, in which positions
        flagged by ``special_tokens_mask`` are set to -100.
    """
    text = data["text"]
    encoded = tokenizer(
        text, return_tensors="pt", return_special_tokens_mask=True
    )
    # The tokenizer returns tensors of size [1, num_tokens]; drop the
    # leading dimension so each tensor has size [num_tokens].
    flatten_inputs = {}
    for name, tensor in encoded.items():
        flatten_inputs[name] = tensor.squeeze(0)

    string_labels = output_labels(
        text, tokenize(text, tokenizer), data["entities"]
    )
    assert len(string_labels) == flatten_inputs["input_ids"].size(0)

    # The "-" placeholders for [CLS]/[SEP] are not in label2id; they fall
    # back to 0 here and are overwritten with -100 just below.
    label_ids = [label2id.get(label, 0) for label in string_labels]
    tensor_labels = torch.tensor(label_ids)
    special_positions = torch.where(flatten_inputs["special_tokens_mask"])
    tensor_labels[special_positions] = -100
    flatten_inputs["labels"] = tensor_labels
    return flatten_inputs

In [11]:
# Apply the same preprocessing to both splits and drop the original columns
# so that only the model inputs and labels remain.
preprocess_kwargs = {"tokenizer": tokenizer, "label2id": label2id}
train_split = dataset["train"]
validation_split = dataset["validation"]

train_dataset = train_split.map(
    preprocess_data,
    fn_kwargs=preprocess_kwargs,
    remove_columns=train_split.column_names,
)
validation_dataset = validation_split.map(
    preprocess_data,
    fn_kwargs=preprocess_kwargs,
    remove_columns=validation_split.column_names,
)

Map:   0%|          | 0/4274 [00:00<?, ? examples/s]

Map:   0%|          | 0/534 [00:00<?, ? examples/s]

In [12]:
def create_transitions(
    label2id: dict[str, int]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build BIO-constrained initial scores for a CRF.

    Allowed transitions get a score of 0 and forbidden ones -100.0.

    Args:
        label2id: Mapping from label string to id. It must contain ``"O"``;
            labels starting with ``"B"`` / ``"I"`` are treated as begin /
            inside labels. Ids are used as tensor indices, so they are
            expected to lie in ``range(len(label2id))``.

    Returns:
        ``(start_transitions, transitions, end_transitions)`` with sizes
        ``[n]``, ``[n, n]`` and ``[n]`` where ``n = len(label2id)``;
        ``transitions[i, j]`` is the score for moving from label ``i`` to
        label ``j``.
    """
    num_labels = len(label2id)
    b_ids = [v for k, v in label2id.items() if k.startswith("B")]
    i_ids = [v for k, v in label2id.items() if k.startswith("I")]
    o_id = label2id["O"]

    # A sequence may start with B or O, but never with I.
    start_transitions = torch.full([num_labels], -100.0)
    start_transitions[b_ids] = 0
    start_transitions[o_id] = 0

    between_labels_transitions = torch.full(
        [num_labels, num_labels], -100.0
    )
    # Every label may be followed by any B label or by O.
    between_labels_transitions[:, b_ids] = 0
    between_labels_transitions[:, o_id] = 0
    # B -> I is allowed only within the same entity type. The partner is
    # looked up by name ("B-X" -> "I-X") rather than by position in the
    # dict, so the result does not depend on the iteration order of
    # label2id; B labels without a matching I label are simply skipped.
    for label, b_id in label2id.items():
        if not label.startswith("B"):
            continue
        partner_id = label2id.get("I" + label[1:])
        if partner_id is not None:
            between_labels_transitions[b_id, partner_id] = 0
    # I -> I is allowed only for the same label (diagonal entries).
    between_labels_transitions[i_ids, i_ids] = 0

    # Any label may end the sequence.
    end_transitions = torch.zeros(num_labels)
    return start_transitions, between_labels_transitions, end_transitions

In [13]:
class BertWithCrfForTokenClassification(BertForTokenClassification):
    """BERT token classifier trained with a CRF negative log-likelihood loss.

    The per-token logits of ``BertForTokenClassification`` are used as the
    emission scores of a CRF layer whose transition parameters are
    initialised from ``create_transitions``.
    """

    def __init__(self, config: PretrainedConfig):
        """Create the BERT classifier and a CRF with one state per label."""
        super().__init__(config)
        self.crf = CRF(len(config.label2id), batch_first=True)

    def _init_weights(self, module: torch.nn.Module) -> None:
        """Initialise ``module``; CRF layers get the BIO-constrained scores."""
        super()._init_weights(module)
        if isinstance(module, CRF):
            st, t, et = create_transitions(self.config.label2id)
            module.start_transitions.data = st
            module.transitions.data = t
            module.end_transitions.data = et

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
    ) -> TokenClassifierOutput:
        """Run BERT and, when ``labels`` is given, attach the CRF loss.

        Args:
            input_ids: Token ids.
            attention_mask: Optional attention mask forwarded to BERT.
            token_type_ids: Optional segment ids forwarded to BERT.
            labels: Optional label ids; positions equal to -100 are excluded
                from the loss. The tensor is not modified.

        Returns:
            The parent model's output; ``output["loss"]`` is set to the
            negated CRF log-likelihood (``reduction="mean"``) when
            ``labels`` is provided.
        """
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        if labels is not None:
            logits = output.logits
            mask = labels != -100
            # Replace -100 with a valid label id (0) for the CRF, but do it
            # out of place: an in-place `labels *= mask` would overwrite the
            # caller's batch and erase the -100 markers for any code that
            # reads the same tensor afterwards.
            crf_labels = labels * mask
            # Position 0 is dropped before calling the CRF: in this
            # notebook's inputs it is the [CLS] slot labelled -100, and
            # pytorch-crf requires the first timestep of the mask to be on.
            output["loss"] = -self.crf(
                logits[:, 1:, :],
                crf_labels[:, 1:],
                mask=mask[:, 1:],
                reduction="mean",
            )
        return output

In [14]:
# Seed the RNGs *before* building the model: the classification head is not
# in the checkpoint and is randomly initialised inside from_pretrained (see
# the warning below), so seeding only afterwards leaves it irreproducible.
set_seed(42)
model_crf = BertWithCrfForTokenClassification.from_pretrained(
    model_name, label2id=label2id, id2label=id2label
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/447M [00:00<?, ?B/s]

Some weights of BertWithCrfForTokenClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.weight', 'crf.transitions', 'crf.end_transitions', 'classifier.bias', 'crf.start_transitions']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
# Fix the random seed before the Trainer is built and training starts.
set_seed(42)

In [16]:
training_args = TrainingArguments(
    # Where checkpoints and logs are written.
    output_dir="output_bert_crf_ner",
    # Evaluate, checkpoint and log once per epoch.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    # Optimisation schedule: linear decay after warming up over the first
    # 10% of the steps.
    num_train_epochs=5,
    learning_rate=1e-4,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    # Batch sizes per device.
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # Mixed-precision training.
    fp16=True,
)
# Collates examples into batches for token classification.
data_collator = DataCollatorForTokenClassification(tokenizer)

In [17]:
# Wire the model, data and training configuration together.
trainer = Trainer(
    model=model_crf,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [18]:
# Fine-tune the BERT+CRF model. In the recorded run below the validation
# loss is lowest at epoch 2 (1.399) and rises afterwards while the training
# loss keeps falling.
trainer.train()



Epoch,Training Loss,Validation Loss
1,18.6555,1.717354
2,1.2243,1.399486
3,0.5614,1.705526
4,0.3113,1.882042
5,0.1895,2.053059


TrainOutput(global_step=670, training_loss=4.188399502768445, metrics={'train_runtime': 244.6214, 'train_samples_per_second': 87.359, 'train_steps_per_second': 2.739, 'total_flos': 1073167784696784.0, 'train_loss': 4.188399502768445, 'epoch': 5.0})

In [19]:
from google.colab import drive

In [20]:
# Mount Google Drive at the relative path "drive" so the training output can
# be copied there.
drive.mount("drive")

Mounted at drive


In [21]:
# Copy the whole training output directory into the mounted Drive folder.
!mkdir -p drive/MyDrive/llm-book
!cp -r output_bert_crf_ner/ drive/MyDrive/llm-book