In [1]:
from load_fsq_model import load_model

In [1]:
# load_fsq_model def load_model()
import fairseq
from typing import Any, Dict, Optional, Union
from fairseq.data import Dictionary
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.tasks.audio_pretraining import AudioPretrainingTask
from fairseq.tasks.audio_finetuning import AudioFinetuningTask
from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model, Wav2Vec2Config
from fairseq.models.wav2vec.wav2vec2_asr import Wav2VecCtc, Wav2Vec2CtcConfig
# from custom_w2v2 import Wav2Vec2Model
# from w2v2config import Wav2Vec2Config
from fairseq import models, quantization_utils

from fairseq.dataclass.utils import convert_namespace_to_omegaconf, merge_with_parent
from fairseq import tasks

def load_model(filename, arg_overrides: Optional[Dict[str, Any]] = None):

    state = load_checkpoint_to_cpu(filename, arg_overrides)

    if "args" in state and state["args"] is not None:
        cfg = convert_namespace_to_omegaconf(state["args"])
    elif "cfg" in state and state["cfg"] is not None:
        cfg = state["cfg"]
    
    # task = AudioPretrainingTask.setup_task(cfg.task)
    # task = AudioFinetuningTask.setup_task(cfg.task)
    
    # model = task.build_model(cfg.model)
    model_cfg = cfg.model
    model_type = getattr(model_cfg, "_name", None) # or getattr(cfg, "arch", None)
    
    if model_type == 'wav2vec2':
        model_cfg = merge_with_parent(Wav2Vec2Config(), model_cfg)
        model = Wav2Vec2Model.build_model(model_cfg)
    elif model_type == 'wav2vec_ctc':
        cfg.task['data'] = './' # Set path where dict exists
        task = AudioFinetuningTask.setup_task(cfg.task)
        model_cfg = merge_with_parent(Wav2Vec2CtcConfig(), model_cfg)
        model = Wav2VecCtc.build_model(model_cfg, task)
    
    model = quantization_utils.quantize_model_scalar(model, cfg)

    model.load_state_dict(state['model'], strict=True, model_cfg=cfg.model)

    return model

In [3]:
model = load_model('xlsr_53_56k.pt')

In [4]:
model

Wav2Vec2Model(
  (feature_extractor): ConvFeatureExtractionModel(
    (conv_layers): ModuleList(
      (0): Sequential(
        (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeLast()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (2): TransposeLast()
        )
        (3): GELU()
      )
      (1): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeLast()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (2): TransposeLast()
        )
        (3): GELU()
      )
      (2): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeLast()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True

---

In [5]:
model = load_model('wav2vec_small_960h.pt')

FileNotFoundError: [Errno 2] No such file or directory: 'wav2vec_small_960h.pt'

In [6]:
model

Wav2Vec2Model(
  (feature_extractor): ConvFeatureExtractionModel(
    (conv_layers): ModuleList(
      (0): Sequential(
        (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeLast()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (2): TransposeLast()
        )
        (3): GELU()
      )
      (1): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeLast()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (2): TransposeLast()
        )
        (3): GELU()
      )
      (2): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeLast()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True

In [6]:
model = load_model('wav2vec_big_960h.pt')

<class 'fairseq.tasks.audio_pretraining.AudioPretrainingTask'>


In [7]:
model

Wav2VecCtc(
  (w2v_encoder): Wav2VecEncoder(
    (w2v_model): Wav2Vec2Model(
      (feature_extractor): ConvFeatureExtractionModel(
        (conv_layers): ModuleList(
          (0): Sequential(
            (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
            (3): GELU()
          )
          (1): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU()
          )
          (2): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU()
          )
          (3): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU()
          )
          (4): Sequen

In [8]:
import torch.nn as nn
isinstance(model, nn.Module)

True

In [32]:
models.ARCH_MODEL_REGISTRY

{'transformer_tiny': fairseq.models.transformer.transformer_legacy.TransformerModel,
 'transformer': fairseq.models.transformer.transformer_legacy.TransformerModel,
 'transformer_iwslt_de_en': fairseq.models.transformer.transformer_legacy.TransformerModel,
 'transformer_wmt_en_de': fairseq.models.transformer.transformer_legacy.TransformerModel,
 'transformer_vaswani_wmt_en_de_big': fairseq.models.transformer.transformer_legacy.TransformerModel,
 'transformer_vaswani_wmt_en_fr_big': fairseq.models.transformer.transformer_legacy.TransformerModel,
 'transformer_wmt_en_de_big': fairseq.models.transformer.transformer_legacy.TransformerModel,
 'transformer_wmt_en_de_big_t2t': fairseq.models.transformer.transformer_legacy.TransformerModel,
 'fconv': fairseq.models.fconv.FConvModel,
 'fconv_iwslt_de_en': fairseq.models.fconv.FConvModel,
 'fconv_wmt_en_ro': fairseq.models.fconv.FConvModel,
 'fconv_wmt_en_de': fairseq.models.fconv.FConvModel,
 'fconv_wmt_en_fr': fairseq.models.fconv.FConvModel,


In [28]:
models.MODEL_DATACLASS_REGISTRY

{'transformer_lm': fairseq.models.transformer_lm.TransformerLanguageModelConfig,
 'wav2vec': fairseq.models.wav2vec.wav2vec.Wav2VecConfig,
 'wav2vec2': fairseq.models.wav2vec.wav2vec2.Wav2Vec2Config,
 'wav2vec_ctc': fairseq.models.wav2vec.wav2vec2_asr.Wav2Vec2CtcConfig,
 'wav2vec_seq2seq': fairseq.models.wav2vec.wav2vec2_asr.Wav2Vec2Seq2SeqConfig,
 'hubert': fairseq.models.hubert.hubert.HubertConfig,
 'hubert_ctc': fairseq.models.hubert.hubert_asr.HubertCtcConfig}