In [8]:
from typing import Literal
from dataclasses import dataclass


@dataclass
class ModelArchConfig:
    encoder_block_nums:int
    decoder_block_nums:int


SUPPORTED_MODEL_ARCHS = {
    # 'tiny',
    # 'base',
    # 'small',
    'medium': ModelArchConfig(24, 24),
    'large': ModelArchConfig(32, 32),
}

WHISPER_HF_ENCODER_BLOCK_MAPPING = {
    'encoder.blocks.{i}.attn.query.weight' :'model.encoder.layers.{i}.self_attn.q_proj.weight',
    'encoder.blocks.{i}.attn.query.bias'   :'model.encoder.layers.{i}.self_attn.q_proj.bias',
    'encoder.blocks.{i}.attn.key.weight'   :'model.encoder.layers.{i}.self_attn.k_proj.weight',
    'encoder.blocks.{i}.attn.value.weight' :'model.encoder.layers.{i}.self_attn.v_proj.weight',
    'encoder.blocks.{i}.attn.value.bias'   :'model.encoder.layers.{i}.self_attn.v_proj.bias',
    'encoder.blocks.{i}.attn.out.weight'   :'model.encoder.layers.{i}.self_attn.out_proj.weight',
    'encoder.blocks.{i}.attn.out.bias'     :'model.encoder.layers.{i}.self_attn.out_proj.bias',
    'encoder.blocks.{i}.attn_ln.weight'    :'model.encoder.layers.{i}.self_attn_layer_norm.weight',
    'encoder.blocks.{i}.attn_ln.bias'      :'model.encoder.layers.{i}.self_attn_layer_norm.bias',
    'encoder.blocks.{i}.mlp.0.weight'      :'model.encoder.layers.{i}.fc1.weight',
    'encoder.blocks.{i}.mlp.0.bias'        :'model.encoder.layers.{i}.fc1.bias',
    'encoder.blocks.{i}.mlp.2.weight'      :'model.encoder.layers.{i}.fc2.weight',
    'encoder.blocks.{i}.mlp.2.bias'        :'model.encoder.layers.{i}.fc2.bias',
    'encoder.blocks.{i}.mlp_ln.weight'     :'model.encoder.layers.{i}.final_layer_norm.weight',
    'encoder.blocks.{i}.mlp_ln.bias'       :'model.encoder.layers.{i}.final_layer_norm.bias',
}


WHISPER_HF_DECODER_BLOCK_MAPPING = {
    'decoder.blocks.{i}.attn.query.weight'       :'model.decoder.layers.{i}.self_attn.q_proj.weight',
    'decoder.blocks.{i}.attn.query.bias'         :'model.decoder.layers.{i}.self_attn.q_proj.bias',
    'decoder.blocks.{i}.attn.key.weight'         :'model.decoder.layers.{i}.self_attn.k_proj.weight',
    'decoder.blocks.{i}.attn.value.weight'       :'model.decoder.layers.{i}.self_attn.v_proj.weight',
    'decoder.blocks.{i}.attn.value.bias'         :'model.decoder.layers.{i}.self_attn.v_proj.bias',
    'decoder.blocks.{i}.attn.out.weight'         :'model.decoder.layers.{i}.self_attn.out_proj.weight',
    'decoder.blocks.{i}.attn.out.bias'           :'model.decoder.layers.{i}.self_attn.out_proj.bias',
    'decoder.blocks.{i}.attn_ln.weight'          :'model.decoder.layers.{i}.self_attn_layer_norm.weight',
    'decoder.blocks.{i}.attn_ln.bias'            :'model.decoder.layers.{i}.self_attn_layer_norm.bias',
    'decoder.blocks.{i}.cross_attn.query.weight' :'model.decoder.layers.{i}.encoder_attn.q_proj.weight',
    'decoder.blocks.{i}.cross_attn.query.bias'   :'model.decoder.layers.{i}.encoder_attn.q_proj.bias',
    'decoder.blocks.{i}.cross_attn.key.weight'   :'model.decoder.layers.{i}.encoder_attn.k_proj.weight',
    'decoder.blocks.{i}.cross_attn.value.weight' :'model.decoder.layers.{i}.encoder_attn.v_proj.weight',
    'decoder.blocks.{i}.cross_attn.value.bias'   :'model.decoder.layers.{i}.encoder_attn.v_proj.bias',
    'decoder.blocks.{i}.cross_attn.out.weight'   :'model.decoder.layers.{i}.encoder_attn.out_proj.weight',
    'decoder.blocks.{i}.cross_attn.out.bias'     :'model.decoder.layers.{i}.encoder_attn.out_proj.bias',
    'decoder.blocks.{i}.cross_attn_ln.weight'    :'model.decoder.layers.{i}.encoder_attn_layer_norm.weight',
    'decoder.blocks.{i}.cross_attn_ln.bias'      :'model.decoder.layers.{i}.encoder_attn_layer_norm.bias',
    'decoder.blocks.{i}.mlp.0.weight'            :'model.decoder.layers.{i}.fc1.weight',
    'decoder.blocks.{i}.mlp.0.bias'              :'model.decoder.layers.{i}.fc1.bias',
    'decoder.blocks.{i}.mlp.2.weight'            :'model.decoder.layers.{i}.fc2.weight',
    'decoder.blocks.{i}.mlp.2.bias'              :'model.decoder.layers.{i}.fc2.bias',
    'decoder.blocks.{i}.mlp_ln.weight'           :'model.decoder.layers.{i}.final_layer_norm.weight',
    'decoder.blocks.{i}.mlp_ln.bias'             :'model.decoder.layers.{i}.final_layer_norm.bias',
}

WHISPER_HF_MAPPING = {

    # encoder
    'encoder.conv1.weight' :'model.encoder.conv1.weight',
    'encoder.conv1.bias'   :'model.encoder.conv1.bias',
    'encoder.conv2.weight' :'model.encoder.conv2.weight',
    'encoder.conv2.bias'   :'model.encoder.conv2.bias',
    # ???:'model.encoder.embed_positions.weight',
    **{block_name.format(i=i):hf_block_name.format(i=i)
     for i in range(24)
     for block_name, hf_block_name in WHISPER_HF_ENCODER_BLOCK_MAPPING.items()},
    'encoder.ln_post.weight':'model.encoder.layer_norm.weight',
    'encoder.ln_post.bias'  :'model.encoder.layer_norm.bias',

    # decoder
    'decoder.positional_embedding':'model.decoder.embed_tokens.weight',
    'decoder.token_embedding.weight':'model.decoder.embed_positions.weight',
    **{block_name.format(i=i):hf_block_name.format(i=i)
     for i in range(24)
     for block_name, hf_block_name in WHISPER_HF_DECODER_BLOCK_MAPPING.items()},
    'decoder.ln.weight'     : 'model.decoder.layer_norm.weight',
    'decoder.ln.bias'       : 'model.decoder.layer_norm.bias',
}


class WhisperHFArchMapping:
    def __init__(self, model_arch):
        self.model_arch = model_arch
        self.encoder_block_nums = SUPPORTED_MODEL_ARCHS[model_arch].encoder_block_nums
        self.decoder_block_nums = SUPPORTED_MODEL_ARCHS[model_arch].decoder_block_nums



In [9]:
WHISPER_HF_MAPPING

{'encoder.conv1.weight': 'model.encoder.conv1.weight',
 'encoder.conv1.bias': 'model.encoder.conv1.bias',
 'encoder.conv2.weight': 'model.encoder.conv2.weight',
 'encoder.conv2.bias': 'model.encoder.conv2.bias',
 'encoder.blocks.0.attn.query.weight': 'model.encoder.layers.0.self_attn.q_proj.weight',
 'encoder.blocks.0.attn.query.bias': 'model.encoder.layers.0.self_attn.q_proj.bias',
 'encoder.blocks.0.attn.key.weight': 'model.encoder.layers.0.self_attn.k_proj.weight',
 'encoder.blocks.0.attn.value.weight': 'model.encoder.layers.0.self_attn.v_proj.weight',
 'encoder.blocks.0.attn.value.bias': 'model.encoder.layers.0.self_attn.v_proj.bias',
 'encoder.blocks.0.attn.out.weight': 'model.encoder.layers.0.self_attn.out_proj.weight',
 'encoder.blocks.0.attn.out.bias': 'model.encoder.layers.0.self_attn.out_proj.bias',
 'encoder.blocks.0.attn_ln.weight': 'model.encoder.layers.0.self_attn_layer_norm.weight',
 'encoder.blocks.0.attn_ln.bias': 'model.encoder.layers.0.self_attn_layer_norm.bias',
 'e