In [None]:
# エラーの出ないバージョン
!pip install transformers==4.45.2

In [None]:
# Set up everything the later cells need
## Base model checkpoint
xlmr_model_name = "xlm-roberta-base"

## Load the datasets
from collections import defaultdict
from datasets import load_dataset, DatasetDict
## Subsample each language according to its speaker ratio
langs = ["de", "fr", "it", "en"]
fracs = [0.629, 0.229, 0.084, 0.059]
panx_ch = defaultdict(DatasetDict)
for lang, frac in zip(langs, fracs):
  # Load the PAN-X configuration for this language
  ds = load_dataset("xtreme", name=f"PAN-X.{lang}")
  for split in ds:
    # Shuffle with a fixed seed, then keep the first `frac` share of the rows
    shuffled_split = ds[split].shuffle(seed=0)
    rows_to_keep = int(frac * ds[split].num_rows)
    panx_ch[lang][split] = shuffled_split.select(range(rows_to_keep))

## Retrieve the NER tag set and build id <-> name lookups
tags = panx_ch["de"]["train"].features["ner_tags"].feature
index2tag = dict(enumerate(tags.names))
tag2index = {name: position for position, name in index2tag.items()}

## Tokenize (encode) the dataset text
from transformers import AutoTokenizer
# Load the tokenizer for the checkpoint named by xlmr_model_name
xlmr_tokenizer = AutoTokenizer.from_pretrained(xlmr_model_name)

def tokenize_and_align_labels(examples):
  """Tokenize pre-split words and align the word-level NER tags to the tokens.

  Args:
    examples: Batch mapping with "tokens" (one list of words per example) and
      "ner_tags" (one list of tag ids per example).

  Returns:
    The tokenizer output with an added "labels" entry. For each example the
    first token of every word carries that word's tag; tokens with no word
    (word id None) and continuation tokens of a word are labelled -100.
  """
  encodings = xlmr_tokenizer(examples["tokens"], truncation=True,
                             is_split_into_words=True)
  aligned_labels = []
  for row, word_tags in enumerate(examples["ner_tags"]):
    row_labels = []
    last_word = None
    for word in encodings.word_ids(batch_index=row):
      # Only the first token of a real word receives the word's tag
      starts_new_word = word is not None and word != last_word
      row_labels.append(word_tags[word] if starts_new_word else -100)
      last_word = word
    aligned_labels.append(row_labels)
  encodings["labels"] = aligned_labels
  return encodings

def encode_panx_dataset(corpus):
  """Apply tokenize_and_align_labels to `corpus` in batches.

  The 'langs', 'ner_tags' and 'tokens' columns are removed from the result.
  """
  columns_to_drop = ['langs', 'ner_tags', 'tokens']
  encoded = corpus.map(tokenize_and_align_labels, batched=True,
                       remove_columns=columns_to_drop)
  return encoded

panx_de_encoded = encode_panx_dataset(panx_ch["de"]) # Encoded German corpus

## カスタムモデルクラス
import torch.nn as nn
from transformers import XLMRobertaConfig
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel

class XLMRobertaForTokenClassification(RobertaPreTrainedModel):
  """RoBERTa encoder body followed by dropout and a linear per-token classifier."""
  config_class = XLMRobertaConfig

  def __init__(self, config):
    """Build the encoder body and the classification head from `config`."""
    super().__init__(config)
    self.num_labels = config.num_labels
    # Encoder body, created without the pooling layer
    self.roberta = RobertaModel(config, add_pooling_layer=False)
    # Token classification head: dropout, then one linear layer over hidden states
    self.dropout = nn.Dropout(config.hidden_dropout_prob)
    self.classifier = nn.Linear(config.hidden_size, config.num_labels)
    # Initialize the weights through the inherited helper
    self.init_weights()

  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
              labels=None, **kwargs):
    """Return logits for every token and, when `labels` is given, the loss.

    Extra keyword arguments are passed through to the encoder body.
    """
    # Encode the inputs with the body
    encoder_out = self.roberta(input_ids, attention_mask=attention_mask,
                               token_type_ids=token_type_ids, **kwargs)
    # First output element is the last hidden state; apply dropout, then classify
    last_hidden_state = encoder_out[0]
    logits = self.classifier(self.dropout(last_hidden_state))

    # Cross-entropy over the flattened token positions, only when labels are supplied
    if labels is None:
      loss = None
    else:
      flat_logits = logits.view(-1, self.num_labels)
      flat_labels = labels.view(-1)
      loss = nn.CrossEntropyLoss()(flat_logits, flat_labels)

    # Package the results in the model output object
    return TokenClassifierOutput(loss=loss, logits=logits,
                                 hidden_states=encoder_out.hidden_states,
                                 attentions=encoder_out.attentions)

## Configuration for the custom model
from transformers import AutoConfig
xlmr_config = AutoConfig.from_pretrained(
    xlmr_model_name,
    num_labels=tags.num_classes,
    id2label=index2tag,
    label2id=tag2index,
)

## Data collator
from transformers import DataCollatorForTokenClassification
data_collator = DataCollatorForTokenClassification(xlmr_tokenizer)

## Load the model
import torch
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = XLMRobertaForTokenClassification.from_pretrained(
    "kirapika2/xlm-roberta-base-finetuned-panx-de", config=xlmr_config)
model = model.to(device)

In [None]:
# トークンごとの損失計算
from torch.nn.functional import cross_entropy

def forward_pass_with_label(batch):
  """Run the model on one batch and return per-token losses and predictions.

  Args:
    batch: Dict of lists (column name -> one value per example), as passed by
      a batched `map`. After collation it must provide "input_ids",
      "attention_mask" and "labels".

  Returns:
    Dict with two NumPy arrays, each with one row per example:
      "loss": unreduced cross-entropy for every token position.
      "predicted_label": class index with the largest logit at every position.
  """
  # Convert the dict of lists into the list of dicts the data collator expects
  features = [dict(zip(batch, t)) for t in zip(*batch.values())]
  # Pad inputs and labels, then move every tensor to the device
  batch = data_collator(features)
  input_ids = batch["input_ids"].to(device)
  attention_mask = batch["attention_mask"].to(device)
  labels = batch["labels"].to(device)
  with torch.no_grad():
    # Pass the data through the model
    output = model(input_ids, attention_mask)
    # logits.size: [batch_size, sequence_length, classes]
    logits = output.logits
    # Predict the class with the largest logit
    predicted_label = torch.argmax(logits, dim=-1).cpu().numpy()
    # Take the class count from the logits instead of hard-coding it, so the
    # function works for any label set
    num_classes = logits.size(-1)
    # Flatten the batch dimension with view and compute the loss per token
    loss = cross_entropy(logits.view(-1, num_classes),
                         labels.view(-1), reduction="none")
  # Unflatten the batch dimension and convert to a NumPy array
  loss = loss.view(len(input_ids), -1).cpu().numpy()

  return {"loss": loss, "predicted_label": predicted_label}

In [None]:
# Compute the loss over the whole validation set
valid_set = panx_de_encoded["validation"].map(
    forward_pass_with_label, batched=True, batch_size=32)
df = valid_set.to_pandas()

In [None]:
# Reshape df for inspection
# Label id -100 is shown as "IGN"
index2tag[-100] = "IGN"
# Rebuild the frame from valid_set so this cell can be re-run: mapping the
# already-converted tag names through index2tag again would raise KeyError.
df = valid_set.to_pandas()
df["input_tokens"] = df["input_ids"].apply(
    lambda x: xlmr_tokenizer.convert_ids_to_tokens(x))
## Convert each label from id to tag name
df["predicted_label"] = df["predicted_label"].apply(
    lambda x: [index2tag[i] for i in x])
df["labels"] = df["labels"].apply(
    lambda x: [index2tag[i] for i in x])
## Truncate the model outputs to the length of the input
df["loss"] = df.apply(
    lambda x: x["loss"][:len(x["input_ids"])], axis=1)
df["predicted_label"] = df.apply(
    lambda x: x["predicted_label"][:len(x["input_ids"])], axis=1)
## Inspect the result
df.head(1)

In [None]:
# Reshape into one row per token
import pandas as pd
df_tokens = (
    # Explode every list-valued column
    df.apply(pd.Series.explode)
    # Drop the positions labelled "IGN"
    .query("labels != 'IGN'")
    # assign() works on a copy, so rounding the loss here avoids the
    # SettingWithCopyWarning raised by assigning a column on the query() result
    .assign(loss=lambda tokens: tokens["loss"].astype(float).round(2))
)
df_tokens.head(7)

In [None]:
# Identify the tokens with the largest total loss
token_loss_stats = df_tokens.groupby("input_tokens")['loss'].agg(["count", "mean", "sum"])
top_loss_tokens = (
    token_loss_stats.sort_values(by="sum", ascending=False)
    .reset_index()
    .round(2)
    .head(10)
)
# Transpose so each of the ten tokens becomes a column
top_loss_tokens.T

In [None]:
# Identify the labels with the largest mean loss
label_loss_stats = df_tokens.groupby("labels")['loss'].agg(["count", "mean", "sum"])
labels_by_mean_loss = (
    label_loss_stats.sort_values(by="mean", ascending=False)
    .reset_index()
    .round(2)
)
# Transpose so each label becomes a column
labels_by_mean_loss.T