In [None]:
import torch
from torch import nn
from typing import Optional, Union, Tuple
from transformers import BertModel
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertEmbeddings, BertEncoder, BertPooler
from transformers.models.hubert.modeling_hubert import HubertPreTrainedModel, HubertFeatureEncoder, HubertFeatureProjection, HubertEncoder, HubertModel
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

In [None]:
# Load the pretrained BERT checkpoint; the result is bound to the global `model`.
model = BertModel.from_pretrained('google/bert_uncased_L-4_H-128_A-2')

In [None]:
def freeze_module(model):
    """Disable gradient tracking for every parameter of ``model``.

    Args:
        model: Any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        None. The parameters are modified in place.
    """
    for weight in model.parameters():
        # In-place equivalent of assigning ``weight.requires_grad = False``.
        weight.requires_grad_(False)

In [None]:
# Dump every parameter of `model`: its name, its shape and its requires_grad
# flag, each on its own line and followed by one blank line.
for param_name, tensor in model.named_parameters():
    print(param_name, tensor.shape, tensor.requires_grad, "", sep="\n")

In [None]:
def count_parameters(model):
    """Count the trainable and frozen parameters of ``model`` in millions.

    Args:
        model: Any object exposing ``parameters()`` whose items provide
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).

    Returns:
        dict: ``{'requires_grad': <float>, 'does_not_require_grad': <float>}``
        where each value is the summed element count divided by ``1e6``.
    """
    trainable = 0
    frozen = 0
    # Single pass over the parameters instead of one traversal per bucket.
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
        else:
            frozen += param.numel()
    return {'requires_grad': trainable / 1e6,
            'does_not_require_grad': frozen / 1e6}

In [None]:
# Show the parameter counts of `model` (values are divided by 1e6) as the cell output.
count_parameters(model)

In [None]:
# Number of trailing encoder layers that stay trainable.
train_layers = 1

# Freeze the embeddings completely and all but the last `train_layers`
# encoder layers; any other child (e.g. the pooler) is left untouched.
for child in model.children():
    print(child._get_name())
    if isinstance(child, BertEmbeddings):
        freeze_module(child)
    elif isinstance(child, BertEncoder):
        # Clamp at 0 so that `train_layers` larger than the layer count
        # freezes nothing instead of producing a negative slice bound.
        n_frozen = max(len(child.layer) - train_layers, 0)
        for layer in child.layer[:n_frozen]:
            freeze_module(layer)

In [None]:
# Show the parameter counts of `model` (values are divided by 1e6) as the cell output.
count_parameters(model)

In [None]:
# `named_children()` yields `(name, module)` pairs, so unpack them before
# type-checking the module; testing the tuple itself can never match.
for name, child in model.named_children():
    print(name)
    if isinstance(child, BertEmbeddings):
        print(child)

In [None]:
# TODO: to test this, insert all the code from encoder_components.py into a cell above;
# anything under src... cannot be imported for some reason.

# Instantiate the three model variants with their default constructor arguments.
bi_enc_no_conv = BiEncoderSpeechTextModelWithoutFeatureEncoder()
bi_enc = BiEncoderSpeechTextModel()
mm_enc = MultiModalSpeechTextEncoder()

In [None]:
# NOTE(review): `children` is referenced here without being called, so the cell
# displays the attribute itself rather than iterating the child modules.
# Presumably `list(bi_enc_no_conv.children())` was intended; confirm.
bi_enc_no_conv.children

In [None]:
# NOTE(review): `children` is referenced here without being called, so the cell
# displays the attribute itself rather than iterating the child modules.
# Presumably `list(mm_enc.children())` was intended; confirm.
mm_enc.children

In [None]:
def freeze_layers_except_last(model, n_layers_to_train=1):
    """Freeze every direct child of ``model`` except the one named 'transformer'.

    Args:
        model: Object exposing ``named_children()`` (e.g. a ``torch.nn.Module``).
        n_layers_to_train: Accepted for interface compatibility; the body does
            not read it, so the 'transformer' child is always left as it is.

    Returns:
        None. Parameters are modified in place.
    """
    frozen_children = (
        module
        for label, module in model.named_children()
        if label != 'transformer'
    )
    for module in frozen_children:
        for weight in module.parameters():
            weight.requires_grad_(False)

In [None]:
from transformers.models.hubert.modeling_hubert import HubertPreTrainedModel, HubertFeatureEncoder, HubertFeatureProjection, HubertEncoder, HubertPositionalConvEmbedding
from torch.nn import LayerNorm, Dropout

In [None]:
def _freeze_leading_layers(layers, trainable_layers):
    """Freeze every module in ``layers`` except the last ``trainable_layers``."""
    n_frozen = len(layers) - trainable_layers
    for index, layer in enumerate(layers):
        if index < n_frozen:
            freeze_module(layer)


def freeze_model(model, trainable_layers=0):
    """Freeze ``model`` except for the last ``trainable_layers`` attention layers.

    The direct children of ``model`` are handled by type:

    * ``BertModel`` (the text model), ``BertEmbeddingsWrapper``,
      ``HubertConvFeatureExtractorWrapper`` and
      ``HubertFeatureProjectionWrapper`` are frozen completely.
    * ``BertEncoderWrapper``: in every ``ModuleList`` found two levels below
      it, all but the last ``trainable_layers`` entries are frozen.
    * ``HubertModelWithoutFeatureEncoder``: its feature projection is frozen;
      inside its encoder wrapper, ``LayerNorm``/``Dropout`` modules are frozen
      and each ``ModuleList`` keeps only its last ``trainable_layers`` entries
      trainable.
    * ``HubertModelWithPooler``: for a contained ``HubertModel``, the feature
      extractor, feature projection, positional conv embedding,
      ``LayerNorm`` and ``Dropout`` are frozen and the encoder's
      ``ModuleList`` keeps only its last ``trainable_layers`` entries
      trainable.
    * ``HubertPooler`` / ``BertPoolerWrapper`` and children of any other type
      are left untouched.

    The parameter counts are printed before and after freezing.

    Args:
        model (
            BiEncoderSpeechTextModelWithoutFeatureEncoder,
            BiEncoderSpeechTextModel,
            MultiModalSpeechTextEncoder
            ): The model to be frozen.
        trainable_layers (int, optional): How many attention layers in the
            speech or multimodal encoder to train. Defaults to 0.

    Returns:
        None. Parameters are modified in place.
    """
    print(f"Parameters before freezing: {count_parameters(model)}")

    for child in model.children():

        # Text model and the input-side modules of the multimodal encoder are
        # always frozen completely. Chained `or` keeps the checks lazy and in
        # the same order as separate branches would evaluate them.
        if (
            isinstance(child, BertModel)
            or isinstance(child, BertEmbeddingsWrapper)
            or isinstance(child, HubertConvFeatureExtractorWrapper)
            or isinstance(child, HubertFeatureProjectionWrapper)
        ):
            freeze_module(child)

        # Multimodal encoder: only the trailing layers stay trainable.
        elif isinstance(child, BertEncoderWrapper):
            for wrapped in child.children():
                for sub in wrapped.children():
                    if isinstance(sub, torch.nn.ModuleList):
                        _freeze_leading_layers(sub, trainable_layers)

        # NOTE(review): poolers stay trainable even when trainable_layers == 0;
        # confirm this is intended.
        elif isinstance(child, HubertPooler) or isinstance(child, BertPoolerWrapper):
            pass

        # Speech encoder without the convolutional feature encoder.
        elif isinstance(child, HubertModelWithoutFeatureEncoder):
            for part in child.children():
                if isinstance(part, HubertFeatureProjectionWrapper):
                    freeze_module(part)
                elif isinstance(part, HubertEncoderWrapper):
                    for wrapped in part.children():
                        for sub in wrapped.children():
                            if isinstance(sub, (LayerNorm, Dropout)):
                                freeze_module(sub)
                            elif isinstance(sub, torch.nn.ModuleList):
                                _freeze_leading_layers(sub, trainable_layers)
                # Any other part (e.g. a HubertPooler) is left untouched.

        # HuBERT speech encoder with convolution and pooler.
        elif isinstance(child, HubertModelWithPooler):
            for part in child.children():
                if isinstance(part, HubertModel):
                    freeze_module(part.feature_extractor)
                    freeze_module(part.feature_projection)
                    for sub in part.encoder.children():
                        if isinstance(
                            sub, (HubertPositionalConvEmbedding, LayerNorm, Dropout)
                        ):
                            freeze_module(sub)
                        elif isinstance(sub, torch.nn.ModuleList):
                            _freeze_leading_layers(sub, trainable_layers)
                # Any other part (e.g. a HubertPooler) is left untouched.

    print(f"Parameters after freezing: {count_parameters(model)}")
    
# Freeze the multimodal encoder, requesting one trainable layer.
freeze_model(mm_enc, trainable_layers=1)

In [None]:
# Display `bi_enc` as the cell output.
bi_enc