In [2]:
import torch
from wenet.transformer.decoder import BiTransformerDecoder

def _compare_state_dicts(model_sd, ckpt_sd):
    """Classify each model parameter by how it lines up with a checkpoint.

    Args:
        model_sd: Mapping of parameter name -> tensor from the instantiated model.
        ckpt_sd: Mapping of parameter name -> tensor from the checkpoint, with
            names already relative to the same module as ``model_sd``.

    Returns:
        Tuple ``(matched, missing, mismatched, extra)`` where
        ``matched`` / ``missing`` are lists of model parameter names (in
        ``model_sd`` order), ``mismatched`` is a list of
        ``(name, model_shape, ckpt_shape)`` tuples, and ``extra`` is the
        sorted list of checkpoint names absent from the model.
    """
    matched, missing, mismatched = [], [], []
    for name, param in model_sd.items():
        if name not in ckpt_sd:
            missing.append(name)
            continue
        model_shape = tuple(param.shape)
        ckpt_shape = tuple(ckpt_sd[name].shape)
        if model_shape != ckpt_shape:
            mismatched.append((name, model_shape, ckpt_shape))
        else:
            matched.append(name)
    extra = sorted(set(ckpt_sd) - set(model_sd))
    return matched, missing, mismatched, extra


def instantiate_and_compare_verbose(checkpoint_path, attention_heads=8):
    """Instantiate a ``BiTransformerDecoder`` sized from a checkpoint and
    compare its left decoder's ``state_dict`` with the checkpoint weights.

    Vocabulary size, model dimension, number of decoder blocks and the
    feed-forward width are inferred from tensor shapes in
    ``<checkpoint_path>/pytorch_model.bin``. The number of attention heads
    cannot be recovered from weight shapes, so it is taken from
    ``attention_heads``. A diagnostic report is printed to stdout.

    Args:
        checkpoint_path: Directory containing ``pytorch_model.bin``.
        attention_heads: Attention heads for the instantiated decoder
            (default 8, the value previously hard-coded).

    Returns:
        dict with keys ``"matched"``, ``"missing"``, ``"mismatched"`` and
        ``"extra"`` holding the lists described in ``_compare_state_dicts``.

    Raises:
        ValueError: if the checkpoint has no
            ``decoder.left_decoder.decoders.<i>.*`` keys, so the block count
            cannot be inferred.
    """
    import os

    prefix = "decoder.left_decoder."
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    # Presumably the file holds a flat name -> tensor state dict; it is indexed as such below.
    ckpt = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"), map_location="cpu")

    # === Infer hyperparameters ===
    vocab_size = ckpt[prefix + "embed.0.weight"].size(0)
    d_model = ckpt[prefix + "after_norm.weight"].size(0)
    # Keys look like "decoder.left_decoder.decoders.<i>.<...>", so field 3 is the block index.
    block_ids = {
        int(k.split(".")[3])
        for k in ckpt
        if k.startswith(prefix + "decoders.")
    }
    if not block_ids:
        raise ValueError(
            f"No '{prefix}decoders.*' keys found in checkpoint at {checkpoint_path!r}; "
            "cannot infer the number of decoder blocks."
        )
    num_blocks = max(block_ids) + 1
    # w_1 is assumed to be nn.Linear(d_model, linear_units), whose weight has shape
    # (linear_units, d_model). Fall back to the previous hard-coded 2048 if absent.
    ffn_key = f"{prefix}decoders.{min(block_ids)}.feed_forward.w_1.weight"
    linear_units = ckpt[ffn_key].size(0) if ffn_key in ckpt else 2048

    decoder = BiTransformerDecoder(
        vocab_size=vocab_size,
        # Assumes the encoder output dim equals the decoder's d_model.
        encoder_output_size=d_model,
        attention_heads=attention_heads,
        linear_units=linear_units,
        num_blocks=num_blocks,
        r_num_blocks=0,  # right-to-left decoder is not compared here
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        self_attention_dropout_rate=0.0,
        src_attention_dropout_rate=0.0,
        input_layer="embed",
        use_output_layer=True,
        normalize_before=True,
        src_attention=True,
        query_bias=True,
        key_bias=True,
        value_bias=True,
        activation_type="relu",
        gradient_checkpointing=False,
        tie_word_embedding=False,
        use_sdpa=False,
        layer_norm_type='layer_norm',
        norm_eps=1e-5,
        n_kv_head=None,
        head_dim=None,
        mlp_type='position_wise_feed_forward',
        mlp_bias=True,
        n_expert=8,
        n_expert_activated=2
    )

    model_sd = decoder.left_decoder.state_dict()
    # Strip the "decoder.left_decoder." prefix so names line up with model_sd.
    ckpt_sd = {
        k[len(prefix):]: v
        for k, v in ckpt.items()
        if k.startswith(prefix)
    }

    matched, missing, mismatched, extra = _compare_state_dicts(model_sd, ckpt_sd)

    # === Print diagnostics ===
    print("=== 🧪 Decoder Checkpoint Verification ===")
    print(f"🔹 Total model params: {len(model_sd)}")
    print(f"✅ Matched params     : {len(matched)}")
    print(f"❌ Missing params     : {len(missing)}")
    print(f"❌ Mismatched shapes  : {len(mismatched)}")
    print(f"⚠️ Extra params in ckpt: {len(extra)}")
    print("\n✅ Matching Keys:")
    for name in matched:
        print(f"   + {name}")

    if missing:
        print("\n❌ Missing in checkpoint:")
        for name in missing:
            print(f"   - {name}")

    if mismatched:
        print("\n❌ Shape mismatches:")
        for name, shape_model, shape_ckpt in mismatched:
            print(f"   - {name}: model={shape_model}, ckpt={shape_ckpt}")

    if extra:
        print("\n⚠️ Extra keys in checkpoint:")
        for name in extra:
            print(f"   - {name}")

    return {
        "matched": matched,
        "missing": missing,
        "mismatched": mismatched,
        "extra": extra,
    }

# Usage: verify the decoder against a local checkpoint directory
# (it must contain pytorch_model.bin).
path = "../../chunkformer-large-vie"
instantiate_and_compare_verbose(checkpoint_path=path)


=== 🧪 Decoder Checkpoint Verification ===
🔹 Total model params: 84
✅ Matched params     : 83
❌ Missing params     : 1
❌ Mismatched shapes  : 0
⚠️ Extra params in ckpt: 0

✅ Matching Keys:
   + embed.0.weight
   + after_norm.weight
   + after_norm.bias
   + output_layer.weight
   + output_layer.bias
   + decoders.0.self_attn.linear_q.weight
   + decoders.0.self_attn.linear_q.bias
   + decoders.0.self_attn.linear_k.weight
   + decoders.0.self_attn.linear_k.bias
   + decoders.0.self_attn.linear_v.weight
   + decoders.0.self_attn.linear_v.bias
   + decoders.0.self_attn.linear_out.weight
   + decoders.0.self_attn.linear_out.bias
   + decoders.0.src_attn.linear_q.weight
   + decoders.0.src_attn.linear_q.bias
   + decoders.0.src_attn.linear_k.weight
   + decoders.0.src_attn.linear_k.bias
   + decoders.0.src_attn.linear_v.weight
   + decoders.0.src_attn.linear_v.bias
   + decoders.0.src_attn.linear_out.weight
   + decoders.0.src_attn.linear_out.bias
   + decoders.0.feed_forward.w_1.weight
   +