In [1]:
import torch
import pickle
import json

# Relative path to the checkpoint file that the next cell reads with torch.load.
path = "../checkpoints/pytorch_model.bin"

In [2]:
# Load the checkpoint with all tensors mapped onto the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code — only open checkpoints that come from a trusted source.
ckpt = torch.load(path, map_location="cpu")

In [3]:
# Display every parameter name stored in the checkpoint, in stored order.
list(ckpt)

['wav2vec2.masked_spec_embed',
 'wav2vec2.feature_extractor.conv_layers.0.conv.weight',
 'wav2vec2.feature_extractor.conv_layers.0.layer_norm.weight',
 'wav2vec2.feature_extractor.conv_layers.0.layer_norm.bias',
 'wav2vec2.feature_extractor.conv_layers.1.conv.weight',
 'wav2vec2.feature_extractor.conv_layers.2.conv.weight',
 'wav2vec2.feature_extractor.conv_layers.3.conv.weight',
 'wav2vec2.feature_extractor.conv_layers.4.conv.weight',
 'wav2vec2.feature_extractor.conv_layers.5.conv.weight',
 'wav2vec2.feature_extractor.conv_layers.6.conv.weight',
 'wav2vec2.feature_projection.layer_norm.weight',
 'wav2vec2.feature_projection.layer_norm.bias',
 'wav2vec2.feature_projection.projection.weight',
 'wav2vec2.feature_projection.projection.bias',
 'wav2vec2.encoder.pos_conv_embed.conv.bias',
 'wav2vec2.encoder.pos_conv_embed.conv.weight_g',
 'wav2vec2.encoder.pos_conv_embed.conv.weight_v',
 'wav2vec2.encoder.layer_norm.weight',
 'wav2vec2.encoder.layer_norm.bias',
 'wav2vec2.encoder.layers.0.

In [4]:
import sys
# Add the parent directory to the import search path before importing from
# `src` (presumably the project package lives one level up — verify the
# notebook is launched from its own directory, since '..' is cwd-relative).
sys.path.append('..')
from src.model.wav2vec2 import Wav2Vec2PretrainingModule

In [5]:
from omegaconf import OmegaConf
# Read the pretraining configuration from its YAML file.
conf = OmegaConf.load("../configs/wav2vec2-base-pretraining.yaml")

In [6]:
# Display the loaded configuration for inspection.
conf

{'model': {'feature_extractor': {'in_channels': 1, 'hidden_channels': [512, 512, 512, 512, 512, 512, 512], 'kernel_sizes': [10, 3, 3, 3, 3, 2, 2], 'strides': [5, 2, 2, 2, 2, 2, 2]}, 'context_encoder': {'d_model': 768, 'feature_projection': {'in_features': 512, 'dropout': 0.1}, 'transformer_encoder': {'pos_embedding': {'kernel_size': 3, 'groups': 2}, 'enc_layer': {'num_heads': 8, 'layer_norm_first': False, 'feed_forward_dim': 2048, 'dropout': 0.1}, 'num_enc_layers': 12, 'layer_drop_prob': 0.05, 'dropout': 0.1}}, 'quantizer': {'in_features': 512, 'num_codebooks': 2, 'num_codewords': 320, 'd_model': 768}, 'train_cfg': {'mask_prob': 0.65, 'mask_length': 10, 'min_masks': 2, 'num_negatives': 100, 'contrastive_logits_temperature': 0.1, 'diversity_loss_weight': 0.2}, 'optimizer': {'_target_': 'torch.optim.Adam', 'lr': 0.1, 'weight_decay': 0.0001}, 'lr_scheduler': {'scheduler': {'_target_': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'T_max': 20}}}}

In [7]:
# Build the module, passing each top-level entry of conf.model
# (feature_extractor, context_encoder, quantizer, ...) as a keyword argument.
model = Wav2Vec2PretrainingModule(**conf.model)

In [8]:
# Display the parameter names the freshly built model expects.
list(model.state_dict())

['context_encoder.masked_spec_embed',
 'context_encoder.feature_projection.layer_norm.weight',
 'context_encoder.feature_projection.layer_norm.bias',
 'context_encoder.feature_projection.projection.weight',
 'context_encoder.feature_projection.projection.bias',
 'context_encoder.encoder.pos_conv_embed.conv.bias',
 'context_encoder.encoder.pos_conv_embed.conv.weight_g',
 'context_encoder.encoder.pos_conv_embed.conv.weight_v',
 'context_encoder.encoder.layer_norm.weight',
 'context_encoder.encoder.layer_norm.bias',
 'context_encoder.encoder.layers.0.attention.k_proj.weight',
 'context_encoder.encoder.layers.0.attention.k_proj.bias',
 'context_encoder.encoder.layers.0.attention.v_proj.weight',
 'context_encoder.encoder.layers.0.attention.v_proj.bias',
 'context_encoder.encoder.layers.0.attention.q_proj.weight',
 'context_encoder.encoder.layers.0.attention.q_proj.bias',
 'context_encoder.encoder.layers.0.attention.out_proj.weight',
 'context_encoder.encoder.layers.0.attention.out_proj.bias

In [9]:
def mapping_weights(model, ckpt):
    """Rename checkpoint parameters to the naming scheme used by ``model``.

    Only the leading ``wav2vec2.`` namespace of each key is rewritten:

    * ``wav2vec2.feature_extractor.*`` -> ``feature_extractor.*``
    * any other ``wav2vec2.*``         -> ``context_encoder.*``
    * keys outside ``wav2vec2.``       -> copied unchanged

    A ``wav2vec2`` or ``feature_extractor`` substring that appears deeper in
    a parameter name is left alone, so it cannot corrupt the renamed key.

    Args:
        model: Object whose ``state_dict()`` yields the target parameter names.
        ckpt: Mapping of source parameter name to weight.

    Returns:
        dict: A new dict mapping each renamed parameter to the same weight
        object it had in ``ckpt``.

    Side effects:
        Prints the target parameters that received no weight. When the
        renamed checkpoint contains keys the target does not define, those
        are printed as well (nothing extra is printed otherwise).
    """
    src_prefix = "wav2vec2."
    extractor_prefix = src_prefix + "feature_extractor."

    model_state_dict = dict()
    for src_param_name, weight in ckpt.items():
        if src_param_name.startswith(extractor_prefix):
            # The feature extractor sits at the top level of the target model.
            dst_param_name = src_param_name[len(src_prefix):]
        elif src_param_name.startswith(src_prefix):
            # Everything else under wav2vec2 belongs to the context encoder.
            dst_param_name = "context_encoder." + src_param_name[len(src_prefix):]
        else:
            dst_param_name = src_param_name
        model_state_dict[dst_param_name] = weight

    # Find missing parameters
    dst_params = set(model.state_dict().keys())
    mapped_params = set(model_state_dict.keys())
    missing_params = dst_params - mapped_params
    print(f"Missing parameters: {missing_params}")

    # Surface checkpoint entries that have no counterpart in the target model;
    # these would otherwise pass through silently.
    unexpected_params = mapped_params - dst_params
    if unexpected_params:
        print(f"Unexpected parameters: {unexpected_params}")

    return model_state_dict

In [10]:
# Run the renaming and display the resulting parameter names.
list(mapping_weights(model, ckpt))

Missing parameters: {'quantizer.projection.weight', 'quantizer.projection.bias', 'quantizer.codebooks'}


['context_encoder.masked_spec_embed',
 'feature_extractor.conv_layers.0.conv.weight',
 'feature_extractor.conv_layers.0.layer_norm.weight',
 'feature_extractor.conv_layers.0.layer_norm.bias',
 'feature_extractor.conv_layers.1.conv.weight',
 'feature_extractor.conv_layers.2.conv.weight',
 'feature_extractor.conv_layers.3.conv.weight',
 'feature_extractor.conv_layers.4.conv.weight',
 'feature_extractor.conv_layers.5.conv.weight',
 'feature_extractor.conv_layers.6.conv.weight',
 'context_encoder.feature_projection.layer_norm.weight',
 'context_encoder.feature_projection.layer_norm.bias',
 'context_encoder.feature_projection.projection.weight',
 'context_encoder.feature_projection.projection.bias',
 'context_encoder.encoder.pos_conv_embed.conv.bias',
 'context_encoder.encoder.pos_conv_embed.conv.weight_g',
 'context_encoder.encoder.pos_conv_embed.conv.weight_v',
 'context_encoder.encoder.layer_norm.weight',
 'context_encoder.encoder.layer_norm.bias',
 'context_encoder.encoder.layers.0.atte