In [36]:
import whisper
import torch.nn as nn
import loralib as lora

In [37]:
# Load the 'tiny' Whisper checkpoint onto the CPU; later cells modify this object in place.
whisper_model = whisper.load_model('tiny', device="cpu")

In [38]:
# Bare expression: the notebook renders the module's repr (its layer tree) as this cell's output.
whisper_model

Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-3): 4 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=384, out_features=384, bias=True)
          (key): Linear(in_features=384, out_features=384, bias=False)
          (value): Linear(in_features=384, out_features=384, bias=True)
          (out): Linear(in_features=384, out_features=384, bias=True)
        )
        (attn_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=1536, out_features=384, bias=True)
        )
        (mlp_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_post): LayerNorm((384,), eps=1e-05,

In [39]:
from whisper.model import MultiHeadAttention
import loralib as lora
import torch

def replace_attention_layers_with_lora(module, config, parent=None, parent_name=None):
    """Swap the targeted projections of every ``MultiHeadAttention`` for ``lora.Linear``.

    The module tree is walked depth-first and modified in place. For each
    ``MultiHeadAttention`` found, the ``query`` / ``key`` / ``value`` sub-layers
    that are listed in ``config['target_modules']`` are replaced by a
    ``lora.Linear`` of the same shape. The existing weight (and bias, when the
    layer has one) is copied into the replacement so the pretrained projection
    is preserved instead of being re-initialised.

    Args:
        module: Root module to process (itself and all descendants).
        config: Mapping with ``'r'`` (passed as ``r`` to ``lora.Linear``) and
            ``'target_modules'`` (container of attribute names to replace).
        parent: Unused; kept so existing call sites remain valid.
        parent_name: Unused; kept so existing call sites remain valid.

    Returns:
        None. ``module`` is mutated in place.
    """
    # Children first, so nested attention blocks are converted before this module is inspected.
    for name, child in module.named_children():
        replace_attention_layers_with_lora(child, config, parent=module, parent_name=name)

    if not isinstance(module, MultiHeadAttention):
        return

    for layer_name in ('query', 'key', 'value'):
        if not (hasattr(module, layer_name) and layer_name in config['target_modules']):
            continue

        base_layer = getattr(module, layer_name)
        has_bias = base_layer.bias is not None

        # Mirror the replaced layer's bias setting, device and dtype rather than hard-coding them.
        lora_layer = lora.Linear(
            base_layer.in_features,
            base_layer.out_features,
            r=config['r'],
            bias=has_bias,
        ).to(device=base_layer.weight.device, dtype=base_layer.weight.dtype)

        # Carry the existing projection over; without this the new layer keeps
        # its constructor-time initialisation and the loaded weights are lost.
        with torch.no_grad():
            lora_layer.weight.copy_(base_layer.weight)
            if has_bias:
                lora_layer.bias.copy_(base_layer.bias)

        setattr(module, layer_name, lora_layer)

def mark_only_lora_as_trainable(model, bias='none'):
    """Delegate to ``lora.mark_only_lora_as_trainable`` for ``model``.

    Thin notebook-level wrapper: both arguments are forwarded unchanged and
    nothing is returned.

    Args:
        model: Module handed to loralib.
        bias: Forwarded as loralib's ``bias`` keyword (default ``'none'``);
            see the loralib documentation for the accepted values.
    """
    lora.mark_only_lora_as_trainable(model, bias=bias)

def save_lora_model(model, path, bias='none'):
    """Serialise the state dict produced by ``lora.lora_state_dict`` to ``path``.

    Args:
        model: Module whose LoRA state is extracted by loralib.
        path: Destination handed to ``torch.save``.
        bias: Forwarded as loralib's ``bias`` keyword (default ``'none'``).
    """
    adapter_state = lora.lora_state_dict(model, bias=bias)
    torch.save(adapter_state, path)

def load_lora_model(model, pretrained_path, lora_path, strict=False):
    """Load a base checkpoint and then a LoRA checkpoint into ``model``.

    The two files are applied in order, so entries in ``lora_path`` overwrite
    same-named entries loaded from ``pretrained_path``.

    Args:
        model: Module that receives both state dicts (mutated in place).
        pretrained_path: File containing the base state dict.
        lora_path: File containing the LoRA state dict (e.g. one written by
            ``save_lora_model``).
        strict: Forwarded to both ``load_state_dict`` calls. The default
            ``False`` lets each file cover only part of the model's keys.

    Returns:
        None.
    """
    for checkpoint_path in (pretrained_path, lora_path):
        # Deserialise onto the CPU so a checkpoint saved from a GPU still loads
        # on a machine without one; load_state_dict then copies each tensor
        # into the model's existing parameters.
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=strict)

def print_model_layers(model, indent=0):
    """
    Print one ``name: ClassName`` line per descendant of ``model``.

    Children are listed in ``named_children()`` order, each followed by its
    own subtree; every nesting level adds two spaces on top of ``indent``.
    """
    def _describe(node, depth):
        # Lazily yield this subtree's lines in pre-order.
        for child_name, child in node.named_children():
            yield ' ' * depth + f'{child_name}: {type(child).__name__}'
            yield from _describe(child, depth + 2)

    for line in _describe(model, indent):
        print(line)

def has_lora_layers(model):
    """
    Return True when any parameter name of ``model`` contains ``'lora_'``.
    """
    return any('lora_' in param_name for param_name, _ in model.named_parameters())

In [40]:
# Report whether whisper_model currently exposes any 'lora_' parameters.
print(
    "LoRA layers have been successfully integrated into the model."
    if has_lora_layers(whisper_model)
    else "No LoRA layers found in the model."
)


No LoRA layers found in the model.


In [43]:
# Settings read by replace_attention_layers_with_lora.
config = dict(
    r=32,  # forwarded as `r` to lora.Linear
    target_modules=['query', 'key', 'value'],  # attention sub-layers to replace
)

In [44]:
# Replace the targeted attention projections with lora.Linear layers (mutates whisper_model in place).
replace_attention_layers_with_lora(whisper_model, config=config)
# Hand the model to loralib's mark_only_lora_as_trainable through the wrapper defined above.
mark_only_lora_as_trainable(whisper_model)

In [45]:
# Re-run the check now that the replacement cell has executed.
print(
    "LoRA layers have been successfully integrated into the model."
    if has_lora_layers(whisper_model)
    else "No LoRA layers found in the model."
)

LoRA layers have been successfully integrated into the model.


In [52]:
def print_trainable_params(model: nn.Module) -> None:
    """Print total vs. trainable parameter counts and the percentage not trainable.

    Args:
        model: Module whose ``parameters()`` are counted; a parameter is
            trainable when its ``requires_grad`` flag is set.

    Returns:
        None. The summary is written to standard output.
    """
    total_params = 0
    trainable_params = 0
    # Single pass over the parameters instead of two separate traversals.
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    # A module without parameters would otherwise raise ZeroDivisionError.
    if total_params:
        reduction_pct = round((1 - trainable_params / total_params) * 100, 1)
    else:
        reduction_pct = 0.0
    print(f"Out of {total_params:,} parameters, {trainable_params:,} are trainable, a reduction of {reduction_pct} % ")
# Summarise how many of whisper_model's parameters still require gradients.
print_trainable_params(whisper_model)

Out of 38,069,376 parameters, 884,736 are trainable, a reduction of 97.7 % 
