In [13]:
import numpy as np
import torch
from transformers import SpeechT5ForSpeechToText#, SpeechT5Model, SpeechT5Config, SpeechT5PreTrainedModel

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [15]:
st5_asr = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr").to(device)

In [16]:
ckpt = torch.load("../checkpoints/speecht5_base.pt", map_location=device)

In [17]:
#ckpt['cfg']['task'].t5_task = 'pretrain'

In [18]:
len(ckpt['model'].keys())

464

## Mapping Speech Pre-net

In [7]:
# Initialize the mappings dictionary
speech_prenet_mapping = {}

# Iterate through st5_asr named_parameters
for name, _ in st5_asr.named_parameters():
    if name.startswith("speecht5.encoder.prenet"):
        for ckpt_name, ckpt_param in ckpt['model'].items():
            if ckpt_name.startswith("speech_encoder_prenet"):
                # Split the layer names based on '.' and '_'
                st5_asr_parts = name.split('.')
                ckpt_parts = ckpt_name.split('.')
                st5_asr_parts[-1] = st5_asr_parts[-1].split('_')[-1]
                ckpt_parts[-1] = ckpt_parts[-1].split('_')[-1]

                # Check if the layer names match
                if st5_asr_parts[-1] == ckpt_parts[-1]:
                    speech_prenet_mapping[name] = ckpt_name
                    break

# Print the mapping
for st5_asr_name, ckpt_name in speech_prenet_mapping.items():
    print(f"{st5_asr_name} -> {ckpt_name}")

speecht5.encoder.prenet.feature_encoder.conv_layers.0.conv.weight -> speech_encoder_prenet.feature_extractor.conv_layers.0.0.weight
speecht5.encoder.prenet.feature_encoder.conv_layers.0.layer_norm.weight -> speech_encoder_prenet.feature_extractor.conv_layers.0.0.weight
speecht5.encoder.prenet.feature_encoder.conv_layers.0.layer_norm.bias -> speech_encoder_prenet.feature_extractor.conv_layers.0.2.bias
speecht5.encoder.prenet.feature_encoder.conv_layers.1.conv.weight -> speech_encoder_prenet.feature_extractor.conv_layers.0.0.weight
speecht5.encoder.prenet.feature_encoder.conv_layers.2.conv.weight -> speech_encoder_prenet.feature_extractor.conv_layers.0.0.weight
speecht5.encoder.prenet.feature_encoder.conv_layers.3.conv.weight -> speech_encoder_prenet.feature_extractor.conv_layers.0.0.weight
speecht5.encoder.prenet.feature_encoder.conv_layers.4.conv.weight -> speech_encoder_prenet.feature_extractor.conv_layers.0.0.weight
speecht5.encoder.prenet.feature_encoder.conv_layers.5.conv.weight ->

## Mapping Encoder

In [28]:
encoder_mapping = {}
listt = []
hf_roi = 3
encoder_name = "speecht5.encoder.wrapped_encoder"

for name, _ in st5_asr.named_parameters():
    if name.startswith(encoder_name):
        listt.append(name.split(".")[hf_roi])

layers = set(listt)
#print(layers)

st5_hf_dict = {l:{} for l in layers}
#print(st5_hf_dict)

for layer_name in layers:
    temp_dict = {}
    for name, _ in st5_asr.named_parameters():    
        if name.startswith(encoder_name+"."+layer_name):
            temp_dict[name.split('.')[hf_roi+1]] = {}
    st5_hf_dict[layer_name] = temp_dict

print(st5_hf_dict)

for main_layer in st5_hf_dict.keys():
    for sublayer in st5_hf_dict[main_layer]:
        temp_dict = {}
        for name, _ in st5_asr.named_parameters():
            if name.startswith(encoder_name+"."+main_layer+"."+sublayer):
                if len(name.split('.')) <= hf_roi+2:
                    continue
                temp_dict[name.split('.')[hf_roi+2]] = {}
        st5_hf_dict[main_layer][sublayer] = temp_dict
        
print(st5_hf_dict)

{'embed_positions': {'pe_k': {}}, 'layers': {'0': {}, '1': {}, '2': {}, '3': {}, '4': {}, '5': {}, '6': {}, '7': {}, '8': {}, '9': {}, '10': {}, '11': {}}, 'layer_norm': {'weight': {}, 'bias': {}}}
{'embed_positions': {'pe_k': {'weight': {}}}, 'layers': {'0': {'attention': {}, 'layer_norm': {}, 'feed_forward': {}, 'final_layer_norm': {}}, '1': {'attention': {}, 'layer_norm': {}, 'feed_forward': {}, 'final_layer_norm': {}}, '2': {'attention': {}, 'layer_norm': {}, 'feed_forward': {}, 'final_layer_norm': {}}, '3': {'attention': {}, 'layer_norm': {}, 'feed_forward': {}, 'final_layer_norm': {}}, '4': {'attention': {}, 'layer_norm': {}, 'feed_forward': {}, 'final_layer_norm': {}}, '5': {'attention': {}, 'layer_norm': {}, 'feed_forward': {}, 'final_layer_norm': {}}, '6': {'attention': {}, 'layer_norm': {}, 'feed_forward': {}, 'final_layer_norm': {}}, '7': {'attention': {}, 'layer_norm': {}, 'feed_forward': {}, 'final_layer_norm': {}}, '8': {'attention': {}, 'layer_norm': {}, 'feed_forward': 

In [8]:
encoder_mapping = {}

#HF after wrapped_encoder: {'embed_positions', 'layer_norm', 'layers'}

# Iterate through st5_asr named_parameters
layer_num = 0
hf_roi = 3
bs_roi = 1
for name, _ in st5_asr.named_parameters():
    if name.startswith("speecht5.encoder.wrapped_encoder"):
        st5_asr_parts = name.split('.')
        st5_hf_layer = st5_asr_parts[hf_roi]
        #listt.append(st5_asr_parts[3])
        for ckpt_name, _ in ckpt['model'].items():
            if ckpt_name.startswith("encoder"):
                ckpt_parts = ckpt_name.split('.')
                print("HF:", st5_asr_parts)
                print("BS:", ckpt_parts)
                st5_bs_layer = ckpt_parts[bs_roi]
                
                if st5_hf_layer == st5_bs_layer:
                    st5_hf_layer_num = st5_asr_parts[hf_roi+1]
                    st5_bs_layer_num = ckpt_parts[bs_roi+1]
                    print("Here")
                    
                    if st5_bs_layer_num == st5_bs_layer_num:
                        st5_hf_layer_num_layer = st5_asr_parts[hf_roi+2]
                        st5_bs_layer_num_layer = ckpt_parts[bs_roi+2]
                        if st5_bs_layer_num_layer == "self_attn":
                            st5_bs_layer_num_layer = "attention"
                        print("Here")
                        
                        if st5_hf_layer_num_layer == st5_bs_layer_num_layer:
                            st5_hf_layer_num_layer_proj = st5_asr_parts[hf_roi+3]
                            st5_bs_layer_num_layer_proj = ckpt_parts[bs_roi+3]
                            print("Here")
                            
                            if st5_hf_layer_num_layer_proj == st5_bs_layer_num_layer_proj:
                                st5_hf_layer_num_layer_proj_param = st5_asr_parts[hf_roi+4]
                                st5_bs_layer_num_layer_proj_param = ckpt_parts[bs_roi+4]
                                print("Here")
                                
                                if st5_hf_layer_num_layer_proj_param == st5_bs_layer_num_layer_proj_param:
                                    encoder_mapping[name] = ckpt_name
                                    break
                        
                        

# Print the mapping
for st5_asr_name, ckpt_name in encoder_mapping.items():
    print(f"{st5_asr_name} -> {ckpt_name}")

HF: ['speecht5', 'encoder', 'wrapped_encoder', 'layer_norm', 'weight']
BS: ['encoder', 'version']
HF: ['speecht5', 'encoder', 'wrapped_encoder', 'layer_norm', 'weight']
BS: ['encoder', 'layers', '0', 'self_attn', 'k_proj', 'weight']
HF: ['speecht5', 'encoder', 'wrapped_encoder', 'layer_norm', 'weight']
BS: ['encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias']
HF: ['speecht5', 'encoder', 'wrapped_encoder', 'layer_norm', 'weight']
BS: ['encoder', 'layers', '0', 'self_attn', 'v_proj', 'weight']
HF: ['speecht5', 'encoder', 'wrapped_encoder', 'layer_norm', 'weight']
BS: ['encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias']
HF: ['speecht5', 'encoder', 'wrapped_encoder', 'layer_norm', 'weight']
BS: ['encoder', 'layers', '0', 'self_attn', 'q_proj', 'weight']
HF: ['speecht5', 'encoder', 'wrapped_encoder', 'layer_norm', 'weight']
BS: ['encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias']
HF: ['speecht5', 'encoder', 'wrapped_encoder', 'layer_norm', 'weight']
BS: ['encoder', 'layers', '

IndexError: list index out of range

In [15]:
encoder_mapping

{}

In [27]:
ctr1 = 1
for name, p in st5_asr.named_parameters():
    if name.startswith("speecht5.encoder.wrapped_encoder"):
        print(ctr1, name, p.size())
        ctr1 += 1

1 speecht5.encoder.wrapped_encoder.layer_norm.weight torch.Size([768])
2 speecht5.encoder.wrapped_encoder.layer_norm.bias torch.Size([768])
3 speecht5.encoder.wrapped_encoder.layers.0.attention.k_proj.weight torch.Size([768, 768])
4 speecht5.encoder.wrapped_encoder.layers.0.attention.k_proj.bias torch.Size([768])
5 speecht5.encoder.wrapped_encoder.layers.0.attention.v_proj.weight torch.Size([768, 768])
6 speecht5.encoder.wrapped_encoder.layers.0.attention.v_proj.bias torch.Size([768])
7 speecht5.encoder.wrapped_encoder.layers.0.attention.q_proj.weight torch.Size([768, 768])
8 speecht5.encoder.wrapped_encoder.layers.0.attention.q_proj.bias torch.Size([768])
9 speecht5.encoder.wrapped_encoder.layers.0.attention.out_proj.weight torch.Size([768, 768])
10 speecht5.encoder.wrapped_encoder.layers.0.attention.out_proj.bias torch.Size([768])
11 speecht5.encoder.wrapped_encoder.layers.0.layer_norm.weight torch.Size([768])
12 speecht5.encoder.wrapped_encoder.layers.0.layer_norm.bias torch.Size([7

In [44]:
compat = {}
ctr2 = 1
for name, p in ckpt['model'].items():
    if name.startswith("encoder"):
        print(ctr2, name, p.size())
        ctr2 += 1


1 encoder.version torch.Size([1])
2 encoder.layers.0.self_attn.k_proj.weight torch.Size([768, 768])
3 encoder.layers.0.self_attn.k_proj.bias torch.Size([768])
4 encoder.layers.0.self_attn.v_proj.weight torch.Size([768, 768])
5 encoder.layers.0.self_attn.v_proj.bias torch.Size([768])
6 encoder.layers.0.self_attn.q_proj.weight torch.Size([768, 768])
7 encoder.layers.0.self_attn.q_proj.bias torch.Size([768])
8 encoder.layers.0.self_attn.out_proj.weight torch.Size([768, 768])
9 encoder.layers.0.self_attn.out_proj.bias torch.Size([768])
10 encoder.layers.0.self_attn_layer_norm.weight torch.Size([768])
11 encoder.layers.0.self_attn_layer_norm.bias torch.Size([768])
12 encoder.layers.0.fc1.weight torch.Size([3072, 768])
13 encoder.layers.0.fc1.bias torch.Size([3072])
14 encoder.layers.0.fc2.weight torch.Size([768, 3072])
15 encoder.layers.0.fc2.bias torch.Size([768])
16 encoder.layers.0.final_layer_norm.weight torch.Size([768])
17 encoder.layers.0.final_layer_norm.bias torch.Size([768])
18 en

In [7]:
compat = {}
ctr2 = 1
for name, p in ckpt['model'].items():
    if name.startswith("speech_encoder_prenet"):
        print(ctr2, name, end=" -> ")
        parts = name.split(".")
        tmp = name.replace("speech", "speecht5")
        tmp = tmp.replace("_", ".", 2)
        tmp = tmp.replace("feature_extractor", "feature_encoder")
        tmp = tmp.replace(".0.weight", ".conv.weight")
        
        if tmp.endswith("mask_emb"):
            tmp = tmp.replace("mask_emb", "masked_spec_emb")
        
        print(tmp)
        ctr2 += 1

1 speech_encoder_prenet.mask_emb -> speecht5.encoder.prenet.masked_spec_emb
2 speech_encoder_prenet.feature_extractor.conv_layers.0.0.weight -> speecht5.encoder.prenet.feature_encoder.conv_layers.0.conv.weight
3 speech_encoder_prenet.feature_extractor.conv_layers.0.2.weight -> speecht5.encoder.prenet.feature_encoder.conv_layers.0.2.weight
4 speech_encoder_prenet.feature_extractor.conv_layers.0.2.bias -> speecht5.encoder.prenet.feature_encoder.conv_layers.0.2.bias
5 speech_encoder_prenet.feature_extractor.conv_layers.1.0.weight -> speecht5.encoder.prenet.feature_encoder.conv_layers.1.conv.weight
6 speech_encoder_prenet.feature_extractor.conv_layers.2.0.weight -> speecht5.encoder.prenet.feature_encoder.conv_layers.2.conv.weight
7 speech_encoder_prenet.feature_extractor.conv_layers.3.0.weight -> speecht5.encoder.prenet.feature_encoder.conv_layers.3.conv.weight
8 speech_encoder_prenet.feature_extractor.conv_layers.4.0.weight -> speecht5.encoder.prenet.feature_encoder.conv_layers.4.conv.wei

In [None]:
mapping = {}

for name, p in ckpt['model'].items():
    if name.startswith("speech_encoder_prenet"):
        

In [None]:
# Create a new state dictionary using the mapped layer names
new_state_dict = {}
for name, param in ckpt['model'].items():
    if name in compat:
        new_name = compat[name]
        new_state_dict[new_name] = param

# Load the new state dictionary into the Hugging Face model
st5_asr.model.load_state_dict(new_state_dict)

# Now you can use the Hugging Face model with some input text
input_text = "Your input text here"
output = st5_asr(input_text)
print(output)

## Mapping Speech Pre-net

In [37]:
# Initialize the mappings dictionary
speech_prenet_mapping = {}

# Iterate through st5_asr named_parameters
for name, _ in st5_asr.named_parameters():
    if name.startswith("speecht5.encoder.prenet"):
        for ckpt_name, ckpt_param in ckpt['model'].items():
            if ckpt_name.startswith("speech_encoder_prenet"):
                # Split the layer names based on '.' and '_'
                st5_asr_parts = name.split('.')
                ckpt_parts = ckpt_name.split('.')
                st5_asr_parts[-1] = st5_asr_parts[-1].split('_')[-1]
                ckpt_parts[-1] = ckpt_parts[-1].split('_')[-1]

                # Check if the layer names match
                if st5_asr_parts[-1] == ckpt_parts[-1]:
                    speech_prenet_mapping[name] = ckpt_name
                    break

# Print the mapping
for st5_asr_name, ckpt_name in speech_prenet_mapping.items():
    print(f"{st5_asr_name} -> {ckpt_name}")

speecht5.encoder.prenet.feature_encoder.conv_layers.0.conv.weight -> speech_encoder_prenet.feature_extractor.conv_layers.0.0.weight
speecht5.encoder.prenet.feature_encoder.conv_layers.0.layer_norm.weight -> speech_encoder_prenet.feature_extractor.conv_layers.0.0.weight
speecht5.encoder.prenet.feature_encoder.conv_layers.0.layer_norm.bias -> speech_encoder_prenet.feature_extractor.conv_layers.0.2.bias
speecht5.encoder.prenet.feature_encoder.conv_layers.1.conv.weight -> speech_encoder_prenet.feature_extractor.conv_layers.0.0.weight
speecht5.encoder.prenet.feature_encoder.conv_layers.2.conv.weight -> speech_encoder_prenet.feature_extractor.conv_layers.0.0.weight
speecht5.encoder.prenet.feature_encoder.conv_layers.3.conv.weight -> speech_encoder_prenet.feature_extractor.conv_layers.0.0.weight
speecht5.encoder.prenet.feature_encoder.conv_layers.4.conv.weight -> speech_encoder_prenet.feature_extractor.conv_layers.0.0.weight
speecht5.encoder.prenet.feature_encoder.conv_layers.5.conv.weight ->

## Mapping Encoder

In [46]:
encoder_mapping = {}

listt = []
#HF after wrapped_encoder: {'embed_positions', 'layer_norm', 'layers'}



# Iterate through st5_asr named_parameters
layer_num = 0
for name, _ in st5_asr.named_parameters():
    if name.startswith("speecht5.encoder.wrapped_encoder"):
        st5_asr_parts = name.split('.')
        listt.append(st5_asr_parts[3])
        #for ckpt_name, _ in ckpt['model'].items():
        #    if ckpt_name.startswith("encoder"):
        #        st5_asr_parts = name.split('.')
        #        ckpt_parts = ckpt_name.split('.')
        #        print(st5_asr_parts)

sett = set(listt)
sett

{'embed_positions', 'layer_norm', 'layers'}

In [36]:
# Initialize the mappings dictionary
mapping = {}

# Iterate through ckpt['model'] items
for ckpt_name, ckpt_param in ckpt['model'].items():
    if ckpt_name.startswith("encoder.layers."):
        # Get the layer number from the ckpt_name
        ckpt_layer_num = int(ckpt_name.split(".")[2])

        # Iterate through st5_asr named_parameters
        for name, _ in st5_asr.named_parameters():
            if name.startswith("speecht5.encoder.wrapped_encoder.layers."):
                # Get the layer number from the st5_asr name
                st5_asr_layer_num = int(name.split(".")[-2])

                # Check if the layer numbers match
                if ckpt_layer_num == st5_asr_layer_num:
                    mapping[name] = ckpt_name
                    break

# Print the mapping
for st5_asr_name, ckpt_name in mapping.items():
    print(f"{st5_asr_name} -> {ckpt_name}")

speecht5.encoder.wrapped_encoder.layers.0.attention.k_proj.weight -> encoder.layers.0.norm_k.bias
speecht5.encoder.wrapped_encoder.layers.1.attention.k_proj.weight -> encoder.layers.11.norm_k.bias
speecht5.encoder.wrapped_encoder.layers.2.attention.k_proj.weight -> encoder.layers.2.norm_k.bias
speecht5.encoder.wrapped_encoder.layers.3.attention.k_proj.weight -> encoder.layers.3.norm_k.bias
speecht5.encoder.wrapped_encoder.layers.4.attention.k_proj.weight -> encoder.layers.4.norm_k.bias
speecht5.encoder.wrapped_encoder.layers.5.attention.k_proj.weight -> encoder.layers.5.norm_k.bias
speecht5.encoder.wrapped_encoder.layers.6.attention.k_proj.weight -> encoder.layers.6.norm_k.bias
speecht5.encoder.wrapped_encoder.layers.7.attention.k_proj.weight -> encoder.layers.7.norm_k.bias
speecht5.encoder.wrapped_encoder.layers.8.attention.k_proj.weight -> encoder.layers.8.norm_k.bias
speecht5.encoder.wrapped_encoder.layers.9.attention.k_proj.weight -> encoder.layers.9.norm_k.bias
speecht5.encoder.wr

In [35]:
mapping

{}

In [None]:
# Iterate through all layer numbers from 0 to 11 for both models
for layer_num in range(12):
    ckpt_layer_prefix = f"encoder.layers.{layer_num}"
    st5_asr_layer_prefix = f"speecht5.encoder.wrapped_encoder.layers.{layer_num}"

    # Iterate through ckpt['model'] items
    for ckpt_name, ckpt_param in ckpt['model'].items():
        if ckpt_name.startswith(ckpt_layer_prefix):
            # Iterate through st5_asr named_parameters
            for name, _ in st5_asr.named_parameters():
                if name.startswith(st5_asr_layer_prefix):
                    mapping[name] = ckpt_name
                    break

# Print the mapping
for st5_asr_name, ckpt_name in mapping.items():
    print(f"{st5_asr_name} -> {ckpt_name}")