<a href="https://colab.research.google.com/github/kokist/modern-bert-sample/blob/main/modern_bert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Environment setup

In [1]:
!pip install git+https://github.com/huggingface/transformers.git
!pip install triton

Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-d4nk1h3_
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-d4nk1h3_
  Resolved https://github.com/huggingface/transformers.git to commit d5aebc64653d09660818109f2fac55b5e1031023
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: transformers
  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone
  Created wheel for transformers: filename=transformers-4.48.0.dev0-py3-none-any.whl size=10328720 sha256=242a4fc62d7bdb6ac31b5a4a4dbf90d38400b16fa211ce2fe4aae1b0d390fda0
  Stored in directory: /tmp/pip-ephem-wheel-cache-bgoy630j/wheels/e7/9c/5b/e1a9c8007c343041e61cc484433d512ea9274272e3fcbe7c16
Successfully b

## Code

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use float16 only on the GPU; many float16 ops are unsupported or slow on CPU
dtype = torch.float16 if device.type == "cuda" else torch.float32

model = AutoModelForMaskedLM.from_pretrained(model_id, torch_dtype=dtype)
model.to(device)

print(device)
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))  # show the GPU model name

cuda
Tesla T4


In [3]:
def predict_masked_tokens(text):
    """
    Predict a token for every [MASK] token in the input text.
    All [MASK] tokens are filled from a single forward pass (greedy argmax per position).
    :param text: a string containing one or more [MASK] tokens
    :return: list of predicted tokens, in the order the [MASK] tokens appear
    """
    # Tokenize the input text and move the tensors to the model's device
    inputs = tokenizer(text, return_tensors="pt")
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Run inference (float16 autocast only on CUDA)
    with torch.no_grad():
        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=device.type == "cuda"):
            outputs = model(**inputs)

    # Collect the positions of all [MASK] tokens in the input
    input_ids = inputs["input_ids"][0].tolist()
    mask_token_indices = [i for i, token_id in enumerate(input_ids) if token_id == tokenizer.mask_token_id]

    # Take the highest-scoring vocabulary entry at each [MASK] position
    predicted_tokens = []
    for masked_index in mask_token_indices:
        logits = outputs.logits[0, masked_index]
        predicted_token_id = logits.argmax(dim=-1).item()
        predicted_token = tokenizer.decode(predicted_token_id)
        predicted_tokens.append(predicted_token)

    return predicted_tokens
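`predict_masked_tokens` keeps only the single best token at each `[MASK]`. To see how confident the model is, or which alternatives it considered, the same per-position step can be extended to return the top k candidates with softmax probabilities. The sketch below is a plain-Python illustration of that step, operating on toy data. `top_k_at_masks`, `id_to_token`, and the list-of-rows `logits` layout (mirroring `outputs.logits[0]`) are hypothetical names chosen for this example, not part of the notebook or of the transformers API.

```python
import math
from typing import Dict, List, Sequence, Tuple


def top_k_at_masks(
    input_ids: Sequence[int],
    logits: Sequence[Sequence[float]],
    mask_id: int,
    id_to_token: Dict[int, str],
    k: int = 3,
) -> List[List[Tuple[str, float]]]:
    """For every mask position, return the k highest-scoring tokens with probabilities.

    `logits` holds one row of vocabulary scores per input position (a stand-in
    for outputs.logits[0]). With k=1 this matches the argmax choice.
    """
    results = []
    for pos, token_id in enumerate(input_ids):
        if token_id != mask_id:
            continue
        row = logits[pos]
        # Numerically stable softmax: subtract the row maximum before exponentiating
        row_max = max(row)
        exps = [math.exp(x - row_max) for x in row]
        total = sum(exps)
        # Rank vocabulary ids by score, highest first, and keep the top k
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        results.append([(id_to_token[i], exps[i] / total) for i in ranked])
    return results


# Toy example: a 4-token vocabulary where id 3 plays the role of [MASK]
vocab = {0: " Paris", 1: " Lyon", 2: " is", 3: "[MASK]"}
toy_logits = [[0.0, 0.0, 9.0, 0.0], [3.0, 1.0, 0.0, 0.0]]
print(top_k_at_masks([2, 3], toy_logits, mask_id=3, id_to_token=vocab, k=2))
```

With the real model, the same result could likely be obtained on the GPU without converting to Python lists, for example with `torch.topk(torch.softmax(outputs.logits[0, masked_index], dim=-1), k)` followed by `tokenizer.decode` on each returned index.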

In [4]:
text1 = "The capital of France is [MASK]."
text2 = "The capital of France [MASK] [MASK]."
text3 = "No pain, [MASK] gain."
text4 = "No pain, no [MASK]."

print(f"Input: {text1} -> Prediction: {predict_masked_tokens(text1)}")
print(f"Input: {text2} -> Prediction: {predict_masked_tokens(text2)}")
print(f"Input: {text3} -> Prediction: {predict_masked_tokens(text3)}")
print(f"Input: {text4} -> Prediction: {predict_masked_tokens(text4)}")

Input: The capital of France is [MASK]. -> Prediction: [' Paris']
Input: The capital of France [MASK] [MASK]. -> Prediction: [' is', ' Paris']
Input: No pain, [MASK] gain. -> Prediction: [' no']
Input: No pain, no [MASK]. -> Prediction: [' gain']
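In the second example, both `[MASK]` tokens are predicted from the same forward pass, so the prediction for the second mask cannot take into account what was chosen for the first. A common alternative is to fill the masks one at a time, left to right, and re-run the model after each fill. The plain-Python sketch below illustrates that idea. It assumes a hypothetical `scorer` callable that maps a list of token ids to one row of vocabulary logits per position. `fill_masks_sequentially` and the toy scorer are illustrative names, not part of the notebook or of transformers.

```python
from typing import Callable, List, Sequence

# Hypothetical scorer: token ids in, one row of vocabulary logits per position out
Scorer = Callable[[Sequence[int]], List[List[float]]]


def fill_masks_sequentially(input_ids: Sequence[int], mask_id: int, scorer: Scorer) -> List[int]:
    """Fill mask positions left to right, re-scoring after each fill (one scorer call per mask)."""
    ids = list(input_ids)
    for pos in range(len(ids)):
        if ids[pos] != mask_id:
            continue
        row = scorer(ids)[pos]
        # Greedy choice: the first vocabulary id with the highest score
        ids[pos] = max(range(len(row)), key=row.__getitem__)
    return ids


def toy_scorer(ids: Sequence[int]) -> List[List[float]]:
    """Toy model whose choice at position 2 depends on what sits at position 1."""
    rows = []
    for pos in range(len(ids)):
        if pos == 2:
            rows.append([0.0, 5.0, 1.0, 0.0] if ids[1] == 0 else [0.0, 1.0, 5.0, 0.0])
        else:
            rows.append([5.0, 0.0, 0.0, 0.0])
    return rows


print(fill_masks_sequentially([7, 3, 3], mask_id=3, scorer=toy_scorer))
```

With the real model, the scorer could plausibly be something like `lambda ids: model(input_ids=torch.tensor([ids], device=device)).logits[0].tolist()` wrapped in `torch.no_grad()`. The trade-off is one forward pass per `[MASK]` instead of one pass in total, in exchange for predictions that condition on earlier fills.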
