# BERT 英↔日翻訳 ― TAID 蒸留 + EfQAT 量子化 (W4A8)
Google Colab (T4 GPU) 用 サンプルノートブック


このノートブックでは **BERT2BERT** を Teacher、**TinyBERT‑4L** を Student とし、

1. JParaCrawl / JESC コーパスの取得・前処理  
2. Teacher 追加学習  
3. TAID (Temporally Adaptive Interpolated Distillation) による蒸留  
4. **EfQAT** (CWPN) で 4‑bit 重み / 8‑bit 活性化 (**W4A8**) 量子化  
5. BLEU・BERTScore 評価  
6. Hugging Face Hub へのアップロード  

をワンショットで実行します。


In [1]:
# Colab で一度だけ実行してください
!pip install -U "datasets>=2.19" "huggingface_hub>=0.23" "fsspec>=2024.3" "transformers>=4.40" sacrebleu bert_score accelerate sentencepiece

Collecting datasets>=2.19
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting huggingface_hub>=0.23
  Downloading huggingface_hub-0.32.2-py3-none-any.whl.metadata (14 kB)
Collecting fsspec>=2024.3
  Downloading fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)
Collecting transformers>=4.40
  Downloading transformers-4.52.3-py3-none-any.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.2/40.2 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sacrebleu
  Downloading sacrebleu-2.5.1-py3-none-any.whl.metadata (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Collecting accelerate
  Downloading accelerate-1.7.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets>=2.19)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 

In [4]:
import torch, os
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("使用デバイス:", device)

使用デバイス: cuda


## 1. データセット取得

In [3]:
from datasets import load_dataset, concatenate_datasets   # ← ここを追加

# ① JParaCrawl（英語列＝english、日本語列＝japanese）
jpara = (
    load_dataset(
        "Verah/JParaCrawl-Filtered-English-Japanese-Parallel-Corpus",
        split="train"
    )
    .shuffle(seed=42)
    .select(range(50_000))                       # デモなので 5 万行
    .rename_columns({"english": "en", "japanese": "ja"})
)

# ② JESC（translation → en / ja に分離する）
jesc_raw = (
    load_dataset("Hoshikuzu/JESC", split="train")
    .shuffle(seed=42)
    .select(range(50_000))
)

def split_translation(example):
    return {
        "en": example["translation"]["en"],
        "ja": example["translation"]["ja"]
    }

jesc = jesc_raw.map(split_translation, remove_columns=["translation"])

# ③ 形がそろったのでそのまま結合
raw_ds = concatenate_datasets([jpara, jesc])

# ④ train / test 分割
dataset = raw_ds.train_test_split(test_size=0.1, seed=42)
print(dataset)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/5.33k [00:00<?, ?B/s]

1m_filtered.tsv.gz:   0%|          | 0.00/104M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000000 [00:00<?, ? examples/s]

README.md:   0%|          | 0.00/2.50k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/175M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2801388 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'en', 'ja', 'model1_accepted', 'model2_accepted'],
        num_rows: 90000
    })
    test: Dataset({
        features: ['id', 'en', 'ja', 'model1_accepted', 'model2_accepted'],
        num_rows: 10000
    })
})


## 2. トークナイズ

In [4]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
max_len = 128

def preprocess(batch):
    src = tokenizer(batch["en"], padding="max_length", truncation=True, max_length=max_len)
    tgt = tokenizer(batch["ja"], padding="max_length", truncation=True, max_length=max_len)
    batch["input_ids"] = src.input_ids
    batch["attention_mask"] = src.attention_mask
    labels = tgt.input_ids
    batch["labels"] = [[tok if tok!=tokenizer.pad_token_id else -100 for tok in seq] for seq in labels]
    return batch

# ここを ds ではなく dataset に変更
train_ds = dataset["train"].map(preprocess, batched=True, remove_columns=["en", "ja"])
val_ds   = dataset["test"].map(preprocess,  batched=True, remove_columns=["en", "ja"])


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

Map:   0%|          | 0/90000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

## 3. Teacher (BERT2BERT) 追加学習

In [5]:
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

enc_cfg = BertConfig.from_pretrained("bert-base-multilingual-cased")
dec_cfg = BertConfig.from_pretrained("bert-base-multilingual-cased")
dec_cfg.is_decoder=True; dec_cfg.add_cross_attention=True

model_cfg = EncoderDecoderConfig.from_encoder_decoder_configs(enc_cfg, dec_cfg)
teacher = EncoderDecoderModel(config=model_cfg)
teacher.config.decoder_start_token_id = tokenizer.cls_token_id
teacher.config.eos_token_id = tokenizer.sep_token_id
teacher.config.pad_token_id = tokenizer.pad_token_id

collator = DataCollatorForSeq2Seq(tokenizer, model=teacher)
args = Seq2SeqTrainingArguments(
    output_dir="teacher_out",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1,           # デモなので 1 epoch
    fp16=True,
    logging_steps=100,
    save_total_limit=1,
)

trainer = Seq2SeqTrainer(
    model=teacher,
    args=args,
    train_dataset=train_ds,
    tokenizer=tokenizer,
    data_collator=collator
)
trainer.train()

  trainer = Seq2SeqTrainer(
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mfukayatti0[0m ([33mfukayatti0-national-institue-of-technology-ibaraki-college[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Step,Training Loss
100,8.9261
200,6.8006
300,6.4895
400,6.3028
500,6.1481
600,5.9372
700,5.8499
800,5.6535
900,5.5606
1000,5.548




TrainOutput(global_step=11250, training_loss=4.302789946831597, metrics={'train_runtime': 5524.6139, 'train_samples_per_second': 16.291, 'train_steps_per_second': 2.036, 'total_flos': 1.38088706688e+16, 'train_loss': 4.302789946831597, 'epoch': 1.0})

## 4. TAID 蒸留で TinyBERT Student 学習

In [17]:
from torch.utils.data import DataLoader
tiny_enc = BertConfig.from_pretrained("bert-base-multilingual-cased", num_hidden_layers=4)
tiny_dec = BertConfig.from_pretrained("bert-base-multilingual-cased", num_hidden_layers=4, is_decoder=True, add_cross_attention=True)
student_cfg = EncoderDecoderConfig.from_encoder_decoder_configs(tiny_enc, tiny_dec)
student = EncoderDecoderModel(config=student_cfg).to(device)
student.config.decoder_start_token_id = tokenizer.cls_token_id
student.config.eos_token_id = tokenizer.sep_token_id
student.config.pad_token_id = tokenizer.pad_token_id

teacher.to(device); teacher.eval()

loader = DataLoader(train_ds, batch_size=8, shuffle=True)
optim = torch.optim.Adam(student.parameters(), lr=5e-5)

alpha, alpha_end, m, beta = 0.2, 1.0, 0.0, 0.9
steps = 1000   # デモ短縮
for step, batch in enumerate(loader):
    if step>=steps: break
    input_ids = torch.tensor(batch["input_ids"]).to(device)
    attn = torch.tensor(batch["attention_mask"]).to(device)
    labels = torch.tensor(batch["labels"]).to(device)

    # Student forward
    out_s = student(input_ids=input_ids, attention_mask=attn, labels=labels)
    logits_s = out_s.logits
    # Teacher forward
    with torch.no_grad():
        logits_t = teacher(input_ids=input_ids, attention_mask=attn, labels=labels).logits

    p_s = torch.softmax(logits_s, dim=-1)
    p_t = torch.softmax(logits_t, dim=-1)
    p_mid = alpha*p_t + (1-alpha)*p_s

    loss = torch.nn.functional.kl_div(torch.log_softmax(logits_s, dim=-1), p_mid, reduction="batchmean")

    optim.zero_grad()
    loss.backward()
    optim.step()

    # α 更新
    target = alpha_end*(step+1)/steps
    m = beta*m + (1-beta)*(target-alpha)
    alpha += 0.005*m

NameError: name 'jpara' is not defined

## 5. EfQAT (CWPN) で 4bit 量子化

In [None]:
freeze_ratio = 0.9; interval = 4096
loader_q = DataLoader(train_ds, batch_size=8, shuffle=True)
optim_q = torch.optim.Adam(student.parameters(), lr=3e-5)

thresh = None
for step, batch in enumerate(loader_q):
    if step>=500: break
    ids = torch.tensor(batch["input_ids"]).to(device)
    attn = torch.tensor(batch["attention_mask"]).to(device)
    labels = torch.tensor(batch["labels"]).to(device)

    out = student(input_ids=ids, attention_mask=attn, labels=labels)
    loss_q = out.loss
    optim_q.zero_grad(); loss_q.backward()

    # 閾値更新
    if step % interval == 0:
        imps=[]
        for mod in student.modules():
            if isinstance(mod, torch.nn.Linear):
                imps.append(mod.weight.data.abs().mean(dim=1).cpu())
        thresh = torch.cat(imps).kthvalue(int(len(torch.cat(imps))*freeze_ratio)).values.item()

    # 勾配マスク
    for mod in student.modules():
        if isinstance(mod, torch.nn.Linear) and mod.weight.grad is not None:
            imp = mod.weight.data.abs().mean(dim=1)
            mask = (imp >= thresh).float().unsqueeze(1).to(device)
            mod.weight.grad.mul_(mask)
    optim_q.step()


In [None]:
# W4A8 量子化 (簡易実装)
for mod in student.modules():
    if isinstance(mod, torch.nn.Linear):
        w = mod.weight.data
        s = (2**3-1)/w.abs().max()
        mod.weight.data = torch.round(w*s)/s
print("量子化完了!")

## 6. BLEU・BERTScore 評価

In [None]:
from sacrebleu import corpus_bleu
from bert_score import score as bert_score
student.eval(); teacher.eval()
samples = ds["test"].select(range(100))
pred_t, pred_s, refs = [],[],[]
for ex in samples:
    ids = tokenizer(ex["en"], return_tensors="pt", truncation=True, padding=True).input_ids.to(device)
    with torch.no_grad():
        pt = teacher.generate(ids, max_length=128)
        ps = student.generate(ids, max_length=128)
    pred_t.append(tokenizer.decode(pt[0], skip_special_tokens=True))
    pred_s.append(tokenizer.decode(ps[0], skip_special_tokens=True))
    refs.append(ex["ja"])
print("BLEU Teacher:", corpus_bleu(pred_t,[refs]).score)
print("BLEU Student:", corpus_bleu(pred_s,[refs]).score)
P,R,F = bert_score(pred_s, refs, lang='ja', model_type='bert-base-multilingual-cased')
print("BERTScore Student:", F.mean().item())

## 7. Hugging Face Hub へアップロード

In [None]:
from huggingface_hub import notebook_login
# notebook_login() でトークン認証
# student.push_to_hub("your-account/tinybert4L-en-ja-w4a8", private=True)
