In [1]:
import warnings

import torch.nn as nn
import whisper
from whisper_finetune.model.lora import (
    replace_attention_layers_with_lora,
    has_lora_layers,
    mark_only_lora_as_trainable,
    print_trainable_params,
)

In [2]:
# Load the pretrained Whisper "tiny" checkpoint on CPU as the unmodified baseline model.
whisper_model = whisper.load_model(name='tiny', device="cpu")

In [3]:
# Baseline: before any LoRA patching every parameter of the loaded model is trainable (see output below).
print_trainable_params(whisper_model)

Out of 37,184,640 parameters, 37,184,640 are trainable, a reduction of 0.0 % 


In [5]:
# Apply the project's own LoRA implementation to `whisper_model` in place (the repr in a later cell shows the
# patched layers). Hyper-parameters live in a named dict so they can be compared with the PEFT `LoraConfig`
# built further down.
lora_config = {'r': 32, 'alpha': 64, 'dropout': 0.05, 'target_modules': ['query', 'value'], 'bias': 'none'}
replace_attention_layers_with_lora(whisper_model, lora_config)
mark_only_lora_as_trainable(whisper_model)
assert has_lora_layers(whisper_model), "Lora layers were somehow not correctly set."

# Whisper's pretrained query/value projections carry a bias (bias=True in the repr of the freshly loaded model).
# Swapping in LoRA layers must not discard it -- presumably 'bias': 'none' only means "do not train biases",
# as it does in PEFT; TODO confirm against whisper_finetune.model.lora.
# The recorded output of this cell reports 9,216 (= 24 x 384) fewer total parameters than PEFT does for the
# same config, consistent with the 384-wide query/value biases having been dropped.
layers_missing_bias = [
    name
    for name, module in whisper_model.named_modules()
    if name.rsplit('.', 1)[-1] in lora_config['target_modules']
    and isinstance(module, nn.Linear)
    and module.bias is None
]
if layers_missing_bias:
    warnings.warn(
        f"LoRA replacement dropped the pretrained bias of {len(layers_missing_bias)} layers, "
        f"e.g. {layers_missing_bias[0]!r}; the patched model no longer matches the base model."
    )
print_trainable_params(whisper_model)

Out of 37,765,248 parameters, 589,824 are trainable, a reduction of 98.4 % 


In [9]:
# Inspect the patched architecture: the targeted query/value layers now carry a `lora_dropout` child.
# NOTE(review): they also print bias=False here, whereas the freshly loaded model shows bias=True.
whisper_model

Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-3): 4 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(
            in_features=384, out_features=384, bias=False
            (lora_dropout): Dropout(p=0.05, inplace=False)
          )
          (key): Linear(in_features=384, out_features=384, bias=False)
          (value): Linear(
            in_features=384, out_features=384, bias=False
            (lora_dropout): Dropout(p=0.05, inplace=False)
          )
          (out): Linear(in_features=384, out_features=384, bias=True)
        )
        (attn_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=1536, 

In [15]:
# Reload a fresh checkpoint: the LoRA cell above patched `whisper_model` in place (its repr shows the
# `lora_dropout` children), so the PEFT comparison below needs an untouched model.
whisper_model = whisper.load_model(name='tiny', device="cpu")

In [16]:
# Repr of the untouched model, for comparison with the LoRA-patched one above (note query/value bias=True).
whisper_model

Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-3): 4 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=384, out_features=384, bias=True)
          (key): Linear(in_features=384, out_features=384, bias=False)
          (value): Linear(in_features=384, out_features=384, bias=True)
          (out): Linear(in_features=384, out_features=384, bias=True)
        )
        (attn_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=1536, out_features=384, bias=True)
        )
        (mlp_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_post): LayerNorm((384,), eps=1e-05,

In [17]:
# Reference implementation: the same LoRA setup built with Hugging Face PEFT, for comparison with the custom
# `whisper_finetune` LoRA above. Hyper-parameters mirror the dict passed to `replace_attention_layers_with_lora`.
# NOTE(review): PeftModel / LoraModel are unused in the cells shown here -- drop them if nothing later needs them.
from peft import LoraConfig, PeftModel, LoraModel, get_peft_model

config = LoraConfig(r=32, lora_alpha=64, target_modules=["query", "value"], lora_dropout=0.05, bias="none")

# Wraps the freshly reloaded `whisper_model`; the reported "all params" total should equal the base model's
# count plus the LoRA parameters.
model = get_peft_model(whisper_model, config)
model.print_trainable_parameters()

trainable params: 589,824 || all params: 37,774,464 || trainable%: 1.5614357890028565


In [18]:
# PEFT wraps each targeted Linear as `lora.Linear`, keeping the original layer (bias included) as `base_layer`.
model

PeftModel(
  (base_model): LoraModel(
    (model): Whisper(
      (encoder): AudioEncoder(
        (conv1): Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
        (conv2): Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
        (blocks): ModuleList(
          (0-3): 4 x ResidualAttentionBlock(
            (attn): MultiHeadAttention(
              (query): lora.Linear(
                (base_layer): Linear(in_features=384, out_features=384, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=384, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=32, out_features=384, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
