In [None]:
# List the configurations (subsets) that the "xtreme" dataset offers
from datasets import get_dataset_config_names

xtreme_subsets = get_dataset_config_names("xtreme")
print(f"XTREME has {len(xtreme_subsets)} configurations") # the author recorded 183

In [None]:
# Keep the configurations whose name begins with "PAN" and show the first three
panx_subsets = list(filter(lambda name: name.startswith("PAN"), xtreme_subsets))
panx_subsets[:3]

In [None]:
# Load the German corpus (the "PAN-X.de" configuration of "xtreme").
# The result is not assigned: it is the cell's last expression, so the
# notebook displays it.
from datasets import load_dataset

load_dataset("xtreme", name="PAN-X.de")

In [None]:
# Build a multilingual corpus whose language mix follows the share of
# speakers in Switzerland
from collections import defaultdict
from datasets import DatasetDict

langs = ["de", "fr", "it", "en"]
fracs = [0.629, 0.229, 0.084, 0.059]
# Looking up a language that is not stored yet yields a fresh DatasetDict
panx_ch = defaultdict(DatasetDict)

for lang, frac in zip(langs, fracs):
  # Load the PAN-X corpus of this language
  dataset = load_dataset("xtreme", name=f"PAN-X.{lang}")
  # Shuffle every split with a fixed seed, then keep only the leading
  # `frac` share of its rows (downsampling)
  for split, split_rows in dataset.items():
    kept_rows = range(int(frac*split_rows.num_rows))
    panx_ch[lang][split] = split_rows.shuffle(seed=0).select(kept_rows)

In [None]:
import pandas as pd

# Number of training examples kept for each language, shown as a one-row table
train_counts = {lang: [panx_ch[lang]["train"].num_rows] for lang in langs}
pd.DataFrame(train_counts, index=["Number of training examples"])

In [None]:
# Print every field of the first German training example
de_train = panx_ch["de"]["train"]
element = de_train[0]
for key, value in element.items():
  print(f"{key}: {value}")

In [None]:
# Inspect the feature schema to see how the named-entity labels are stored
de_features = panx_ch["de"]["train"].features
for key, value in de_features.items():
  print(f"{key}: {value}") # ner_tags is a list of class labels

In [None]:
# Extract the label feature that sits inside the "ner_tags" sequence feature.
# `tags` stays a notebook-level global; create_tag_names reads it later.
tags = panx_ch["de"]["train"].features["ner_tags"].feature
print(tags)

In [None]:
# Add a column holding the label names next to the integer tag ids
def create_tag_names(batch):
  """Return a ``ner_tags_str`` column with the name of every tag id.

  Each id in ``batch["ner_tags"]`` is converted with ``tags.int2str``, where
  ``tags`` is the notebook-level label feature.

  NOTE(review): the call below does not pass ``batched=True``, so ``batch``
  is presumably a single example rather than a batch; confirm against the
  ``datasets`` documentation.
  """
  tag_names = list(map(tags.int2str, batch["ner_tags"]))
  return {"ner_tags_str": tag_names}

panx_de = panx_ch["de"].map(create_tag_names)

In [None]:
# Line up the tokens of the first German example with their tag names
de_example = panx_de["train"][0]
token_tag_rows = [de_example["tokens"], de_example["ner_tags_str"]]
pd.DataFrame(token_tag_rows, index=['Tokens', 'Tags'])

In [None]:
# Count how often each entity type occurs in every split
from collections import Counter

split2freqs = defaultdict(Counter)
for split, dataset in panx_de.items():
  for row in dataset["ner_tags_str"]:
    for tag in row:
      # Count an entity once, at the tag that opens it. Matching the full
      # "B-" prefix (not just "B") skips labels that merely start with the
      # letter B and have no hyphen, which would otherwise raise IndexError.
      if tag.startswith("B-"):
        # Split only on the first hyphen so an entity type that itself
        # contains a hyphen is kept whole.
        tag_type = tag.split("-", 1)[1]
        split2freqs[split][tag_type] += 1

# One row per split, one column per entity type
pd.DataFrame.from_dict(split2freqs, orient="index")